Merge remote-tracking branch 'spulec/master'

This commit is contained in:
Alexander Mohr 2017-12-04 12:27:07 -08:00
commit f541ff932c
5 changed files with 233 additions and 53 deletions

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
import json
import re
try:
from urllib import unquote
@ -198,7 +197,7 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _get_aws_region(self, full_url):
region = re.search(self.region_regex, full_url)
region = self.region_regex.search(full_url)
if region:
return region.group(1)
else:

View File

@ -1,13 +1,13 @@
import json
import json
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
import boto.ec2.cloudwatch
import datetime
from datetime import datetime, timedelta
from dateutil.tz import tzutc
from .utils import make_arn_for_dashboard
DEFAULT_ACCOUNT_ID = 123456789012
@ -18,6 +18,34 @@ class Dimension(object):
self.value = value
def daterange(start, stop, step=timedelta(days=1), inclusive=False):
"""
This method will iterate from `start` to `stop` datetimes with a timedelta step of `step`
(supports iteration forwards or backwards in time)
:param start: start datetime
:param stop: end datetime
:param step: step size as a timedelta
:param inclusive: if True, last item returned will be as step closest to `end` (or `end` if no remainder).
"""
# inclusive=False to behave like range by default
total_step_secs = step.total_seconds()
assert total_step_secs != 0
if total_step_secs > 0:
while start < stop:
yield start
start = start + step
else:
while stop < start:
yield start
start = start + step
if inclusive and start == stop:
yield start
class FakeAlarm(BaseModel):
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
@ -38,14 +66,14 @@ class FakeAlarm(BaseModel):
self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions
self.unit = unit
self.configuration_updated_timestamp = datetime.datetime.utcnow()
self.configuration_updated_timestamp = datetime.utcnow()
self.history = []
self.state_reason = ''
self.state_reason_data = '{}'
self.state = 'OK'
self.state_updated_timestamp = datetime.datetime.utcnow()
self.state_updated_timestamp = datetime.utcnow()
def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
@ -56,17 +84,18 @@ class FakeAlarm(BaseModel):
self.state_reason = reason
self.state_reason_data = reason_data
self.state = state_value
self.state_updated_timestamp = datetime.datetime.utcnow()
self.state_updated_timestamp = datetime.utcnow()
class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions):
def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace
self.name = name
self.value = value
self.dimensions = [Dimension(dimension['name'], dimension[
'value']) for dimension in dimensions]
self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc())
self.dimensions = [Dimension(dimension['Name'], dimension[
'Value']) for dimension in dimensions]
class Dashboard(BaseModel):
@ -75,7 +104,7 @@ class Dashboard(BaseModel):
self.arn = make_arn_for_dashboard(DEFAULT_ACCOUNT_ID, name)
self.name = name
self.body = body
self.last_modified = datetime.datetime.now()
self.last_modified = datetime.now()
@property
def last_modified_iso(self):
@ -92,6 +121,53 @@ class Dashboard(BaseModel):
return '<CloudWatchDashboard {0}>'.format(self.name)
class Statistics:
def __init__(self, stats, dt):
self.timestamp = iso_8601_datetime_with_milliseconds(dt)
self.values = []
self.stats = stats
@property
def sample_count(self):
if 'SampleCount' not in self.stats:
return None
return len(self.values)
@property
def unit(self):
return None
@property
def sum(self):
if 'Sum' not in self.stats:
return None
return sum(self.values)
@property
def min(self):
if 'Minimum' not in self.stats:
return None
return min(self.values)
@property
def max(self):
if 'Maximum' not in self.stats:
return None
return max(self.values)
@property
def average(self):
if 'Average' not in self.stats:
return None
# when moto is 3.4+ we can switch to the statistics module
return sum(self.values) / len(self.values)
class CloudWatchBackend(BaseBackend):
def __init__(self):
@ -150,9 +226,34 @@ class CloudWatchBackend(BaseBackend):
self.alarms.pop(alarm_name, None)
def put_metric_data(self, namespace, metric_data):
for name, value, dimensions in metric_data:
for metric_member in metric_data:
self.metric_data.append(MetricDatum(
namespace, name, value, dimensions))
namespace, metric_member['MetricName'], float(metric_member['Value']), metric_member['Dimensions.member'], metric_member.get('Timestamp')))
def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats):
period_delta = timedelta(seconds=period)
filtered_data = [md for md in self.metric_data if
md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time]
# earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)
if not filtered_data:
return []
idx = 0
data = list()
for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta):
s = Statistics(stats, dt)
while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta):
s.values.append(filtered_data[idx].value)
idx += 1
if not s.values:
continue
data.append(s)
return data
def get_all_metrics(self):
return self.metric_data

View File

@ -2,6 +2,7 @@ import json
from moto.core.utils import amzn_request_id
from moto.core.responses import BaseResponse
from .models import cloudwatch_backends
from dateutil.parser import parse as dtparse
class CloudWatchResponse(BaseResponse):
@ -75,35 +76,36 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def put_metric_data(self):
namespace = self._get_param('Namespace')
metric_data = []
metric_index = 1
while True:
try:
metric_name = self.querystring[
'MetricData.member.{0}.MetricName'.format(metric_index)][0]
except KeyError:
break
value = self.querystring.get(
'MetricData.member.{0}.Value'.format(metric_index), [None])[0]
dimensions = []
dimension_index = 1
while True:
try:
dimension_name = self.querystring[
'MetricData.member.{0}.Dimensions.member.{1}.Name'.format(metric_index, dimension_index)][0]
except KeyError:
break
dimension_value = self.querystring[
'MetricData.member.{0}.Dimensions.member.{1}.Value'.format(metric_index, dimension_index)][0]
dimensions.append(
{'name': dimension_name, 'value': dimension_value})
dimension_index += 1
metric_data.append([metric_name, value, dimensions])
metric_index += 1
metric_data = self._get_multi_param('MetricData.member')
self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render()
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param('Namespace')
metric_name = self._get_param('MetricName')
start_time = dtparse(self._get_param('StartTime'))
end_time = dtparse(self._get_param('EndTime'))
period = int(self._get_param('Period'))
statistics = self._get_multi_param("Statistics.member")
# Unsupported Parameters (To Be Implemented)
unit = self._get_param('Unit')
extended_statistics = self._get_param('ExtendedStatistics')
dimensions = self._get_param('Dimensions')
if unit or extended_statistics or dimensions:
raise NotImplemented()
# TODO: this should instead throw InvalidParameterCombination
if not statistics:
raise NotImplemented("Must specify either Statistics or ExtendedStatistics")
datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints)
@amzn_request_id
def list_metrics(self):
metrics = self.cloudwatch_backend.get_all_metrics()
@ -150,10 +152,6 @@ class CloudWatchResponse(BaseResponse):
template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard)
@amzn_request_id
def get_metric_statistics(self):
raise NotImplementedError()
@amzn_request_id
def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '')
@ -266,6 +264,50 @@ PUT_METRIC_DATA_TEMPLATE = """<PutMetricDataResponse xmlns="http://monitoring.am
</ResponseMetadata>
</PutMetricDataResponse>"""
GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>
{{ request_id }}
</RequestId>
</ResponseMetadata>
<GetMetricStatisticsResult>
<Label> {{ label }} </Label>
<Datapoints>
{% for datapoint in datapoints %}
<Datapoint>
{% if datapoint.sum %}
<Sum>{{ datapoint.sum }}</Sum>
{% endif %}
{% if datapoint.average %}
<Average>{{ datapoint.average }}</Average>
{% endif %}
{% if datapoint.maximum %}
<Maximum>{{ datapoint.maximum }}</Maximum>
{% endif %}
{% if datapoint.minimum %}
<Minimum>{{ datapoint.minimum }}</Minimum>
{% endif %}
{% if datapoint.sample_count %}
<SampleCount>{{ datapoint.sample_count }}</SampleCount>
{% endif %}
{% if datapoint.extended_statistics %}
<ExtendedStatistics>{{ datapoint.extended_statistics }}</ExtendedStatistics>
{% endif %}
<Timestamp>{{ datapoint.timestamp }}</Timestamp>
<Unit>{{ datapoint.unit }}</Unit>
</Datapoint>
{% endfor %}
</Datapoints>
</GetMetricStatisticsResult>
</GetMetricStatisticsResponse>"""
LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ListMetricsResult>
<Metrics>

View File

@ -106,7 +106,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
default_region = 'us-east-1'
# to extract region, use [^.]
region_regex = r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com'
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com')
param_list_regex = re.compile(r'(.*)\.(\d+)\.')
aws_service_spec = None
@classmethod
@ -167,7 +168,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url):
match = re.search(self.region_regex, full_url)
match = self.region_regex.search(full_url)
if match:
region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
@ -311,6 +312,41 @@ class BaseResponse(_TemplateEnvironmentMixin):
return False
return if_none
def _get_multi_param_helper(self, param_prefix):
value_dict = dict()
tracked_prefixes = set() # prefixes which have already been processed
def is_tracked(name_param):
for prefix_loop in tracked_prefixes:
if name_param.startswith(prefix_loop):
return True
return False
for name, value in self.querystring.items():
if is_tracked(name) or not name.startswith(param_prefix):
continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None
if match:
prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix)
tracked_prefixes.add(prefix)
name = prefix
value_dict[name] = value
else:
value_dict[name] = value[0]
if not value_dict:
return None
if len(value_dict) > 1:
# strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()}
else:
value_dict = list(value_dict.values())[0]
return value_dict
def _get_multi_param(self, param_prefix):
"""
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
@ -323,12 +359,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
values = []
index = 1
while True:
try:
values.append(self.querystring[prefix + str(index)][0])
except KeyError:
value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict:
break
else:
index += 1
values.append(value_dict)
index += 1
return values
def _get_dict_param(self, param_prefix):

View File

@ -1,4 +1,6 @@
from __future__ import unicode_literals
import re
from six.moves.urllib.parse import urlparse
from moto.core.responses import BaseResponse
@ -15,12 +17,11 @@ from .exceptions import (
MAXIMUM_VISIBILTY_TIMEOUT = 43200
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
DEFAULT_RECEIVED_MESSAGES = 1
SQS_REGION_REGEX = r'://(.+?)\.queue\.amazonaws\.com'
class SQSResponse(BaseResponse):
region_regex = SQS_REGION_REGEX
region_regex = re.compile(r'://(.+?)\.queue\.amazonaws\.com')
@property
def sqs_backend(self):