diff --git a/README.md b/README.md index cca50a16e..39dc49fea 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L |------------------------------------------------------------------------------| | SWF | @mock_swf | basic endpoints done | |------------------------------------------------------------------------------| +| X-Ray | @mock_xray | core endpoints done | +|------------------------------------------------------------------------------| ``` ### Another Example diff --git a/moto/__init__.py b/moto/__init__.py index 728d8db71..871aab881 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -36,6 +36,7 @@ from .sts import mock_sts, mock_sts_deprecated # flake8: noqa from .ssm import mock_ssm # flake8: noqa from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa from .swf import mock_swf, mock_swf_deprecated # flake8: noqa +from .xray import mock_xray # flake8: noqa try: diff --git a/moto/xray/__init__.py b/moto/xray/__init__.py new file mode 100644 index 000000000..7b32ca0b0 --- /dev/null +++ b/moto/xray/__init__.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .models import xray_backends +from ..core.models import base_decorator + +xray_backend = xray_backends['us-east-1'] +mock_xray = base_decorator(xray_backends) diff --git a/moto/xray/exceptions.py b/moto/xray/exceptions.py new file mode 100644 index 000000000..24f700178 --- /dev/null +++ b/moto/xray/exceptions.py @@ -0,0 +1,39 @@ +import json + + +class AWSError(Exception): + CODE = None + STATUS = 400 + + def __init__(self, message, code=None, status=None): + self.message = message + self.code = code if code is not None else self.CODE + self.status = status if status is not None else self.STATUS + + def response(self): + return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + + +class InvalidRequestException(AWSError): + CODE = 'InvalidRequestException' + + +class BadSegmentException(Exception): + def __init__(self, seg_id=None, code=None, message=None): + self.id = seg_id + self.code = code + self.message = message + + def __repr__(self): + return ''.format('-'.join([self.id, self.code, self.message])) + + def to_dict(self): + result = {} + if self.id is not None: + result['Id'] = self.id + if self.code is not None: + result['ErrorCode'] = self.code + if self.message is not None: + result['Message'] = self.message + + return result diff --git a/moto/xray/models.py b/moto/xray/models.py new file mode 100644 index 000000000..f22edeb9f --- /dev/null +++ b/moto/xray/models.py @@ -0,0 +1,208 @@ +from __future__ import unicode_literals + +import bisect +import datetime +from collections import defaultdict +import json +from moto.core import BaseBackend, BaseModel +from moto.ec2 import ec2_backends +from .exceptions import BadSegmentException, AWSError + + +class TelemetryRecords(BaseModel): + def __init__(self, instance_id, hostname, resource_arn, records): + self.instance_id = instance_id + self.hostname = hostname + self.resource_arn = resource_arn + self.records = records + + @classmethod + def from_json(cls, json): + instance_id = json.get('EC2InstanceId', None) + hostname = json.get('Hostname') + resource_arn = json.get('ResourceARN') + telemetry_records = json['TelemetryRecords'] + + return cls(instance_id, hostname, resource_arn, telemetry_records) + + +# https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html +class TraceSegment(BaseModel): + def __init__(self, name, segment_id, trace_id, start_time, end_time=None, in_progress=False, service=None, user=None, + origin=None, parent_id=None, http=None, aws=None, metadata=None, annotations=None, subsegments=None, **kwargs): + self.name = name + self.id = segment_id + self.trace_id = trace_id + self._trace_version = None + self._original_request_start_time = None + self._trace_identifier = None + self.start_time = start_time + self._start_date = None + self.end_time = end_time + self._end_date = None + self.in_progress = in_progress + self.service = service + self.user = user + self.origin = origin + self.parent_id = parent_id + self.http = http + self.aws = aws + self.metadata = metadata + self.annotations = annotations + self.subsegments = subsegments + self.misc = kwargs + + def __lt__(self, other): + return self.start_date < other.start_date + + @property + def trace_version(self): + if self._trace_version is None: + self._trace_version = int(self.trace_id.split('-', 1)[0]) + return self._trace_version + + @property + def request_start_date(self): + if self._original_request_start_time is None: + start_time = int(self.trace_id.split('-')[1], 16) + self._original_request_start_time = datetime.datetime.fromtimestamp(start_time) + return self._original_request_start_time + + @property + def start_date(self): + if self._start_date is None: + self._start_date = datetime.datetime.fromtimestamp(self.start_time) + return self._start_date + + @property + def end_date(self): + if self._end_date is None: + self._end_date = datetime.datetime.fromtimestamp(self.end_time) + return self._end_date + + @classmethod + def from_dict(cls, data): + # Check manditory args + if 'id' not in data: + raise BadSegmentException(code='MissingParam', message='Missing segment ID') + seg_id = data['id'] + data['segment_id'] = seg_id # Just adding this key for future convenience + + for arg in ('name', 'trace_id', 'start_time'): + if arg not in data: + raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing segment ID') + + if 'end_time' not in data and 'in_progress' not in data: + raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time or in_progress') + if 'end_time' not in data and data['in_progress'] == 'false': + raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time') + + return cls(**data) + + +class SegmentCollection(object): + def __init__(self): + self._segments = defaultdict(self._new_trace_item) + + @staticmethod + def _new_trace_item(): + return { + 'start_date': datetime.datetime(1970, 1, 1), + 'end_date': datetime.datetime(1970, 1, 1), + 'finished': False, + 'segments': [] + } + + def put_segment(self, segment): + # insert into a sorted list + bisect.insort_left(self._segments[segment.trace_id]['segments'], segment) + + # Get the last segment (takes into account incorrect ordering) + # and if its the last one, mark trace as complete + if self._segments[segment.trace_id]['segments'][-1].end_time is not None: + self._segments[segment.trace_id]['finished'] = True + + start_time = self._segments[segment.trace_id]['segments'][0].start_date + end_time = self._segments[segment.trace_id]['segments'][-1].end_date + self._segments[segment.trace_id]['start_date'] = start_time + self._segments[segment.trace_id]['end_date'] = end_time + + # Todo consolidate trace segments into a trace. + # not enough working knowledge of xray to do this + + def summary(self, start_time, end_time, filter_expression=None, sampling=False): + # This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax + if filter_expression is not None: + raise AWSError('Not implemented yet - moto', code='InternalFailure', status=500) + + summaries = [] + + for tid, trace in self._segments.items(): + if trace['finished'] and start_time < trace['start_date'] and trace['end_date'] < end_time: + duration = int((trace['end_date'] - trace['start_date']).total_seconds()) + # this stuff is mostly guesses, refer to TODO above + has_error = any(['error' in seg.misc for seg in trace['segments']]) + has_fault = any(['fault' in seg.misc for seg in trace['segments']]) + has_throttle = any(['throttle' in seg.misc for seg in trace['segments']]) + + # Apparently all of these options are optional + summary_part = { + 'Annotations': {}, # Not implemented yet + 'Duration': duration, + 'HasError': has_error, + 'HasFault': has_fault, + 'HasThrottle': has_throttle, + 'Http': {}, # Not implemented yet + 'Id': tid, + 'IsParital': False, # needs lots more work to work on partials + 'ResponseTime': 1, # definitely 1ms resposnetime + 'ServiceIds': [], # Not implemented yet + 'Users': {} # Not implemented yet + } + summaries.append(summary_part) + + result = { + "ApproximateTime": int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds()), + "TracesProcessedCount": len(summaries), + "TraceSummaries": summaries + } + + return result + + +class XRayBackend(BaseBackend): + + def __init__(self): + self._telemetry_records = [] + self._segment_collection = SegmentCollection() + + def add_telemetry_records(self, json): + self._telemetry_records.append( + TelemetryRecords.from_json(json) + ) + + def process_segment(self, doc): + try: + data = json.loads(doc) + except ValueError: + raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + + try: + # Get Segment Object + segment = TraceSegment.from_dict(data) + except ValueError: + raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + + try: + # Store Segment Object + self._segment_collection.put_segment(segment) + except Exception as err: + raise BadSegmentException(seg_id=segment.id, code='InternalFailure', message=str(err)) + + def get_trace_summary(self, start_time, end_time, filter_expression, summaries): + return self._segment_collection.summary(start_time, end_time, filter_expression, summaries) + + +xray_backends = {} +for region, ec2_backend in ec2_backends.items(): + xray_backends[region] = XRayBackend() diff --git a/moto/xray/responses.py b/moto/xray/responses.py new file mode 100644 index 000000000..3c69e105c --- /dev/null +++ b/moto/xray/responses.py @@ -0,0 +1,133 @@ +from __future__ import unicode_literals +from urllib.parse import urlsplit +import json +import six +import datetime + +from moto.core.responses import BaseResponse +from moto.core.utils import camelcase_to_underscores, method_names_from_class +from werkzeug.exceptions import HTTPException + +from .models import xray_backends +from .exceptions import AWSError, BadSegmentException + + +class XRayResponse(BaseResponse): + + def _error(self, code, message): + return json.dumps({'__type': code, 'message': message}), dict(status=400) + + @property + def xray_backend(self): + return xray_backends[self.region] + + @property + def request_params(self): + try: + return json.loads(self.body) + except ValueError: + return {} + + def _get_param(self, param, default=None): + return self.request_params.get(param, default) + + def call_action(self): + # Amazon is just calling urls like /TelemetryRecords etc... + action = urlsplit(self.uri).path.lstrip('/') + action = camelcase_to_underscores(action) + headers = self.response_headers + method_names = method_names_from_class(self.__class__) + if action in method_names: + method = getattr(self, action) + try: + response = method() + except HTTPException as http_error: + response = http_error.description, dict(status=http_error.code) + if isinstance(response, six.string_types): + return 200, headers, response + else: + body, new_headers = response + status = new_headers.get('status', 200) + headers.update(new_headers) + # Cast status to string + if "status" in headers: + headers['status'] = str(headers['status']) + return status, headers, body + + raise NotImplementedError( + "The {0} action has not been implemented".format(action)) + + # PutTelemetryRecords + def telemetry_records(self): + try: + self.xray_backend.add_telemetry_records(self.request_params) + except AWSError as err: + return err.response() + + return '' + + # PutTraceSegments + def trace_segments(self): + docs = self._get_param('TraceSegmentDocuments') + + if docs is None: + msg = 'Parameter TraceSegmentDocuments is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + + # Raises an exception that contains info about a bad segment, + # the object also has a to_dict() method + bad_segments = [] + for doc in docs: + try: + self.xray_backend.process_segment(doc) + except BadSegmentException as bad_seg: + bad_segments.append(bad_seg) + except Exception as err: + return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + + result = {'UnprocessedTraceSegments': [x.to_dict() for x in bad_segments]} + return json.dumps(result) + + # GetTraceSummaries + def trace_summaries(self): + start_time = self._get_param('StartTime') + end_time = self._get_param('EndTime') + if start_time is None: + msg = 'Parameter StartTime is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + if end_time is None: + msg = 'Parameter EndTime is missing' + return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + + filter_expression = self._get_param('FilterExpression') + sampling = self._get_param('Sampling', 'false') == 'true' + + try: + start_time = datetime.datetime.fromtimestamp(int(start_time)) + end_time = datetime.datetime.fromtimestamp(int(end_time)) + except ValueError: + msg = 'start_time and end_time are not integers' + return json.dumps({'__type': 'InvalidParameterValue', 'message': msg}), dict(status=400) + except Exception as err: + return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + + try: + result = self.xray_backend.get_trace_summary(start_time, end_time, filter_expression, sampling) + except AWSError as err: + return err.response() + except Exception as err: + return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + + return json.dumps(result) + + # BatchGetTraces + def traces(self): + raise NotImplementedError() + + # GetServiceGraph + def service_graph(self): + raise NotImplementedError() + + # GetTraceGraph + def trace_graph(self): + raise NotImplementedError() diff --git a/moto/xray/urls.py b/moto/xray/urls.py new file mode 100644 index 000000000..c224e8d38 --- /dev/null +++ b/moto/xray/urls.py @@ -0,0 +1,10 @@ +from __future__ import unicode_literals +from .responses import XRayResponse + +url_bases = [ + "https?://xray.(.+).amazonaws.com", +] + +url_paths = { + '{0}/.+$': XRayResponse.dispatch, +} diff --git a/tests/test_xray/test_xray_boto3.py b/tests/test_xray/test_xray_boto3.py new file mode 100644 index 000000000..9da55ad1e --- /dev/null +++ b/tests/test_xray/test_xray_boto3.py @@ -0,0 +1,84 @@ +from __future__ import unicode_literals + +import boto3 +import json +import botocore.exceptions +import sure # noqa + +from moto import mock_xray + +import datetime + + +@mock_xray +def test_put_telemetry(): + client = boto3.client('xray', region_name='us-east-1') + + client.put_telemetry_records( + TelemetryRecords=[ + { + 'Timestamp': datetime.datetime(2015, 1, 1), + 'SegmentsReceivedCount': 123, + 'SegmentsSentCount': 123, + 'SegmentsSpilloverCount': 123, + 'SegmentsRejectedCount': 123, + 'BackendConnectionErrors': { + 'TimeoutCount': 123, + 'ConnectionRefusedCount': 123, + 'HTTPCode4XXCount': 123, + 'HTTPCode5XXCount': 123, + 'UnknownHostCount': 123, + 'OtherCount': 123 + } + }, + ], + EC2InstanceId='string', + Hostname='string', + ResourceARN='string' + ) + + +@mock_xray +def test_put_trace_segments(): + client = boto3.client('xray', region_name='us-east-1') + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps({ + 'name': 'example.com', + 'id': '70de5b6f19ff9a0a', + 'start_time': 1.478293361271E9, + 'trace_id': '1-581cf771-a006649127e371903a2de979', + 'end_time': 1.478293361449E9 + }) + ] + ) + + +@mock_xray +def test_trace_summary(): + client = boto3.client('xray', region_name='us-east-1') + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps({ + 'name': 'example.com', + 'id': '70de5b6f19ff9a0a', + 'start_time': 1.478293361271E9, + 'trace_id': '1-581cf771-a006649127e371903a2de979', + 'in_progress': True + }), + json.dumps({ + 'name': 'example.com', + 'id': '70de5b6f19ff9a0b', + 'start_time': 1478293365, + 'trace_id': '1-581cf771-a006649127e371903a2de979', + 'end_time': 1478293385 + }) + ] + ) + + client.get_trace_summaries( + StartTime=datetime.datetime(2014, 1, 1), + EndTime=datetime.datetime(2017, 1, 1) + )