diff --git a/moto/core/responses.py b/moto/core/responses.py index a97f66f6c..781a0b284 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -178,8 +178,7 @@ class BaseResponse(_TemplateEnvironmentMixin): self.setup_class(request, full_url, headers) return self.call_action() - def call_action(self): - headers = self.response_headers + def _get_action(self): action = self.querystring.get('Action', [""])[0] if not action: # Some services use a header for the action # Headers are case-insensitive. Probably a better way to do this. @@ -188,7 +187,11 @@ class BaseResponse(_TemplateEnvironmentMixin): if match: action = match.split(".")[-1] - action = camelcase_to_underscores(action) + return action + + def call_action(self): + headers = self.response_headers + action = camelcase_to_underscores(self._get_action()) method_names = method_names_from_class(self.__class__) if action in method_names: method = getattr(self, action) diff --git a/moto/server.py b/moto/server.py index 8d0103cc2..966cb1614 100644 --- a/moto/server.py +++ b/moto/server.py @@ -139,10 +139,13 @@ def create_backend_app(service): else: endpoint = None - if endpoint in backend_app.view_functions: + original_endpoint = endpoint + index = 2 + while endpoint in backend_app.view_functions: # HACK: Sometimes we map the same view to multiple url_paths. Flask # requries us to have different names. - endpoint += "2" + endpoint = original_endpoint + str(index) + index += 1 backend_app.add_url_rule( url_path, diff --git a/moto/xray/models.py b/moto/xray/models.py index f22edeb9f..b2d418232 100644 --- a/moto/xray/models.py +++ b/moto/xray/models.py @@ -28,7 +28,7 @@ class TelemetryRecords(BaseModel): # https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html class TraceSegment(BaseModel): - def __init__(self, name, segment_id, trace_id, start_time, end_time=None, in_progress=False, service=None, user=None, + def __init__(self, name, segment_id, trace_id, start_time, raw, end_time=None, in_progress=False, service=None, user=None, origin=None, parent_id=None, http=None, aws=None, metadata=None, annotations=None, subsegments=None, **kwargs): self.name = name self.id = segment_id @@ -52,6 +52,9 @@ class TraceSegment(BaseModel): self.subsegments = subsegments self.misc = kwargs + # Raw json string + self.raw = raw + def __lt__(self, other): return self.start_date < other.start_date @@ -81,7 +84,7 @@ class TraceSegment(BaseModel): return self._end_date @classmethod - def from_dict(cls, data): + def from_dict(cls, data, raw): # Check manditory args if 'id' not in data: raise BadSegmentException(code='MissingParam', message='Missing segment ID') @@ -97,12 +100,12 @@ class TraceSegment(BaseModel): if 'end_time' not in data and data['in_progress'] == 'false': raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time') - return cls(**data) + return cls(raw=raw, **data) class SegmentCollection(object): def __init__(self): - self._segments = defaultdict(self._new_trace_item) + self._traces = defaultdict(self._new_trace_item) @staticmethod def _new_trace_item(): @@ -110,23 +113,24 @@ class SegmentCollection(object): 'start_date': datetime.datetime(1970, 1, 1), 'end_date': datetime.datetime(1970, 1, 1), 'finished': False, + 'trace_id': None, 'segments': [] } def put_segment(self, segment): # insert into a sorted list - bisect.insort_left(self._segments[segment.trace_id]['segments'], segment) + bisect.insort_left(self._traces[segment.trace_id]['segments'], segment) # Get the last segment (takes into account incorrect ordering) # and if its the last one, mark trace as complete - if self._segments[segment.trace_id]['segments'][-1].end_time is not None: - self._segments[segment.trace_id]['finished'] = True - - start_time = self._segments[segment.trace_id]['segments'][0].start_date - end_time = self._segments[segment.trace_id]['segments'][-1].end_date - self._segments[segment.trace_id]['start_date'] = start_time - self._segments[segment.trace_id]['end_date'] = end_time + if self._traces[segment.trace_id]['segments'][-1].end_time is not None: + self._traces[segment.trace_id]['finished'] = True + start_time = self._traces[segment.trace_id]['segments'][0].start_date + end_time = self._traces[segment.trace_id]['segments'][-1].end_date + self._traces[segment.trace_id]['start_date'] = start_time + self._traces[segment.trace_id]['end_date'] = end_time + self._traces[segment.trace_id]['trace_id'] = segment.trace_id # Todo consolidate trace segments into a trace. # not enough working knowledge of xray to do this @@ -137,7 +141,7 @@ class SegmentCollection(object): summaries = [] - for tid, trace in self._segments.items(): + for tid, trace in self._traces.items(): if trace['finished'] and start_time < trace['start_date'] and trace['end_date'] < end_time: duration = int((trace['end_date'] - trace['start_date']).total_seconds()) # this stuff is mostly guesses, refer to TODO above @@ -169,6 +173,20 @@ class SegmentCollection(object): return result + def get_trace_ids(self, trace_ids): + traces = [] + unprocessed = [] + + # Its a default dict + existing_trace_ids = list(self._traces.keys()) + for trace_id in trace_ids: + if trace_id in existing_trace_ids: + traces.append(self._traces[trace_id]) + else: + unprocessed.append(trace_id) + + return traces, unprocessed + class XRayBackend(BaseBackend): @@ -189,7 +207,7 @@ class XRayBackend(BaseBackend): try: # Get Segment Object - segment = TraceSegment.from_dict(data) + segment = TraceSegment.from_dict(data, raw=doc) except ValueError: raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') @@ -202,6 +220,31 @@ class XRayBackend(BaseBackend): def get_trace_summary(self, start_time, end_time, filter_expression, summaries): return self._segment_collection.summary(start_time, end_time, filter_expression, summaries) + def get_trace_ids(self, trace_ids, next_token): + traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids) + + result = { + 'Traces': [], + 'UnprocessedTraceIds': unprocessed_ids + + } + + for trace in traces: + segments = [] + for segment in trace['segments']: + segments.append({ + 'Id': segment.id, + 'Document': segment.raw + }) + + result['Traces'].append({ + 'Duration': int((trace['end_date'] - trace['start_date']).total_seconds()), + 'Id': trace['trace_id'], + 'Segments': segments + }) + + return result + xray_backends = {} for region, ec2_backend in ec2_backends.items(): diff --git a/moto/xray/responses.py b/moto/xray/responses.py index 89705fb5b..328a266bf 100644 --- a/moto/xray/responses.py +++ b/moto/xray/responses.py @@ -1,11 +1,8 @@ from __future__ import unicode_literals import json -import six import datetime from moto.core.responses import BaseResponse -from moto.core.utils import camelcase_to_underscores, method_names_from_class -from werkzeug.exceptions import HTTPException from six.moves.urllib.parse import urlsplit from .models import xray_backends @@ -31,31 +28,11 @@ class XRayResponse(BaseResponse): def _get_param(self, param, default=None): return self.request_params.get(param, default) - def call_action(self): + def _get_action(self): # Amazon is just calling urls like /TelemetryRecords etc... - action = urlsplit(self.uri).path.lstrip('/') - action = camelcase_to_underscores(action) - headers = self.response_headers - method_names = method_names_from_class(self.__class__) - if action in method_names: - method = getattr(self, action) - try: - response = method() - except HTTPException as http_error: - response = http_error.description, dict(status=http_error.code) - if isinstance(response, six.string_types): - return 200, headers, response - else: - body, new_headers = response - status = new_headers.get('status', 200) - headers.update(new_headers) - # Cast status to string - if "status" in headers: - headers['status'] = str(headers['status']) - return status, headers, body - - raise NotImplementedError( - "The {0} action has not been implemented".format(action)) + # This uses the value after / as the camalcase action, which then + # gets converted in call_action to find the following methods + return urlsplit(self.uri).path.lstrip('/') # PutTelemetryRecords def telemetry_records(self): @@ -122,12 +99,52 @@ class XRayResponse(BaseResponse): # BatchGetTraces def traces(self): - raise NotImplementedError() + trace_ids = self._get_param('TraceIds') + next_token = self._get_param('NextToken') # not implemented yet - # GetServiceGraph + if trace_ids is None: + msg = 'Parameter TraceIds is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + + try: + result = self.xray_backend.get_trace_ids(trace_ids, next_token) + except AWSError as err: + return err.response() + except Exception as err: + return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + + return json.dumps(result) + + # GetServiceGraph - just a dummy response for now def service_graph(self): - raise NotImplementedError() + start_time = self._get_param('StartTime') + end_time = self._get_param('EndTime') + # next_token = self._get_param('NextToken') # not implemented yet - # GetTraceGraph + if start_time is None: + msg = 'Parameter StartTime is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + if end_time is None: + msg = 'Parameter EndTime is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + + result = { + 'StartTime': start_time, + 'EndTime': end_time, + 'Services': [] + } + return json.dumps(result) + + # GetTraceGraph - just a dummy response for now def trace_graph(self): - raise NotImplementedError() + trace_ids = self._get_param('TraceIds') + # next_token = self._get_param('NextToken') # not implemented yet + + if trace_ids is None: + msg = 'Parameter TraceIds is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + + result = { + 'Services': [] + } + return json.dumps(result) diff --git a/moto/xray/urls.py b/moto/xray/urls.py index c224e8d38..b0f13a980 100644 --- a/moto/xray/urls.py +++ b/moto/xray/urls.py @@ -6,5 +6,10 @@ url_bases = [ ] url_paths = { - '{0}/.+$': XRayResponse.dispatch, + '{0}/TelemetryRecords$': XRayResponse.dispatch, + '{0}/TraceSegments$': XRayResponse.dispatch, + '{0}/Traces$': XRayResponse.dispatch, + '{0}/ServiceGraph$': XRayResponse.dispatch, + '{0}/TraceGraph$': XRayResponse.dispatch, + '{0}/TraceSummaries$': XRayResponse.dispatch, } diff --git a/tests/test_xray/test_xray_boto3.py b/tests/test_xray/test_xray_boto3.py index 9da55ad1e..5ad8f8bc7 100644 --- a/tests/test_xray/test_xray_boto3.py +++ b/tests/test_xray/test_xray_boto3.py @@ -82,3 +82,58 @@ def test_trace_summary(): StartTime=datetime.datetime(2014, 1, 1), EndTime=datetime.datetime(2017, 1, 1) ) + + +@mock_xray +def test_batch_get_trace(): + client = boto3.client('xray', region_name='us-east-1') + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps({ + 'name': 'example.com', + 'id': '70de5b6f19ff9a0a', + 'start_time': 1.478293361271E9, + 'trace_id': '1-581cf771-a006649127e371903a2de979', + 'in_progress': True + }), + json.dumps({ + 'name': 'example.com', + 'id': '70de5b6f19ff9a0b', + 'start_time': 1478293365, + 'trace_id': '1-581cf771-a006649127e371903a2de979', + 'end_time': 1478293385 + }) + ] + ) + + resp = client.batch_get_traces( + TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] + ) + len(resp['UnprocessedTraceIds']).should.equal(1) + len(resp['Traces']).should.equal(1) + + +# Following are not implemented, just testing it returns what boto expects +@mock_xray +def test_batch_get_service_graph(): + client = boto3.client('xray', region_name='us-east-1') + + client.get_service_graph( + StartTime=datetime.datetime(2014, 1, 1), + EndTime=datetime.datetime(2017, 1, 1) + ) + + +@mock_xray +def test_batch_get_trace_graph(): + client = boto3.client('xray', region_name='us-east-1') + + client.batch_get_traces( + TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] + ) + + + + +