This commit is contained in:
Steve Pulec 2017-02-23 21:37:43 -05:00
parent 1433f28846
commit f37bad0e00
260 changed files with 6370 additions and 3773 deletions

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import logging import logging
#logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = 'moto'
__version__ = '0.4.31' __version__ = '0.4.31'

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import apigateway_backends from .models import apigateway_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
apigateway_backend = apigateway_backends['us-east-1'] apigateway_backend = apigateway_backends['us-east-1']
mock_apigateway = base_decorator(apigateway_backends) mock_apigateway = base_decorator(apigateway_backends)

View File

@ -4,9 +4,7 @@ from moto.core.exceptions import RESTError
class StageNotFoundException(RESTError): class StageNotFoundException(RESTError):
code = 404 code = 404
def __init__(self): def __init__(self):
super(StageNotFoundException, self).__init__( super(StageNotFoundException, self).__init__(
"NotFoundException", "Invalid stage identifier specified") "NotFoundException", "Invalid stage identifier specified")

View File

@ -14,15 +14,18 @@ STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_nam
class Deployment(dict): class Deployment(dict):
def __init__(self, deployment_id, name, description=""): def __init__(self, deployment_id, name, description=""):
super(Deployment, self).__init__() super(Deployment, self).__init__()
self['id'] = deployment_id self['id'] = deployment_id
self['stageName'] = name self['stageName'] = name
self['description'] = description self['description'] = description
self['createdDate'] = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self['createdDate'] = iso_8601_datetime_with_milliseconds(
datetime.datetime.now())
class IntegrationResponse(dict): class IntegrationResponse(dict):
def __init__(self, status_code, selection_pattern=None): def __init__(self, status_code, selection_pattern=None):
self['responseTemplates'] = {"application/json": None} self['responseTemplates'] = {"application/json": None}
self['statusCode'] = status_code self['statusCode'] = status_code
@ -31,6 +34,7 @@ class IntegrationResponse(dict):
class Integration(dict): class Integration(dict):
def __init__(self, integration_type, uri, http_method, request_templates=None): def __init__(self, integration_type, uri, http_method, request_templates=None):
super(Integration, self).__init__() super(Integration, self).__init__()
self['type'] = integration_type self['type'] = integration_type
@ -42,7 +46,8 @@ class Integration(dict):
} }
def create_integration_response(self, status_code, selection_pattern): def create_integration_response(self, status_code, selection_pattern):
integration_response = IntegrationResponse(status_code, selection_pattern) integration_response = IntegrationResponse(
status_code, selection_pattern)
self["integrationResponses"][status_code] = integration_response self["integrationResponses"][status_code] = integration_response
return integration_response return integration_response
@ -54,12 +59,14 @@ class Integration(dict):
class MethodResponse(dict): class MethodResponse(dict):
def __init__(self, status_code): def __init__(self, status_code):
super(MethodResponse, self).__init__() super(MethodResponse, self).__init__()
self['statusCode'] = status_code self['statusCode'] = status_code
class Method(dict): class Method(dict):
def __init__(self, method_type, authorization_type): def __init__(self, method_type, authorization_type):
super(Method, self).__init__() super(Method, self).__init__()
self.update(dict( self.update(dict(
@ -86,6 +93,7 @@ class Method(dict):
class Resource(object): class Resource(object):
def __init__(self, id, region_name, api_id, path_part, parent_id): def __init__(self, id, region_name, api_id, path_part, parent_id):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
@ -127,14 +135,17 @@ class Resource(object):
if integration_type == 'HTTP': if integration_type == 'HTTP':
uri = integration['uri'] uri = integration['uri']
requests_func = getattr(requests, integration['httpMethod'].lower()) requests_func = getattr(requests, integration[
'httpMethod'].lower())
response = requests_func(uri) response = requests_func(uri)
else: else:
raise NotImplementedError("The {0} type has not been implemented".format(integration_type)) raise NotImplementedError(
"The {0} type has not been implemented".format(integration_type))
return response.status_code, response.text return response.status_code, response.text
def add_method(self, method_type, authorization_type): def add_method(self, method_type, authorization_type):
method = Method(method_type=method_type, authorization_type=authorization_type) method = Method(method_type=method_type,
authorization_type=authorization_type)
self.resource_methods[method_type] = method self.resource_methods[method_type] = method
return method return method
@ -142,7 +153,8 @@ class Resource(object):
return self.resource_methods[method_type] return self.resource_methods[method_type]
def add_integration(self, method_type, integration_type, uri, request_templates=None): def add_integration(self, method_type, integration_type, uri, request_templates=None):
integration = Integration(integration_type, uri, method_type, request_templates=request_templates) integration = Integration(
integration_type, uri, method_type, request_templates=request_templates)
self.resource_methods[method_type]['methodIntegration'] = integration self.resource_methods[method_type]['methodIntegration'] = integration
return integration return integration
@ -155,9 +167,8 @@ class Resource(object):
class Stage(dict): class Stage(dict):
def __init__(self, name=None, deployment_id=None, variables=None, def __init__(self, name=None, deployment_id=None, variables=None,
description='',cacheClusterEnabled=False,cacheClusterSize=None): description='', cacheClusterEnabled=False, cacheClusterSize=None):
super(Stage, self).__init__() super(Stage, self).__init__()
if variables is None: if variables is None:
variables = {} variables = {}
@ -190,21 +201,24 @@ class Stage(dict):
elif op['op'] == 'replace': elif op['op'] == 'replace':
# Method Settings drop into here # Method Settings drop into here
# (e.g., path could be '/*/*/logging/loglevel') # (e.g., path could be '/*/*/logging/loglevel')
split_path = op['path'].split('/',3) split_path = op['path'].split('/', 3)
if len(split_path)!=4: if len(split_path) != 4:
continue continue
self._patch_method_setting('/'.join(split_path[1:3]),split_path[3],op['value']) self._patch_method_setting(
'/'.join(split_path[1:3]), split_path[3], op['value'])
else: else:
raise Exception('Patch operation "%s" not implemented' % op['op']) raise Exception(
'Patch operation "%s" not implemented' % op['op'])
return self return self
def _patch_method_setting(self,resource_path_and_method,key,value): def _patch_method_setting(self, resource_path_and_method, key, value):
updated_key = self._method_settings_translations(key) updated_key = self._method_settings_translations(key)
if updated_key is not None: if updated_key is not None:
if resource_path_and_method not in self['methodSettings']: if resource_path_and_method not in self['methodSettings']:
self['methodSettings'][resource_path_and_method] = self._get_default_method_settings() self['methodSettings'][
self['methodSettings'][resource_path_and_method][updated_key] = self._convert_to_type(updated_key,value) resource_path_and_method] = self._get_default_method_settings()
self['methodSettings'][resource_path_and_method][
updated_key] = self._convert_to_type(updated_key, value)
def _get_default_method_settings(self): def _get_default_method_settings(self):
return { return {
@ -219,18 +233,18 @@ class Stage(dict):
"requireAuthorizationForCacheControl": True "requireAuthorizationForCacheControl": True
} }
def _method_settings_translations(self,key): def _method_settings_translations(self, key):
mappings = { mappings = {
'metrics/enabled' :'metricsEnabled', 'metrics/enabled': 'metricsEnabled',
'logging/loglevel' : 'loggingLevel', 'logging/loglevel': 'loggingLevel',
'logging/dataTrace' : 'dataTraceEnabled' , 'logging/dataTrace': 'dataTraceEnabled',
'throttling/burstLimit' : 'throttlingBurstLimit', 'throttling/burstLimit': 'throttlingBurstLimit',
'throttling/rateLimit' : 'throttlingRateLimit', 'throttling/rateLimit': 'throttlingRateLimit',
'caching/enabled' : 'cachingEnabled', 'caching/enabled': 'cachingEnabled',
'caching/ttlInSeconds' : 'cacheTtlInSeconds', 'caching/ttlInSeconds': 'cacheTtlInSeconds',
'caching/dataEncrypted' : 'cacheDataEncrypted', 'caching/dataEncrypted': 'cacheDataEncrypted',
'caching/requireAuthorizationForCacheControl' : 'requireAuthorizationForCacheControl', 'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl',
'caching/unauthorizedCacheControlHeaderStrategy' : 'unauthorizedCacheControlHeaderStrategy' 'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy'
} }
if key in mappings: if key in mappings:
@ -238,21 +252,21 @@ class Stage(dict):
else: else:
None None
def _str2bool(self,v): def _str2bool(self, v):
return v.lower() == "true" return v.lower() == "true"
def _convert_to_type(self,key,val): def _convert_to_type(self, key, val):
type_mappings = { type_mappings = {
'metricsEnabled' : 'bool', 'metricsEnabled': 'bool',
'loggingLevel' : 'str', 'loggingLevel': 'str',
'dataTraceEnabled' : 'bool', 'dataTraceEnabled': 'bool',
'throttlingBurstLimit' : 'int', 'throttlingBurstLimit': 'int',
'throttlingRateLimit' : 'float', 'throttlingRateLimit': 'float',
'cachingEnabled' : 'bool', 'cachingEnabled': 'bool',
'cacheTtlInSeconds' : 'int', 'cacheTtlInSeconds': 'int',
'cacheDataEncrypted' : 'bool', 'cacheDataEncrypted': 'bool',
'requireAuthorizationForCacheControl' :'bool', 'requireAuthorizationForCacheControl': 'bool',
'unauthorizedCacheControlHeaderStrategy' : 'str' 'unauthorizedCacheControlHeaderStrategy': 'str'
} }
if key in type_mappings: if key in type_mappings:
@ -261,7 +275,7 @@ class Stage(dict):
if type_value == 'bool': if type_value == 'bool':
return self._str2bool(val) return self._str2bool(val)
elif type_value == 'int': elif type_value == 'int':
return int(val) return int(val)
elif type_value == 'float': elif type_value == 'float':
return float(val) return float(val)
else: else:
@ -269,10 +283,8 @@ class Stage(dict):
else: else:
return str(val) return str(val)
def _apply_operation_to_variables(self, op):
key = op['path'][op['path'].rindex("variables/") + 10:]
def _apply_operation_to_variables(self,op):
key = op['path'][op['path'].rindex("variables/")+10:]
if op['op'] == 'remove': if op['op'] == 'remove':
self['variables'].pop(key, None) self['variables'].pop(key, None)
elif op['op'] == 'replace': elif op['op'] == 'replace':
@ -281,8 +293,8 @@ class Stage(dict):
raise Exception('Patch operation "%s" not implemented' % op['op']) raise Exception('Patch operation "%s" not implemented' % op['op'])
class RestAPI(object): class RestAPI(object):
def __init__(self, id, region_name, name, description): def __init__(self, id, region_name, name, description):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
@ -306,7 +318,8 @@ class RestAPI(object):
def add_child(self, path, parent_id=None): def add_child(self, path, parent_id=None):
child_id = create_id() child_id = create_id()
child = Resource(id=child_id, region_name=self.region_name, api_id=self.id, path_part=path, parent_id=parent_id) child = Resource(id=child_id, region_name=self.region_name,
api_id=self.id, path_part=path, parent_id=parent_id)
self.resources[child_id] = child self.resources[child_id] = child
return child return child
@ -326,25 +339,28 @@ class RestAPI(object):
return status_code, {}, response return status_code, {}, response
def update_integration_mocks(self, stage_name): def update_integration_mocks(self, stage_name):
stage_url = STAGE_URL.format(api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name) stage_url = STAGE_URL.format(api_id=self.id.upper(),
responses.add_callback(responses.GET, stage_url, callback=self.resource_callback) region_name=self.region_name, stage_name=stage_name)
responses.add_callback(responses.GET, stage_url,
callback=self.resource_callback)
def create_stage(self, name, deployment_id,variables=None,description='',cacheClusterEnabled=None,cacheClusterSize=None): def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
if variables is None: if variables is None:
variables = {} variables = {}
stage = Stage(name=name, deployment_id=deployment_id,variables=variables, stage = Stage(name=name, deployment_id=deployment_id, variables=variables,
description=description,cacheClusterSize=cacheClusterSize,cacheClusterEnabled=cacheClusterEnabled) description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled)
self.stages[name] = stage self.stages[name] = stage
self.update_integration_mocks(name) self.update_integration_mocks(name)
return stage return stage
def create_deployment(self, name, description="",stage_variables=None): def create_deployment(self, name, description="", stage_variables=None):
if stage_variables is None: if stage_variables is None:
stage_variables = {} stage_variables = {}
deployment_id = create_id() deployment_id = create_id()
deployment = Deployment(deployment_id, name, description) deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment self.deployments[deployment_id] = deployment
self.stages[name] = Stage(name=name, deployment_id=deployment_id,variables=stage_variables) self.stages[name] = Stage(
name=name, deployment_id=deployment_id, variables=stage_variables)
self.update_integration_mocks(name) self.update_integration_mocks(name)
return deployment return deployment
@ -353,7 +369,7 @@ class RestAPI(object):
return self.deployments[deployment_id] return self.deployments[deployment_id]
def get_stages(self): def get_stages(self):
return list(self.stages.values()) return list(self.stages.values())
def get_deployments(self): def get_deployments(self):
return list(self.deployments.values()) return list(self.deployments.values())
@ -363,6 +379,7 @@ class RestAPI(object):
class APIGatewayBackend(BaseBackend): class APIGatewayBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
super(APIGatewayBackend, self).__init__() super(APIGatewayBackend, self).__init__()
self.apis = {} self.apis = {}
@ -429,19 +446,17 @@ class APIGatewayBackend(BaseBackend):
else: else:
return stage return stage
def get_stages(self, function_id): def get_stages(self, function_id):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
return api.get_stages() return api.get_stages()
def create_stage(self, function_id, stage_name, deploymentId, def create_stage(self, function_id, stage_name, deploymentId,
variables=None,description='',cacheClusterEnabled=None,cacheClusterSize=None): variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
if variables is None: if variables is None:
variables = {} variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
api.create_stage(stage_name,deploymentId,variables=variables, api.create_stage(stage_name, deploymentId, variables=variables,
description=description,cacheClusterEnabled=cacheClusterEnabled,cacheClusterSize=cacheClusterSize) description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize)
return api.stages.get(stage_name) return api.stages.get(stage_name)
def update_stage(self, function_id, stage_name, patch_operations): def update_stage(self, function_id, stage_name, patch_operations):
@ -467,10 +482,10 @@ class APIGatewayBackend(BaseBackend):
return method_response return method_response
def create_integration(self, function_id, resource_id, method_type, integration_type, uri, def create_integration(self, function_id, resource_id, method_type, integration_type, uri,
request_templates=None): request_templates=None):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
integration = resource.add_integration(method_type, integration_type, uri, integration = resource.add_integration(method_type, integration_type, uri,
request_templates=request_templates) request_templates=request_templates)
return integration return integration
def get_integration(self, function_id, resource_id, method_type): def get_integration(self, function_id, resource_id, method_type):
@ -482,25 +497,31 @@ class APIGatewayBackend(BaseBackend):
return resource.delete_integration(method_type) return resource.delete_integration(method_type)
def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern): def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern):
integration = self.get_integration(function_id, resource_id, method_type) integration = self.get_integration(
integration_response = integration.create_integration_response(status_code, selection_pattern) function_id, resource_id, method_type)
integration_response = integration.create_integration_response(
status_code, selection_pattern)
return integration_response return integration_response
def get_integration_response(self, function_id, resource_id, method_type, status_code): def get_integration_response(self, function_id, resource_id, method_type, status_code):
integration = self.get_integration(function_id, resource_id, method_type) integration = self.get_integration(
integration_response = integration.get_integration_response(status_code) function_id, resource_id, method_type)
integration_response = integration.get_integration_response(
status_code)
return integration_response return integration_response
def delete_integration_response(self, function_id, resource_id, method_type, status_code): def delete_integration_response(self, function_id, resource_id, method_type, status_code):
integration = self.get_integration(function_id, resource_id, method_type) integration = self.get_integration(
integration_response = integration.delete_integration_response(status_code) function_id, resource_id, method_type)
integration_response = integration.delete_integration_response(
status_code)
return integration_response return integration_response
def create_deployment(self, function_id, name, description ="", stage_variables=None): def create_deployment(self, function_id, name, description="", stage_variables=None):
if stage_variables is None: if stage_variables is None:
stage_variables = {} stage_variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
deployment = api.create_deployment(name, description,stage_variables) deployment = api.create_deployment(name, description, stage_variables)
return deployment return deployment
def get_deployment(self, function_id, deployment_id): def get_deployment(self, function_id, deployment_id):
@ -515,6 +536,8 @@ class APIGatewayBackend(BaseBackend):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
return api.delete_deployment(deployment_id) return api.delete_deployment(deployment_id)
apigateway_backends = {} apigateway_backends = {}
for region_name in ['us-east-1', 'us-west-2', 'eu-west-1', 'ap-northeast-1']: # Not available in boto yet # Not available in boto yet
for region_name in ['us-east-1', 'us-west-2', 'eu-west-1', 'ap-northeast-1']:
apigateway_backends[region_name] = APIGatewayBackend(region_name) apigateway_backends[region_name] = APIGatewayBackend(region_name)

View File

@ -12,7 +12,6 @@ class APIGatewayResponse(BaseResponse):
def _get_param(self, key): def _get_param(self, key):
return json.loads(self.body).get(key) return json.loads(self.body).get(key)
def _get_param_with_default_value(self, key, default): def _get_param_with_default_value(self, key, default):
jsonbody = json.loads(self.body) jsonbody = json.loads(self.body)
@ -69,7 +68,8 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.get_resource(function_id, resource_id) resource = self.backend.get_resource(function_id, resource_id)
elif self.method == 'POST': elif self.method == 'POST':
path_part = self._get_param("pathPart") path_part = self._get_param("pathPart")
resource = self.backend.create_resource(function_id, resource_id, path_part) resource = self.backend.create_resource(
function_id, resource_id, path_part)
elif self.method == 'DELETE': elif self.method == 'DELETE':
resource = self.backend.delete_resource(function_id, resource_id) resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict()) return 200, {}, json.dumps(resource.to_dict())
@ -82,11 +82,13 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': if self.method == 'GET':
method = self.backend.get_method(function_id, resource_id, method_type) method = self.backend.get_method(
function_id, resource_id, method_type)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == 'PUT': elif self.method == 'PUT':
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
method = self.backend.create_method(function_id, resource_id, method_type, authorization_type) method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers): def resource_method_responses(self, request, full_url, headers):
@ -98,11 +100,14 @@ class APIGatewayResponse(BaseResponse):
response_code = url_path_parts[8] response_code = url_path_parts[8]
if self.method == 'GET': if self.method == 'GET':
method_response = self.backend.get_method_response(function_id, resource_id, method_type, response_code) method_response = self.backend.get_method_response(
function_id, resource_id, method_type, response_code)
elif self.method == 'PUT': elif self.method == 'PUT':
method_response = self.backend.create_method_response(function_id, resource_id, method_type, response_code) method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code)
elif self.method == 'DELETE': elif self.method == 'DELETE':
method_response = self.backend.delete_method_response(function_id, resource_id, method_type, response_code) method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code)
return 200, {}, json.dumps(method_response) return 200, {}, json.dumps(method_response)
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
@ -113,10 +118,13 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'POST': if self.method == 'POST':
stage_name = self._get_param("stageName") stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId") deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value('variables',{}) stage_variables = self._get_param_with_default_value(
description = self._get_param_with_default_value('description','') 'variables', {})
cacheClusterEnabled = self._get_param_with_default_value('cacheClusterEnabled',False) description = self._get_param_with_default_value('description', '')
cacheClusterSize = self._get_param_with_default_value('cacheClusterSize',None) cacheClusterEnabled = self._get_param_with_default_value(
'cacheClusterEnabled', False)
cacheClusterSize = self._get_param_with_default_value(
'cacheClusterSize', None)
stage_response = self.backend.create_stage(function_id, stage_name, deployment_id, stage_response = self.backend.create_stage(function_id, stage_name, deployment_id,
variables=stage_variables, description=description, variables=stage_variables, description=description,
@ -135,12 +143,14 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
try: try:
stage_response = self.backend.get_stage(function_id, stage_name) stage_response = self.backend.get_stage(
function_id, stage_name)
except StageNotFoundException as error: except StageNotFoundException as error:
return error.code, {},'{{"message":"{0}","code":"{1}"}}'.format(error.message,error.error_type) return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)
elif self.method == 'PATCH': elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations') patch_operations = self._get_param('patchOperations')
stage_response = self.backend.update_stage(function_id, stage_name, patch_operations) stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations)
return 200, {}, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
@ -151,14 +161,17 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': if self.method == 'GET':
integration_response = self.backend.get_integration(function_id, resource_id, method_type) integration_response = self.backend.get_integration(
function_id, resource_id, method_type)
elif self.method == 'PUT': elif self.method == 'PUT':
integration_type = self._get_param('type') integration_type = self._get_param('type')
uri = self._get_param('uri') uri = self._get_param('uri')
request_templates = self._get_param('requestTemplates') request_templates = self._get_param('requestTemplates')
integration_response = self.backend.create_integration(function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) integration_response = self.backend.create_integration(
function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates)
elif self.method == 'DELETE': elif self.method == 'DELETE':
integration_response = self.backend.delete_integration(function_id, resource_id, method_type) integration_response = self.backend.delete_integration(
function_id, resource_id, method_type)
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
@ -193,9 +206,11 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps({"item": deployments}) return 200, {}, json.dumps({"item": deployments})
elif self.method == 'POST': elif self.method == 'POST':
name = self._get_param("stageName") name = self._get_param("stageName")
description = self._get_param_with_default_value("description","") description = self._get_param_with_default_value("description", "")
stage_variables = self._get_param_with_default_value('variables',{}) stage_variables = self._get_param_with_default_value(
deployment = self.backend.create_deployment(function_id, name, description,stage_variables) 'variables', {})
deployment = self.backend.create_deployment(
function_id, name, description, stage_variables)
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
@ -205,7 +220,9 @@ class APIGatewayResponse(BaseResponse):
deployment_id = url_path_parts[4] deployment_id = url_path_parts[4]
if self.method == 'GET': if self.method == 'GET':
deployment = self.backend.get_deployment(function_id, deployment_id) deployment = self.backend.get_deployment(
function_id, deployment_id)
elif self.method == 'DELETE': elif self.method == 'DELETE':
deployment = self.backend.delete_deployment(function_id, deployment_id) deployment = self.backend.delete_deployment(
function_id, deployment_id)
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import autoscaling_backends from .models import autoscaling_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
autoscaling_backend = autoscaling_backends['us-east-1'] autoscaling_backend = autoscaling_backends['us-east-1']
mock_autoscaling = base_decorator(autoscaling_backends) mock_autoscaling = base_decorator(autoscaling_backends)

View File

@ -10,12 +10,14 @@ DEFAULT_COOLDOWN = 300
class InstanceState(object): class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService"): def __init__(self, instance, lifecycle_state="InService"):
self.instance = instance self.instance = instance
self.lifecycle_state = lifecycle_state self.lifecycle_state = lifecycle_state
class FakeScalingPolicy(object): class FakeScalingPolicy(object):
def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment,
cooldown, autoscaling_backend): cooldown, autoscaling_backend):
self.name = name self.name = name
@ -31,14 +33,18 @@ class FakeScalingPolicy(object):
def execute(self): def execute(self):
if self.adjustment_type == 'ExactCapacity': if self.adjustment_type == 'ExactCapacity':
self.autoscaling_backend.set_desired_capacity(self.as_name, self.scaling_adjustment) self.autoscaling_backend.set_desired_capacity(
self.as_name, self.scaling_adjustment)
elif self.adjustment_type == 'ChangeInCapacity': elif self.adjustment_type == 'ChangeInCapacity':
self.autoscaling_backend.change_capacity(self.as_name, self.scaling_adjustment) self.autoscaling_backend.change_capacity(
self.as_name, self.scaling_adjustment)
elif self.adjustment_type == 'PercentChangeInCapacity': elif self.adjustment_type == 'PercentChangeInCapacity':
self.autoscaling_backend.change_capacity_percent(self.as_name, self.scaling_adjustment) self.autoscaling_backend.change_capacity_percent(
self.as_name, self.scaling_adjustment)
class FakeLaunchConfiguration(object): class FakeLaunchConfiguration(object):
def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data,
instance_type, instance_monitoring, instance_profile_name, instance_type, instance_monitoring, instance_profile_name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict):
@ -77,14 +83,16 @@ class FakeLaunchConfiguration(object):
instance_profile_name=instance_profile_name, instance_profile_name=instance_profile_name,
spot_price=properties.get("SpotPrice"), spot_price=properties.get("SpotPrice"),
ebs_optimized=properties.get("EbsOptimized"), ebs_optimized=properties.get("EbsOptimized"),
associate_public_ip_address=properties.get("AssociatePublicIpAddress"), associate_public_ip_address=properties.get(
"AssociatePublicIpAddress"),
block_device_mappings=properties.get("BlockDeviceMapping.member") block_device_mappings=properties.get("BlockDeviceMapping.member")
) )
return config return config
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
cls.delete_from_cloudformation_json(original_resource.name, cloudformation_json, region_name) cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name)
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name)
@classmethod @classmethod
@ -126,7 +134,8 @@ class FakeLaunchConfiguration(object):
else: else:
block_type.volume_type = mapping.get('ebs._volume_type') block_type.volume_type = mapping.get('ebs._volume_type')
block_type.snapshot_id = mapping.get('ebs._snapshot_id') block_type.snapshot_id = mapping.get('ebs._snapshot_id')
block_type.delete_on_termination = mapping.get('ebs._delete_on_termination') block_type.delete_on_termination = mapping.get(
'ebs._delete_on_termination')
block_type.size = mapping.get('ebs._volume_size') block_type.size = mapping.get('ebs._volume_size')
block_type.iops = mapping.get('ebs._iops') block_type.iops = mapping.get('ebs._iops')
block_device_map[mount_point] = block_type block_device_map[mount_point] = block_type
@ -134,6 +143,7 @@ class FakeLaunchConfiguration(object):
class FakeAutoScalingGroup(object): class FakeAutoScalingGroup(object):
def __init__(self, name, availability_zones, desired_capacity, max_size, def __init__(self, name, availability_zones, desired_capacity, max_size,
min_size, launch_config_name, vpc_zone_identifier, min_size, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, health_check_type, default_cooldown, health_check_period, health_check_type,
@ -145,7 +155,8 @@ class FakeAutoScalingGroup(object):
self.max_size = max_size self.max_size = max_size
self.min_size = min_size self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[launch_config_name] self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
self.vpc_zone_identifier = vpc_zone_identifier self.vpc_zone_identifier = vpc_zone_identifier
@ -175,7 +186,8 @@ class FakeAutoScalingGroup(object):
max_size=properties.get("MaxSize"), max_size=properties.get("MaxSize"),
min_size=properties.get("MinSize"), min_size=properties.get("MinSize"),
launch_config_name=launch_config_name, launch_config_name=launch_config_name,
vpc_zone_identifier=(','.join(properties.get("VPCZoneIdentifier", [])) or None), vpc_zone_identifier=(
','.join(properties.get("VPCZoneIdentifier", [])) or None),
default_cooldown=properties.get("Cooldown"), default_cooldown=properties.get("Cooldown"),
health_check_period=properties.get("HealthCheckGracePeriod"), health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"), health_check_type=properties.get("HealthCheckType"),
@ -188,7 +200,8 @@ class FakeAutoScalingGroup(object):
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
cls.delete_from_cloudformation_json(original_resource.name, cloudformation_json, region_name) cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name)
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name)
@classmethod @classmethod
@ -219,7 +232,8 @@ class FakeAutoScalingGroup(object):
self.min_size = min_size self.min_size = min_size
if launch_config_name: if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[launch_config_name] self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
if vpc_zone_identifier is not None: if vpc_zone_identifier is not None:
self.vpc_zone_identifier = vpc_zone_identifier self.vpc_zone_identifier = vpc_zone_identifier
@ -244,7 +258,8 @@ class FakeAutoScalingGroup(object):
if self.desired_capacity > curr_instance_count: if self.desired_capacity > curr_instance_count:
# Need more instances # Need more instances
count_needed = int(self.desired_capacity) - int(curr_instance_count) count_needed = int(self.desired_capacity) - \
int(curr_instance_count)
reservation = self.autoscaling_backend.ec2_backend.add_instances( reservation = self.autoscaling_backend.ec2_backend.add_instances(
self.launch_config.image_id, self.launch_config.image_id,
count_needed, count_needed,
@ -259,8 +274,10 @@ class FakeAutoScalingGroup(object):
# Need to remove some instances # Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity count_to_remove = curr_instance_count - self.desired_capacity
instances_to_remove = self.instance_states[:count_to_remove] instances_to_remove = self.instance_states[:count_to_remove]
instance_ids_to_remove = [instance.instance.id for instance in instances_to_remove] instance_ids_to_remove = [
self.autoscaling_backend.ec2_backend.terminate_instances(instance_ids_to_remove) instance.instance.id for instance in instances_to_remove]
self.autoscaling_backend.ec2_backend.terminate_instances(
instance_ids_to_remove)
self.instance_states = self.instance_states[count_to_remove:] self.instance_states = self.instance_states[count_to_remove:]
@ -419,8 +436,8 @@ class AutoScalingBackend(BaseBackend):
def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None):
return [policy for policy in self.policies.values() return [policy for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and
(not policy_names or policy.name in policy_names) and (not policy_names or policy.name in policy_names) and
(not policy_types or policy.policy_type in policy_types)] (not policy_types or policy.policy_type in policy_types)]
def delete_policy(self, group_name): def delete_policy(self, group_name):
self.policies.pop(group_name, None) self.policies.pop(group_name, None)
@ -431,18 +448,22 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name): def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set(state.instance.id for state in group.instance_states) group_instance_ids = set(
state.instance.id for state in group.instance_states)
try: try:
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) elbs = self.elb_backend.describe_load_balancers(
names=group.load_balancers)
except LoadBalancerNotFoundError: except LoadBalancerNotFoundError:
# ELBs can be deleted before their autoscaling group # ELBs can be deleted before their autoscaling group
return return
for elb in elbs: for elb in elbs:
elb_instace_ids = set(elb.instance_ids) elb_instace_ids = set(elb.instance_ids)
self.elb_backend.register_instances(elb.name, group_instance_ids - elb_instace_ids) self.elb_backend.register_instances(
self.elb_backend.deregister_instances(elb.name, elb_instace_ids - group_instance_ids) elb.name, group_instance_ids - elb_instace_ids)
self.elb_backend.deregister_instances(
elb.name, elb_instace_ids - group_instance_ids)
def create_or_update_tags(self, tags): def create_or_update_tags(self, tags):
@ -452,19 +473,21 @@ class AutoScalingBackend(BaseBackend):
old_tags = group.tags old_tags = group.tags
new_tags = [] new_tags = []
#if key was in old_tags, update old tag # if key was in old_tags, update old tag
for old_tag in old_tags: for old_tag in old_tags:
if old_tag["key"] == tag["key"]: if old_tag["key"] == tag["key"]:
new_tags.append(tag) new_tags.append(tag)
else: else:
new_tags.append(old_tag) new_tags.append(old_tag)
#if key was never in old_tag's add it (create tag) # if key was never in old_tag's add it (create tag)
if not any(new_tag['key'] == tag['key'] for new_tag in new_tags): if not any(new_tag['key'] == tag['key'] for new_tag in new_tags):
new_tags.append(tag) new_tags.append(tag)
group.tags = new_tags group.tags = new_tags
autoscaling_backends = {} autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend(ec2_backend, elb_backends[region]) autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region])

View File

@ -11,7 +11,8 @@ class AutoScalingResponse(BaseResponse):
return autoscaling_backends[self.region] return autoscaling_backends[self.region]
def create_launch_configuration(self): def create_launch_configuration(self):
instance_monitoring_string = self._get_param('InstanceMonitoring.Enabled') instance_monitoring_string = self._get_param(
'InstanceMonitoring.Enabled')
if instance_monitoring_string == 'true': if instance_monitoring_string == 'true':
instance_monitoring = True instance_monitoring = True
else: else:
@ -29,28 +30,35 @@ class AutoScalingResponse(BaseResponse):
instance_profile_name=self._get_param('IamInstanceProfile'), instance_profile_name=self._get_param('IamInstanceProfile'),
spot_price=self._get_param('SpotPrice'), spot_price=self._get_param('SpotPrice'),
ebs_optimized=self._get_param('EbsOptimized'), ebs_optimized=self._get_param('EbsOptimized'),
associate_public_ip_address=self._get_param("AssociatePublicIpAddress"), associate_public_ip_address=self._get_param(
block_device_mappings=self._get_list_prefix('BlockDeviceMappings.member') "AssociatePublicIpAddress"),
block_device_mappings=self._get_list_prefix(
'BlockDeviceMappings.member')
) )
template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def describe_launch_configurations(self): def describe_launch_configurations(self):
names = self._get_multi_param('LaunchConfigurationNames.member') names = self._get_multi_param('LaunchConfigurationNames.member')
launch_configurations = self.autoscaling_backend.describe_launch_configurations(names) launch_configurations = self.autoscaling_backend.describe_launch_configurations(
template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) names)
template = self.response_template(
DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
return template.render(launch_configurations=launch_configurations) return template.render(launch_configurations=launch_configurations)
def delete_launch_configuration(self): def delete_launch_configuration(self):
launch_configurations_name = self.querystring.get('LaunchConfigurationName')[0] launch_configurations_name = self.querystring.get(
self.autoscaling_backend.delete_launch_configuration(launch_configurations_name) 'LaunchConfigurationName')[0]
self.autoscaling_backend.delete_launch_configuration(
launch_configurations_name)
template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def create_auto_scaling_group(self): def create_auto_scaling_group(self):
self.autoscaling_backend.create_autoscaling_group( self.autoscaling_backend.create_autoscaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param('AvailabilityZones.member'), availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'), desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'), max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'), min_size=self._get_int_param('MinSize'),
@ -61,7 +69,8 @@ class AutoScalingResponse(BaseResponse):
health_check_type=self._get_param('HealthCheckType'), health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'), load_balancers=self._get_multi_param('LoadBalancerNames.member'),
placement_group=self._get_param('PlacementGroup'), placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param('TerminationPolicies.member'), termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
tags=self._get_list_prefix('Tags.member'), tags=self._get_list_prefix('Tags.member'),
) )
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
@ -76,7 +85,8 @@ class AutoScalingResponse(BaseResponse):
def update_auto_scaling_group(self): def update_auto_scaling_group(self):
self.autoscaling_backend.update_autoscaling_group( self.autoscaling_backend.update_autoscaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param('AvailabilityZones.member'), availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'), desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'), max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'), min_size=self._get_int_param('MinSize'),
@ -87,7 +97,8 @@ class AutoScalingResponse(BaseResponse):
health_check_type=self._get_param('HealthCheckType'), health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'), load_balancers=self._get_multi_param('LoadBalancerNames.member'),
placement_group=self._get_param('PlacementGroup'), placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param('TerminationPolicies.member'), termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
) )
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
@ -101,7 +112,8 @@ class AutoScalingResponse(BaseResponse):
def set_desired_capacity(self): def set_desired_capacity(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param('AutoScalingGroupName')
desired_capacity = self._get_int_param('DesiredCapacity') desired_capacity = self._get_int_param('DesiredCapacity')
self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity) self.autoscaling_backend.set_desired_capacity(
group_name, desired_capacity)
template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE) template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE)
return template.render() return template.render()
@ -114,7 +126,8 @@ class AutoScalingResponse(BaseResponse):
def describe_auto_scaling_instances(self): def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_autoscaling_instances() instance_states = self.autoscaling_backend.describe_autoscaling_instances()
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE) template = self.response_template(
DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states) return template.render(instance_states=instance_states)
def put_scaling_policy(self): def put_scaling_policy(self):

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import lambda_backends from .models import lambda_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
lambda_backend = lambda_backends['us-east-1'] lambda_backend = lambda_backends['us-east-1']
mock_lambda = base_decorator(lambda_backends) mock_lambda = base_decorator(lambda_backends)

View File

@ -32,19 +32,22 @@ class LambdaFunction(object):
# optional # optional
self.description = spec.get('Description', '') self.description = spec.get('Description', '')
self.memory_size = spec.get('MemorySize', 128) self.memory_size = spec.get('MemorySize', 128)
self.publish = spec.get('Publish', False) # this is ignored currently self.publish = spec.get('Publish', False) # this is ignored currently
self.timeout = spec.get('Timeout', 3) self.timeout = spec.get('Timeout', 3)
# this isn't finished yet. it needs to find out the VpcId value # this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get('VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []}) self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
# auto-generated # auto-generated
self.version = '$LATEST' self.version = '$LATEST'
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
if 'ZipFile' in self.code: if 'ZipFile' in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 - argh! # more hackery to handle unicode/bytes/str in python3 and python2 -
# argh!
try: try:
to_unzip_code = base64.b64decode(bytes(self.code['ZipFile'], 'utf-8')) to_unzip_code = base64.b64decode(
bytes(self.code['ZipFile'], 'utf-8'))
except Exception: except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile']) to_unzip_code = base64.b64decode(self.code['ZipFile'])
@ -58,7 +61,8 @@ class LambdaFunction(object):
# validate s3 bucket # validate s3 bucket
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_key(self.code['S3Bucket'], self.code['S3Key']) key = s3_backend.get_key(
self.code['S3Bucket'], self.code['S3Key'])
except MissingBucket: except MissingBucket:
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
@ -72,7 +76,8 @@ class LambdaFunction(object):
else: else:
self.code_size = key.size self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest() self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = 'arn:aws:lambda:123456789012:function:{0}'.format(self.function_name) self.function_arn = 'arn:aws:lambda:123456789012:function:{0}'.format(
self.function_name)
@property @property
def vpc_config(self): def vpc_config(self):
@ -130,7 +135,6 @@ class LambdaFunction(object):
self.convert(self.code), self.convert(self.code),
self.convert('print(json.dumps(lambda_handler(%s, %s)))' % (self.is_json(self.convert(event)), context))]) self.convert('print(json.dumps(lambda_handler(%s, %s)))' % (self.is_json(self.convert(event)), context))])
#print("moto_lambda_debug: ", mycode)
except Exception as ex: except Exception as ex:
print("Exception %s", ex) print("Exception %s", ex)
@ -182,7 +186,8 @@ class LambdaFunction(object):
'Runtime': properties['Runtime'], 'Runtime': properties['Runtime'],
} }
optional_properties = 'Description MemorySize Publish Timeout VpcConfig'.split() optional_properties = 'Description MemorySize Publish Timeout VpcConfig'.split()
# NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the default logic # NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the
# default logic
for prop in optional_properties: for prop in optional_properties:
if prop in properties: if prop in properties:
spec[prop] = properties[prop] spec[prop] = properties[prop]
@ -219,6 +224,6 @@ lambda_backends = {}
for region in boto.awslambda.regions(): for region in boto.awslambda.regions():
lambda_backends[region.name] = LambdaBackend() lambda_backends[region.name] = LambdaBackend()
# Handle us forgotten regions, unless Lambda truly only runs out of US and EU????? # Handle us forgotten regions, unless Lambda truly only runs out of US and
for region in ['ap-southeast-2']: for region in ['ap-southeast-2']:
lambda_backends[region] = LambdaBackend() lambda_backends[region] = LambdaBackend()

View File

@ -2,10 +2,8 @@ from __future__ import unicode_literals
import json import json
import re import re
import uuid
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import lambda_backends
class LambdaResponse(BaseResponse): class LambdaResponse(BaseResponse):

View File

@ -1,7 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import cloudformation_backends from .models import cloudformation_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudformation_backend = cloudformation_backends['us-east-1'] cloudformation_backend = cloudformation_backends['us-east-1']
mock_cloudformation = base_decorator(cloudformation_backends) mock_cloudformation = base_decorator(cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends) mock_cloudformation_deprecated = deprecated_base_decorator(
cloudformation_backends)

View File

@ -9,9 +9,10 @@ class UnformattedGetAttTemplateException(Exception):
class ValidationError(BadRequest): class ValidationError(BadRequest):
def __init__(self, name_or_id, message=None): def __init__(self, name_or_id, message=None):
if message is None: if message is None:
message="Stack with id {0} does not exist".format(name_or_id) message = "Stack with id {0} does not exist".format(name_or_id)
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(ValidationError, self).__init__() super(ValidationError, self).__init__()
@ -22,6 +23,7 @@ class ValidationError(BadRequest):
class MissingParameterError(BadRequest): class MissingParameterError(BadRequest):
def __init__(self, parameter_name): def __init__(self, parameter_name):
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(MissingParameterError, self).__init__() super(MissingParameterError, self).__init__()

View File

@ -11,6 +11,7 @@ from .exceptions import ValidationError
class FakeStack(object): class FakeStack(object):
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None): def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None):
self.stack_id = stack_id self.stack_id = stack_id
self.name = name self.name = name
@ -22,7 +23,8 @@ class FakeStack(object):
self.role_arn = role_arn self.role_arn = role_arn
self.tags = tags if tags else {} self.tags = tags if tags else {}
self.events = [] self.events = []
self._add_stack_event("CREATE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event("CREATE_IN_PROGRESS",
resource_status_reason="User Initiated")
self.description = self.template_dict.get('Description') self.description = self.template_dict.get('Description')
self.resource_map = self._create_resource_map() self.resource_map = self._create_resource_map()
@ -31,7 +33,8 @@ class FakeStack(object):
self.status = 'CREATE_COMPLETE' self.status = 'CREATE_COMPLETE'
def _create_resource_map(self): def _create_resource_map(self):
resource_map = ResourceMap(self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict) resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict)
resource_map.create() resource_map.create()
return resource_map return resource_map
@ -79,7 +82,8 @@ class FakeStack(object):
return self.output_map.values() return self.output_map.values()
def update(self, template, role_arn=None): def update(self, template, role_arn=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event("UPDATE_IN_PROGRESS",
resource_status_reason="User Initiated")
self.template = template self.template = template
self.resource_map.update(json.loads(template)) self.resource_map.update(json.loads(template))
self.output_map = self._create_output_map() self.output_map = self._create_output_map()
@ -88,13 +92,15 @@ class FakeStack(object):
self.role_arn = role_arn self.role_arn = role_arn
def delete(self): def delete(self):
self._add_stack_event("DELETE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event("DELETE_IN_PROGRESS",
resource_status_reason="User Initiated")
self.resource_map.delete() self.resource_map.delete()
self._add_stack_event("DELETE_COMPLETE") self._add_stack_event("DELETE_COMPLETE")
self.status = "DELETE_COMPLETE" self.status = "DELETE_COMPLETE"
class FakeEvent(object): class FakeEvent(object):
def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None): def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None):
self.stack_id = stack_id self.stack_id = stack_id
self.stack_name = stack_name self.stack_name = stack_name

View File

@ -94,6 +94,7 @@ logger = logging.getLogger("moto")
class LazyDict(dict): class LazyDict(dict):
def __getitem__(self, key): def __getitem__(self, key):
val = dict.__getitem__(self, key) val = dict.__getitem__(self, key)
if callable(val): if callable(val):
@ -133,7 +134,8 @@ def clean_json(resource_json, resources_map):
try: try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1])
except NotImplementedError as n: except NotImplementedError as n:
logger.warning(n.message.format(resource_json['Fn::GetAtt'][0])) logger.warning(n.message.format(
resource_json['Fn::GetAtt'][0]))
except UnformattedGetAttTemplateException: except UnformattedGetAttTemplateException:
raise BotoServerError( raise BotoServerError(
UnformattedGetAttTemplateException.status_code, UnformattedGetAttTemplateException.status_code,
@ -152,7 +154,8 @@ def clean_json(resource_json, resources_map):
join_list = [] join_list = []
for val in resource_json['Fn::Join'][1]: for val in resource_json['Fn::Join'][1]:
cleaned_val = clean_json(val, resources_map) cleaned_val = clean_json(val, resources_map)
join_list.append('{0}'.format(cleaned_val) if cleaned_val else '{0}'.format(val)) join_list.append('{0}'.format(cleaned_val)
if cleaned_val else '{0}'.format(val))
return resource_json['Fn::Join'][0].join(join_list) return resource_json['Fn::Join'][0].join(join_list)
cleaned_json = {} cleaned_json = {}
@ -215,14 +218,16 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
if not resource_tuple: if not resource_tuple:
return None return None
resource_class, resource_json, resource_name = resource_tuple resource_class, resource_json, resource_name = resource_tuple
resource = resource_class.create_from_cloudformation_json(resource_name, resource_json, region_name) resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name)
resource.type = resource_type resource.type = resource_type
resource.logical_resource_id = logical_id resource.logical_resource_id = logical_id
return resource return resource
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource(logical_id, resource_json, resources_map) resource_class, new_resource_json, new_resource_name = parse_resource(
logical_id, resource_json, resources_map)
original_resource = resources_map[logical_id] original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json( new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource, original_resource=original_resource,
@ -236,8 +241,10 @@ def parse_and_update_resource(logical_id, resource_json, resources_map, region_n
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name): def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource(logical_id, resource_json, resources_map) resource_class, resource_json, resource_name = parse_resource(
resource_class.delete_from_cloudformation_json(resource_name, resource_json, region_name) logical_id, resource_json, resources_map)
resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name)
def parse_condition(condition, resources_map, condition_map): def parse_condition(condition, resources_map, condition_map):
@ -312,7 +319,8 @@ class ResourceMap(collections.Mapping):
resource_json = self._resource_json_map.get(resource_logical_id) resource_json = self._resource_json_map.get(resource_logical_id)
if not resource_json: if not resource_json:
raise KeyError(resource_logical_id) raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource(resource_logical_id, resource_json, self, self._region_name) new_resource = parse_and_create_resource(
resource_logical_id, resource_json, self, self._region_name)
if new_resource is not None: if new_resource is not None:
self._parsed_resources[resource_logical_id] = new_resource self._parsed_resources[resource_logical_id] = new_resource
return new_resource return new_resource
@ -343,7 +351,8 @@ class ResourceMap(collections.Mapping):
value = value.split(',') value = value.split(',')
self.resolved_parameters[key] = value self.resolved_parameters[key] = value
# Check if there are any non-default params that were not passed input params # Check if there are any non-default params that were not passed input
# params
for key, value in self.resolved_parameters.items(): for key, value in self.resolved_parameters.items():
if value is None: if value is None:
raise MissingParameterError(key) raise MissingParameterError(key)
@ -355,10 +364,11 @@ class ResourceMap(collections.Mapping):
lazy_condition_map = LazyDict() lazy_condition_map = LazyDict()
for condition_name, condition in conditions.items(): for condition_name, condition in conditions.items():
lazy_condition_map[condition_name] = functools.partial(parse_condition, lazy_condition_map[condition_name] = functools.partial(parse_condition,
condition, self._parsed_resources, lazy_condition_map) condition, self._parsed_resources, lazy_condition_map)
for condition_name in lazy_condition_map: for condition_name in lazy_condition_map:
self._parsed_resources[condition_name] = lazy_condition_map[condition_name] self._parsed_resources[
condition_name] = lazy_condition_map[condition_name]
def create(self): def create(self):
self.load_mapping() self.load_mapping()
@ -368,11 +378,12 @@ class ResourceMap(collections.Mapping):
# Since this is a lazy map, to create every object we just need to # Since this is a lazy map, to create every object we just need to
# iterate through self. # iterate through self.
self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'), self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'),
'aws:cloudformation:stack-id': self.get('AWS::StackId')}) 'aws:cloudformation:stack-id': self.get('AWS::StackId')})
for resource in self.resources: for resource in self.resources:
if isinstance(self[resource], ec2_models.TaggedEC2Resource): if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags['aws:cloudformation:logical-id'] = resource self.tags['aws:cloudformation:logical-id'] = resource
ec2_models.ec2_backends[self._region_name].create_tags([self[resource].physical_resource_id], self.tags) ec2_models.ec2_backends[self._region_name].create_tags(
[self[resource].physical_resource_id], self.tags)
def update(self, template): def update(self, template):
self.load_mapping() self.load_mapping()
@ -386,24 +397,29 @@ class ResourceMap(collections.Mapping):
new_resource_names = set(new_template) - set(old_template) new_resource_names = set(new_template) - set(old_template)
for resource_name in new_resource_names: for resource_name in new_resource_names:
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
new_resource = parse_and_create_resource(resource_name, resource_json, self, self._region_name) new_resource = parse_and_create_resource(
resource_name, resource_json, self, self._region_name)
self._parsed_resources[resource_name] = new_resource self._parsed_resources[resource_name] = new_resource
removed_resource_nams = set(old_template) - set(new_template) removed_resource_nams = set(old_template) - set(new_template)
for resource_name in removed_resource_nams: for resource_name in removed_resource_nams:
resource_json = old_template[resource_name] resource_json = old_template[resource_name]
parse_and_delete_resource(resource_name, resource_json, self, self._region_name) parse_and_delete_resource(
resource_name, resource_json, self, self._region_name)
self._parsed_resources.pop(resource_name) self._parsed_resources.pop(resource_name)
resources_to_update = set(name for name in new_template if name in old_template and new_template[name] != old_template[name]) resources_to_update = set(name for name in new_template if name in old_template and new_template[
name] != old_template[name])
tries = 1 tries = 1
while resources_to_update and tries < 5: while resources_to_update and tries < 5:
for resource_name in resources_to_update.copy(): for resource_name in resources_to_update.copy():
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
try: try:
changed_resource = parse_and_update_resource(resource_name, resource_json, self, self._region_name) changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a second pass # skip over dependency violations, and try again in a
# second pass
last_exception = e last_exception = e
else: else:
self._parsed_resources[resource_name] = changed_resource self._parsed_resources[resource_name] = changed_resource
@ -422,7 +438,8 @@ class ResourceMap(collections.Mapping):
if parsed_resource and hasattr(parsed_resource, 'delete'): if parsed_resource and hasattr(parsed_resource, 'delete'):
parsed_resource.delete(self._region_name) parsed_resource.delete(self._region_name)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a second pass # skip over dependency violations, and try again in a
# second pass
last_exception = e last_exception = e
else: else:
remaining_resources.remove(resource) remaining_resources.remove(resource)
@ -430,7 +447,9 @@ class ResourceMap(collections.Mapping):
if tries == 5: if tries == 5:
raise last_exception raise last_exception
class OutputMap(collections.Mapping): class OutputMap(collections.Mapping):
def __init__(self, resources, template): def __init__(self, resources, template):
self._template = template self._template = template
self._output_json_map = template.get('Outputs') self._output_json_map = template.get('Outputs')
@ -446,7 +465,8 @@ class OutputMap(collections.Mapping):
return self._parsed_outputs[output_logical_id] return self._parsed_outputs[output_logical_id]
else: else:
output_json = self._output_json_map.get(output_logical_id) output_json = self._output_json_map.get(output_logical_id)
new_output = parse_output(output_logical_id, output_json, self._resource_map) new_output = parse_output(
output_logical_id, output_json, self._resource_map)
self._parsed_outputs[output_logical_id] = new_output self._parsed_outputs[output_logical_id] = new_output
return new_output return new_output

View File

@ -18,7 +18,8 @@ class CloudFormationResponse(BaseResponse):
def _get_stack_from_s3_url(self, template_url): def _get_stack_from_s3_url(self, template_url):
template_url_parts = urlparse(template_url) template_url_parts = urlparse(template_url)
if "localhost" in template_url: if "localhost" in template_url:
bucket_name, key_name = template_url_parts.path.lstrip("/").split("/") bucket_name, key_name = template_url_parts.path.lstrip(
"/").split("/")
else: else:
bucket_name = template_url_parts.netloc.split(".")[0] bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/") key_name = template_url_parts.path.lstrip("/")
@ -32,7 +33,8 @@ class CloudFormationResponse(BaseResponse):
template_url = self._get_param('TemplateURL') template_url = self._get_param('TemplateURL')
role_arn = self._get_param('RoleARN') role_arn = self._get_param('RoleARN')
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) for item in self._get_list_prefix("Tags.member")) tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
# Hack dict-comprehension # Hack dict-comprehension
parameters = dict([ parameters = dict([
@ -42,7 +44,8 @@ class CloudFormationResponse(BaseResponse):
]) ])
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param('NotificationARNs.member') stack_notification_arns = self._get_multi_param(
'NotificationARNs.member')
stack = self.cloudformation_backend.create_stack( stack = self.cloudformation_backend.create_stack(
name=stack_name, name=stack_name,
@ -86,7 +89,8 @@ class CloudFormationResponse(BaseResponse):
else: else:
raise ValidationError(logical_resource_id) raise ValidationError(logical_resource_id)
template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE) template = self.response_template(
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
return template.render(stack=stack, resource=resource) return template.render(stack=stack, resource=resource)
def describe_stack_resources(self): def describe_stack_resources(self):
@ -110,7 +114,8 @@ class CloudFormationResponse(BaseResponse):
def list_stack_resources(self): def list_stack_resources(self):
stack_name_or_id = self._get_param('StackName') stack_name_or_id = self._get_param('StackName')
resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id) resources = self.cloudformation_backend.list_stack_resources(
stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources) return template.render(resources=resources)
@ -138,13 +143,15 @@ class CloudFormationResponse(BaseResponse):
stack_name = self._get_param('StackName') stack_name = self._get_param('StackName')
role_arn = self._get_param('RoleARN') role_arn = self._get_param('RoleARN')
if self._get_param('UsePreviousTemplate') == "true": if self._get_param('UsePreviousTemplate') == "true":
stack_body = self.cloudformation_backend.get_stack(stack_name).template stack_body = self.cloudformation_backend.get_stack(
stack_name).template
else: else:
stack_body = self._get_param('TemplateBody') stack_body = self._get_param('TemplateBody')
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
if stack.status == 'ROLLBACK_COMPLETE': if stack.status == 'ROLLBACK_COMPLETE':
raise ValidationError(stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id)) raise ValidationError(
stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id))
stack = self.cloudformation_backend.update_stack( stack = self.cloudformation_backend.update_stack(
name=stack_name, name=stack_name,

View File

@ -1,5 +1,5 @@
from .models import cloudwatch_backends from .models import cloudwatch_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudwatch_backend = cloudwatch_backends['us-east-1'] cloudwatch_backend = cloudwatch_backends['us-east-1']
mock_cloudwatch = base_decorator(cloudwatch_backends) mock_cloudwatch = base_decorator(cloudwatch_backends)

View File

@ -4,12 +4,14 @@ import datetime
class Dimension(object): class Dimension(object):
def __init__(self, name, value): def __init__(self, name, value):
self.name = name self.name = name
self.value = value self.value = value
class FakeAlarm(object): class FakeAlarm(object):
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods, def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
period, threshold, statistic, description, dimensions, alarm_actions, period, threshold, statistic, description, dimensions, alarm_actions,
ok_actions, insufficient_data_actions, unit): ok_actions, insufficient_data_actions, unit):
@ -22,7 +24,8 @@ class FakeAlarm(object):
self.threshold = threshold self.threshold = threshold
self.statistic = statistic self.statistic = statistic
self.description = description self.description = description
self.dimensions = [Dimension(dimension['name'], dimension['value']) for dimension in dimensions] self.dimensions = [Dimension(dimension['name'], dimension[
'value']) for dimension in dimensions]
self.alarm_actions = alarm_actions self.alarm_actions = alarm_actions
self.ok_actions = ok_actions self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions self.insufficient_data_actions = insufficient_data_actions
@ -32,11 +35,13 @@ class FakeAlarm(object):
class MetricDatum(object): class MetricDatum(object):
def __init__(self, namespace, name, value, dimensions): def __init__(self, namespace, name, value, dimensions):
self.namespace = namespace self.namespace = namespace
self.name = name self.name = name
self.value = value self.value = value
self.dimensions = [Dimension(dimension['name'], dimension['value']) for dimension in dimensions] self.dimensions = [Dimension(dimension['name'], dimension[
'value']) for dimension in dimensions]
class CloudWatchBackend(BaseBackend): class CloudWatchBackend(BaseBackend):
@ -99,7 +104,8 @@ class CloudWatchBackend(BaseBackend):
def put_metric_data(self, namespace, metric_data): def put_metric_data(self, namespace, metric_data):
for name, value, dimensions in metric_data: for name, value, dimensions in metric_data:
self.metric_data.append(MetricDatum(namespace, name, value, dimensions)) self.metric_data.append(MetricDatum(
namespace, name, value, dimensions))
def get_all_metrics(self): def get_all_metrics(self):
return self.metric_data return self.metric_data

View File

@ -1,6 +1,5 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cloudwatch_backends from .models import cloudwatch_backends
import logging
class CloudWatchResponse(BaseResponse): class CloudWatchResponse(BaseResponse):
@ -18,7 +17,8 @@ class CloudWatchResponse(BaseResponse):
dimensions = self._get_list_prefix('Dimensions.member') dimensions = self._get_list_prefix('Dimensions.member')
alarm_actions = self._get_multi_param('AlarmActions.member') alarm_actions = self._get_multi_param('AlarmActions.member')
ok_actions = self._get_multi_param('OKActions.member') ok_actions = self._get_multi_param('OKActions.member')
insufficient_data_actions = self._get_multi_param("InsufficientDataActions.member") insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member")
unit = self._get_param('Unit') unit = self._get_param('Unit')
cloudwatch_backend = cloudwatch_backends[self.region] cloudwatch_backend = cloudwatch_backends[self.region]
alarm = cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, alarm = cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
@ -40,14 +40,16 @@ class CloudWatchResponse(BaseResponse):
cloudwatch_backend = cloudwatch_backends[self.region] cloudwatch_backend = cloudwatch_backends[self.region]
if action_prefix: if action_prefix:
alarms = cloudwatch_backend.get_alarms_by_action_prefix(action_prefix) alarms = cloudwatch_backend.get_alarms_by_action_prefix(
action_prefix)
elif alarm_name_prefix: elif alarm_name_prefix:
alarms = cloudwatch_backend.get_alarms_by_alarm_name_prefix(alarm_name_prefix) alarms = cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix)
elif alarm_names: elif alarm_names:
alarms = cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) alarms = cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value: elif state_value:
alarms = cloudwatch_backend.get_alarms_by_state_value(state_value) alarms = cloudwatch_backend.get_alarms_by_state_value(state_value)
else : else:
alarms = cloudwatch_backend.get_all_alarms() alarms = cloudwatch_backend.get_all_alarms()
template = self.response_template(DESCRIBE_ALARMS_TEMPLATE) template = self.response_template(DESCRIBE_ALARMS_TEMPLATE)
@ -66,19 +68,24 @@ class CloudWatchResponse(BaseResponse):
metric_index = 1 metric_index = 1
while True: while True:
try: try:
metric_name = self.querystring['MetricData.member.{0}.MetricName'.format(metric_index)][0] metric_name = self.querystring[
'MetricData.member.{0}.MetricName'.format(metric_index)][0]
except KeyError: except KeyError:
break break
value = self.querystring.get('MetricData.member.{0}.Value'.format(metric_index), [None])[0] value = self.querystring.get(
'MetricData.member.{0}.Value'.format(metric_index), [None])[0]
dimensions = [] dimensions = []
dimension_index = 1 dimension_index = 1
while True: while True:
try: try:
dimension_name = self.querystring['MetricData.member.{0}.Dimensions.member.{1}.Name'.format(metric_index, dimension_index)][0] dimension_name = self.querystring[
'MetricData.member.{0}.Dimensions.member.{1}.Name'.format(metric_index, dimension_index)][0]
except KeyError: except KeyError:
break break
dimension_value = self.querystring['MetricData.member.{0}.Dimensions.member.{1}.Value'.format(metric_index, dimension_index)][0] dimension_value = self.querystring[
dimensions.append({'name': dimension_name, 'value': dimension_value}) 'MetricData.member.{0}.Dimensions.member.{1}.Value'.format(metric_index, dimension_index)][0]
dimensions.append(
{'name': dimension_name, 'value': dimension_value})
dimension_index += 1 dimension_index += 1
metric_data.append([metric_name, value, dimensions]) metric_data.append([metric_name, value, dimensions])
metric_index += 1 metric_index += 1

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
from six import text_type
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
@ -33,6 +32,7 @@ ERROR_JSON_RESPONSE = u"""{
} }
""" """
class RESTError(HTTPException): class RESTError(HTTPException):
templates = { templates = {
'single_error': SINGLE_ERROR_RESPONSE, 'single_error': SINGLE_ERROR_RESPONSE,
@ -54,8 +54,10 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError): class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs): def __init__(self, error_type, message, template='error_json', **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs) super(JsonRESTError, self).__init__(
error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs): def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')] return [('Content-Type', 'application/json')]

View File

@ -3,7 +3,6 @@ from __future__ import absolute_import
import functools import functools
import inspect import inspect
import os
import re import re
from moto import settings from moto import settings
@ -15,6 +14,7 @@ from .utils import (
convert_flask_to_responses_response, convert_flask_to_responses_response,
) )
class BaseMockAWS(object): class BaseMockAWS(object):
nested_count = 0 nested_count = 0
@ -58,7 +58,6 @@ class BaseMockAWS(object):
if self.__class__.nested_count < 0: if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().') raise RuntimeError('Called stop() before start().')
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.disable_patching() self.disable_patching()
@ -96,6 +95,7 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS): class HttprettyMockAWS(BaseMockAWS):
def reset(self): def reset(self):
HTTPretty.reset() HTTPretty.reset()
@ -118,10 +118,11 @@ class HttprettyMockAWS(BaseMockAWS):
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD,
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT] responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT]
class ResponsesMockAWS(BaseMockAWS): class ResponsesMockAWS(BaseMockAWS):
def reset(self): def reset(self):
responses.reset() responses.reset()
@ -146,6 +147,7 @@ class ResponsesMockAWS(BaseMockAWS):
pass pass
responses.reset() responses.reset()
MockAWS = ResponsesMockAWS MockAWS = ResponsesMockAWS
@ -167,12 +169,14 @@ class ServerModeMockAWS(BaseMockAWS):
if 'endpoint_url' not in kwargs: if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:8086" kwargs['endpoint_url'] = "http://localhost:8086"
return real_boto3_client(*args, **kwargs) return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs): def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs: if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:8086" kwargs['endpoint_url'] = "http://localhost:8086"
return real_boto3_resource(*args, **kwargs) return real_boto3_resource(*args, **kwargs)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client) self._client_patcher = mock.patch('boto3.client', fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource) self._resource_patcher = mock.patch(
'boto3.resource', fake_boto3_resource)
self._client_patcher.start() self._client_patcher.start()
self._resource_patcher.start() self._resource_patcher.start()
@ -181,7 +185,9 @@ class ServerModeMockAWS(BaseMockAWS):
self._client_patcher.stop() self._client_patcher.stop()
self._resource_patcher.stop() self._resource_patcher.stop()
class Model(type): class Model(type):
def __new__(self, clsname, bases, namespace): def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace) cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {} cls.__models__ = {}
@ -203,6 +209,7 @@ class Model(type):
class BaseBackend(object): class BaseBackend(object):
def reset(self): def reset(self):
self.__dict__ = {} self.__dict__ = {}
self.__init__() self.__init__()
@ -211,7 +218,8 @@ class BaseBackend(object):
def _url_module(self): def _url_module(self):
backend_module = self.__class__.__module__ backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls") backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=['url_bases', 'url_paths']) backend_urls_module = __import__(backend_urls_module_name, fromlist=[
'url_bases', 'url_paths'])
return backend_urls_module return backend_urls_module
@property @property
@ -306,6 +314,7 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend): class MotoAPIBackend(BaseBackend):
def reset(self): def reset(self):
from moto.backends import BACKENDS from moto.backends import BACKENDS
for name, backends in BACKENDS.items(): for name, backends in BACKENDS.items():
@ -315,4 +324,5 @@ class MotoAPIBackend(BaseBackend):
backend.reset() backend.reset()
self.__init__() self.__init__()
moto_api_backend = MotoAPIBackend() moto_api_backend = MotoAPIBackend()

View File

@ -59,6 +59,7 @@ class DynamicDictLoader(DictLoader):
Including the fixed (current) method version here to ensure performance benefit Including the fixed (current) method version here to ensure performance benefit
even for those using older jinja versions. even for those using older jinja versions.
""" """
def get_source(self, environment, template): def get_source(self, environment, template):
if template in self.mapping: if template in self.mapping:
source = self.mapping[template] source = self.mapping[template]
@ -77,7 +78,8 @@ class _TemplateEnvironmentMixin(object):
def __init__(self): def __init__(self):
super(_TemplateEnvironmentMixin, self).__init__() super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({}) self.loader = DynamicDictLoader({})
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape) self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape)
@property @property
def should_autoescape(self): def should_autoescape(self):
@ -127,12 +129,14 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.body = self.body.decode('utf-8') self.body = self.body.decode('utf-8')
if not querystring: if not querystring:
querystring.update(parse_qs(urlparse(full_url).query, keep_blank_values=True)) querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True))
if not querystring: if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec: if 'json' in request.headers.get('content-type', []) and self.aws_service_spec:
decoded = json.loads(self.body) decoded = json.loads(self.body)
target = request.headers.get('x-amz-target') or request.headers.get('X-Amz-Target') target = request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target')
service, method = target.split('.') service, method = target.split('.')
input_spec = self.aws_service_spec.input_spec(method) input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec) flat = flatten_json_request_body('', decoded, input_spec)
@ -161,7 +165,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
if match: if match:
region = match.group(1) region = match.group(1)
elif 'Authorization' in request.headers: elif 'Authorization' in request.headers:
region = request.headers['Authorization'].split(",")[0].split("/")[2] region = request.headers['Authorization'].split(",")[
0].split("/")[2]
else: else:
region = self.default_region region = self.default_region
return region return region
@ -175,7 +180,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
action = self.querystring.get('Action', [""])[0] action = self.querystring.get('Action', [""])[0]
if not action: # Some services use a header for the action if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this. # Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get('x-amz-target') or self.headers.get('X-Amz-Target') match = self.headers.get(
'x-amz-target') or self.headers.get('X-Amz-Target')
if match: if match:
action = match.split(".")[-1] action = match.split(".")[-1]
@ -198,7 +204,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
headers['status'] = str(headers['status']) headers['status'] = str(headers['status'])
return status, headers, body return status, headers, body
raise NotImplementedError("The {0} action has not been implemented".format(action)) raise NotImplementedError(
"The {0} action has not been implemented".format(action))
def _get_param(self, param_name, if_none=None): def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name) val = self.querystring.get(param_name)
@ -258,7 +265,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
params = {} params = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(param_prefix): if key.startswith(param_prefix):
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[0] params[camelcase_to_underscores(
key.replace(param_prefix, ""))] = value[0]
return params return params
def _get_list_prefix(self, param_prefix): def _get_list_prefix(self, param_prefix):
@ -291,7 +299,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
new_items = {} new_items = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(index_prefix): if key.startswith(index_prefix):
new_items[camelcase_to_underscores(key.replace(index_prefix, ""))] = value[0] new_items[camelcase_to_underscores(
key.replace(index_prefix, ""))] = value[0]
if not new_items: if not new_items:
break break
results.append(new_items) results.append(new_items)
@ -327,7 +336,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
def is_not_dryrun(self, action): def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']): if 'true' in self.querystring.get('DryRun', ['false']):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action
raise DryRunClientError(error_type="DryRunOperation", message=message) raise DryRunClientError(
error_type="DryRunOperation", message=message)
return True return True
@ -343,6 +353,7 @@ class MotoAPIResponse(BaseResponse):
class _RecursiveDictRef(object): class _RecursiveDictRef(object):
"""Store a recursive reference to dict.""" """Store a recursive reference to dict."""
def __init__(self): def __init__(self):
self.key = None self.key = None
self.dic = {} self.dic = {}
@ -502,12 +513,15 @@ def flatten_json_request_body(prefix, dict_body, spec):
if node_type == 'list': if node_type == 'list':
for idx, v in enumerate(value, 1): for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx) pref = key + '.member.' + str(idx)
flat.update(flatten_json_request_body(pref, v, spec[key]['member'])) flat.update(flatten_json_request_body(
pref, v, spec[key]['member']))
elif node_type == 'map': elif node_type == 'map':
for idx, (k, v) in enumerate(value.items(), 1): for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx) pref = key + '.entry.' + str(idx)
flat.update(flatten_json_request_body(pref + '.key', k, spec[key]['key'])) flat.update(flatten_json_request_body(
flat.update(flatten_json_request_body(pref + '.value', v, spec[key]['value'])) pref + '.key', k, spec[key]['key']))
flat.update(flatten_json_request_body(
pref + '.value', v, spec[key]['value']))
else: else:
flat.update(flatten_json_request_body(key, value, spec[key])) flat.update(flatten_json_request_body(key, value, spec[key]))
@ -542,7 +556,8 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
# this can happen when with an older version of # this can happen when with an older version of
# botocore for which the node in XML template is not # botocore for which the node in XML template is not
# defined in service spec. # defined in service spec.
log.warning('Field %s is not defined by the botocore version in use', k) log.warning(
'Field %s is not defined by the botocore version in use', k)
continue continue
if spec[k]['type'] == 'list': if spec[k]['type'] == 'list':
@ -554,7 +569,8 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
else: else:
od[k] = [transform(v['member'], spec[k]['member'])] od[k] = [transform(v['member'], spec[k]['member'])]
elif isinstance(v['member'], list): elif isinstance(v['member'], list):
od[k] = [transform(o, spec[k]['member']) for o in v['member']] od[k] = [transform(o, spec[k]['member'])
for o in v['member']]
elif isinstance(v['member'], OrderedDict): elif isinstance(v['member'], OrderedDict):
od[k] = [transform(v['member'], spec[k]['member'])] od[k] = [transform(v['member'], spec[k]['member'])]
else: else:

View File

@ -98,7 +98,7 @@ class convert_httpretty_response(object):
result = self.callback(request, url, headers) result = self.callback(request, url, headers)
status, headers, response = result status, headers, response = result
if 'server' not in headers: if 'server' not in headers:
headers["server"] = "amazon.com" headers["server"] = "amazon.com"
return status, headers, response return status, headers, response

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import datapipeline_backends from .models import datapipeline_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
datapipeline_backend = datapipeline_backends['us-east-1'] datapipeline_backend = datapipeline_backends['us-east-1']
mock_datapipeline = base_decorator(datapipeline_backends) mock_datapipeline = base_decorator(datapipeline_backends)

View File

@ -7,6 +7,7 @@ from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
class PipelineObject(object): class PipelineObject(object):
def __init__(self, object_id, name, fields): def __init__(self, object_id, name, fields):
self.object_id = object_id self.object_id = object_id
self.name = name self.name = name
@ -21,6 +22,7 @@ class PipelineObject(object):
class Pipeline(object): class Pipeline(object):
def __init__(self, name, unique_id): def __init__(self, name, unique_id):
self.name = name self.name = name
self.unique_id = unique_id self.unique_id = unique_id
@ -82,7 +84,8 @@ class Pipeline(object):
def set_pipeline_objects(self, pipeline_objects): def set_pipeline_objects(self, pipeline_objects):
self.objects = [ self.objects = [
PipelineObject(pipeline_object['id'], pipeline_object['name'], pipeline_object['fields']) PipelineObject(pipeline_object['id'], pipeline_object[
'name'], pipeline_object['fields'])
for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects) for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
] ]
@ -95,8 +98,10 @@ class Pipeline(object):
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"] cloudformation_unique_id = "cf-" + properties["Name"]
pipeline = datapipeline_backend.create_pipeline(properties["Name"], cloudformation_unique_id) pipeline = datapipeline_backend.create_pipeline(
datapipeline_backend.put_pipeline_definition(pipeline.pipeline_id, properties["PipelineObjects"]) properties["Name"], cloudformation_unique_id)
datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"])
if properties["Activate"]: if properties["Activate"]:
pipeline.activate() pipeline.activate()
@ -117,7 +122,8 @@ class DataPipelineBackend(BaseBackend):
return self.pipelines.values() return self.pipelines.values()
def describe_pipelines(self, pipeline_ids): def describe_pipelines(self, pipeline_ids):
pipelines = [pipeline for pipeline in self.pipelines.values() if pipeline.pipeline_id in pipeline_ids] pipelines = [pipeline for pipeline in self.pipelines.values(
) if pipeline.pipeline_id in pipeline_ids]
return pipelines return pipelines
def get_pipeline(self, pipeline_id): def get_pipeline(self, pipeline_id):

View File

@ -52,12 +52,14 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"] pipeline_objects = self.parameters["pipelineObjects"]
self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects) self.datapipeline_backend.put_pipeline_definition(
pipeline_id, pipeline_objects)
return json.dumps({"errored": False}) return json.dumps({"errored": False})
def get_pipeline_definition(self): def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_definition = self.datapipeline_backend.get_pipeline_definition(pipeline_id) pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id)
return json.dumps({ return json.dumps({
"pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition] "pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition]
}) })
@ -66,7 +68,8 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"] object_ids = self.parameters["objectIds"]
pipeline_objects = self.datapipeline_backend.describe_objects(object_ids, pipeline_id) pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id)
return json.dumps({ return json.dumps({
"hasMoreResults": False, "hasMoreResults": False,
"marker": None, "marker": None,

View File

@ -10,6 +10,7 @@ from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder): class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if hasattr(obj, 'to_json'): if hasattr(obj, 'to_json'):
return obj.to_json() return obj.to_json()
@ -53,6 +54,7 @@ class DynamoType(object):
class Item(object): class Item(object):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key self.hash_key = hash_key
self.hash_key_type = hash_key_type self.hash_key_type = hash_key_type
@ -157,7 +159,8 @@ class Table(object):
else: else:
range_value = None range_value = None
item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs) item = Item(hash_value, self.hash_key_type, range_value,
self.range_key_type, item_attrs)
if range_value: if range_value:
self.items[hash_value][range_value] = item self.items[hash_value][range_value] = item
@ -167,7 +170,8 @@ class Table(object):
def get_item(self, hash_key, range_key): def get_item(self, hash_key, range_key):
if self.has_range_key and not range_key: if self.has_range_key and not range_key:
raise ValueError("Table has a range key, but no range key was passed into get_item") raise ValueError(
"Table has a range key, but no range key was passed into get_item")
try: try:
if range_key: if range_key:
return self.items[hash_key][range_key] return self.items[hash_key][range_key]
@ -222,7 +226,8 @@ class Table(object):
# Comparison is NULL and we don't have the attribute # Comparison is NULL and we don't have the attribute
continue continue
else: else:
# No attribute found and comparison is no NULL. This item fails # No attribute found and comparison is no NULL. This item
# fails
passes_all_conditions = False passes_all_conditions = False
break break
@ -283,7 +288,8 @@ class DynamoDBBackend(BaseBackend):
return None, None return None, None
hash_key = DynamoType(hash_key_dict) hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) for range_value in range_value_dicts] range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values) return table.query(hash_key, range_comparison, range_values)

View File

@ -130,7 +130,8 @@ class DynamoHandler(BaseResponse):
throughput = self.body["ProvisionedThroughput"] throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"] new_read_units = throughput["ReadCapacityUnits"]
new_write_units = throughput["WriteCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.update_table_throughput(name, new_read_units, new_write_units) table = dynamodb_backend.update_table_throughput(
name, new_read_units, new_write_units)
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
def describe_table(self): def describe_table(self):
@ -169,7 +170,8 @@ class DynamoHandler(BaseResponse):
key = request['Key'] key = request['Key']
hash_key = key['HashKeyElement'] hash_key = key['HashKeyElement']
range_key = key.get('RangeKeyElement') range_key = key.get('RangeKeyElement')
item = dynamodb_backend.delete_item(table_name, hash_key, range_key) item = dynamodb_backend.delete_item(
table_name, hash_key, range_key)
response = { response = {
"Responses": { "Responses": {
@ -221,11 +223,13 @@ class DynamoHandler(BaseResponse):
for key in keys: for key in keys:
hash_key = key["HashKeyElement"] hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
item = dynamodb_backend.get_item(table_name, hash_key, range_key) item = dynamodb_backend.get_item(
table_name, hash_key, range_key)
if item: if item:
item_describe = item.describe_attrs(attributes_to_get) item_describe = item.describe_attrs(attributes_to_get)
items.append(item_describe) items.append(item_describe)
results["Responses"][table_name] = {"Items": items, "ConsumedCapacityUnits": 1} results["Responses"][table_name] = {
"Items": items, "ConsumedCapacityUnits": 1}
return dynamo_json_dump(results) return dynamo_json_dump(results)
def query(self): def query(self):
@ -239,7 +243,8 @@ class DynamoHandler(BaseResponse):
range_comparison = None range_comparison = None
range_values = [] range_values = []
items, last_page = dynamodb_backend.query(name, hash_key, range_comparison, range_values) items, last_page = dynamodb_backend.query(
name, hash_key, range_comparison, range_values)
if items is None: if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
@ -265,7 +270,8 @@ class DynamoHandler(BaseResponse):
filters = {} filters = {}
scan_filters = self.body.get('ScanFilter', {}) scan_filters = self.body.get('ScanFilter', {})
for attribute_name, scan_filter in scan_filters.items(): for attribute_name, scan_filter in scan_filters.items():
# Keys are attribute names. Values are tuples of (comparison, comparison_value) # Keys are attribute names. Values are tuples of (comparison,
# comparison_value)
comparison_operator = scan_filter["ComparisonOperator"] comparison_operator = scan_filter["ComparisonOperator"]
comparison_values = scan_filter.get("AttributeValueList", []) comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values) filters[attribute_name] = (comparison_operator, comparison_values)

View File

@ -3,4 +3,4 @@ from .models import dynamodb_backend2
dynamodb_backends2 = {"global": dynamodb_backend2} dynamodb_backends2 = {"global": dynamodb_backend2}
mock_dynamodb2 = dynamodb_backend2.decorator mock_dynamodb2 = dynamodb_backend2.decorator
mock_dynamodb2_deprecated = dynamodb_backend2.deprecated_decorator mock_dynamodb2_deprecated = dynamodb_backend2.deprecated_decorator

View File

@ -1,12 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# TODO add tests for all of these # TODO add tests for all of these
EQ_FUNCTION = lambda item_value, test_value: item_value == test_value EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa
NE_FUNCTION = lambda item_value, test_value: item_value != test_value NE_FUNCTION = lambda item_value, test_value: item_value != test_value # flake8: noqa
LE_FUNCTION = lambda item_value, test_value: item_value <= test_value LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # flake8: noqa
LT_FUNCTION = lambda item_value, test_value: item_value < test_value LT_FUNCTION = lambda item_value, test_value: item_value < test_value # flake8: noqa
GE_FUNCTION = lambda item_value, test_value: item_value >= test_value GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # flake8: noqa
GT_FUNCTION = lambda item_value, test_value: item_value > test_value GT_FUNCTION = lambda item_value, test_value: item_value > test_value # flake8: noqa
COMPARISON_FUNCS = { COMPARISON_FUNCS = {
'EQ': EQ_FUNCTION, 'EQ': EQ_FUNCTION,

View File

@ -11,6 +11,7 @@ from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder): class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if hasattr(obj, 'to_json'): if hasattr(obj, 'to_json'):
return obj.to_json() return obj.to_json()
@ -76,6 +77,7 @@ class DynamoType(object):
class Item(object): class Item(object):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key self.hash_key = hash_key
self.hash_key_type = hash_key_type self.hash_key_type = hash_key_type
@ -131,14 +133,15 @@ class Item(object):
elif action == 'SET' or action == 'set': elif action == 'SET' or action == 'set':
key, value = value.split("=") key, value = value.split("=")
if value in expression_attribute_values: if value in expression_attribute_values:
self.attrs[key] = DynamoType(expression_attribute_values[value]) self.attrs[key] = DynamoType(
expression_attribute_values[value])
else: else:
self.attrs[key] = DynamoType({"S": value}) self.attrs[key] = DynamoType({"S": value})
def update_with_attribute_updates(self, attribute_updates): def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items(): for attribute_name, update_action in attribute_updates.items():
action = update_action['Action'] action = update_action['Action']
if action == 'DELETE' and not 'Value' in update_action: if action == 'DELETE' and 'Value' not in update_action:
if attribute_name in self.attrs: if attribute_name in self.attrs:
del self.attrs[attribute_name] del self.attrs[attribute_name]
continue continue
@ -158,14 +161,16 @@ class Item(object):
self.attrs[attribute_name] = DynamoType({"S": new_value}) self.attrs[attribute_name] = DynamoType({"S": new_value})
elif action == 'ADD': elif action == 'ADD':
if set(update_action['Value'].keys()) == set(['N']): if set(update_action['Value'].keys()) == set(['N']):
existing = self.attrs.get(attribute_name, DynamoType({"N": '0'})) existing = self.attrs.get(
attribute_name, DynamoType({"N": '0'}))
self.attrs[attribute_name] = DynamoType({"N": str( self.attrs[attribute_name] = DynamoType({"N": str(
decimal.Decimal(existing.value) + decimal.Decimal(existing.value) +
decimal.Decimal(new_value) decimal.Decimal(new_value)
)}) )})
else: else:
# TODO: implement other data types # TODO: implement other data types
raise NotImplementedError('ADD not supported for %s' % ', '.join(update_action['Value'].keys())) raise NotImplementedError(
'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
class Table(object): class Table(object):
@ -186,7 +191,8 @@ class Table(object):
self.range_key_attr = elem["AttributeName"] self.range_key_attr = elem["AttributeName"]
self.range_key_type = elem["KeyType"] self.range_key_type = elem["KeyType"]
if throughput is None: if throughput is None:
self.throughput = {'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10} self.throughput = {
'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10}
else: else:
self.throughput = throughput self.throughput = throughput
self.throughput["NumberOfDecreasesToday"] = 0 self.throughput["NumberOfDecreasesToday"] = 0
@ -250,14 +256,16 @@ class Table(object):
else: else:
range_value = None range_value = None
item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs) item = Item(hash_value, self.hash_key_type, range_value,
self.range_key_type, item_attrs)
if not overwrite: if not overwrite:
if expected is None: if expected is None:
expected = {} expected = {}
lookup_range_value = range_value lookup_range_value = range_value
else: else:
expected_range_value = expected.get(self.range_key_attr, {}).get("Value") expected_range_value = expected.get(
self.range_key_attr, {}).get("Value")
if(expected_range_value is None): if(expected_range_value is None):
lookup_range_value = range_value lookup_range_value = range_value
else: else:
@ -281,8 +289,10 @@ class Table(object):
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value: elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed") raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val: elif 'ComparisonOperator' in val:
comparison_func = get_comparison_func(val['ComparisonOperator']) comparison_func = get_comparison_func(
dynamo_types = [DynamoType(ele) for ele in val["AttributeValueList"]] val['ComparisonOperator'])
dynamo_types = [DynamoType(ele) for ele in val[
"AttributeValueList"]]
for t in dynamo_types: for t in dynamo_types:
if not comparison_func(current_attr[key].value, t.value): if not comparison_func(current_attr[key].value, t.value):
raise ValueError('The conditional request failed') raise ValueError('The conditional request failed')
@ -304,7 +314,8 @@ class Table(object):
def get_item(self, hash_key, range_key=None): def get_item(self, hash_key, range_key=None):
if self.has_range_key and not range_key: if self.has_range_key and not range_key:
raise ValueError("Table has a range key, but no range key was passed into get_item") raise ValueError(
"Table has a range key, but no range key was passed into get_item")
try: try:
if range_key: if range_key:
return self.items[hash_key][range_key] return self.items[hash_key][range_key]
@ -339,9 +350,11 @@ class Table(object):
index = indexes_by_name[index_name] index = indexes_by_name[index_name]
try: try:
index_hash_key = [key for key in index['KeySchema'] if key['KeyType'] == 'HASH'][0] index_hash_key = [key for key in index[
'KeySchema'] if key['KeyType'] == 'HASH'][0]
except IndexError: except IndexError:
raise ValueError('Missing Hash Key. KeySchema: %s' % index['KeySchema']) raise ValueError('Missing Hash Key. KeySchema: %s' %
index['KeySchema'])
possible_results = [] possible_results = []
for item in self.all_items(): for item in self.all_items():
@ -351,17 +364,20 @@ class Table(object):
if item_hash_key and item_hash_key == hash_key: if item_hash_key and item_hash_key == hash_key:
possible_results.append(item) possible_results.append(item)
else: else:
possible_results = [item for item in list(self.all_items()) if isinstance(item, Item) and item.hash_key == hash_key] possible_results = [item for item in list(self.all_items()) if isinstance(
item, Item) and item.hash_key == hash_key]
if index_name: if index_name:
try: try:
index_range_key = [key for key in index['KeySchema'] if key['KeyType'] == 'RANGE'][0] index_range_key = [key for key in index[
'KeySchema'] if key['KeyType'] == 'RANGE'][0]
except IndexError: except IndexError:
index_range_key = None index_range_key = None
if range_comparison: if range_comparison:
if index_name and not index_range_key: if index_name and not index_range_key:
raise ValueError('Range Key comparison but no range key found for index: %s' % index_name) raise ValueError(
'Range Key comparison but no range key found for index: %s' % index_name)
elif index_name: elif index_name:
for result in possible_results: for result in possible_results:
@ -375,19 +391,21 @@ class Table(object):
if filter_kwargs: if filter_kwargs:
for result in possible_results: for result in possible_results:
for field, value in filter_kwargs.items(): for field, value in filter_kwargs.items():
dynamo_types = [DynamoType(ele) for ele in value["AttributeValueList"]] dynamo_types = [DynamoType(ele) for ele in value[
"AttributeValueList"]]
if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types): if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types):
results.append(result) results.append(result)
if not range_comparison and not filter_kwargs: if not range_comparison and not filter_kwargs:
# If we're not filtering on range key or on an index return all values # If we're not filtering on range key or on an index return all
# values
results = possible_results results = possible_results
if index_name: if index_name:
if index_range_key: if index_range_key:
results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value
if item.attrs.get(index_range_key['AttributeName']) else None) if item.attrs.get(index_range_key['AttributeName']) else None)
else: else:
results.sort(key=lambda item: item.range_key) results.sort(key=lambda item: item.range_key)
@ -427,7 +445,8 @@ class Table(object):
# Comparison is NULL and we don't have the attribute # Comparison is NULL and we don't have the attribute
continue continue
else: else:
# No attribute found and comparison is no NULL. This item fails # No attribute found and comparison is no NULL. This item
# fails
passes_all_conditions = False passes_all_conditions = False
break break
@ -460,7 +479,6 @@ class Table(object):
return results, last_evaluated_key return results, last_evaluated_key
def lookup(self, *args, **kwargs): def lookup(self, *args, **kwargs):
if not self.schema: if not self.schema:
self.describe() self.describe()
@ -517,7 +535,8 @@ class DynamoDBBackend(BaseBackend):
if gsi_to_create: if gsi_to_create:
if gsi_to_create['IndexName'] in gsis_by_name: if gsi_to_create['IndexName'] in gsis_by_name:
raise ValueError('Global Secondary Index already exists: %s' % gsi_to_create['IndexName']) raise ValueError(
'Global Secondary Index already exists: %s' % gsi_to_create['IndexName'])
gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
@ -555,9 +574,11 @@ class DynamoDBBackend(BaseBackend):
def get_keys_value(self, table, keys): def get_keys_value(self, table, keys):
if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys): if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys):
raise ValueError("Table has a range key, but no range key was passed into get_item") raise ValueError(
"Table has a range key, but no range key was passed into get_item")
hash_key = DynamoType(keys[table.hash_key_attr]) hash_key = DynamoType(keys[table.hash_key_attr])
range_key = DynamoType(keys[table.range_key_attr]) if table.has_range_key else None range_key = DynamoType(
keys[table.range_key_attr]) if table.has_range_key else None
return hash_key, range_key return hash_key, range_key
def get_table(self, table_name): def get_table(self, table_name):
@ -577,7 +598,8 @@ class DynamoDBBackend(BaseBackend):
return None, None return None, None
hash_key = DynamoType(hash_key_dict) hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) for range_value in range_value_dicts] range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values, limit, return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name, **filter_kwargs) exclusive_start_key, scan_index_forward, index_name, **filter_kwargs)
@ -598,7 +620,8 @@ class DynamoDBBackend(BaseBackend):
table = self.get_table(table_name) table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]): if all([table.hash_key_attr in key, table.range_key_attr in key]):
# Covers cases where table has hash and range keys, ``key`` param will be a dict # Covers cases where table has hash and range keys, ``key`` param
# will be a dict
hash_value = DynamoType(key[table.hash_key_attr]) hash_value = DynamoType(key[table.hash_key_attr])
range_value = DynamoType(key[table.range_key_attr]) range_value = DynamoType(key[table.range_key_attr])
elif table.hash_key_attr in key: elif table.hash_key_attr in key:
@ -629,7 +652,8 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value) item = table.get_item(hash_value, range_value)
if update_expression: if update_expression:
item.update(update_expression, expression_attribute_names, expression_attribute_values) item.update(update_expression, expression_attribute_names,
expression_attribute_values)
else: else:
item.update_with_attribute_updates(attribute_updates) item.update_with_attribute_updates(attribute_updates)
return item return item

View File

@ -104,11 +104,11 @@ class DynamoHandler(BaseResponse):
local_secondary_indexes = body.get("LocalSecondaryIndexes", []) local_secondary_indexes = body.get("LocalSecondaryIndexes", [])
table = dynamodb_backend2.create_table(table_name, table = dynamodb_backend2.create_table(table_name,
schema=key_schema, schema=key_schema,
throughput=throughput, throughput=throughput,
attr=attr, attr=attr,
global_indexes=global_indexes, global_indexes=global_indexes,
indexes=local_secondary_indexes) indexes=local_secondary_indexes)
if table is not None: if table is not None:
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
else: else:
@ -127,7 +127,8 @@ class DynamoHandler(BaseResponse):
def update_table(self): def update_table(self):
name = self.body['TableName'] name = self.body['TableName']
if 'GlobalSecondaryIndexUpdates' in self.body: if 'GlobalSecondaryIndexUpdates' in self.body:
table = dynamodb_backend2.update_table_global_indexes(name, self.body['GlobalSecondaryIndexUpdates']) table = dynamodb_backend2.update_table_global_indexes(
name, self.body['GlobalSecondaryIndexUpdates'])
if 'ProvisionedThroughput' in self.body: if 'ProvisionedThroughput' in self.body:
throughput = self.body["ProvisionedThroughput"] throughput = self.body["ProvisionedThroughput"]
table = dynamodb_backend2.update_table_throughput(name, throughput) table = dynamodb_backend2.update_table_throughput(name, throughput)
@ -151,17 +152,20 @@ class DynamoHandler(BaseResponse):
else: else:
expected = None expected = None
# Attempt to parse simple ConditionExpressions into an Expected expression # Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected: if not expected:
condition_expression = self.body.get('ConditionExpression') condition_expression = self.body.get('ConditionExpression')
if condition_expression and 'OR' not in condition_expression: if condition_expression and 'OR' not in condition_expression:
cond_items = [c.strip() for c in condition_expression.split('AND')] cond_items = [c.strip()
for c in condition_expression.split('AND')]
if cond_items: if cond_items:
expected = {} expected = {}
overwrite = False overwrite = False
exists_re = re.compile('^attribute_exists\((.*)\)$') exists_re = re.compile('^attribute_exists\((.*)\)$')
not_exists_re = re.compile('^attribute_not_exists\((.*)\)$') not_exists_re = re.compile(
'^attribute_not_exists\((.*)\)$')
for cond in cond_items: for cond in cond_items:
exists_m = exists_re.match(cond) exists_m = exists_re.match(cond)
@ -172,7 +176,8 @@ class DynamoHandler(BaseResponse):
expected[not_exists_m.group(1)] = {'Exists': False} expected[not_exists_m.group(1)] = {'Exists': False}
try: try:
result = dynamodb_backend2.put_item(name, item, expected, overwrite) result = dynamodb_backend2.put_item(
name, item, expected, overwrite)
except Exception: except Exception:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er) return self.error(er)
@ -249,7 +254,8 @@ class DynamoHandler(BaseResponse):
item = dynamodb_backend2.get_item(table_name, key) item = dynamodb_backend2.get_item(table_name, key)
if item: if item:
item_describe = item.describe_attrs(attributes_to_get) item_describe = item.describe_attrs(attributes_to_get)
results["Responses"][table_name].append(item_describe["Item"]) results["Responses"][table_name].append(
item_describe["Item"])
results["ConsumedCapacity"].append({ results["ConsumedCapacity"].append({
"CapacityUnits": len(keys), "CapacityUnits": len(keys),
@ -268,8 +274,10 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.get_table(name) table = dynamodb_backend2.get_table(name)
index_name = self.body.get('IndexName') index_name = self.body.get('IndexName')
if index_name: if index_name:
all_indexes = (table.global_indexes or []) + (table.indexes or []) all_indexes = (table.global_indexes or []) + \
indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) (table.indexes or [])
indexes_by_name = dict((i['IndexName'], i)
for i in all_indexes)
if index_name not in indexes_by_name: if index_name not in indexes_by_name:
raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % ( raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % (
index_name, name, ', '.join(indexes_by_name.keys()) index_name, name, ', '.join(indexes_by_name.keys())
@ -279,16 +287,21 @@ class DynamoHandler(BaseResponse):
else: else:
index = table.schema index = table.schema
key_map = [column for _, column in sorted((k, v) for k, v in self.body['ExpressionAttributeNames'].items())] key_map = [column for _, column in sorted(
(k, v) for k, v in self.body['ExpressionAttributeNames'].items())]
if " AND " in key_condition_expression: if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1) expressions = key_condition_expression.split(" AND ", 1)
index_hash_key = [key for key in index if key['KeyType'] == 'HASH'][0] index_hash_key = [
hash_key_index_in_key_map = key_map.index(index_hash_key['AttributeName']) key for key in index if key['KeyType'] == 'HASH'][0]
hash_key_index_in_key_map = key_map.index(
index_hash_key['AttributeName'])
hash_key_expression = expressions.pop(hash_key_index_in_key_map).strip('()') hash_key_expression = expressions.pop(
# TODO implement more than one range expression and OR operators hash_key_index_in_key_map).strip('()')
# TODO implement more than one range expression and OR
# operators
range_key_expression = expressions[0].strip('()') range_key_expression = expressions[0].strip('()')
range_key_expression_components = range_key_expression.split() range_key_expression_components = range_key_expression.split()
range_comparison = range_key_expression_components[1] range_comparison = range_key_expression_components[1]
@ -304,7 +317,8 @@ class DynamoHandler(BaseResponse):
value_alias_map[range_key_expression_components[1]], value_alias_map[range_key_expression_components[1]],
] ]
else: else:
range_values = [value_alias_map[range_key_expression_components[2]]] range_values = [value_alias_map[
range_key_expression_components[2]]]
else: else:
hash_key_expression = key_condition_expression hash_key_expression = key_condition_expression
range_comparison = None range_comparison = None
@ -316,14 +330,16 @@ class DynamoHandler(BaseResponse):
# 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}}
key_conditions = self.body.get('KeyConditions') key_conditions = self.body.get('KeyConditions')
if key_conditions: if key_conditions:
hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(name, key_conditions.keys()) hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(
name, key_conditions.keys())
for key, value in key_conditions.items(): for key, value in key_conditions.items():
if key not in (hash_key_name, range_key_name): if key not in (hash_key_name, range_key_name):
filter_kwargs[key] = value filter_kwargs[key] = value
if hash_key_name is None: if hash_key_name is None:
er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
return self.error(er) return self.error(er)
hash_key = key_conditions[hash_key_name]['AttributeValueList'][0] hash_key = key_conditions[hash_key_name][
'AttributeValueList'][0]
if len(key_conditions) == 1: if len(key_conditions) == 1:
range_comparison = None range_comparison = None
range_values = [] range_values = []
@ -334,8 +350,10 @@ class DynamoHandler(BaseResponse):
else: else:
range_condition = key_conditions.get(range_key_name) range_condition = key_conditions.get(range_key_name)
if range_condition: if range_condition:
range_comparison = range_condition['ComparisonOperator'] range_comparison = range_condition[
range_values = range_condition['AttributeValueList'] 'ComparisonOperator']
range_values = range_condition[
'AttributeValueList']
else: else:
range_comparison = None range_comparison = None
range_values = [] range_values = []
@ -369,7 +387,8 @@ class DynamoHandler(BaseResponse):
filters = {} filters = {}
scan_filters = self.body.get('ScanFilter', {}) scan_filters = self.body.get('ScanFilter', {})
for attribute_name, scan_filter in scan_filters.items(): for attribute_name, scan_filter in scan_filters.items():
# Keys are attribute names. Values are tuples of (comparison, comparison_value) # Keys are attribute names. Values are tuples of (comparison,
# comparison_value)
comparison_operator = scan_filter["ComparisonOperator"] comparison_operator = scan_filter["ComparisonOperator"]
comparison_values = scan_filter.get("AttributeValueList", []) comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values) filters[attribute_name] = (comparison_operator, comparison_values)
@ -416,16 +435,20 @@ class DynamoHandler(BaseResponse):
key = self.body['Key'] key = self.body['Key']
update_expression = self.body.get('UpdateExpression') update_expression = self.body.get('UpdateExpression')
attribute_updates = self.body.get('AttributeUpdates') attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_names = self.body.get(
expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) 'ExpressionAttributeNames', {})
expression_attribute_values = self.body.get(
'ExpressionAttributeValues', {})
existing_item = dynamodb_backend2.get_item(name, key) existing_item = dynamodb_backend2.get_item(name, key)
# Support spaces between operators in an update expression # Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c` # E.g. `a = b + c` -> `a=b+c`
if update_expression: if update_expression:
update_expression = re.sub('\s*([=\+-])\s*', '\\1', update_expression) update_expression = re.sub(
'\s*([=\+-])\s*', '\\1', update_expression)
item = dynamodb_backend2.update_item(name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values) item = dynamodb_backend2.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values)
item_dict = item.to_json() item_dict = item.to_json()
item_dict['ConsumedCapacityUnits'] = 0.5 item_dict['ConsumedCapacityUnits'] = 0.5

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import ec2_backends from .models import ec2_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
ec2_backend = ec2_backends['us-east-1'] ec2_backend = ec2_backends['us-east-1']
mock_ec2 = base_decorator(ec2_backends) mock_ec2 = base_decorator(ec2_backends)

View File

@ -7,12 +7,14 @@ class EC2ClientError(RESTError):
class DependencyViolationError(EC2ClientError): class DependencyViolationError(EC2ClientError):
def __init__(self, message): def __init__(self, message):
super(DependencyViolationError, self).__init__( super(DependencyViolationError, self).__init__(
"DependencyViolation", message) "DependencyViolation", message)
class MissingParameterError(EC2ClientError): class MissingParameterError(EC2ClientError):
def __init__(self, parameter): def __init__(self, parameter):
super(MissingParameterError, self).__init__( super(MissingParameterError, self).__init__(
"MissingParameter", "MissingParameter",
@ -21,6 +23,7 @@ class MissingParameterError(EC2ClientError):
class InvalidDHCPOptionsIdError(EC2ClientError): class InvalidDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id): def __init__(self, dhcp_options_id):
super(InvalidDHCPOptionsIdError, self).__init__( super(InvalidDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionID.NotFound", "InvalidDhcpOptionID.NotFound",
@ -29,6 +32,7 @@ class InvalidDHCPOptionsIdError(EC2ClientError):
class MalformedDHCPOptionsIdError(EC2ClientError): class MalformedDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id): def __init__(self, dhcp_options_id):
super(MalformedDHCPOptionsIdError, self).__init__( super(MalformedDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionsId.Malformed", "InvalidDhcpOptionsId.Malformed",
@ -37,6 +41,7 @@ class MalformedDHCPOptionsIdError(EC2ClientError):
class InvalidKeyPairNameError(EC2ClientError): class InvalidKeyPairNameError(EC2ClientError):
def __init__(self, key): def __init__(self, key):
super(InvalidKeyPairNameError, self).__init__( super(InvalidKeyPairNameError, self).__init__(
"InvalidKeyPair.NotFound", "InvalidKeyPair.NotFound",
@ -45,6 +50,7 @@ class InvalidKeyPairNameError(EC2ClientError):
class InvalidKeyPairDuplicateError(EC2ClientError): class InvalidKeyPairDuplicateError(EC2ClientError):
def __init__(self, key): def __init__(self, key):
super(InvalidKeyPairDuplicateError, self).__init__( super(InvalidKeyPairDuplicateError, self).__init__(
"InvalidKeyPair.Duplicate", "InvalidKeyPair.Duplicate",
@ -53,6 +59,7 @@ class InvalidKeyPairDuplicateError(EC2ClientError):
class InvalidVPCIdError(EC2ClientError): class InvalidVPCIdError(EC2ClientError):
def __init__(self, vpc_id): def __init__(self, vpc_id):
super(InvalidVPCIdError, self).__init__( super(InvalidVPCIdError, self).__init__(
"InvalidVpcID.NotFound", "InvalidVpcID.NotFound",
@ -61,6 +68,7 @@ class InvalidVPCIdError(EC2ClientError):
class InvalidSubnetIdError(EC2ClientError): class InvalidSubnetIdError(EC2ClientError):
def __init__(self, subnet_id): def __init__(self, subnet_id):
super(InvalidSubnetIdError, self).__init__( super(InvalidSubnetIdError, self).__init__(
"InvalidSubnetID.NotFound", "InvalidSubnetID.NotFound",
@ -69,6 +77,7 @@ class InvalidSubnetIdError(EC2ClientError):
class InvalidNetworkAclIdError(EC2ClientError): class InvalidNetworkAclIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidNetworkAclIdError, self).__init__( super(InvalidNetworkAclIdError, self).__init__(
"InvalidNetworkAclID.NotFound", "InvalidNetworkAclID.NotFound",
@ -77,6 +86,7 @@ class InvalidNetworkAclIdError(EC2ClientError):
class InvalidVpnGatewayIdError(EC2ClientError): class InvalidVpnGatewayIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidVpnGatewayIdError, self).__init__( super(InvalidVpnGatewayIdError, self).__init__(
"InvalidVpnGatewayID.NotFound", "InvalidVpnGatewayID.NotFound",
@ -85,6 +95,7 @@ class InvalidVpnGatewayIdError(EC2ClientError):
class InvalidVpnConnectionIdError(EC2ClientError): class InvalidVpnConnectionIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidVpnConnectionIdError, self).__init__( super(InvalidVpnConnectionIdError, self).__init__(
"InvalidVpnConnectionID.NotFound", "InvalidVpnConnectionID.NotFound",
@ -93,6 +104,7 @@ class InvalidVpnConnectionIdError(EC2ClientError):
class InvalidCustomerGatewayIdError(EC2ClientError): class InvalidCustomerGatewayIdError(EC2ClientError):
def __init__(self, customer_gateway_id): def __init__(self, customer_gateway_id):
super(InvalidCustomerGatewayIdError, self).__init__( super(InvalidCustomerGatewayIdError, self).__init__(
"InvalidCustomerGatewayID.NotFound", "InvalidCustomerGatewayID.NotFound",
@ -101,6 +113,7 @@ class InvalidCustomerGatewayIdError(EC2ClientError):
class InvalidNetworkInterfaceIdError(EC2ClientError): class InvalidNetworkInterfaceIdError(EC2ClientError):
def __init__(self, eni_id): def __init__(self, eni_id):
super(InvalidNetworkInterfaceIdError, self).__init__( super(InvalidNetworkInterfaceIdError, self).__init__(
"InvalidNetworkInterfaceID.NotFound", "InvalidNetworkInterfaceID.NotFound",
@ -109,6 +122,7 @@ class InvalidNetworkInterfaceIdError(EC2ClientError):
class InvalidNetworkAttachmentIdError(EC2ClientError): class InvalidNetworkAttachmentIdError(EC2ClientError):
def __init__(self, attachment_id): def __init__(self, attachment_id):
super(InvalidNetworkAttachmentIdError, self).__init__( super(InvalidNetworkAttachmentIdError, self).__init__(
"InvalidAttachmentID.NotFound", "InvalidAttachmentID.NotFound",
@ -117,6 +131,7 @@ class InvalidNetworkAttachmentIdError(EC2ClientError):
class InvalidSecurityGroupDuplicateError(EC2ClientError): class InvalidSecurityGroupDuplicateError(EC2ClientError):
def __init__(self, name): def __init__(self, name):
super(InvalidSecurityGroupDuplicateError, self).__init__( super(InvalidSecurityGroupDuplicateError, self).__init__(
"InvalidGroup.Duplicate", "InvalidGroup.Duplicate",
@ -125,6 +140,7 @@ class InvalidSecurityGroupDuplicateError(EC2ClientError):
class InvalidSecurityGroupNotFoundError(EC2ClientError): class InvalidSecurityGroupNotFoundError(EC2ClientError):
def __init__(self, name): def __init__(self, name):
super(InvalidSecurityGroupNotFoundError, self).__init__( super(InvalidSecurityGroupNotFoundError, self).__init__(
"InvalidGroup.NotFound", "InvalidGroup.NotFound",
@ -133,6 +149,7 @@ class InvalidSecurityGroupNotFoundError(EC2ClientError):
class InvalidPermissionNotFoundError(EC2ClientError): class InvalidPermissionNotFoundError(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidPermissionNotFoundError, self).__init__( super(InvalidPermissionNotFoundError, self).__init__(
"InvalidPermission.NotFound", "InvalidPermission.NotFound",
@ -140,6 +157,7 @@ class InvalidPermissionNotFoundError(EC2ClientError):
class InvalidRouteTableIdError(EC2ClientError): class InvalidRouteTableIdError(EC2ClientError):
def __init__(self, route_table_id): def __init__(self, route_table_id):
super(InvalidRouteTableIdError, self).__init__( super(InvalidRouteTableIdError, self).__init__(
"InvalidRouteTableID.NotFound", "InvalidRouteTableID.NotFound",
@ -148,6 +166,7 @@ class InvalidRouteTableIdError(EC2ClientError):
class InvalidRouteError(EC2ClientError): class InvalidRouteError(EC2ClientError):
def __init__(self, route_table_id, cidr): def __init__(self, route_table_id, cidr):
super(InvalidRouteError, self).__init__( super(InvalidRouteError, self).__init__(
"InvalidRoute.NotFound", "InvalidRoute.NotFound",
@ -156,6 +175,7 @@ class InvalidRouteError(EC2ClientError):
class InvalidInstanceIdError(EC2ClientError): class InvalidInstanceIdError(EC2ClientError):
def __init__(self, instance_id): def __init__(self, instance_id):
super(InvalidInstanceIdError, self).__init__( super(InvalidInstanceIdError, self).__init__(
"InvalidInstanceID.NotFound", "InvalidInstanceID.NotFound",
@ -164,6 +184,7 @@ class InvalidInstanceIdError(EC2ClientError):
class InvalidAMIIdError(EC2ClientError): class InvalidAMIIdError(EC2ClientError):
def __init__(self, ami_id): def __init__(self, ami_id):
super(InvalidAMIIdError, self).__init__( super(InvalidAMIIdError, self).__init__(
"InvalidAMIID.NotFound", "InvalidAMIID.NotFound",
@ -172,6 +193,7 @@ class InvalidAMIIdError(EC2ClientError):
class InvalidAMIAttributeItemValueError(EC2ClientError): class InvalidAMIAttributeItemValueError(EC2ClientError):
def __init__(self, attribute, value): def __init__(self, attribute, value):
super(InvalidAMIAttributeItemValueError, self).__init__( super(InvalidAMIAttributeItemValueError, self).__init__(
"InvalidAMIAttributeItemValue", "InvalidAMIAttributeItemValue",
@ -180,6 +202,7 @@ class InvalidAMIAttributeItemValueError(EC2ClientError):
class MalformedAMIIdError(EC2ClientError): class MalformedAMIIdError(EC2ClientError):
def __init__(self, ami_id): def __init__(self, ami_id):
super(MalformedAMIIdError, self).__init__( super(MalformedAMIIdError, self).__init__(
"InvalidAMIID.Malformed", "InvalidAMIID.Malformed",
@ -188,6 +211,7 @@ class MalformedAMIIdError(EC2ClientError):
class InvalidSnapshotIdError(EC2ClientError): class InvalidSnapshotIdError(EC2ClientError):
def __init__(self, snapshot_id): def __init__(self, snapshot_id):
super(InvalidSnapshotIdError, self).__init__( super(InvalidSnapshotIdError, self).__init__(
"InvalidSnapshot.NotFound", "InvalidSnapshot.NotFound",
@ -195,6 +219,7 @@ class InvalidSnapshotIdError(EC2ClientError):
class InvalidVolumeIdError(EC2ClientError): class InvalidVolumeIdError(EC2ClientError):
def __init__(self, volume_id): def __init__(self, volume_id):
super(InvalidVolumeIdError, self).__init__( super(InvalidVolumeIdError, self).__init__(
"InvalidVolume.NotFound", "InvalidVolume.NotFound",
@ -203,6 +228,7 @@ class InvalidVolumeIdError(EC2ClientError):
class InvalidVolumeAttachmentError(EC2ClientError): class InvalidVolumeAttachmentError(EC2ClientError):
def __init__(self, volume_id, instance_id): def __init__(self, volume_id, instance_id):
super(InvalidVolumeAttachmentError, self).__init__( super(InvalidVolumeAttachmentError, self).__init__(
"InvalidAttachment.NotFound", "InvalidAttachment.NotFound",
@ -211,6 +237,7 @@ class InvalidVolumeAttachmentError(EC2ClientError):
class InvalidDomainError(EC2ClientError): class InvalidDomainError(EC2ClientError):
def __init__(self, domain): def __init__(self, domain):
super(InvalidDomainError, self).__init__( super(InvalidDomainError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
@ -219,6 +246,7 @@ class InvalidDomainError(EC2ClientError):
class InvalidAddressError(EC2ClientError): class InvalidAddressError(EC2ClientError):
def __init__(self, ip): def __init__(self, ip):
super(InvalidAddressError, self).__init__( super(InvalidAddressError, self).__init__(
"InvalidAddress.NotFound", "InvalidAddress.NotFound",
@ -227,6 +255,7 @@ class InvalidAddressError(EC2ClientError):
class InvalidAllocationIdError(EC2ClientError): class InvalidAllocationIdError(EC2ClientError):
def __init__(self, allocation_id): def __init__(self, allocation_id):
super(InvalidAllocationIdError, self).__init__( super(InvalidAllocationIdError, self).__init__(
"InvalidAllocationID.NotFound", "InvalidAllocationID.NotFound",
@ -235,6 +264,7 @@ class InvalidAllocationIdError(EC2ClientError):
class InvalidAssociationIdError(EC2ClientError): class InvalidAssociationIdError(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id):
super(InvalidAssociationIdError, self).__init__( super(InvalidAssociationIdError, self).__init__(
"InvalidAssociationID.NotFound", "InvalidAssociationID.NotFound",
@ -243,6 +273,7 @@ class InvalidAssociationIdError(EC2ClientError):
class InvalidVPCPeeringConnectionIdError(EC2ClientError): class InvalidVPCPeeringConnectionIdError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionIdError, self).__init__( super(InvalidVPCPeeringConnectionIdError, self).__init__(
"InvalidVpcPeeringConnectionId.NotFound", "InvalidVpcPeeringConnectionId.NotFound",
@ -251,6 +282,7 @@ class InvalidVPCPeeringConnectionIdError(EC2ClientError):
class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__( super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__(
"InvalidStateTransition", "InvalidStateTransition",
@ -259,6 +291,7 @@ class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
class InvalidParameterValueError(EC2ClientError): class InvalidParameterValueError(EC2ClientError):
def __init__(self, parameter_value): def __init__(self, parameter_value):
super(InvalidParameterValueError, self).__init__( super(InvalidParameterValueError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
@ -267,6 +300,7 @@ class InvalidParameterValueError(EC2ClientError):
class InvalidParameterValueErrorTagNull(EC2ClientError): class InvalidParameterValueErrorTagNull(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidParameterValueErrorTagNull, self).__init__( super(InvalidParameterValueErrorTagNull, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
@ -274,6 +308,7 @@ class InvalidParameterValueErrorTagNull(EC2ClientError):
class InvalidInternetGatewayIdError(EC2ClientError): class InvalidInternetGatewayIdError(EC2ClientError):
def __init__(self, internet_gateway_id): def __init__(self, internet_gateway_id):
super(InvalidInternetGatewayIdError, self).__init__( super(InvalidInternetGatewayIdError, self).__init__(
"InvalidInternetGatewayID.NotFound", "InvalidInternetGatewayID.NotFound",
@ -282,6 +317,7 @@ class InvalidInternetGatewayIdError(EC2ClientError):
class GatewayNotAttachedError(EC2ClientError): class GatewayNotAttachedError(EC2ClientError):
def __init__(self, internet_gateway_id, vpc_id): def __init__(self, internet_gateway_id, vpc_id):
super(GatewayNotAttachedError, self).__init__( super(GatewayNotAttachedError, self).__init__(
"Gateway.NotAttached", "Gateway.NotAttached",
@ -290,6 +326,7 @@ class GatewayNotAttachedError(EC2ClientError):
class ResourceAlreadyAssociatedError(EC2ClientError): class ResourceAlreadyAssociatedError(EC2ClientError):
def __init__(self, resource_id): def __init__(self, resource_id):
super(ResourceAlreadyAssociatedError, self).__init__( super(ResourceAlreadyAssociatedError, self).__init__(
"Resource.AlreadyAssociated", "Resource.AlreadyAssociated",
@ -298,6 +335,7 @@ class ResourceAlreadyAssociatedError(EC2ClientError):
class TagLimitExceeded(EC2ClientError): class TagLimitExceeded(EC2ClientError):
def __init__(self): def __init__(self):
super(TagLimitExceeded, self).__init__( super(TagLimitExceeded, self).__init__(
"TagLimitExceeded", "TagLimitExceeded",
@ -305,6 +343,7 @@ class TagLimitExceeded(EC2ClientError):
class InvalidID(EC2ClientError): class InvalidID(EC2ClientError):
def __init__(self, resource_id): def __init__(self, resource_id):
super(InvalidID, self).__init__( super(InvalidID, self).__init__(
"InvalidID", "InvalidID",
@ -313,6 +352,7 @@ class InvalidID(EC2ClientError):
class InvalidCIDRSubnetError(EC2ClientError): class InvalidCIDRSubnetError(EC2ClientError):
def __init__(self, cidr): def __init__(self, cidr):
super(InvalidCIDRSubnetError, self).__init__( super(InvalidCIDRSubnetError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
@ -321,6 +361,7 @@ class InvalidCIDRSubnetError(EC2ClientError):
class RulesPerSecurityGroupLimitExceededError(EC2ClientError): class RulesPerSecurityGroupLimitExceededError(EC2ClientError):
def __init__(self): def __init__(self):
super(RulesPerSecurityGroupLimitExceededError, self).__init__( super(RulesPerSecurityGroupLimitExceededError, self).__init__(
"RulesPerSecurityGroupLimitExceeded", "RulesPerSecurityGroupLimitExceeded",

File diff suppressed because it is too large Load Diff

View File

@ -66,6 +66,7 @@ class EC2Response(
Windows, Windows,
NatGateways, NatGateways,
): ):
@property @property
def ec2_backend(self): def ec2_backend(self):
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends

View File

@ -3,5 +3,7 @@ from moto.core.responses import BaseResponse
class AmazonDevPay(BaseResponse): class AmazonDevPay(BaseResponse):
def confirm_product_instance(self): def confirm_product_instance(self):
raise NotImplementedError('AmazonDevPay.confirm_product_instance is not yet implemented') raise NotImplementedError(
'AmazonDevPay.confirm_product_instance is not yet implemented')

View File

@ -5,6 +5,7 @@ from moto.ec2.utils import instance_ids_from_querystring, image_ids_from_queryst
class AmisResponse(BaseResponse): class AmisResponse(BaseResponse):
def create_image(self): def create_image(self):
name = self.querystring.get('Name')[0] name = self.querystring.get('Name')[0]
if "Description" in self.querystring: if "Description" in self.querystring:
@ -14,17 +15,21 @@ class AmisResponse(BaseResponse):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
if self.is_not_dryrun('CreateImage'): if self.is_not_dryrun('CreateImage'):
image = self.ec2_backend.create_image(instance_id, name, description) image = self.ec2_backend.create_image(
instance_id, name, description)
template = self.response_template(CREATE_IMAGE_RESPONSE) template = self.response_template(CREATE_IMAGE_RESPONSE)
return template.render(image=image) return template.render(image=image)
def copy_image(self): def copy_image(self):
source_image_id = self.querystring.get('SourceImageId')[0] source_image_id = self.querystring.get('SourceImageId')[0]
source_region = self.querystring.get('SourceRegion')[0] source_region = self.querystring.get('SourceRegion')[0]
name = self.querystring.get('Name')[0] if self.querystring.get('Name') else None name = self.querystring.get(
description = self.querystring.get('Description')[0] if self.querystring.get('Description') else None 'Name')[0] if self.querystring.get('Name') else None
description = self.querystring.get(
'Description')[0] if self.querystring.get('Description') else None
if self.is_not_dryrun('CopyImage'): if self.is_not_dryrun('CopyImage'):
image = self.ec2_backend.copy_image(source_image_id, source_region, name, description) image = self.ec2_backend.copy_image(
source_image_id, source_region, name, description)
template = self.response_template(COPY_IMAGE_RESPONSE) template = self.response_template(COPY_IMAGE_RESPONSE)
return template.render(image=image) return template.render(image=image)
@ -38,7 +43,8 @@ class AmisResponse(BaseResponse):
def describe_images(self): def describe_images(self):
ami_ids = image_ids_from_querystring(self.querystring) ami_ids = image_ids_from_querystring(self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
images = self.ec2_backend.describe_images(ami_ids=ami_ids, filters=filters) images = self.ec2_backend.describe_images(
ami_ids=ami_ids, filters=filters)
template = self.response_template(DESCRIBE_IMAGES_RESPONSE) template = self.response_template(DESCRIBE_IMAGES_RESPONSE)
return template.render(images=images) return template.render(images=images)
@ -56,18 +62,22 @@ class AmisResponse(BaseResponse):
user_ids = sequence_from_querystring('UserId', self.querystring) user_ids = sequence_from_querystring('UserId', self.querystring)
if self.is_not_dryrun('ModifyImageAttribute'): if self.is_not_dryrun('ModifyImageAttribute'):
if (operation_type == 'add'): if (operation_type == 'add'):
self.ec2_backend.add_launch_permission(ami_id, user_ids=user_ids, group=group) self.ec2_backend.add_launch_permission(
ami_id, user_ids=user_ids, group=group)
elif (operation_type == 'remove'): elif (operation_type == 'remove'):
self.ec2_backend.remove_launch_permission(ami_id, user_ids=user_ids, group=group) self.ec2_backend.remove_launch_permission(
ami_id, user_ids=user_ids, group=group)
return MODIFY_IMAGE_ATTRIBUTE_RESPONSE return MODIFY_IMAGE_ATTRIBUTE_RESPONSE
def register_image(self): def register_image(self):
if self.is_not_dryrun('RegisterImage'): if self.is_not_dryrun('RegisterImage'):
raise NotImplementedError('AMIs.register_image is not yet implemented') raise NotImplementedError(
'AMIs.register_image is not yet implemented')
def reset_image_attribute(self): def reset_image_attribute(self):
if self.is_not_dryrun('ResetImageAttribute'): if self.is_not_dryrun('ResetImageAttribute'):
raise NotImplementedError('AMIs.reset_image_attribute is not yet implemented') raise NotImplementedError(
'AMIs.reset_image_attribute is not yet implemented')
CREATE_IMAGE_RESPONSE = """<CreateImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_IMAGE_RESPONSE = """<CreateImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -80,7 +90,8 @@ COPY_IMAGE_RESPONSE = """<CopyImageResponse xmlns="http://ec2.amazonaws.com/doc/
<imageId>{{ image.id }}</imageId> <imageId>{{ image.id }}</imageId>
</CopyImageResponse>""" </CopyImageResponse>"""
# TODO almost all of these params should actually be templated based on the ec2 image # TODO almost all of these params should actually be templated based on
# the ec2 image
DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<imagesSet> <imagesSet>

View File

@ -3,6 +3,7 @@ from moto.core.responses import BaseResponse
class AvailabilityZonesAndRegions(BaseResponse): class AvailabilityZonesAndRegions(BaseResponse):
def describe_availability_zones(self): def describe_availability_zones(self):
zones = self.ec2_backend.describe_availability_zones() zones = self.ec2_backend.describe_availability_zones()
template = self.response_template(DESCRIBE_ZONES_RESPONSE) template = self.response_template(DESCRIBE_ZONES_RESPONSE)
@ -13,6 +14,7 @@ class AvailabilityZonesAndRegions(BaseResponse):
template = self.response_template(DESCRIBE_REGIONS_RESPONSE) template = self.response_template(DESCRIBE_REGIONS_RESPONSE)
return template.render(regions=regions) return template.render(regions=regions)
DESCRIBE_REGIONS_RESPONSE = """<DescribeRegionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> DESCRIBE_REGIONS_RESPONSE = """<DescribeRegionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<regionInfo> <regionInfo>

View File

@ -10,13 +10,15 @@ class CustomerGateways(BaseResponse):
type = self.querystring.get('Type', None)[0] type = self.querystring.get('Type', None)[0]
ip_address = self.querystring.get('IpAddress', None)[0] ip_address = self.querystring.get('IpAddress', None)[0]
bgp_asn = self.querystring.get('BgpAsn', None)[0] bgp_asn = self.querystring.get('BgpAsn', None)[0]
customer_gateway = self.ec2_backend.create_customer_gateway(type, ip_address=ip_address, bgp_asn=bgp_asn) customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway) return template.render(customer_gateway=customer_gateway)
def delete_customer_gateway(self): def delete_customer_gateway(self):
customer_gateway_id = self.querystring.get('CustomerGatewayId')[0] customer_gateway_id = self.querystring.get('CustomerGatewayId')[0]
delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id) delete_status = self.ec2_backend.delete_customer_gateway(
customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=delete_status) return template.render(customer_gateway=delete_status)

View File

@ -7,6 +7,7 @@ from moto.ec2.utils import (
class DHCPOptions(BaseResponse): class DHCPOptions(BaseResponse):
def associate_dhcp_options(self): def associate_dhcp_options(self):
dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0] dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self.querystring.get("VpcId", [None])[0]
@ -48,9 +49,11 @@ class DHCPOptions(BaseResponse):
return template.render(delete_status=delete_status) return template.render(delete_status=delete_status)
def describe_dhcp_options(self): def describe_dhcp_options(self):
dhcp_opt_ids = sequence_from_querystring("DhcpOptionsId", self.querystring) dhcp_opt_ids = sequence_from_querystring(
"DhcpOptionsId", self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
dhcp_opts = self.ec2_backend.get_all_dhcp_options(dhcp_opt_ids, filters) dhcp_opts = self.ec2_backend.get_all_dhcp_options(
dhcp_opt_ids, filters)
template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE) template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE)
return template.render(dhcp_options=dhcp_opts) return template.render(dhcp_options=dhcp_opts)

View File

@ -10,13 +10,15 @@ class ElasticBlockStore(BaseResponse):
instance_id = self.querystring.get('InstanceId')[0] instance_id = self.querystring.get('InstanceId')[0]
device_path = self.querystring.get('Device')[0] device_path = self.querystring.get('Device')[0]
if self.is_not_dryrun('AttachVolume'): if self.is_not_dryrun('AttachVolume'):
attachment = self.ec2_backend.attach_volume(volume_id, instance_id, device_path) attachment = self.ec2_backend.attach_volume(
volume_id, instance_id, device_path)
template = self.response_template(ATTACHED_VOLUME_RESPONSE) template = self.response_template(ATTACHED_VOLUME_RESPONSE)
return template.render(attachment=attachment) return template.render(attachment=attachment)
def copy_snapshot(self): def copy_snapshot(self):
if self.is_not_dryrun('CopySnapshot'): if self.is_not_dryrun('CopySnapshot'):
raise NotImplementedError('ElasticBlockStore.copy_snapshot is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.copy_snapshot is not yet implemented')
def create_snapshot(self): def create_snapshot(self):
description = self.querystring.get('Description', [None])[0] description = self.querystring.get('Description', [None])[0]
@ -32,7 +34,8 @@ class ElasticBlockStore(BaseResponse):
snapshot_id = self.querystring.get('SnapshotId', [None])[0] snapshot_id = self.querystring.get('SnapshotId', [None])[0]
encrypted = self.querystring.get('Encrypted', ['false'])[0] encrypted = self.querystring.get('Encrypted', ['false'])[0]
if self.is_not_dryrun('CreateVolume'): if self.is_not_dryrun('CreateVolume'):
volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted) volume = self.ec2_backend.create_volume(
size, zone, snapshot_id, encrypted)
template = self.response_template(CREATE_VOLUME_RESPONSE) template = self.response_template(CREATE_VOLUME_RESPONSE)
return template.render(volume=volume) return template.render(volume=volume)
@ -50,51 +53,64 @@ class ElasticBlockStore(BaseResponse):
def describe_snapshots(self): def describe_snapshots(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
# querystring for multiple snapshotids results in SnapshotId.1, SnapshotId.2 etc # querystring for multiple snapshotids results in SnapshotId.1,
snapshot_ids = ','.join([','.join(s[1]) for s in self.querystring.items() if 'SnapshotId' in s[0]]) # SnapshotId.2 etc
snapshot_ids = ','.join(
[','.join(s[1]) for s in self.querystring.items() if 'SnapshotId' in s[0]])
snapshots = self.ec2_backend.describe_snapshots(filters=filters) snapshots = self.ec2_backend.describe_snapshots(filters=filters)
# Describe snapshots to handle filter on snapshot_ids # Describe snapshots to handle filter on snapshot_ids
snapshots = [s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots snapshots = [
s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots) return template.render(snapshots=snapshots)
def describe_volumes(self): def describe_volumes(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
# querystring for multiple volumeids results in VolumeId.1, VolumeId.2 etc # querystring for multiple volumeids results in VolumeId.1, VolumeId.2
volume_ids = ','.join([','.join(v[1]) for v in self.querystring.items() if 'VolumeId' in v[0]]) # etc
volume_ids = ','.join(
[','.join(v[1]) for v in self.querystring.items() if 'VolumeId' in v[0]])
volumes = self.ec2_backend.describe_volumes(filters=filters) volumes = self.ec2_backend.describe_volumes(filters=filters)
# Describe volumes to handle filter on volume_ids # Describe volumes to handle filter on volume_ids
volumes = [v for v in volumes if v.id in volume_ids] if volume_ids else volumes volumes = [
v for v in volumes if v.id in volume_ids] if volume_ids else volumes
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes) return template.render(volumes=volumes)
def describe_volume_attribute(self): def describe_volume_attribute(self):
raise NotImplementedError('ElasticBlockStore.describe_volume_attribute is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.describe_volume_attribute is not yet implemented')
def describe_volume_status(self): def describe_volume_status(self):
raise NotImplementedError('ElasticBlockStore.describe_volume_status is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.describe_volume_status is not yet implemented')
def detach_volume(self): def detach_volume(self):
volume_id = self.querystring.get('VolumeId')[0] volume_id = self.querystring.get('VolumeId')[0]
instance_id = self.querystring.get('InstanceId')[0] instance_id = self.querystring.get('InstanceId')[0]
device_path = self.querystring.get('Device')[0] device_path = self.querystring.get('Device')[0]
if self.is_not_dryrun('DetachVolume'): if self.is_not_dryrun('DetachVolume'):
attachment = self.ec2_backend.detach_volume(volume_id, instance_id, device_path) attachment = self.ec2_backend.detach_volume(
volume_id, instance_id, device_path)
template = self.response_template(DETATCH_VOLUME_RESPONSE) template = self.response_template(DETATCH_VOLUME_RESPONSE)
return template.render(attachment=attachment) return template.render(attachment=attachment)
def enable_volume_io(self): def enable_volume_io(self):
if self.is_not_dryrun('EnableVolumeIO'): if self.is_not_dryrun('EnableVolumeIO'):
raise NotImplementedError('ElasticBlockStore.enable_volume_io is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.enable_volume_io is not yet implemented')
def import_volume(self): def import_volume(self):
if self.is_not_dryrun('ImportVolume'): if self.is_not_dryrun('ImportVolume'):
raise NotImplementedError('ElasticBlockStore.import_volume is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.import_volume is not yet implemented')
def describe_snapshot_attribute(self): def describe_snapshot_attribute(self):
snapshot_id = self.querystring.get('SnapshotId')[0] snapshot_id = self.querystring.get('SnapshotId')[0]
groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id) groups = self.ec2_backend.get_create_volume_permission_groups(
template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE) snapshot_id)
template = self.response_template(
DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
return template.render(snapshot_id=snapshot_id, groups=groups) return template.render(snapshot_id=snapshot_id, groups=groups)
def modify_snapshot_attribute(self): def modify_snapshot_attribute(self):
@ -104,18 +120,22 @@ class ElasticBlockStore(BaseResponse):
user_id = self.querystring.get('UserId.1', [None])[0] user_id = self.querystring.get('UserId.1', [None])[0]
if self.is_not_dryrun('ModifySnapshotAttribute'): if self.is_not_dryrun('ModifySnapshotAttribute'):
if (operation_type == 'add'): if (operation_type == 'add'):
self.ec2_backend.add_create_volume_permission(snapshot_id, user_id=user_id, group=group) self.ec2_backend.add_create_volume_permission(
snapshot_id, user_id=user_id, group=group)
elif (operation_type == 'remove'): elif (operation_type == 'remove'):
self.ec2_backend.remove_create_volume_permission(snapshot_id, user_id=user_id, group=group) self.ec2_backend.remove_create_volume_permission(
snapshot_id, user_id=user_id, group=group)
return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE
def modify_volume_attribute(self): def modify_volume_attribute(self):
if self.is_not_dryrun('ModifyVolumeAttribute'): if self.is_not_dryrun('ModifyVolumeAttribute'):
raise NotImplementedError('ElasticBlockStore.modify_volume_attribute is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.modify_volume_attribute is not yet implemented')
def reset_snapshot_attribute(self): def reset_snapshot_attribute(self):
if self.is_not_dryrun('ResetSnapshotAttribute'): if self.is_not_dryrun('ResetSnapshotAttribute'):
raise NotImplementedError('ElasticBlockStore.reset_snapshot_attribute is not yet implemented') raise NotImplementedError(
'ElasticBlockStore.reset_snapshot_attribute is not yet implemented')
CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -272,4 +292,4 @@ MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE = """
<requestId>666d2944-9276-4d6a-be12-1f4ada972fd8</requestId> <requestId>666d2944-9276-4d6a-be12-1f4ada972fd8</requestId>
<return>true</return> <return>true</return>
</ModifySnapshotAttributeResponse> </ModifySnapshotAttributeResponse>
""" """

View File

@ -4,6 +4,7 @@ from moto.ec2.utils import sequence_from_querystring
class ElasticIPAddresses(BaseResponse): class ElasticIPAddresses(BaseResponse):
def allocate_address(self): def allocate_address(self):
if "Domain" in self.querystring: if "Domain" in self.querystring:
domain = self.querystring.get('Domain')[0] domain = self.querystring.get('Domain')[0]
@ -18,11 +19,14 @@ class ElasticIPAddresses(BaseResponse):
instance = eni = None instance = eni = None
if "InstanceId" in self.querystring: if "InstanceId" in self.querystring:
instance = self.ec2_backend.get_instance(self.querystring['InstanceId'][0]) instance = self.ec2_backend.get_instance(
self.querystring['InstanceId'][0])
elif "NetworkInterfaceId" in self.querystring: elif "NetworkInterfaceId" in self.querystring:
eni = self.ec2_backend.get_network_interface(self.querystring['NetworkInterfaceId'][0]) eni = self.ec2_backend.get_network_interface(
self.querystring['NetworkInterfaceId'][0])
else: else:
self.ec2_backend.raise_error("MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.") self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.")
reassociate = False reassociate = False
if "AllowReassociation" in self.querystring: if "AllowReassociation" in self.querystring:
@ -31,13 +35,17 @@ class ElasticIPAddresses(BaseResponse):
if self.is_not_dryrun('AssociateAddress'): if self.is_not_dryrun('AssociateAddress'):
if instance or eni: if instance or eni:
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, address=self.querystring['PublicIp'][0], reassociate=reassociate) eip = self.ec2_backend.associate_address(instance=instance, eni=eni, address=self.querystring[
'PublicIp'][0], reassociate=reassociate)
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, allocation_id=self.querystring['AllocationId'][0], reassociate=reassociate) eip = self.ec2_backend.associate_address(instance=instance, eni=eni, allocation_id=self.querystring[
'AllocationId'][0], reassociate=reassociate)
else: else:
self.ec2_backend.raise_error("MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
else: else:
self.ec2_backend.raise_error("MissingParameter", "Invalid request, expect either instance or ENI.") self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect either instance or ENI.")
template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE) template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE)
return template.render(address=eip) return template.render(address=eip)
@ -46,17 +54,23 @@ class ElasticIPAddresses(BaseResponse):
template = self.response_template(DESCRIBE_ADDRESS_RESPONSE) template = self.response_template(DESCRIBE_ADDRESS_RESPONSE)
if "Filter.1.Name" in self.querystring: if "Filter.1.Name" in self.querystring:
filter_by = sequence_from_querystring("Filter.1.Name", self.querystring)[0] filter_by = sequence_from_querystring(
filter_value = sequence_from_querystring("Filter.1.Value", self.querystring) "Filter.1.Name", self.querystring)[0]
filter_value = sequence_from_querystring(
"Filter.1.Value", self.querystring)
if filter_by == 'instance-id': if filter_by == 'instance-id':
addresses = filter(lambda x: x.instance.id == filter_value[0], self.ec2_backend.describe_addresses()) addresses = filter(lambda x: x.instance.id == filter_value[
0], self.ec2_backend.describe_addresses())
else: else:
raise NotImplementedError("Filtering not supported in describe_address.") raise NotImplementedError(
"Filtering not supported in describe_address.")
elif "PublicIp.1" in self.querystring: elif "PublicIp.1" in self.querystring:
public_ips = sequence_from_querystring("PublicIp", self.querystring) public_ips = sequence_from_querystring(
"PublicIp", self.querystring)
addresses = self.ec2_backend.address_by_ip(public_ips) addresses = self.ec2_backend.address_by_ip(public_ips)
elif "AllocationId.1" in self.querystring: elif "AllocationId.1" in self.querystring:
allocation_ids = sequence_from_querystring("AllocationId", self.querystring) allocation_ids = sequence_from_querystring(
"AllocationId", self.querystring)
addresses = self.ec2_backend.address_by_allocation(allocation_ids) addresses = self.ec2_backend.address_by_allocation(allocation_ids)
else: else:
addresses = self.ec2_backend.describe_addresses() addresses = self.ec2_backend.describe_addresses()
@ -65,22 +79,28 @@ class ElasticIPAddresses(BaseResponse):
def disassociate_address(self): def disassociate_address(self):
if self.is_not_dryrun('DisAssociateAddress'): if self.is_not_dryrun('DisAssociateAddress'):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.disassociate_address(address=self.querystring['PublicIp'][0]) self.ec2_backend.disassociate_address(
address=self.querystring['PublicIp'][0])
elif "AssociationId" in self.querystring: elif "AssociationId" in self.querystring:
self.ec2_backend.disassociate_address(association_id=self.querystring['AssociationId'][0]) self.ec2_backend.disassociate_address(
association_id=self.querystring['AssociationId'][0])
else: else:
self.ec2_backend.raise_error("MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.") self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.")
return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render() return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render()
def release_address(self): def release_address(self):
if self.is_not_dryrun('ReleaseAddress'): if self.is_not_dryrun('ReleaseAddress'):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.release_address(address=self.querystring['PublicIp'][0]) self.ec2_backend.release_address(
address=self.querystring['PublicIp'][0])
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
self.ec2_backend.release_address(allocation_id=self.querystring['AllocationId'][0]) self.ec2_backend.release_address(
allocation_id=self.querystring['AllocationId'][0])
else: else:
self.ec2_backend.raise_error("MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
return self.response_template(RELEASE_ADDRESS_RESPONSE).render() return self.response_template(RELEASE_ADDRESS_RESPONSE).render()

View File

@ -4,28 +4,35 @@ from moto.ec2.utils import sequence_from_querystring, filters_from_querystring
class ElasticNetworkInterfaces(BaseResponse): class ElasticNetworkInterfaces(BaseResponse):
def create_network_interface(self): def create_network_interface(self):
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self.querystring.get('SubnetId')[0]
private_ip_address = self.querystring.get('PrivateIpAddress', [None])[0] private_ip_address = self.querystring.get(
'PrivateIpAddress', [None])[0]
groups = sequence_from_querystring('SecurityGroupId', self.querystring) groups = sequence_from_querystring('SecurityGroupId', self.querystring)
subnet = self.ec2_backend.get_subnet(subnet_id) subnet = self.ec2_backend.get_subnet(subnet_id)
if self.is_not_dryrun('CreateNetworkInterface'): if self.is_not_dryrun('CreateNetworkInterface'):
eni = self.ec2_backend.create_network_interface(subnet, private_ip_address, groups) eni = self.ec2_backend.create_network_interface(
template = self.response_template(CREATE_NETWORK_INTERFACE_RESPONSE) subnet, private_ip_address, groups)
template = self.response_template(
CREATE_NETWORK_INTERFACE_RESPONSE)
return template.render(eni=eni) return template.render(eni=eni)
def delete_network_interface(self): def delete_network_interface(self):
eni_id = self.querystring.get('NetworkInterfaceId')[0] eni_id = self.querystring.get('NetworkInterfaceId')[0]
if self.is_not_dryrun('DeleteNetworkInterface'): if self.is_not_dryrun('DeleteNetworkInterface'):
self.ec2_backend.delete_network_interface(eni_id) self.ec2_backend.delete_network_interface(eni_id)
template = self.response_template(DELETE_NETWORK_INTERFACE_RESPONSE) template = self.response_template(
DELETE_NETWORK_INTERFACE_RESPONSE)
return template.render() return template.render()
def describe_network_interface_attribute(self): def describe_network_interface_attribute(self):
raise NotImplementedError('ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented') raise NotImplementedError(
'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented')
def describe_network_interfaces(self): def describe_network_interfaces(self):
eni_ids = sequence_from_querystring('NetworkInterfaceId', self.querystring) eni_ids = sequence_from_querystring(
'NetworkInterfaceId', self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters) enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE) template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE)
@ -36,15 +43,18 @@ class ElasticNetworkInterfaces(BaseResponse):
instance_id = self.querystring.get('InstanceId')[0] instance_id = self.querystring.get('InstanceId')[0]
device_index = self.querystring.get('DeviceIndex')[0] device_index = self.querystring.get('DeviceIndex')[0]
if self.is_not_dryrun('AttachNetworkInterface'): if self.is_not_dryrun('AttachNetworkInterface'):
attachment_id = self.ec2_backend.attach_network_interface(eni_id, instance_id, device_index) attachment_id = self.ec2_backend.attach_network_interface(
template = self.response_template(ATTACH_NETWORK_INTERFACE_RESPONSE) eni_id, instance_id, device_index)
template = self.response_template(
ATTACH_NETWORK_INTERFACE_RESPONSE)
return template.render(attachment_id=attachment_id) return template.render(attachment_id=attachment_id)
def detach_network_interface(self): def detach_network_interface(self):
attachment_id = self.querystring.get('AttachmentId')[0] attachment_id = self.querystring.get('AttachmentId')[0]
if self.is_not_dryrun('DetachNetworkInterface'): if self.is_not_dryrun('DetachNetworkInterface'):
self.ec2_backend.detach_network_interface(attachment_id) self.ec2_backend.detach_network_interface(attachment_id)
template = self.response_template(DETACH_NETWORK_INTERFACE_RESPONSE) template = self.response_template(
DETACH_NETWORK_INTERFACE_RESPONSE)
return template.render() return template.render()
def modify_network_interface_attribute(self): def modify_network_interface_attribute(self):
@ -52,12 +62,15 @@ class ElasticNetworkInterfaces(BaseResponse):
eni_id = self.querystring.get('NetworkInterfaceId')[0] eni_id = self.querystring.get('NetworkInterfaceId')[0]
group_id = self.querystring.get('SecurityGroupId.1')[0] group_id = self.querystring.get('SecurityGroupId.1')[0]
if self.is_not_dryrun('ModifyNetworkInterface'): if self.is_not_dryrun('ModifyNetworkInterface'):
self.ec2_backend.modify_network_interface_attribute(eni_id, group_id) self.ec2_backend.modify_network_interface_attribute(
eni_id, group_id)
return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE
def reset_network_interface_attribute(self): def reset_network_interface_attribute(self):
if self.is_not_dryrun('ResetNetworkInterface'): if self.is_not_dryrun('ResetNetworkInterface'):
raise NotImplementedError('ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented') raise NotImplementedError(
'ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented')
CREATE_NETWORK_INTERFACE_RESPONSE = """ CREATE_NETWORK_INTERFACE_RESPONSE = """
<CreateNetworkInterfaceResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <CreateNetworkInterfaceResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">

View File

@ -4,6 +4,7 @@ from moto.ec2.utils import instance_ids_from_querystring
class General(BaseResponse): class General(BaseResponse):
def get_console_output(self): def get_console_output(self):
self.instance_ids = instance_ids_from_querystring(self.querystring) self.instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = self.instance_ids[0] instance_id = self.instance_ids[0]

View File

@ -5,14 +5,18 @@ from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, \ from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, \
dict_from_querystring, optional_from_querystring dict_from_querystring, optional_from_querystring
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
def describe_instances(self): def describe_instances(self):
filter_dict = filters_from_querystring(self.querystring) filter_dict = filters_from_querystring(self.querystring)
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
if instance_ids: if instance_ids:
reservations = self.ec2_backend.get_reservations_by_instance_ids(instance_ids, filters=filter_dict) reservations = self.ec2_backend.get_reservations_by_instance_ids(
instance_ids, filters=filter_dict)
else: else:
reservations = self.ec2_backend.all_reservations(make_copy=True, filters=filter_dict) reservations = self.ec2_backend.all_reservations(
make_copy=True, filters=filter_dict)
template = self.response_template(EC2_DESCRIBE_INSTANCES) template = self.response_template(EC2_DESCRIBE_INSTANCES)
return template.render(reservations=reservations) return template.render(reservations=reservations)
@ -25,10 +29,12 @@ class InstanceResponse(BaseResponse):
security_group_ids = self._get_multi_param('SecurityGroupId') security_group_ids = self._get_multi_param('SecurityGroupId')
nics = dict_from_querystring("NetworkInterface", self.querystring) nics = dict_from_querystring("NetworkInterface", self.querystring)
instance_type = self.querystring.get("InstanceType", ["m1.small"])[0] instance_type = self.querystring.get("InstanceType", ["m1.small"])[0]
placement = self.querystring.get("Placement.AvailabilityZone", [None])[0] placement = self.querystring.get(
"Placement.AvailabilityZone", [None])[0]
subnet_id = self.querystring.get("SubnetId", [None])[0] subnet_id = self.querystring.get("SubnetId", [None])[0]
private_ip = self.querystring.get("PrivateIpAddress", [None])[0] private_ip = self.querystring.get("PrivateIpAddress", [None])[0]
associate_public_ip = self.querystring.get("AssociatePublicIpAddress", [None])[0] associate_public_ip = self.querystring.get(
"AssociatePublicIpAddress", [None])[0]
key_name = self.querystring.get("KeyName", [None])[0] key_name = self.querystring.get("KeyName", [None])[0]
if self.is_not_dryrun('RunInstance'): if self.is_not_dryrun('RunInstance'):
@ -72,10 +78,11 @@ class InstanceResponse(BaseResponse):
def describe_instance_status(self): def describe_instance_status(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
include_all_instances = optional_from_querystring('IncludeAllInstances', include_all_instances = optional_from_querystring('IncludeAllInstances',
self.querystring) == 'true' self.querystring) == 'true'
if instance_ids: if instance_ids:
instances = self.ec2_backend.get_multi_instances_by_id(instance_ids) instances = self.ec2_backend.get_multi_instances_by_id(
instance_ids)
elif include_all_instances: elif include_all_instances:
instances = self.ec2_backend.all_instances() instances = self.ec2_backend.all_instances()
else: else:
@ -85,7 +92,8 @@ class InstanceResponse(BaseResponse):
return template.render(instances=instances) return template.render(instances=instances)
def describe_instance_types(self): def describe_instance_types(self):
instance_types = [InstanceType(name='t1.micro', cores=1, memory=644874240, disk=0)] instance_types = [InstanceType(
name='t1.micro', cores=1, memory=644874240, disk=0)]
template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES) template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES)
return template.render(instance_types=instance_types) return template.render(instance_types=instance_types)
@ -96,10 +104,12 @@ class InstanceResponse(BaseResponse):
key = camelcase_to_underscores(attribute) key = camelcase_to_underscores(attribute)
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
instance, value = self.ec2_backend.describe_instance_attribute(instance_id, key) instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, key)
if key == "group_set": if key == "group_set":
template = self.response_template(EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) template = self.response_template(
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE)
else: else:
template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE)
@ -152,7 +162,8 @@ class InstanceResponse(BaseResponse):
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
if self.is_not_dryrun('ModifyInstanceAttribute'): if self.is_not_dryrun('ModifyInstanceAttribute'):
block_device_type = instance.block_device_mapping[device_name_value] block_device_type = instance.block_device_mapping[
device_name_value]
block_device_type.delete_on_termination = del_on_term_value block_device_type.delete_on_termination = del_on_term_value
# +1 for the next device # +1 for the next device
@ -171,24 +182,27 @@ class InstanceResponse(BaseResponse):
if not attribute_key: if not attribute_key:
return return
if self.is_not_dryrun('Modify'+attribute_key.split(".")[0]): if self.is_not_dryrun('Modify' + attribute_key.split(".")[0]):
value = self.querystring.get(attribute_key)[0] value = self.querystring.get(attribute_key)[0]
normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0]) normalized_attribute = camelcase_to_underscores(
attribute_key.split(".")[0])
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
self.ec2_backend.modify_instance_attribute(instance_id, normalized_attribute, value) self.ec2_backend.modify_instance_attribute(
instance_id, normalized_attribute, value)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE
def _security_grp_instance_attribute_handler(self): def _security_grp_instance_attribute_handler(self):
new_security_grp_list = [] new_security_grp_list = []
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if 'GroupId.' in key: if 'GroupId.' in key:
new_security_grp_list.append(self.querystring.get(key)[0]) new_security_grp_list.append(self.querystring.get(key)[0])
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
if self.is_not_dryrun('ModifyInstanceSecurityGroups'): if self.is_not_dryrun('ModifyInstanceSecurityGroups'):
self.ec2_backend.modify_instance_security_groups(instance_id, new_security_grp_list) self.ec2_backend.modify_instance_security_groups(
instance_id, new_security_grp_list)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE
@ -630,4 +644,4 @@ EC2_DESCRIBE_INSTANCE_TYPES = """<?xml version="1.0" encoding="UTF-8"?>
</item> </item>
{% endfor %} {% endfor %}
</instanceTypeSet> </instanceTypeSet>
</DescribeInstanceTypesResponse>""" </DescribeInstanceTypesResponse>"""

View File

@ -7,6 +7,7 @@ from moto.ec2.utils import (
class InternetGateways(BaseResponse): class InternetGateways(BaseResponse):
def attach_internet_gateway(self): def attach_internet_gateway(self):
igw_id = self.querystring.get("InternetGatewayId", [None])[0] igw_id = self.querystring.get("InternetGatewayId", [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self.querystring.get("VpcId", [None])[0]
@ -33,9 +34,11 @@ class InternetGateways(BaseResponse):
if "InternetGatewayId.1" in self.querystring: if "InternetGatewayId.1" in self.querystring:
igw_ids = sequence_from_querystring( igw_ids = sequence_from_querystring(
"InternetGatewayId", self.querystring) "InternetGatewayId", self.querystring)
igws = self.ec2_backend.describe_internet_gateways(igw_ids, filters=filter_dict) igws = self.ec2_backend.describe_internet_gateways(
igw_ids, filters=filter_dict)
else: else:
igws = self.ec2_backend.describe_internet_gateways(filters=filter_dict) igws = self.ec2_backend.describe_internet_gateways(
filters=filter_dict)
template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE) template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE)
return template.render(internet_gateways=igws) return template.render(internet_gateways=igws)

View File

@ -4,10 +4,13 @@ from moto.core.responses import BaseResponse
class IPAddresses(BaseResponse): class IPAddresses(BaseResponse):
def assign_private_ip_addresses(self): def assign_private_ip_addresses(self):
if self.is_not_dryrun('AssignPrivateIPAddress'): if self.is_not_dryrun('AssignPrivateIPAddress'):
raise NotImplementedError('IPAddresses.assign_private_ip_addresses is not yet implemented') raise NotImplementedError(
'IPAddresses.assign_private_ip_addresses is not yet implemented')
def unassign_private_ip_addresses(self): def unassign_private_ip_addresses(self):
if self.is_not_dryrun('UnAssignPrivateIPAddress'): if self.is_not_dryrun('UnAssignPrivateIPAddress'):
raise NotImplementedError('IPAddresses.unassign_private_ip_addresses is not yet implemented') raise NotImplementedError(
'IPAddresses.unassign_private_ip_addresses is not yet implemented')

View File

@ -16,14 +16,16 @@ class KeyPairs(BaseResponse):
def delete_key_pair(self): def delete_key_pair(self):
name = self.querystring.get('KeyName')[0] name = self.querystring.get('KeyName')[0]
if self.is_not_dryrun('DeleteKeyPair'): if self.is_not_dryrun('DeleteKeyPair'):
success = six.text_type(self.ec2_backend.delete_key_pair(name)).lower() success = six.text_type(
self.ec2_backend.delete_key_pair(name)).lower()
return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success) return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success)
def describe_key_pairs(self): def describe_key_pairs(self):
names = keypair_names_from_querystring(self.querystring) names = keypair_names_from_querystring(self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
if len(filters) > 0: if len(filters) > 0:
raise NotImplementedError('Using filters in KeyPairs.describe_key_pairs is not yet implemented') raise NotImplementedError(
'Using filters in KeyPairs.describe_key_pairs is not yet implemented')
keypairs = self.ec2_backend.describe_key_pairs(names) keypairs = self.ec2_backend.describe_key_pairs(names)
template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE) template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE)

View File

@ -3,10 +3,13 @@ from moto.core.responses import BaseResponse
class Monitoring(BaseResponse): class Monitoring(BaseResponse):
def monitor_instances(self): def monitor_instances(self):
if self.is_not_dryrun('MonitorInstances'): if self.is_not_dryrun('MonitorInstances'):
raise NotImplementedError('Monitoring.monitor_instances is not yet implemented') raise NotImplementedError(
'Monitoring.monitor_instances is not yet implemented')
def unmonitor_instances(self): def unmonitor_instances(self):
if self.is_not_dryrun('UnMonitorInstances'): if self.is_not_dryrun('UnMonitorInstances'):
raise NotImplementedError('Monitoring.unmonitor_instances is not yet implemented') raise NotImplementedError(
'Monitoring.unmonitor_instances is not yet implemented')

View File

@ -8,7 +8,8 @@ class NatGateways(BaseResponse):
def create_nat_gateway(self): def create_nat_gateway(self):
subnet_id = self._get_param('SubnetId') subnet_id = self._get_param('SubnetId')
allocation_id = self._get_param('AllocationId') allocation_id = self._get_param('AllocationId')
nat_gateway = self.ec2_backend.create_nat_gateway(subnet_id=subnet_id, allocation_id=allocation_id) nat_gateway = self.ec2_backend.create_nat_gateway(
subnet_id=subnet_id, allocation_id=allocation_id)
template = self.response_template(CREATE_NAT_GATEWAY) template = self.response_template(CREATE_NAT_GATEWAY)
return template.render(nat_gateway=nat_gateway) return template.render(nat_gateway=nat_gateway)

View File

@ -45,7 +45,8 @@ class NetworkACLs(BaseResponse):
def describe_network_acls(self): def describe_network_acls(self):
network_acl_ids = network_acl_ids_from_querystring(self.querystring) network_acl_ids = network_acl_ids_from_querystring(self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
network_acls = self.ec2_backend.get_all_network_acls(network_acl_ids, filters) network_acls = self.ec2_backend.get_all_network_acls(
network_acl_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_ACL_RESPONSE) template = self.response_template(DESCRIBE_NETWORK_ACL_RESPONSE)
return template.render(network_acls=network_acls) return template.render(network_acls=network_acls)

View File

@ -3,13 +3,17 @@ from moto.core.responses import BaseResponse
class PlacementGroups(BaseResponse): class PlacementGroups(BaseResponse):
def create_placement_group(self): def create_placement_group(self):
if self.is_not_dryrun('CreatePlacementGroup'): if self.is_not_dryrun('CreatePlacementGroup'):
raise NotImplementedError('PlacementGroups.create_placement_group is not yet implemented') raise NotImplementedError(
'PlacementGroups.create_placement_group is not yet implemented')
def delete_placement_group(self): def delete_placement_group(self):
if self.is_not_dryrun('DeletePlacementGroup'): if self.is_not_dryrun('DeletePlacementGroup'):
raise NotImplementedError('PlacementGroups.delete_placement_group is not yet implemented') raise NotImplementedError(
'PlacementGroups.delete_placement_group is not yet implemented')
def describe_placement_groups(self): def describe_placement_groups(self):
raise NotImplementedError('PlacementGroups.describe_placement_groups is not yet implemented') raise NotImplementedError(
'PlacementGroups.describe_placement_groups is not yet implemented')

View File

@ -3,23 +3,30 @@ from moto.core.responses import BaseResponse
class ReservedInstances(BaseResponse): class ReservedInstances(BaseResponse):
def cancel_reserved_instances_listing(self): def cancel_reserved_instances_listing(self):
if self.is_not_dryrun('CancelReservedInstances'): if self.is_not_dryrun('CancelReservedInstances'):
raise NotImplementedError('ReservedInstances.cancel_reserved_instances_listing is not yet implemented') raise NotImplementedError(
'ReservedInstances.cancel_reserved_instances_listing is not yet implemented')
def create_reserved_instances_listing(self): def create_reserved_instances_listing(self):
if self.is_not_dryrun('CreateReservedInstances'): if self.is_not_dryrun('CreateReservedInstances'):
raise NotImplementedError('ReservedInstances.create_reserved_instances_listing is not yet implemented') raise NotImplementedError(
'ReservedInstances.create_reserved_instances_listing is not yet implemented')
def describe_reserved_instances(self): def describe_reserved_instances(self):
raise NotImplementedError('ReservedInstances.describe_reserved_instances is not yet implemented') raise NotImplementedError(
'ReservedInstances.describe_reserved_instances is not yet implemented')
def describe_reserved_instances_listings(self): def describe_reserved_instances_listings(self):
raise NotImplementedError('ReservedInstances.describe_reserved_instances_listings is not yet implemented') raise NotImplementedError(
'ReservedInstances.describe_reserved_instances_listings is not yet implemented')
def describe_reserved_instances_offerings(self): def describe_reserved_instances_offerings(self):
raise NotImplementedError('ReservedInstances.describe_reserved_instances_offerings is not yet implemented') raise NotImplementedError(
'ReservedInstances.describe_reserved_instances_offerings is not yet implemented')
def purchase_reserved_instances_offering(self): def purchase_reserved_instances_offering(self):
if self.is_not_dryrun('PurchaseReservedInstances'): if self.is_not_dryrun('PurchaseReservedInstances'):
raise NotImplementedError('ReservedInstances.purchase_reserved_instances_offering is not yet implemented') raise NotImplementedError(
'ReservedInstances.purchase_reserved_instances_offering is not yet implemented')

View File

@ -8,24 +8,28 @@ class RouteTables(BaseResponse):
def associate_route_table(self): def associate_route_table(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self.querystring.get('RouteTableId')[0]
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self.querystring.get('SubnetId')[0]
association_id = self.ec2_backend.associate_route_table(route_table_id, subnet_id) association_id = self.ec2_backend.associate_route_table(
route_table_id, subnet_id)
template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE) template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE)
return template.render(association_id=association_id) return template.render(association_id=association_id)
def create_route(self): def create_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get('DestinationCidrBlock')[0] destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
gateway_id = optional_from_querystring('GatewayId', self.querystring) gateway_id = optional_from_querystring('GatewayId', self.querystring)
instance_id = optional_from_querystring('InstanceId', self.querystring) instance_id = optional_from_querystring('InstanceId', self.querystring)
interface_id = optional_from_querystring('NetworkInterfaceId', self.querystring) interface_id = optional_from_querystring(
pcx_id = optional_from_querystring('VpcPeeringConnectionId', self.querystring) 'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
self.ec2_backend.create_route(route_table_id, destination_cidr_block, self.ec2_backend.create_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id, gateway_id=gateway_id,
instance_id=instance_id, instance_id=instance_id,
interface_id=interface_id, interface_id=interface_id,
vpc_peering_connection_id=pcx_id) vpc_peering_connection_id=pcx_id)
template = self.response_template(CREATE_ROUTE_RESPONSE) template = self.response_template(CREATE_ROUTE_RESPONSE)
return template.render() return template.render()
@ -38,7 +42,8 @@ class RouteTables(BaseResponse):
def delete_route(self): def delete_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get('DestinationCidrBlock')[0] destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
self.ec2_backend.delete_route(route_table_id, destination_cidr_block) self.ec2_backend.delete_route(route_table_id, destination_cidr_block)
template = self.response_template(DELETE_ROUTE_RESPONSE) template = self.response_template(DELETE_ROUTE_RESPONSE)
return template.render() return template.render()
@ -52,7 +57,8 @@ class RouteTables(BaseResponse):
def describe_route_tables(self): def describe_route_tables(self):
route_table_ids = route_table_ids_from_querystring(self.querystring) route_table_ids = route_table_ids_from_querystring(self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
route_tables = self.ec2_backend.get_all_route_tables(route_table_ids, filters) route_tables = self.ec2_backend.get_all_route_tables(
route_table_ids, filters)
template = self.response_template(DESCRIBE_ROUTE_TABLES_RESPONSE) template = self.response_template(DESCRIBE_ROUTE_TABLES_RESPONSE)
return template.render(route_tables=route_tables) return template.render(route_tables=route_tables)
@ -64,18 +70,21 @@ class RouteTables(BaseResponse):
def replace_route(self): def replace_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get('DestinationCidrBlock')[0] destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
gateway_id = optional_from_querystring('GatewayId', self.querystring) gateway_id = optional_from_querystring('GatewayId', self.querystring)
instance_id = optional_from_querystring('InstanceId', self.querystring) instance_id = optional_from_querystring('InstanceId', self.querystring)
interface_id = optional_from_querystring('NetworkInterfaceId', self.querystring) interface_id = optional_from_querystring(
pcx_id = optional_from_querystring('VpcPeeringConnectionId', self.querystring) 'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
self.ec2_backend.replace_route(route_table_id, destination_cidr_block, self.ec2_backend.replace_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id, gateway_id=gateway_id,
instance_id=instance_id, instance_id=instance_id,
interface_id=interface_id, interface_id=interface_id,
vpc_peering_connection_id=pcx_id) vpc_peering_connection_id=pcx_id)
template = self.response_template(REPLACE_ROUTE_RESPONSE) template = self.response_template(REPLACE_ROUTE_RESPONSE)
return template.render() return template.render()
@ -83,8 +92,10 @@ class RouteTables(BaseResponse):
def replace_route_table_association(self): def replace_route_table_association(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self.querystring.get('RouteTableId')[0]
association_id = self.querystring.get('AssociationId')[0] association_id = self.querystring.get('AssociationId')[0]
new_association_id = self.ec2_backend.replace_route_table_association(association_id, route_table_id) new_association_id = self.ec2_backend.replace_route_table_association(
template = self.response_template(REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE) association_id, route_table_id)
template = self.response_template(
REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE)
return template.render(association_id=new_association_id) return template.render(association_id=new_association_id)

View File

@ -1,7 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring from moto.ec2.utils import filters_from_querystring
@ -55,10 +53,11 @@ def process_rules_from_querystring(querystring):
source_groups.append(group_dict['GroupName'][0]) source_groups.append(group_dict['GroupName'][0])
yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges,
source_groups, source_group_ids) source_groups, source_group_ids)
class SecurityGroups(BaseResponse): class SecurityGroups(BaseResponse):
def authorize_security_group_egress(self): def authorize_security_group_egress(self):
if self.is_not_dryrun('GrantSecurityGroupEgress'): if self.is_not_dryrun('GrantSecurityGroupEgress'):
for args in process_rules_from_querystring(self.querystring): for args in process_rules_from_querystring(self.querystring):
@ -77,12 +76,15 @@ class SecurityGroups(BaseResponse):
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self.querystring.get("VpcId", [None])[0]
if self.is_not_dryrun('CreateSecurityGroup'): if self.is_not_dryrun('CreateSecurityGroup'):
group = self.ec2_backend.create_security_group(name, description, vpc_id=vpc_id) group = self.ec2_backend.create_security_group(
name, description, vpc_id=vpc_id)
template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE) template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE)
return template.render(group=group) return template.render(group=group)
def delete_security_group(self): def delete_security_group(self):
# TODO this should raise an error if there are instances in the group. See http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html # TODO this should raise an error if there are instances in the group.
# See
# http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html
name = self.querystring.get('GroupName') name = self.querystring.get('GroupName')
sg_id = self.querystring.get('GroupId') sg_id = self.querystring.get('GroupId')

View File

@ -7,21 +7,25 @@ class SpotFleets(BaseResponse):
def cancel_spot_fleet_requests(self): def cancel_spot_fleet_requests(self):
spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.")
terminate_instances = self._get_param("TerminateInstances") terminate_instances = self._get_param("TerminateInstances")
spot_fleets = self.ec2_backend.cancel_spot_fleet_requests(spot_fleet_request_ids, terminate_instances) spot_fleets = self.ec2_backend.cancel_spot_fleet_requests(
spot_fleet_request_ids, terminate_instances)
template = self.response_template(CANCEL_SPOT_FLEETS_TEMPLATE) template = self.response_template(CANCEL_SPOT_FLEETS_TEMPLATE)
return template.render(spot_fleets=spot_fleets) return template.render(spot_fleets=spot_fleets)
def describe_spot_fleet_instances(self): def describe_spot_fleet_instances(self):
spot_fleet_request_id = self._get_param("SpotFleetRequestId") spot_fleet_request_id = self._get_param("SpotFleetRequestId")
spot_requests = self.ec2_backend.describe_spot_fleet_instances(spot_fleet_request_id) spot_requests = self.ec2_backend.describe_spot_fleet_instances(
template = self.response_template(DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE) spot_fleet_request_id)
template = self.response_template(
DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE)
return template.render(spot_request_id=spot_fleet_request_id, spot_requests=spot_requests) return template.render(spot_request_id=spot_fleet_request_id, spot_requests=spot_requests)
def describe_spot_fleet_requests(self): def describe_spot_fleet_requests(self):
spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.")
requests = self.ec2_backend.describe_spot_fleet_requests(spot_fleet_request_ids) requests = self.ec2_backend.describe_spot_fleet_requests(
spot_fleet_request_ids)
template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE) template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE)
return template.render(requests=requests) return template.render(requests=requests)
@ -32,7 +36,8 @@ class SpotFleets(BaseResponse):
iam_fleet_role = spot_config['iam_fleet_role'] iam_fleet_role = spot_config['iam_fleet_role']
allocation_strategy = spot_config['allocation_strategy'] allocation_strategy = spot_config['allocation_strategy']
launch_specs = self._get_list_prefix("SpotFleetRequestConfig.LaunchSpecifications") launch_specs = self._get_list_prefix(
"SpotFleetRequestConfig.LaunchSpecifications")
request = self.ec2_backend.request_spot_fleet( request = self.ec2_backend.request_spot_fleet(
spot_price=spot_price, spot_price=spot_price,
@ -45,6 +50,7 @@ class SpotFleets(BaseResponse):
template = self.response_template(REQUEST_SPOT_FLEET_TEMPLATE) template = self.response_template(REQUEST_SPOT_FLEET_TEMPLATE)
return template.render(request=request) return template.render(request=request)
REQUEST_SPOT_FLEET_TEMPLATE = """<RequestSpotFleetResponse xmlns="http://ec2.amazonaws.com/doc/2016-09-15/"> REQUEST_SPOT_FLEET_TEMPLATE = """<RequestSpotFleetResponse xmlns="http://ec2.amazonaws.com/doc/2016-09-15/">
<requestId>60262cc5-2bd4-4c8d-98ed-example</requestId> <requestId>60262cc5-2bd4-4c8d-98ed-example</requestId>
<spotFleetRequestId>{{ request.id }}</spotFleetRequestId> <spotFleetRequestId>{{ request.id }}</spotFleetRequestId>

View File

@ -8,29 +8,35 @@ class SpotInstances(BaseResponse):
def cancel_spot_instance_requests(self): def cancel_spot_instance_requests(self):
request_ids = self._get_multi_param('SpotInstanceRequestId') request_ids = self._get_multi_param('SpotInstanceRequestId')
if self.is_not_dryrun('CancelSpotInstance'): if self.is_not_dryrun('CancelSpotInstance'):
requests = self.ec2_backend.cancel_spot_instance_requests(request_ids) requests = self.ec2_backend.cancel_spot_instance_requests(
request_ids)
template = self.response_template(CANCEL_SPOT_INSTANCES_TEMPLATE) template = self.response_template(CANCEL_SPOT_INSTANCES_TEMPLATE)
return template.render(requests=requests) return template.render(requests=requests)
def create_spot_datafeed_subscription(self): def create_spot_datafeed_subscription(self):
if self.is_not_dryrun('CreateSpotDatafeedSubscription'): if self.is_not_dryrun('CreateSpotDatafeedSubscription'):
raise NotImplementedError('SpotInstances.create_spot_datafeed_subscription is not yet implemented') raise NotImplementedError(
'SpotInstances.create_spot_datafeed_subscription is not yet implemented')
def delete_spot_datafeed_subscription(self): def delete_spot_datafeed_subscription(self):
if self.is_not_dryrun('DeleteSpotDatafeedSubscription'): if self.is_not_dryrun('DeleteSpotDatafeedSubscription'):
raise NotImplementedError('SpotInstances.delete_spot_datafeed_subscription is not yet implemented') raise NotImplementedError(
'SpotInstances.delete_spot_datafeed_subscription is not yet implemented')
def describe_spot_datafeed_subscription(self): def describe_spot_datafeed_subscription(self):
raise NotImplementedError('SpotInstances.describe_spot_datafeed_subscription is not yet implemented') raise NotImplementedError(
'SpotInstances.describe_spot_datafeed_subscription is not yet implemented')
def describe_spot_instance_requests(self): def describe_spot_instance_requests(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
requests = self.ec2_backend.describe_spot_instance_requests(filters=filters) requests = self.ec2_backend.describe_spot_instance_requests(
filters=filters)
template = self.response_template(DESCRIBE_SPOT_INSTANCES_TEMPLATE) template = self.response_template(DESCRIBE_SPOT_INSTANCES_TEMPLATE)
return template.render(requests=requests) return template.render(requests=requests)
def describe_spot_price_history(self): def describe_spot_price_history(self):
raise NotImplementedError('SpotInstances.describe_spot_price_history is not yet implemented') raise NotImplementedError(
'SpotInstances.describe_spot_price_history is not yet implemented')
def request_spot_instances(self): def request_spot_instances(self):
price = self._get_param('SpotPrice') price = self._get_param('SpotPrice')
@ -42,13 +48,17 @@ class SpotInstances(BaseResponse):
launch_group = self._get_param('LaunchGroup') launch_group = self._get_param('LaunchGroup')
availability_zone_group = self._get_param('AvailabilityZoneGroup') availability_zone_group = self._get_param('AvailabilityZoneGroup')
key_name = self._get_param('LaunchSpecification.KeyName') key_name = self._get_param('LaunchSpecification.KeyName')
security_groups = self._get_multi_param('LaunchSpecification.SecurityGroup') security_groups = self._get_multi_param(
'LaunchSpecification.SecurityGroup')
user_data = self._get_param('LaunchSpecification.UserData') user_data = self._get_param('LaunchSpecification.UserData')
instance_type = self._get_param('LaunchSpecification.InstanceType', 'm1.small') instance_type = self._get_param(
placement = self._get_param('LaunchSpecification.Placement.AvailabilityZone') 'LaunchSpecification.InstanceType', 'm1.small')
placement = self._get_param(
'LaunchSpecification.Placement.AvailabilityZone')
kernel_id = self._get_param('LaunchSpecification.KernelId') kernel_id = self._get_param('LaunchSpecification.KernelId')
ramdisk_id = self._get_param('LaunchSpecification.RamdiskId') ramdisk_id = self._get_param('LaunchSpecification.RamdiskId')
monitoring_enabled = self._get_param('LaunchSpecification.Monitoring.Enabled') monitoring_enabled = self._get_param(
'LaunchSpecification.Monitoring.Enabled')
subnet_id = self._get_param('LaunchSpecification.SubnetId') subnet_id = self._get_param('LaunchSpecification.SubnetId')
if self.is_not_dryrun('RequestSpotInstance'): if self.is_not_dryrun('RequestSpotInstance'):

View File

@ -5,13 +5,15 @@ from moto.ec2.utils import filters_from_querystring
class Subnets(BaseResponse): class Subnets(BaseResponse):
def create_subnet(self): def create_subnet(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self.querystring.get('VpcId')[0]
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self.querystring.get('CidrBlock')[0]
if 'AvailabilityZone' in self.querystring: if 'AvailabilityZone' in self.querystring:
availability_zone = self.querystring['AvailabilityZone'][0] availability_zone = self.querystring['AvailabilityZone'][0]
else: else:
zone = random.choice(self.ec2_backend.describe_availability_zones()) zone = random.choice(
self.ec2_backend.describe_availability_zones())
availability_zone = zone.name availability_zone = zone.name
subnet = self.ec2_backend.create_subnet( subnet = self.ec2_backend.create_subnet(
vpc_id, vpc_id,

View File

@ -8,7 +8,8 @@ from moto.ec2.utils import sequence_from_querystring, tags_from_query_string, fi
class TagResponse(BaseResponse): class TagResponse(BaseResponse):
def create_tags(self): def create_tags(self):
resource_ids = sequence_from_querystring('ResourceId', self.querystring) resource_ids = sequence_from_querystring(
'ResourceId', self.querystring)
validate_resource_ids(resource_ids) validate_resource_ids(resource_ids)
self.ec2_backend.do_resources_exist(resource_ids) self.ec2_backend.do_resources_exist(resource_ids)
tags = tags_from_query_string(self.querystring) tags = tags_from_query_string(self.querystring)
@ -17,7 +18,8 @@ class TagResponse(BaseResponse):
return CREATE_RESPONSE return CREATE_RESPONSE
def delete_tags(self): def delete_tags(self):
resource_ids = sequence_from_querystring('ResourceId', self.querystring) resource_ids = sequence_from_querystring(
'ResourceId', self.querystring)
validate_resource_ids(resource_ids) validate_resource_ids(resource_ids)
tags = tags_from_query_string(self.querystring) tags = tags_from_query_string(self.querystring)
if self.is_not_dryrun('DeleteTags'): if self.is_not_dryrun('DeleteTags'):

View File

@ -4,6 +4,7 @@ from moto.ec2.utils import filters_from_querystring
class VirtualPrivateGateways(BaseResponse): class VirtualPrivateGateways(BaseResponse):
def attach_vpn_gateway(self): def attach_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0] vpn_gateway_id = self.querystring.get('VpnGatewayId')[0]
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self.querystring.get('VpcId')[0]
@ -42,6 +43,7 @@ class VirtualPrivateGateways(BaseResponse):
template = self.response_template(DETACH_VPN_GATEWAY_RESPONSE) template = self.response_template(DETACH_VPN_GATEWAY_RESPONSE)
return template.render(attachment=attachment) return template.render(attachment=attachment)
CREATE_VPN_GATEWAY_RESPONSE = """ CREATE_VPN_GATEWAY_RESPONSE = """
<CreateVpnGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <CreateVpnGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>

View File

@ -3,11 +3,15 @@ from moto.core.responses import BaseResponse
class VMExport(BaseResponse): class VMExport(BaseResponse):
def cancel_export_task(self): def cancel_export_task(self):
raise NotImplementedError('VMExport.cancel_export_task is not yet implemented') raise NotImplementedError(
'VMExport.cancel_export_task is not yet implemented')
def create_instance_export_task(self): def create_instance_export_task(self):
raise NotImplementedError('VMExport.create_instance_export_task is not yet implemented') raise NotImplementedError(
'VMExport.create_instance_export_task is not yet implemented')
def describe_export_tasks(self): def describe_export_tasks(self):
raise NotImplementedError('VMExport.describe_export_tasks is not yet implemented') raise NotImplementedError(
'VMExport.describe_export_tasks is not yet implemented')

View File

@ -3,14 +3,19 @@ from moto.core.responses import BaseResponse
class VMImport(BaseResponse): class VMImport(BaseResponse):
def cancel_conversion_task(self): def cancel_conversion_task(self):
raise NotImplementedError('VMImport.cancel_conversion_task is not yet implemented') raise NotImplementedError(
'VMImport.cancel_conversion_task is not yet implemented')
def describe_conversion_tasks(self): def describe_conversion_tasks(self):
raise NotImplementedError('VMImport.describe_conversion_tasks is not yet implemented') raise NotImplementedError(
'VMImport.describe_conversion_tasks is not yet implemented')
def import_instance(self): def import_instance(self):
raise NotImplementedError('VMImport.import_instance is not yet implemented') raise NotImplementedError(
'VMImport.import_instance is not yet implemented')
def import_volume(self): def import_volume(self):
raise NotImplementedError('VMImport.import_volume is not yet implemented') raise NotImplementedError(
'VMImport.import_volume is not yet implemented')

View File

@ -3,34 +3,41 @@ from moto.core.responses import BaseResponse
class VPCPeeringConnections(BaseResponse): class VPCPeeringConnections(BaseResponse):
def create_vpc_peering_connection(self): def create_vpc_peering_connection(self):
vpc = self.ec2_backend.get_vpc(self.querystring.get('VpcId')[0]) vpc = self.ec2_backend.get_vpc(self.querystring.get('VpcId')[0])
peer_vpc = self.ec2_backend.get_vpc(self.querystring.get('PeerVpcId')[0]) peer_vpc = self.ec2_backend.get_vpc(
self.querystring.get('PeerVpcId')[0])
vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc)
template = self.response_template(CREATE_VPC_PEERING_CONNECTION_RESPONSE) template = self.response_template(
CREATE_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx) return template.render(vpc_pcx=vpc_pcx)
def delete_vpc_peering_connection(self): def delete_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(DELETE_VPC_PEERING_CONNECTION_RESPONSE) template = self.response_template(
DELETE_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx) return template.render(vpc_pcx=vpc_pcx)
def describe_vpc_peering_connections(self): def describe_vpc_peering_connections(self):
vpc_pcxs = self.ec2_backend.get_all_vpc_peering_connections() vpc_pcxs = self.ec2_backend.get_all_vpc_peering_connections()
template = self.response_template(DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) template = self.response_template(
DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE)
return template.render(vpc_pcxs=vpc_pcxs) return template.render(vpc_pcxs=vpc_pcxs)
def accept_vpc_peering_connection(self): def accept_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) template = self.response_template(
ACCEPT_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx) return template.render(vpc_pcx=vpc_pcx)
def reject_vpc_peering_connection(self): def reject_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id) self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(REJECT_VPC_PEERING_CONNECTION_RESPONSE) template = self.response_template(
REJECT_VPC_PEERING_CONNECTION_RESPONSE)
return template.render() return template.render()

View File

@ -5,9 +5,11 @@ from moto.ec2.utils import filters_from_querystring, vpc_ids_from_querystring
class VPCs(BaseResponse): class VPCs(BaseResponse):
def create_vpc(self): def create_vpc(self):
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self.querystring.get('CidrBlock')[0]
instance_tenancy = self.querystring.get('InstanceTenancy', ['default'])[0] instance_tenancy = self.querystring.get(
'InstanceTenancy', ['default'])[0]
vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy) vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy)
template = self.response_template(CREATE_VPC_RESPONSE) template = self.response_template(CREATE_VPC_RESPONSE)
return template.render(vpc=vpc) return template.render(vpc=vpc)
@ -40,7 +42,8 @@ class VPCs(BaseResponse):
if self.querystring.get('%s.Value' % attribute): if self.querystring.get('%s.Value' % attribute):
attr_name = camelcase_to_underscores(attribute) attr_name = camelcase_to_underscores(attribute)
attr_value = self.querystring.get('%s.Value' % attribute)[0] attr_value = self.querystring.get('%s.Value' % attribute)[0]
self.ec2_backend.modify_vpc_attribute(vpc_id, attr_name, attr_value) self.ec2_backend.modify_vpc_attribute(
vpc_id, attr_name, attr_value)
return MODIFY_VPC_ATTRIBUTE_RESPONSE return MODIFY_VPC_ATTRIBUTE_RESPONSE

View File

@ -4,23 +4,27 @@ from moto.ec2.utils import filters_from_querystring, sequence_from_querystring
class VPNConnections(BaseResponse): class VPNConnections(BaseResponse):
def create_vpn_connection(self): def create_vpn_connection(self):
type = self.querystring.get("Type", [None])[0] type = self.querystring.get("Type", [None])[0]
cgw_id = self.querystring.get("CustomerGatewayId", [None])[0] cgw_id = self.querystring.get("CustomerGatewayId", [None])[0]
vgw_id = self.querystring.get("VPNGatewayId", [None])[0] vgw_id = self.querystring.get("VPNGatewayId", [None])[0]
static_routes = self.querystring.get("StaticRoutesOnly", [None])[0] static_routes = self.querystring.get("StaticRoutesOnly", [None])[0]
vpn_connection = self.ec2_backend.create_vpn_connection(type, cgw_id, vgw_id, static_routes_only=static_routes) vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes)
template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE) template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
def delete_vpn_connection(self): def delete_vpn_connection(self):
vpn_connection_id = self.querystring.get('VpnConnectionId')[0] vpn_connection_id = self.querystring.get('VpnConnectionId')[0]
vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id) vpn_connection = self.ec2_backend.delete_vpn_connection(
vpn_connection_id)
template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE) template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
def describe_vpn_connections(self): def describe_vpn_connections(self):
vpn_connection_ids = sequence_from_querystring('VpnConnectionId', self.querystring) vpn_connection_ids = sequence_from_querystring(
'VpnConnectionId', self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
vpn_connections = self.ec2_backend.get_all_vpn_connections( vpn_connections = self.ec2_backend.get_all_vpn_connections(
vpn_connection_ids=vpn_connection_ids, filters=filters) vpn_connection_ids=vpn_connection_ids, filters=filters)

View File

@ -3,14 +3,19 @@ from moto.core.responses import BaseResponse
class Windows(BaseResponse): class Windows(BaseResponse):
def bundle_instance(self): def bundle_instance(self):
raise NotImplementedError('Windows.bundle_instance is not yet implemented') raise NotImplementedError(
'Windows.bundle_instance is not yet implemented')
def cancel_bundle_task(self): def cancel_bundle_task(self):
raise NotImplementedError('Windows.cancel_bundle_task is not yet implemented') raise NotImplementedError(
'Windows.cancel_bundle_task is not yet implemented')
def describe_bundle_tasks(self): def describe_bundle_tasks(self):
raise NotImplementedError('Windows.describe_bundle_tasks is not yet implemented') raise NotImplementedError(
'Windows.describe_bundle_tasks is not yet implemented')
def get_password_data(self): def get_password_data(self):
raise NotImplementedError('Windows.get_password_data is not yet implemented') raise NotImplementedError(
'Windows.get_password_data is not yet implemented')

View File

@ -32,13 +32,15 @@ EC2_RESOURCE_TO_PREFIX = {
'vpn-gateway': 'vgw'} 'vpn-gateway': 'vgw'}
EC2_PREFIX_TO_RESOURCE = dict((v, k) for (k, v) in EC2_RESOURCE_TO_PREFIX.items()) EC2_PREFIX_TO_RESOURCE = dict((v, k)
for (k, v) in EC2_RESOURCE_TO_PREFIX.items())
def random_id(prefix='', size=8): def random_id(prefix='', size=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f']
resource_id = ''.join(six.text_type(random.choice(chars)) for x in range(size)) resource_id = ''.join(six.text_type(random.choice(chars))
for x in range(size))
return '{0}-{1}'.format(prefix, resource_id) return '{0}-{1}'.format(prefix, resource_id)
@ -228,7 +230,8 @@ def tags_from_query_string(querystring_dict):
tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0] tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0]
tag_value_key = "Tag.{0}.Value".format(tag_index) tag_value_key = "Tag.{0}.Value".format(tag_index)
if tag_value_key in querystring_dict: if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0] response_values[tag_key] = querystring_dict.get(tag_value_key)[
0]
else: else:
response_values[tag_key] = None response_values[tag_key] = None
return response_values return response_values
@ -262,7 +265,8 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration'
key_index = key.split(".")[1] key_index = key.split(".")[1]
value_index = 1 value_index = 1
while True: while True:
value_key = u'{0}.{1}.Value.{2}'.format(option, key_index, value_index) value_key = u'{0}.{1}.Value.{2}'.format(
option, key_index, value_index)
if value_key in querystring: if value_key in querystring:
values.extend(querystring[value_key]) values.extend(querystring[value_key])
else: else:
@ -337,16 +341,20 @@ def get_obj_tag(obj, filter_name):
tags = dict((tag['key'], tag['value']) for tag in obj.get_tags()) tags = dict((tag['key'], tag['value']) for tag in obj.get_tags())
return tags.get(tag_name) return tags.get(tag_name)
def get_obj_tag_names(obj): def get_obj_tag_names(obj):
tags = set((tag['key'] for tag in obj.get_tags())) tags = set((tag['key'] for tag in obj.get_tags()))
return tags return tags
def get_obj_tag_values(obj): def get_obj_tag_values(obj):
tags = set((tag['value'] for tag in obj.get_tags())) tags = set((tag['value'] for tag in obj.get_tags()))
return tags return tags
def tag_filter_matches(obj, filter_name, filter_values): def tag_filter_matches(obj, filter_name, filter_values):
regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values] regex_filters = [re.compile(simple_aws_filter_to_re(f))
for f in filter_values]
if filter_name == 'tag-key': if filter_name == 'tag-key':
tag_values = get_obj_tag_names(obj) tag_values = get_obj_tag_names(obj)
elif filter_name == 'tag-value': elif filter_name == 'tag-value':
@ -400,7 +408,7 @@ def instance_value_in_filter_values(instance_value, filter_values):
if not set(filter_values).intersection(set(instance_value)): if not set(filter_values).intersection(set(instance_value)):
return False return False
elif instance_value not in filter_values: elif instance_value not in filter_values:
return False return False
return True return True
@ -464,7 +472,8 @@ def is_filter_matching(obj, filter, filter_value):
def generic_filter(filters, objects): def generic_filter(filters, objects):
if filters: if filters:
for (_filter, _filter_value) in filters.items(): for (_filter, _filter_value) in filters.items():
objects = [obj for obj in objects if is_filter_matching(obj, _filter, _filter_value)] objects = [obj for obj in objects if is_filter_matching(
obj, _filter, _filter_value)]
return objects return objects
@ -480,8 +489,10 @@ def simple_aws_filter_to_re(filter_string):
def random_key_pair(): def random_key_pair():
def random_hex(): def random_hex():
return chr(random.choice(list(range(48, 58)) + list(range(97, 102)))) return chr(random.choice(list(range(48, 58)) + list(range(97, 102))))
def random_fingerprint(): def random_fingerprint():
return ':'.join([random_hex()+random_hex() for i in range(20)]) return ':'.join([random_hex() + random_hex() for i in range(20)])
def random_material(): def random_material():
return ''.join([ return ''.join([
chr(random.choice(list(range(65, 91)) + list(range(48, 58)) + chr(random.choice(list(range(65, 91)) + list(range(48, 58)) +
@ -489,7 +500,7 @@ def random_key_pair():
for i in range(1000) for i in range(1000)
]) ])
material = "---- BEGIN RSA PRIVATE KEY ----" + random_material() + \ material = "---- BEGIN RSA PRIVATE KEY ----" + random_material() + \
"-----END RSA PRIVATE KEY-----" "-----END RSA PRIVATE KEY-----"
return { return {
'fingerprint': random_fingerprint(), 'fingerprint': random_fingerprint(),
'material': material 'material': material
@ -500,9 +511,11 @@ def get_prefix(resource_id):
resource_id_prefix, separator, after = resource_id.partition('-') resource_id_prefix, separator, after = resource_id.partition('-')
if resource_id_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: if resource_id_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']:
if after.startswith('attach'): if after.startswith('attach'):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX['network-interface-attachment'] resource_id_prefix = EC2_RESOURCE_TO_PREFIX[
'network-interface-attachment']
if resource_id_prefix not in EC2_RESOURCE_TO_PREFIX.values(): if resource_id_prefix not in EC2_RESOURCE_TO_PREFIX.values():
uuid4hex = re.compile('[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z', re.I) uuid4hex = re.compile(
'[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z', re.I)
if uuid4hex.match(resource_id) is not None: if uuid4hex.match(resource_id) is not None:
resource_id_prefix = EC2_RESOURCE_TO_PREFIX['reserved-instance'] resource_id_prefix = EC2_RESOURCE_TO_PREFIX['reserved-instance']
else: else:
@ -539,20 +552,20 @@ def generate_instance_identity_document(instance):
""" """
document = { document = {
'devPayProductCodes': None, 'devPayProductCodes': None,
'availabilityZone': instance.placement['AvailabilityZone'], 'availabilityZone': instance.placement['AvailabilityZone'],
'privateIp': instance.private_ip_address, 'privateIp': instance.private_ip_address,
'version': '2010-8-31', 'version': '2010-8-31',
'region': instance.placement['AvailabilityZone'][:-1], 'region': instance.placement['AvailabilityZone'][:-1],
'instanceId': instance.id, 'instanceId': instance.id,
'billingProducts': None, 'billingProducts': None,
'instanceType': instance.instance_type, 'instanceType': instance.instance_type,
'accountId': '012345678910', 'accountId': '012345678910',
'pendingTime': '2015-11-19T16:32:11Z', 'pendingTime': '2015-11-19T16:32:11Z',
'imageId': instance.image_id, 'imageId': instance.image_id,
'kernelId': instance.kernel_id, 'kernelId': instance.kernel_id,
'ramdiskId': instance.ramdisk_id, 'ramdiskId': instance.ramdisk_id,
'architecture': instance.architecture, 'architecture': instance.architecture,
} }
return document return document

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import ecs_backends from .models import ecs_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
ecs_backend = ecs_backends['us-east-1'] ecs_backend = ecs_backends['us-east-1']
mock_ecs = base_decorator(ecs_backends) mock_ecs = base_decorator(ecs_backends)

View File

@ -8,6 +8,7 @@ from copy import copy
class BaseObject(object): class BaseObject(object):
def camelCase(self, key): def camelCase(self, key):
words = [] words = []
for i, word in enumerate(key.split('_')): for i, word in enumerate(key.split('_')):
@ -31,9 +32,11 @@ class BaseObject(object):
class Cluster(BaseObject): class Cluster(BaseObject):
def __init__(self, cluster_name): def __init__(self, cluster_name):
self.active_services_count = 0 self.active_services_count = 0
self.arn = 'arn:aws:ecs:us-east-1:012345678910:cluster/{0}'.format(cluster_name) self.arn = 'arn:aws:ecs:us-east-1:012345678910:cluster/{0}'.format(
cluster_name)
self.name = cluster_name self.name = cluster_name
self.pending_tasks_count = 0 self.pending_tasks_count = 0
self.registered_container_instances_count = 0 self.registered_container_instances_count = 0
@ -58,9 +61,12 @@ class Cluster(BaseObject):
ecs_backend = ecs_backends[region_name] ecs_backend = ecs_backends[region_name]
return ecs_backend.create_cluster( return ecs_backend.create_cluster(
# ClusterName is optional in CloudFormation, thus create a random name if necessary # ClusterName is optional in CloudFormation, thus create a random
cluster_name=properties.get('ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), # name if necessary
cluster_name=properties.get(
'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))),
) )
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
@ -69,8 +75,10 @@ class Cluster(BaseObject):
ecs_backend = ecs_backends[region_name] ecs_backend = ecs_backends[region_name]
ecs_backend.delete_cluster(original_resource.arn) ecs_backend.delete_cluster(original_resource.arn)
return ecs_backend.create_cluster( return ecs_backend.create_cluster(
# ClusterName is optional in CloudFormation, thus create a random name if necessary # ClusterName is optional in CloudFormation, thus create a
cluster_name=properties.get('ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), # random name if necessary
cluster_name=properties.get(
'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))),
) )
else: else:
# no-op when nothing changed between old and new resources # no-op when nothing changed between old and new resources
@ -78,9 +86,11 @@ class Cluster(BaseObject):
class TaskDefinition(BaseObject): class TaskDefinition(BaseObject):
def __init__(self, family, revision, container_definitions, volumes=None): def __init__(self, family, revision, container_definitions, volumes=None):
self.family = family self.family = family
self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format(family, revision) self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format(
family, revision)
self.container_definitions = container_definitions self.container_definitions = container_definitions
if volumes is None: if volumes is None:
self.volumes = [] self.volumes = []
@ -98,7 +108,8 @@ class TaskDefinition(BaseObject):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
family = properties.get('Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions'] container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes'] volumes = properties['Volumes']
@ -110,14 +121,16 @@ class TaskDefinition(BaseObject):
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
family = properties.get('Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions'] container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes'] volumes = properties['Volumes']
if (original_resource.family != family or if (original_resource.family != family or
original_resource.container_definitions != container_definitions or original_resource.container_definitions != container_definitions or
original_resource.volumes != volumes original_resource.volumes != volumes):
# currently TaskRoleArn isn't stored at TaskDefinition instances # currently TaskRoleArn isn't stored at TaskDefinition
): # instances
ecs_backend = ecs_backends[region_name] ecs_backend = ecs_backends[region_name]
ecs_backend.deregister_task_definition(original_resource.arn) ecs_backend.deregister_task_definition(original_resource.arn)
return ecs_backend.register_task_definition( return ecs_backend.register_task_definition(
@ -126,10 +139,13 @@ class TaskDefinition(BaseObject):
# no-op when nothing changed between old and new resources # no-op when nothing changed between old and new resources
return original_resource return original_resource
class Task(BaseObject): class Task(BaseObject):
def __init__(self, cluster, task_definition, container_instance_arn, overrides={}, started_by=''): def __init__(self, cluster, task_definition, container_instance_arn, overrides={}, started_by=''):
self.cluster_arn = cluster.arn self.cluster_arn = cluster.arn
self.task_arn = 'arn:aws:ecs:us-east-1:012345678910:task/{0}'.format(str(uuid.uuid1())) self.task_arn = 'arn:aws:ecs:us-east-1:012345678910:task/{0}'.format(
str(uuid.uuid1()))
self.container_instance_arn = container_instance_arn self.container_instance_arn = container_instance_arn
self.last_status = 'RUNNING' self.last_status = 'RUNNING'
self.desired_status = 'RUNNING' self.desired_status = 'RUNNING'
@ -146,9 +162,11 @@ class Task(BaseObject):
class Service(BaseObject): class Service(BaseObject):
def __init__(self, cluster, service_name, task_definition, desired_count): def __init__(self, cluster, service_name, task_definition, desired_count):
self.cluster_arn = cluster.arn self.cluster_arn = cluster.arn
self.arn = 'arn:aws:ecs:us-east-1:012345678910:service/{0}'.format(service_name) self.arn = 'arn:aws:ecs:us-east-1:012345678910:service/{0}'.format(
service_name)
self.name = service_name self.name = service_name
self.status = 'ACTIVE' self.status = 'ACTIVE'
self.running_count = 0 self.running_count = 0
@ -209,7 +227,8 @@ class Service(BaseObject):
# TODO: LoadBalancers # TODO: LoadBalancers
# TODO: Role # TODO: Role
ecs_backend.delete_service(cluster_name, service_name) ecs_backend.delete_service(cluster_name, service_name)
new_service_name = '{0}Service{1}'.format(cluster_name, int(random() * 10 ** 6)) new_service_name = '{0}Service{1}'.format(
cluster_name, int(random() * 10 ** 6))
return ecs_backend.create_service( return ecs_backend.create_service(
cluster_name, new_service_name, task_definition, desired_count) cluster_name, new_service_name, task_definition, desired_count)
else: else:
@ -217,20 +236,22 @@ class Service(BaseObject):
class ContainerInstance(BaseObject): class ContainerInstance(BaseObject):
def __init__(self, ec2_instance_id): def __init__(self, ec2_instance_id):
self.ec2_instance_id = ec2_instance_id self.ec2_instance_id = ec2_instance_id
self.status = 'ACTIVE' self.status = 'ACTIVE'
self.registeredResources = [] self.registeredResources = []
self.agentConnected = True self.agentConnected = True
self.containerInstanceArn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format(str(uuid.uuid1())) self.containerInstanceArn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format(
str(uuid.uuid1()))
self.pendingTaskCount = 0 self.pendingTaskCount = 0
self.remainingResources = [] self.remainingResources = []
self.runningTaskCount = 0 self.runningTaskCount = 0
self.versionInfo = { self.versionInfo = {
'agentVersion': "1.0.0", 'agentVersion': "1.0.0",
'agentHash': '4023248', 'agentHash': '4023248',
'dockerVersion': 'DockerVersion: 1.5.0' 'dockerVersion': 'DockerVersion: 1.5.0'
} }
@property @property
def response_object(self): def response_object(self):
@ -240,9 +261,11 @@ class ContainerInstance(BaseObject):
class ContainerInstanceFailure(BaseObject): class ContainerInstanceFailure(BaseObject):
def __init__(self, reason, container_instance_id): def __init__(self, reason, container_instance_id):
self.reason = reason self.reason = reason
self.arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format(container_instance_id) self.arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format(
container_instance_id)
@property @property
def response_object(self): def response_object(self):
@ -253,6 +276,7 @@ class ContainerInstanceFailure(BaseObject):
class EC2ContainerServiceBackend(BaseBackend): class EC2ContainerServiceBackend(BaseBackend):
def __init__(self): def __init__(self):
self.clusters = {} self.clusters = {}
self.task_definitions = {} self.task_definitions = {}
@ -261,19 +285,21 @@ class EC2ContainerServiceBackend(BaseBackend):
self.container_instances = {} self.container_instances = {}
def describe_task_definition(self, task_definition_str): def describe_task_definition(self, task_definition_str):
task_definition_components = task_definition_str.split(':') task_definition_name = task_definition_str.split('/')[-1]
if len(task_definition_components) == 2: if ':' in task_definition_name:
family, revision = task_definition_components family, revision = task_definition_name.split(':')
revision = int(revision) revision = int(revision)
else: else:
family = task_definition_components[0] family = task_definition_name
revision = -1 revision = len(self.task_definitions.get(family, []))
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]): if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
return self.task_definitions[family][revision - 1] return self.task_definitions[family][revision - 1]
elif family in self.task_definitions and revision == -1: elif family in self.task_definitions and revision == -1:
return self.task_definitions[family][revision] return self.task_definitions[family][revision]
else: else:
raise Exception("{0} is not a task_definition".format(task_definition_str)) raise Exception(
"{0} is not a task_definition".format(task_definition_name))
def create_cluster(self, cluster_name): def create_cluster(self, cluster_name):
cluster = Cluster(cluster_name) cluster = Cluster(cluster_name)
@ -295,9 +321,11 @@ class EC2ContainerServiceBackend(BaseBackend):
for cluster in list_clusters_name: for cluster in list_clusters_name:
cluster_name = cluster.split('/')[-1] cluster_name = cluster.split('/')[-1]
if cluster_name in self.clusters: if cluster_name in self.clusters:
list_clusters.append(self.clusters[cluster_name].response_object) list_clusters.append(
self.clusters[cluster_name].response_object)
else: else:
raise Exception("{0} is not a cluster".format(cluster_name)) raise Exception(
"{0} is not a cluster".format(cluster_name))
return list_clusters return list_clusters
def delete_cluster(self, cluster_str): def delete_cluster(self, cluster_str):
@ -313,7 +341,8 @@ class EC2ContainerServiceBackend(BaseBackend):
else: else:
self.task_definitions[family] = [] self.task_definitions[family] = []
revision = 1 revision = 1
task_definition = TaskDefinition(family, revision, container_definitions, volumes) task_definition = TaskDefinition(
family, revision, container_definitions, volumes)
self.task_definitions[family].append(task_definition) self.task_definitions[family].append(task_definition)
return task_definition return task_definition
@ -324,23 +353,10 @@ class EC2ContainerServiceBackend(BaseBackend):
""" """
task_arns = [] task_arns = []
for task_definition_list in self.task_definitions.values(): for task_definition_list in self.task_definitions.values():
task_arns.extend([task_definition.arn for task_definition in task_definition_list]) task_arns.extend(
[task_definition.arn for task_definition in task_definition_list])
return task_arns return task_arns
def describe_task_definition(self, task_definition_str):
task_definition_name = task_definition_str.split('/')[-1]
if ':' in task_definition_name:
family, revision = task_definition_name.split(':')
revision = int(revision)
else:
family = task_definition_name
revision = len(self.task_definitions.get(family, []))
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
return self.task_definitions[family][revision-1]
else:
raise Exception("{0} is not a task_definition".format(task_definition_name))
def deregister_task_definition(self, task_definition_str): def deregister_task_definition(self, task_definition_str):
task_definition_name = task_definition_str.split('/')[-1] task_definition_name = task_definition_str.split('/')[-1]
family, revision = task_definition_name.split(':') family, revision = task_definition_name.split(':')
@ -348,7 +364,8 @@ class EC2ContainerServiceBackend(BaseBackend):
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]): if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
return self.task_definitions[family].pop(revision - 1) return self.task_definitions[family].pop(revision - 1)
else: else:
raise Exception("{0} is not a task_definition".format(task_definition_name)) raise Exception(
"{0} is not a task_definition".format(task_definition_name))
def run_task(self, cluster_str, task_definition_str, count, overrides, started_by): def run_task(self, cluster_str, task_definition_str, count, overrides, started_by):
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
@ -360,14 +377,17 @@ class EC2ContainerServiceBackend(BaseBackend):
if cluster_name not in self.tasks: if cluster_name not in self.tasks:
self.tasks[cluster_name] = {} self.tasks[cluster_name] = {}
tasks = [] tasks = []
container_instances = list(self.container_instances.get(cluster_name, {}).keys()) container_instances = list(
self.container_instances.get(cluster_name, {}).keys())
if not container_instances: if not container_instances:
raise Exception("No instances found in cluster {}".format(cluster_name)) raise Exception(
"No instances found in cluster {}".format(cluster_name))
for _ in range(count or 1): for _ in range(count or 1):
container_instance_arn = self.container_instances[cluster_name][ container_instance_arn = self.container_instances[cluster_name][
container_instances[randint(0, len(container_instances) - 1)] container_instances[randint(0, len(container_instances) - 1)]
].containerInstanceArn ].containerInstanceArn
task = Task(cluster, task_definition, container_instance_arn, overrides or {}, started_by or '') task = Task(cluster, task_definition, container_instance_arn,
overrides or {}, started_by or '')
tasks.append(task) tasks.append(task)
self.tasks[cluster_name][task.task_arn] = task self.tasks[cluster_name][task.task_arn] = task
return tasks return tasks
@ -385,13 +405,15 @@ class EC2ContainerServiceBackend(BaseBackend):
if not container_instances: if not container_instances:
raise Exception("No container instance list provided") raise Exception("No container instance list provided")
container_instance_ids = [x.split('/')[-1] for x in container_instances] container_instance_ids = [x.split('/')[-1]
for x in container_instances]
for container_instance_id in container_instance_ids: for container_instance_id in container_instance_ids:
container_instance_arn = self.container_instances[cluster_name][ container_instance_arn = self.container_instances[cluster_name][
container_instance_id container_instance_id
].containerInstanceArn ].containerInstanceArn
task = Task(cluster, task_definition, container_instance_arn, overrides or {}, started_by or '') task = Task(cluster, task_definition, container_instance_arn,
overrides or {}, started_by or '')
tasks.append(task) tasks.append(task)
self.tasks[cluster_name][task.task_arn] = task self.tasks[cluster_name][task.task_arn] = task
return tasks return tasks
@ -418,17 +440,18 @@ class EC2ContainerServiceBackend(BaseBackend):
filtered_tasks.append(task) filtered_tasks.append(task)
if cluster_str: if cluster_str:
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
if cluster_name in self.clusters: if cluster_name not in self.clusters:
cluster = self.clusters[cluster_name]
else:
raise Exception("{0} is not a cluster".format(cluster_name)) raise Exception("{0} is not a cluster".format(cluster_name))
filtered_tasks = list(filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks)) filtered_tasks = list(
filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks))
if container_instance: if container_instance:
filtered_tasks = list(filter(lambda t: container_instance in t.container_instance_arn, filtered_tasks)) filtered_tasks = list(filter(
lambda t: container_instance in t.container_instance_arn, filtered_tasks))
if started_by: if started_by:
filtered_tasks = list(filter(lambda t: started_by == t.started_by, filtered_tasks)) filtered_tasks = list(
filter(lambda t: started_by == t.started_by, filtered_tasks))
return [t.task_arn for t in filtered_tasks] return [t.task_arn for t in filtered_tasks]
def stop_task(self, cluster_str, task_str, reason): def stop_task(self, cluster_str, task_str, reason):
@ -441,14 +464,16 @@ class EC2ContainerServiceBackend(BaseBackend):
task_id = task_str.split('/')[-1] task_id = task_str.split('/')[-1]
tasks = self.tasks.get(cluster_name, None) tasks = self.tasks.get(cluster_name, None)
if not tasks: if not tasks:
raise Exception("Cluster {} has no registered tasks".format(cluster_name)) raise Exception(
"Cluster {} has no registered tasks".format(cluster_name))
for task in tasks.keys(): for task in tasks.keys():
if task.endswith(task_id): if task.endswith(task_id):
tasks[task].last_status = 'STOPPED' tasks[task].last_status = 'STOPPED'
tasks[task].desired_status = 'STOPPED' tasks[task].desired_status = 'STOPPED'
tasks[task].stopped_reason = reason tasks[task].stopped_reason = reason
return tasks[task] return tasks[task]
raise Exception("Could not find task {} on cluster {}".format(task_str, cluster_name)) raise Exception("Could not find task {} on cluster {}".format(
task_str, cluster_name))
def create_service(self, cluster_str, service_name, task_definition_str, desired_count): def create_service(self, cluster_str, service_name, task_definition_str, desired_count):
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
@ -458,7 +483,8 @@ class EC2ContainerServiceBackend(BaseBackend):
raise Exception("{0} is not a cluster".format(cluster_name)) raise Exception("{0} is not a cluster".format(cluster_name))
task_definition = self.describe_task_definition(task_definition_str) task_definition = self.describe_task_definition(task_definition_str)
desired_count = desired_count if desired_count is not None else 0 desired_count = desired_count if desired_count is not None else 0
service = Service(cluster, service_name, task_definition, desired_count) service = Service(cluster, service_name,
task_definition, desired_count)
cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name)
self.services[cluster_service_pair] = service self.services[cluster_service_pair] = service
return service return service
@ -476,7 +502,8 @@ class EC2ContainerServiceBackend(BaseBackend):
result = [] result = []
for existing_service_name, existing_service_obj in sorted(self.services.items()): for existing_service_name, existing_service_obj in sorted(self.services.items()):
for requested_name_or_arn in service_names_or_arns: for requested_name_or_arn in service_names_or_arns:
cluster_service_pair = '{0}:{1}'.format(cluster_name, requested_name_or_arn) cluster_service_pair = '{0}:{1}'.format(
cluster_name, requested_name_or_arn)
if cluster_service_pair == existing_service_name or existing_service_obj.arn == requested_name_or_arn: if cluster_service_pair == existing_service_name or existing_service_obj.arn == requested_name_or_arn:
result.append(existing_service_obj) result.append(existing_service_obj)
return result return result
@ -486,13 +513,16 @@ class EC2ContainerServiceBackend(BaseBackend):
cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name)
if cluster_service_pair in self.services: if cluster_service_pair in self.services:
if task_definition_str is not None: if task_definition_str is not None:
task_definition = self.describe_task_definition(task_definition_str) self.describe_task_definition(task_definition_str)
self.services[cluster_service_pair].task_definition = task_definition_str self.services[
cluster_service_pair].task_definition = task_definition_str
if desired_count is not None: if desired_count is not None:
self.services[cluster_service_pair].desired_count = desired_count self.services[
cluster_service_pair].desired_count = desired_count
return self.services[cluster_service_pair] return self.services[cluster_service_pair]
else: else:
raise Exception("cluster {0} or service {1} does not exist".format(cluster_name, service_name)) raise Exception("cluster {0} or service {1} does not exist".format(
cluster_name, service_name))
def delete_service(self, cluster_name, service_name): def delete_service(self, cluster_name, service_name):
cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name)
@ -503,7 +533,8 @@ class EC2ContainerServiceBackend(BaseBackend):
else: else:
return self.services.pop(cluster_service_pair) return self.services.pop(cluster_service_pair)
else: else:
raise Exception("cluster {0} or service {1} does not exist".format(cluster_name, service_name)) raise Exception("cluster {0} or service {1} does not exist".format(
cluster_name, service_name))
def register_container_instance(self, cluster_str, ec2_instance_id): def register_container_instance(self, cluster_str, ec2_instance_id):
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
@ -512,14 +543,18 @@ class EC2ContainerServiceBackend(BaseBackend):
container_instance = ContainerInstance(ec2_instance_id) container_instance = ContainerInstance(ec2_instance_id)
if not self.container_instances.get(cluster_name): if not self.container_instances.get(cluster_name):
self.container_instances[cluster_name] = {} self.container_instances[cluster_name] = {}
container_instance_id = container_instance.containerInstanceArn.split('/')[-1] container_instance_id = container_instance.containerInstanceArn.split(
self.container_instances[cluster_name][container_instance_id] = container_instance '/')[-1]
self.container_instances[cluster_name][
container_instance_id] = container_instance
return container_instance return container_instance
def list_container_instances(self, cluster_str): def list_container_instances(self, cluster_str):
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
container_instances_values = self.container_instances.get(cluster_name, {}).values() container_instances_values = self.container_instances.get(
container_instances = [ci.containerInstanceArn for ci in container_instances_values] cluster_name, {}).values()
container_instances = [
ci.containerInstanceArn for ci in container_instances_values]
return sorted(container_instances) return sorted(container_instances)
def describe_container_instances(self, cluster_str, list_container_instance_ids): def describe_container_instances(self, cluster_str, list_container_instance_ids):
@ -529,11 +564,13 @@ class EC2ContainerServiceBackend(BaseBackend):
failures = [] failures = []
container_instance_objects = [] container_instance_objects = []
for container_instance_id in list_container_instance_ids: for container_instance_id in list_container_instance_ids:
container_instance = self.container_instances[cluster_name].get(container_instance_id, None) container_instance = self.container_instances[
cluster_name].get(container_instance_id, None)
if container_instance is not None: if container_instance is not None:
container_instance_objects.append(container_instance) container_instance_objects.append(container_instance)
else: else:
failures.append(ContainerInstanceFailure('MISSING', container_instance_id)) failures.append(ContainerInstanceFailure(
'MISSING', container_instance_id))
return container_instance_objects, failures return container_instance_objects, failures

View File

@ -1,12 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import json import json
import uuid
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import ecs_backends from .models import ecs_backends
class EC2ContainerServiceResponse(BaseResponse): class EC2ContainerServiceResponse(BaseResponse):
@property @property
def ecs_backend(self): def ecs_backend(self):
return ecs_backends[self.region] return ecs_backends[self.region]
@ -34,8 +34,7 @@ class EC2ContainerServiceResponse(BaseResponse):
cluster_arns = self.ecs_backend.list_clusters() cluster_arns = self.ecs_backend.list_clusters()
return json.dumps({ return json.dumps({
'clusterArns': cluster_arns 'clusterArns': cluster_arns
#, # 'nextToken': str(uuid.uuid1())
#'nextToken': str(uuid.uuid1())
}) })
def describe_clusters(self): def describe_clusters(self):
@ -57,7 +56,8 @@ class EC2ContainerServiceResponse(BaseResponse):
family = self._get_param('family') family = self._get_param('family')
container_definitions = self._get_param('containerDefinitions') container_definitions = self._get_param('containerDefinitions')
volumes = self._get_param('volumes') volumes = self._get_param('volumes')
task_definition = self.ecs_backend.register_task_definition(family, container_definitions, volumes) task_definition = self.ecs_backend.register_task_definition(
family, container_definitions, volumes)
return json.dumps({ return json.dumps({
'taskDefinition': task_definition.response_object 'taskDefinition': task_definition.response_object
}) })
@ -66,43 +66,7 @@ class EC2ContainerServiceResponse(BaseResponse):
task_definition_arns = self.ecs_backend.list_task_definitions() task_definition_arns = self.ecs_backend.list_task_definitions()
return json.dumps({ return json.dumps({
'taskDefinitionArns': task_definition_arns 'taskDefinitionArns': task_definition_arns
#, # 'nextToken': str(uuid.uuid1())
#'nextToken': str(uuid.uuid1())
})
def describe_task_definition(self):
task_definition_str = self._get_param('taskDefinition')
task_definition = self.ecs_backend.describe_task_definition(task_definition_str)
return json.dumps({
'taskDefinition': task_definition.response_object
})
def deregister_task_definition(self):
task_definition_str = self._get_param('taskDefinition')
task_definition = self.ecs_backend.deregister_task_definition(task_definition_str)
return json.dumps({
'taskDefinition': task_definition.response_object
})
def run_task(self):
cluster_str = self._get_param('cluster')
overrides = self._get_param('overrides')
task_definition_str = self._get_param('taskDefinition')
count = self._get_int_param('count')
started_by = self._get_param('startedBy')
tasks = self.ecs_backend.run_task(cluster_str, task_definition_str, count, overrides, started_by)
return json.dumps({
'tasks': [task.response_object for task in tasks],
'failures': []
})
def describe_tasks(self):
cluster = self._get_param('cluster')
tasks = self._get_param('tasks')
data = self.ecs_backend.describe_tasks(cluster, tasks)
return json.dumps({
'tasks': [task.response_object for task in data],
'failures': []
}) })
def describe_task_definition(self): def describe_task_definition(self):
@ -113,17 +77,48 @@ class EC2ContainerServiceResponse(BaseResponse):
'failures': [] 'failures': []
}) })
def deregister_task_definition(self):
task_definition_str = self._get_param('taskDefinition')
task_definition = self.ecs_backend.deregister_task_definition(
task_definition_str)
return json.dumps({
'taskDefinition': task_definition.response_object
})
def run_task(self):
cluster_str = self._get_param('cluster')
overrides = self._get_param('overrides')
task_definition_str = self._get_param('taskDefinition')
count = self._get_int_param('count')
started_by = self._get_param('startedBy')
tasks = self.ecs_backend.run_task(
cluster_str, task_definition_str, count, overrides, started_by)
return json.dumps({
'tasks': [task.response_object for task in tasks],
'failures': []
})
def describe_tasks(self):
cluster = self._get_param('cluster')
tasks = self._get_param('tasks')
data = self.ecs_backend.describe_tasks(cluster, tasks)
return json.dumps({
'tasks': [task.response_object for task in data],
'failures': []
})
def start_task(self): def start_task(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
overrides = self._get_param('overrides') overrides = self._get_param('overrides')
task_definition_str = self._get_param('taskDefinition') task_definition_str = self._get_param('taskDefinition')
container_instances = self._get_param('containerInstances') container_instances = self._get_param('containerInstances')
started_by = self._get_param('startedBy') started_by = self._get_param('startedBy')
tasks = self.ecs_backend.start_task(cluster_str, task_definition_str, container_instances, overrides, started_by) tasks = self.ecs_backend.start_task(
cluster_str, task_definition_str, container_instances, overrides, started_by)
return json.dumps({ return json.dumps({
'tasks': [task.response_object for task in tasks], 'tasks': [task.response_object for task in tasks],
'failures': [] 'failures': []
}) })
def list_tasks(self): def list_tasks(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
@ -132,11 +127,11 @@ class EC2ContainerServiceResponse(BaseResponse):
started_by = self._get_param('startedBy') started_by = self._get_param('startedBy')
service_name = self._get_param('serviceName') service_name = self._get_param('serviceName')
desiredStatus = self._get_param('desiredStatus') desiredStatus = self._get_param('desiredStatus')
task_arns = self.ecs_backend.list_tasks(cluster_str, container_instance, family, started_by, service_name, desiredStatus) task_arns = self.ecs_backend.list_tasks(
cluster_str, container_instance, family, started_by, service_name, desiredStatus)
return json.dumps({ return json.dumps({
'taskArns': task_arns 'taskArns': task_arns
}) })
def stop_task(self): def stop_task(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
@ -145,15 +140,15 @@ class EC2ContainerServiceResponse(BaseResponse):
task = self.ecs_backend.stop_task(cluster_str, task, reason) task = self.ecs_backend.stop_task(cluster_str, task, reason)
return json.dumps({ return json.dumps({
'task': task.response_object 'task': task.response_object
}) })
def create_service(self): def create_service(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
service_name = self._get_param('serviceName') service_name = self._get_param('serviceName')
task_definition_str = self._get_param('taskDefinition') task_definition_str = self._get_param('taskDefinition')
desired_count = self._get_int_param('desiredCount') desired_count = self._get_int_param('desiredCount')
service = self.ecs_backend.create_service(cluster_str, service_name, task_definition_str, desired_count) service = self.ecs_backend.create_service(
cluster_str, service_name, task_definition_str, desired_count)
return json.dumps({ return json.dumps({
'service': service.response_object 'service': service.response_object
}) })
@ -170,7 +165,8 @@ class EC2ContainerServiceResponse(BaseResponse):
def describe_services(self): def describe_services(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
service_names = self._get_param('services') service_names = self._get_param('services')
services = self.ecs_backend.describe_services(cluster_str, service_names) services = self.ecs_backend.describe_services(
cluster_str, service_names)
return json.dumps({ return json.dumps({
'services': [service.response_object for service in services], 'services': [service.response_object for service in services],
'failures': [] 'failures': []
@ -181,7 +177,8 @@ class EC2ContainerServiceResponse(BaseResponse):
service_name = self._get_param('service') service_name = self._get_param('service')
task_definition = self._get_param('taskDefinition') task_definition = self._get_param('taskDefinition')
desired_count = self._get_int_param('desiredCount') desired_count = self._get_int_param('desiredCount')
service = self.ecs_backend.update_service(cluster_str, service_name, task_definition, desired_count) service = self.ecs_backend.update_service(
cluster_str, service_name, task_definition, desired_count)
return json.dumps({ return json.dumps({
'service': service.response_object 'service': service.response_object
}) })
@ -196,17 +193,20 @@ class EC2ContainerServiceResponse(BaseResponse):
def register_container_instance(self): def register_container_instance(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
instance_identity_document_str = self._get_param('instanceIdentityDocument') instance_identity_document_str = self._get_param(
'instanceIdentityDocument')
instance_identity_document = json.loads(instance_identity_document_str) instance_identity_document = json.loads(instance_identity_document_str)
ec2_instance_id = instance_identity_document["instanceId"] ec2_instance_id = instance_identity_document["instanceId"]
container_instance = self.ecs_backend.register_container_instance(cluster_str, ec2_instance_id) container_instance = self.ecs_backend.register_container_instance(
cluster_str, ec2_instance_id)
return json.dumps({ return json.dumps({
'containerInstance' : container_instance.response_object 'containerInstance': container_instance.response_object
}) })
def list_container_instances(self): def list_container_instances(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
container_instance_arns = self.ecs_backend.list_container_instances(cluster_str) container_instance_arns = self.ecs_backend.list_container_instances(
cluster_str)
return json.dumps({ return json.dumps({
'containerInstanceArns': container_instance_arns 'containerInstanceArns': container_instance_arns
}) })
@ -214,8 +214,9 @@ class EC2ContainerServiceResponse(BaseResponse):
def describe_container_instances(self): def describe_container_instances(self):
cluster_str = self._get_param('cluster') cluster_str = self._get_param('cluster')
list_container_instance_arns = self._get_param('containerInstances') list_container_instance_arns = self._get_param('containerInstances')
container_instances, failures = self.ecs_backend.describe_container_instances(cluster_str, list_container_instance_arns) container_instances, failures = self.ecs_backend.describe_container_instances(
cluster_str, list_container_instance_arns)
return json.dumps({ return json.dumps({
'failures': [ci.response_object for ci in failures], 'failures': [ci.response_object for ci in failures],
'containerInstances': [ci.response_object for ci in container_instances] 'containerInstances': [ci.response_object for ci in container_instances]
}) })

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import elb_backends from .models import elb_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
elb_backend = elb_backends['us-east-1'] elb_backend = elb_backends['us-east-1']
mock_elb = base_decorator(elb_backends) mock_elb = base_decorator(elb_backends)

View File

@ -7,6 +7,7 @@ class ELBClientError(RESTError):
class DuplicateTagKeysError(ELBClientError): class DuplicateTagKeysError(ELBClientError):
def __init__(self, cidr): def __init__(self, cidr):
super(DuplicateTagKeysError, self).__init__( super(DuplicateTagKeysError, self).__init__(
"DuplicateTagKeys", "DuplicateTagKeys",
@ -15,6 +16,7 @@ class DuplicateTagKeysError(ELBClientError):
class LoadBalancerNotFoundError(ELBClientError): class LoadBalancerNotFoundError(ELBClientError):
def __init__(self, cidr): def __init__(self, cidr):
super(LoadBalancerNotFoundError, self).__init__( super(LoadBalancerNotFoundError, self).__init__(
"LoadBalancerNotFound", "LoadBalancerNotFound",
@ -23,6 +25,7 @@ class LoadBalancerNotFoundError(ELBClientError):
class TooManyTagsError(ELBClientError): class TooManyTagsError(ELBClientError):
def __init__(self): def __init__(self):
super(TooManyTagsError, self).__init__( super(TooManyTagsError, self).__init__(
"LoadBalancerNotFound", "LoadBalancerNotFound",
@ -30,6 +33,7 @@ class TooManyTagsError(ELBClientError):
class BadHealthCheckDefinition(ELBClientError): class BadHealthCheckDefinition(ELBClientError):
def __init__(self): def __init__(self):
super(BadHealthCheckDefinition, self).__init__( super(BadHealthCheckDefinition, self).__init__(
"ValidationError", "ValidationError",
@ -37,9 +41,9 @@ class BadHealthCheckDefinition(ELBClientError):
class DuplicateLoadBalancerName(ELBClientError): class DuplicateLoadBalancerName(ELBClientError):
def __init__(self, name): def __init__(self, name):
super(DuplicateLoadBalancerName, self).__init__( super(DuplicateLoadBalancerName, self).__init__(
"DuplicateLoadBalancerName", "DuplicateLoadBalancerName",
"The specified load balancer name already exists for this account: {0}" "The specified load balancer name already exists for this account: {0}"
.format(name)) .format(name))

View File

@ -1,6 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto.ec2.elb
from boto.ec2.elb.attributes import ( from boto.ec2.elb.attributes import (
LbAttributes, LbAttributes,
ConnectionSettingAttribute, ConnectionSettingAttribute,
@ -22,8 +21,8 @@ from .exceptions import (
) )
class FakeHealthCheck(object): class FakeHealthCheck(object):
def __init__(self, timeout, healthy_threshold, unhealthy_threshold, def __init__(self, timeout, healthy_threshold, unhealthy_threshold,
interval, target): interval, target):
self.timeout = timeout self.timeout = timeout
@ -36,6 +35,7 @@ class FakeHealthCheck(object):
class FakeListener(object): class FakeListener(object):
def __init__(self, load_balancer_port, instance_port, protocol, ssl_certificate_id): def __init__(self, load_balancer_port, instance_port, protocol, ssl_certificate_id):
self.load_balancer_port = load_balancer_port self.load_balancer_port = load_balancer_port
self.instance_port = instance_port self.instance_port = instance_port
@ -48,6 +48,7 @@ class FakeListener(object):
class FakeBackend(object): class FakeBackend(object):
def __init__(self, instance_port): def __init__(self, instance_port):
self.instance_port = instance_port self.instance_port = instance_port
self.policy_names = [] self.policy_names = []
@ -57,6 +58,7 @@ class FakeBackend(object):
class FakeLoadBalancer(object): class FakeLoadBalancer(object):
def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None): def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None):
self.name = name self.name = name
self.health_check = None self.health_check = None
@ -78,16 +80,20 @@ class FakeLoadBalancer(object):
for port in ports: for port in ports:
listener = FakeListener( listener = FakeListener(
protocol=(port.get('protocol') or port['Protocol']), protocol=(port.get('protocol') or port['Protocol']),
load_balancer_port=(port.get('load_balancer_port') or port['LoadBalancerPort']), load_balancer_port=(
instance_port=(port.get('instance_port') or port['InstancePort']), port.get('load_balancer_port') or port['LoadBalancerPort']),
ssl_certificate_id=port.get('sslcertificate_id', port.get('SSLCertificateId')), instance_port=(
port.get('instance_port') or port['InstancePort']),
ssl_certificate_id=port.get(
'sslcertificate_id', port.get('SSLCertificateId')),
) )
self.listeners.append(listener) self.listeners.append(listener)
# it is unclear per the AWS documentation as to when or how backend # it is unclear per the AWS documentation as to when or how backend
# information gets set, so let's guess and set it here *shrug* # information gets set, so let's guess and set it here *shrug*
backend = FakeBackend( backend = FakeBackend(
instance_port=(port.get('instance_port') or port['InstancePort']), instance_port=(
port.get('instance_port') or port['InstancePort']),
) )
self.backends.append(backend) self.backends.append(backend)
@ -120,7 +126,8 @@ class FakeLoadBalancer(object):
port_policies[port] = policies_for_port port_policies[port] = policies_for_port
for port, policies in port_policies.items(): for port, policies in port_policies.items():
elb_backend.set_load_balancer_policies_of_backend_server(new_elb.name, port, list(policies)) elb_backend.set_load_balancer_policies_of_backend_server(
new_elb.name, port, list(policies))
health_check = properties.get('HealthCheck') health_check = properties.get('HealthCheck')
if health_check: if health_check:
@ -137,7 +144,8 @@ class FakeLoadBalancer(object):
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
cls.delete_from_cloudformation_json(original_resource.name, cloudformation_json, region_name) cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name)
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name)
@classmethod @classmethod
@ -155,15 +163,19 @@ class FakeLoadBalancer(object):
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'CanonicalHostedZoneName': if attribute_name == 'CanonicalHostedZoneName':
raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"') raise NotImplementedError(
'"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"')
elif attribute_name == 'CanonicalHostedZoneNameID': elif attribute_name == 'CanonicalHostedZoneNameID':
raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"') raise NotImplementedError(
'"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"')
elif attribute_name == 'DNSName': elif attribute_name == 'DNSName':
return self.dns_name return self.dns_name
elif attribute_name == 'SourceSecurityGroup.GroupName': elif attribute_name == 'SourceSecurityGroup.GroupName':
raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"') raise NotImplementedError(
'"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"')
elif attribute_name == 'SourceSecurityGroup.OwnerAlias': elif attribute_name == 'SourceSecurityGroup.OwnerAlias':
raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"') raise NotImplementedError(
'"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"')
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@classmethod @classmethod
@ -224,7 +236,8 @@ class ELBBackend(BaseBackend):
vpc_id = subnet.vpc_id vpc_id = subnet.vpc_id
if name in self.load_balancers: if name in self.load_balancers:
raise DuplicateLoadBalancerName(name) raise DuplicateLoadBalancerName(name)
new_load_balancer = FakeLoadBalancer(name=name, zones=zones, ports=ports, scheme=scheme, subnets=subnets, vpc_id=vpc_id) new_load_balancer = FakeLoadBalancer(
name=name, zones=zones, ports=ports, scheme=scheme, subnets=subnets, vpc_id=vpc_id)
self.load_balancers[name] = new_load_balancer self.load_balancers[name] = new_load_balancer
return new_load_balancer return new_load_balancer
@ -240,14 +253,16 @@ class ELBBackend(BaseBackend):
if lb_port == listener.load_balancer_port: if lb_port == listener.load_balancer_port:
break break
else: else:
balancer.listeners.append(FakeListener(lb_port, instance_port, protocol, ssl_certificate_id)) balancer.listeners.append(FakeListener(
lb_port, instance_port, protocol, ssl_certificate_id))
return balancer return balancer
def describe_load_balancers(self, names): def describe_load_balancers(self, names):
balancers = self.load_balancers.values() balancers = self.load_balancers.values()
if names: if names:
matched_balancers = [balancer for balancer in balancers if balancer.name in names] matched_balancers = [
balancer for balancer in balancers if balancer.name in names]
if len(names) != len(matched_balancers): if len(names) != len(matched_balancers):
missing_elb = list(set(names) - set(matched_balancers))[0] missing_elb = list(set(names) - set(matched_balancers))[0]
raise LoadBalancerNotFoundError(missing_elb) raise LoadBalancerNotFoundError(missing_elb)
@ -288,7 +303,8 @@ class ELBBackend(BaseBackend):
if balancer: if balancer:
for idx, listener in enumerate(balancer.listeners): for idx, listener in enumerate(balancer.listeners):
if lb_port == listener.load_balancer_port: if lb_port == listener.load_balancer_port:
balancer.listeners[idx].ssl_certificate_id = ssl_certificate_id balancer.listeners[
idx].ssl_certificate_id = ssl_certificate_id
return balancer return balancer
@ -299,7 +315,8 @@ class ELBBackend(BaseBackend):
def deregister_instances(self, load_balancer_name, instance_ids): def deregister_instances(self, load_balancer_name, instance_ids):
load_balancer = self.get_load_balancer(load_balancer_name) load_balancer = self.get_load_balancer(load_balancer_name)
new_instance_ids = [instance_id for instance_id in load_balancer.instance_ids if instance_id not in instance_ids] new_instance_ids = [
instance_id for instance_id in load_balancer.instance_ids if instance_id not in instance_ids]
load_balancer.instance_ids = new_instance_ids load_balancer.instance_ids = new_instance_ids
return load_balancer return load_balancer
@ -342,7 +359,8 @@ class ELBBackend(BaseBackend):
def set_load_balancer_policies_of_backend_server(self, load_balancer_name, instance_port, policies): def set_load_balancer_policies_of_backend_server(self, load_balancer_name, instance_port, policies):
load_balancer = self.get_load_balancer(load_balancer_name) load_balancer = self.get_load_balancer(load_balancer_name)
backend = [b for b in load_balancer.backends if int(b.instance_port) == instance_port][0] backend = [b for b in load_balancer.backends if int(
b.instance_port) == instance_port][0]
backend_idx = load_balancer.backends.index(backend) backend_idx = load_balancer.backends.index(backend)
backend.policy_names = policies backend.policy_names = policies
load_balancer.backends[backend_idx] = backend load_balancer.backends[backend_idx] = backend
@ -350,7 +368,8 @@ class ELBBackend(BaseBackend):
def set_load_balancer_policies_of_listener(self, load_balancer_name, load_balancer_port, policies): def set_load_balancer_policies_of_listener(self, load_balancer_name, load_balancer_port, policies):
load_balancer = self.get_load_balancer(load_balancer_name) load_balancer = self.get_load_balancer(load_balancer_name)
listener = [l for l in load_balancer.listeners if int(l.load_balancer_port) == load_balancer_port][0] listener = [l for l in load_balancer.listeners if int(
l.load_balancer_port) == load_balancer_port][0]
listener_idx = load_balancer.listeners.index(listener) listener_idx = load_balancer.listeners.index(listener)
listener.policy_names = policies listener.policy_names = policies
load_balancer.listeners[listener_idx] = listener load_balancer.listeners[listener_idx] = listener

View File

@ -43,9 +43,11 @@ class ELBResponse(BaseResponse):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
ports = self._get_list_prefix("Listeners.member") ports = self._get_list_prefix("Listeners.member")
self.elb_backend.create_load_balancer_listeners(name=load_balancer_name, ports=ports) self.elb_backend.create_load_balancer_listeners(
name=load_balancer_name, ports=ports)
template = self.response_template(CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE) template = self.response_template(
CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE)
return template.render() return template.render()
def describe_load_balancers(self): def describe_load_balancers(self):
@ -59,7 +61,8 @@ class ELBResponse(BaseResponse):
ports = self._get_multi_param("LoadBalancerPorts.member") ports = self._get_multi_param("LoadBalancerPorts.member")
ports = [int(port) for port in ports] ports = [int(port) for port in ports]
self.elb_backend.delete_load_balancer_listeners(load_balancer_name, ports) self.elb_backend.delete_load_balancer_listeners(
load_balancer_name, ports)
template = self.response_template(DELETE_LOAD_BALANCER_LISTENERS) template = self.response_template(DELETE_LOAD_BALANCER_LISTENERS)
return template.render() return template.render()
@ -74,7 +77,8 @@ class ELBResponse(BaseResponse):
load_balancer_name=self._get_param('LoadBalancerName'), load_balancer_name=self._get_param('LoadBalancerName'),
timeout=self._get_param('HealthCheck.Timeout'), timeout=self._get_param('HealthCheck.Timeout'),
healthy_threshold=self._get_param('HealthCheck.HealthyThreshold'), healthy_threshold=self._get_param('HealthCheck.HealthyThreshold'),
unhealthy_threshold=self._get_param('HealthCheck.UnhealthyThreshold'), unhealthy_threshold=self._get_param(
'HealthCheck.UnhealthyThreshold'),
interval=self._get_param('HealthCheck.Interval'), interval=self._get_param('HealthCheck.Interval'),
target=self._get_param('HealthCheck.Target'), target=self._get_param('HealthCheck.Target'),
) )
@ -83,9 +87,11 @@ class ELBResponse(BaseResponse):
def register_instances_with_load_balancer(self): def register_instances_with_load_balancer(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items() if "Instances.member" in key] instance_ids = [value[0] for key, value in self.querystring.items(
) if "Instances.member" in key]
template = self.response_template(REGISTER_INSTANCES_TEMPLATE) template = self.response_template(REGISTER_INSTANCES_TEMPLATE)
load_balancer = self.elb_backend.register_instances(load_balancer_name, instance_ids) load_balancer = self.elb_backend.register_instances(
load_balancer_name, instance_ids)
return template.render(load_balancer=load_balancer) return template.render(load_balancer=load_balancer)
def set_load_balancer_listener_sslcertificate(self): def set_load_balancer_listener_sslcertificate(self):
@ -93,16 +99,19 @@ class ELBResponse(BaseResponse):
ssl_certificate_id = self.querystring['SSLCertificateId'][0] ssl_certificate_id = self.querystring['SSLCertificateId'][0]
lb_port = self.querystring['LoadBalancerPort'][0] lb_port = self.querystring['LoadBalancerPort'][0]
self.elb_backend.set_load_balancer_listener_sslcertificate(load_balancer_name, lb_port, ssl_certificate_id) self.elb_backend.set_load_balancer_listener_sslcertificate(
load_balancer_name, lb_port, ssl_certificate_id)
template = self.response_template(SET_LOAD_BALANCER_SSL_CERTIFICATE) template = self.response_template(SET_LOAD_BALANCER_SSL_CERTIFICATE)
return template.render() return template.render()
def deregister_instances_from_load_balancer(self): def deregister_instances_from_load_balancer(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items() if "Instances.member" in key] instance_ids = [value[0] for key, value in self.querystring.items(
) if "Instances.member" in key]
template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE) template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE)
load_balancer = self.elb_backend.deregister_instances(load_balancer_name, instance_ids) load_balancer = self.elb_backend.deregister_instances(
load_balancer_name, instance_ids)
return template.render(load_balancer=load_balancer) return template.render(load_balancer=load_balancer)
def describe_load_balancer_attributes(self): def describe_load_balancer_attributes(self):
@ -115,11 +124,13 @@ class ELBResponse(BaseResponse):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) load_balancer = self.elb_backend.get_load_balancer(load_balancer_name)
cross_zone = self._get_dict_param("LoadBalancerAttributes.CrossZoneLoadBalancing.") cross_zone = self._get_dict_param(
"LoadBalancerAttributes.CrossZoneLoadBalancing.")
if cross_zone: if cross_zone:
attribute = CrossZoneLoadBalancingAttribute() attribute = CrossZoneLoadBalancingAttribute()
attribute.enabled = cross_zone["enabled"] == "true" attribute.enabled = cross_zone["enabled"] == "true"
self.elb_backend.set_cross_zone_load_balancing_attribute(load_balancer_name, attribute) self.elb_backend.set_cross_zone_load_balancing_attribute(
load_balancer_name, attribute)
access_log = self._get_dict_param("LoadBalancerAttributes.AccessLog.") access_log = self._get_dict_param("LoadBalancerAttributes.AccessLog.")
if access_log: if access_log:
@ -128,20 +139,25 @@ class ELBResponse(BaseResponse):
attribute.s3_bucket_name = access_log['s3_bucket_name'] attribute.s3_bucket_name = access_log['s3_bucket_name']
attribute.s3_bucket_prefix = access_log['s3_bucket_prefix'] attribute.s3_bucket_prefix = access_log['s3_bucket_prefix']
attribute.emit_interval = access_log["emit_interval"] attribute.emit_interval = access_log["emit_interval"]
self.elb_backend.set_access_log_attribute(load_balancer_name, attribute) self.elb_backend.set_access_log_attribute(
load_balancer_name, attribute)
connection_draining = self._get_dict_param("LoadBalancerAttributes.ConnectionDraining.") connection_draining = self._get_dict_param(
"LoadBalancerAttributes.ConnectionDraining.")
if connection_draining: if connection_draining:
attribute = ConnectionDrainingAttribute() attribute = ConnectionDrainingAttribute()
attribute.enabled = connection_draining["enabled"] == "true" attribute.enabled = connection_draining["enabled"] == "true"
attribute.timeout = connection_draining["timeout"] attribute.timeout = connection_draining["timeout"]
self.elb_backend.set_connection_draining_attribute(load_balancer_name, attribute) self.elb_backend.set_connection_draining_attribute(
load_balancer_name, attribute)
connection_settings = self._get_dict_param("LoadBalancerAttributes.ConnectionSettings.") connection_settings = self._get_dict_param(
"LoadBalancerAttributes.ConnectionSettings.")
if connection_settings: if connection_settings:
attribute = ConnectionSettingAttribute() attribute = ConnectionSettingAttribute()
attribute.idle_timeout = connection_settings["idle_timeout"] attribute.idle_timeout = connection_settings["idle_timeout"]
self.elb_backend.set_connection_settings_attribute(load_balancer_name, attribute) self.elb_backend.set_connection_settings_attribute(
load_balancer_name, attribute)
template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE) template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE)
return template.render(attributes=load_balancer.attributes) return template.render(attributes=load_balancer.attributes)
@ -153,7 +169,8 @@ class ELBResponse(BaseResponse):
policy_name = self._get_param("PolicyName") policy_name = self._get_param("PolicyName")
other_policy.policy_name = policy_name other_policy.policy_name = policy_name
self.elb_backend.create_lb_other_policy(load_balancer_name, other_policy) self.elb_backend.create_lb_other_policy(
load_balancer_name, other_policy)
template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE)
return template.render() return template.render()
@ -165,7 +182,8 @@ class ELBResponse(BaseResponse):
policy.policy_name = self._get_param("PolicyName") policy.policy_name = self._get_param("PolicyName")
policy.cookie_name = self._get_param("CookieName") policy.cookie_name = self._get_param("CookieName")
self.elb_backend.create_app_cookie_stickiness_policy(load_balancer_name, policy) self.elb_backend.create_app_cookie_stickiness_policy(
load_balancer_name, policy)
template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE)
return template.render() return template.render()
@ -181,7 +199,8 @@ class ELBResponse(BaseResponse):
else: else:
policy.cookie_expiration_period = None policy.cookie_expiration_period = None
self.elb_backend.create_lb_cookie_stickiness_policy(load_balancer_name, policy) self.elb_backend.create_lb_cookie_stickiness_policy(
load_balancer_name, policy)
template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE)
return template.render() return template.render()
@ -191,13 +210,16 @@ class ELBResponse(BaseResponse):
load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) load_balancer = self.elb_backend.get_load_balancer(load_balancer_name)
load_balancer_port = int(self._get_param('LoadBalancerPort')) load_balancer_port = int(self._get_param('LoadBalancerPort'))
mb_listener = [l for l in load_balancer.listeners if int(l.load_balancer_port) == load_balancer_port] mb_listener = [l for l in load_balancer.listeners if int(
l.load_balancer_port) == load_balancer_port]
if mb_listener: if mb_listener:
policies = self._get_multi_param("PolicyNames.member") policies = self._get_multi_param("PolicyNames.member")
self.elb_backend.set_load_balancer_policies_of_listener(load_balancer_name, load_balancer_port, policies) self.elb_backend.set_load_balancer_policies_of_listener(
load_balancer_name, load_balancer_port, policies)
# else: explode? # else: explode?
template = self.response_template(SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE) template = self.response_template(
SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE)
return template.render() return template.render()
def set_load_balancer_policies_for_backend_server(self): def set_load_balancer_policies_for_backend_server(self):
@ -205,20 +227,25 @@ class ELBResponse(BaseResponse):
load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) load_balancer = self.elb_backend.get_load_balancer(load_balancer_name)
instance_port = int(self.querystring.get('InstancePort')[0]) instance_port = int(self.querystring.get('InstancePort')[0])
mb_backend = [b for b in load_balancer.backends if int(b.instance_port) == instance_port] mb_backend = [b for b in load_balancer.backends if int(
b.instance_port) == instance_port]
if mb_backend: if mb_backend:
policies = self._get_multi_param('PolicyNames.member') policies = self._get_multi_param('PolicyNames.member')
self.elb_backend.set_load_balancer_policies_of_backend_server(load_balancer_name, instance_port, policies) self.elb_backend.set_load_balancer_policies_of_backend_server(
load_balancer_name, instance_port, policies)
# else: explode? # else: explode?
template = self.response_template(SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE) template = self.response_template(
SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE)
return template.render() return template.render()
def describe_instance_health(self): def describe_instance_health(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items() if "Instances.member" in key] instance_ids = [value[0] for key, value in self.querystring.items(
) if "Instances.member" in key]
if len(instance_ids) == 0: if len(instance_ids) == 0:
instance_ids = self.elb_backend.get_load_balancer(load_balancer_name).instance_ids instance_ids = self.elb_backend.get_load_balancer(
load_balancer_name).instance_ids
template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE) template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE)
return template.render(instance_ids=instance_ids) return template.render(instance_ids=instance_ids)
@ -226,7 +253,6 @@ class ELBResponse(BaseResponse):
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if "LoadBalancerNames.member" in key: if "LoadBalancerNames.member" in key:
number = key.split('.')[2]
load_balancer_name = value[0] load_balancer_name = value[0]
elb = self.elb_backend.get_load_balancer(load_balancer_name) elb = self.elb_backend.get_load_balancer(load_balancer_name)
if not elb: if not elb:
@ -241,7 +267,8 @@ class ELBResponse(BaseResponse):
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if "LoadBalancerNames.member" in key: if "LoadBalancerNames.member" in key:
number = key.split('.')[2] number = key.split('.')[2]
load_balancer_name = self._get_param('LoadBalancerNames.member.{0}'.format(number)) load_balancer_name = self._get_param(
'LoadBalancerNames.member.{0}'.format(number))
elb = self.elb_backend.get_load_balancer(load_balancer_name) elb = self.elb_backend.get_load_balancer(load_balancer_name)
if not elb: if not elb:
raise LoadBalancerNotFoundError(load_balancer_name) raise LoadBalancerNotFoundError(load_balancer_name)
@ -260,7 +287,8 @@ class ELBResponse(BaseResponse):
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if "LoadBalancerNames.member" in key: if "LoadBalancerNames.member" in key:
number = key.split('.')[2] number = key.split('.')[2]
load_balancer_name = self._get_param('LoadBalancerNames.member.{0}'.format(number)) load_balancer_name = self._get_param(
'LoadBalancerNames.member.{0}'.format(number))
elb = self.elb_backend.get_load_balancer(load_balancer_name) elb = self.elb_backend.get_load_balancer(load_balancer_name)
if not elb: if not elb:
raise LoadBalancerNotFoundError(load_balancer_name) raise LoadBalancerNotFoundError(load_balancer_name)
@ -284,7 +312,7 @@ class ELBResponse(BaseResponse):
for i in tag_keys: for i in tag_keys:
counts[i] = tag_keys.count(i) counts[i] = tag_keys.count(i)
counts = sorted(counts.items(), key=lambda i:i[1], reverse=True) counts = sorted(counts.items(), key=lambda i: i[1], reverse=True)
if counts and counts[0][1] > 1: if counts and counts[0][1] > 1:
# We have dupes... # We have dupes...

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import emr_backends from .models import emr_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
emr_backend = emr_backends['us-east-1'] emr_backend = emr_backends['us-east-1']
mock_emr = base_decorator(emr_backends) mock_emr = base_decorator(emr_backends)

View File

@ -11,6 +11,7 @@ from .utils import random_instance_group_id, random_cluster_id, random_step_id
class FakeApplication(object): class FakeApplication(object):
def __init__(self, name, version, args=None, additional_info=None): def __init__(self, name, version, args=None, additional_info=None):
self.additional_info = additional_info or {} self.additional_info = additional_info or {}
self.args = args or [] self.args = args or []
@ -19,6 +20,7 @@ class FakeApplication(object):
class FakeBootstrapAction(object): class FakeBootstrapAction(object):
def __init__(self, args, name, script_path): def __init__(self, args, name, script_path):
self.args = args or [] self.args = args or []
self.name = name self.name = name
@ -26,6 +28,7 @@ class FakeBootstrapAction(object):
class FakeInstanceGroup(object): class FakeInstanceGroup(object):
def __init__(self, instance_count, instance_role, instance_type, def __init__(self, instance_count, instance_role, instance_type,
market='ON_DEMAND', name=None, id=None, bid_price=None): market='ON_DEMAND', name=None, id=None, bid_price=None):
self.id = id or random_instance_group_id() self.id = id or random_instance_group_id()
@ -55,6 +58,7 @@ class FakeInstanceGroup(object):
class FakeStep(object): class FakeStep(object):
def __init__(self, def __init__(self,
state, state,
name='', name='',
@ -78,6 +82,7 @@ class FakeStep(object):
class FakeCluster(object): class FakeCluster(object):
def __init__(self, def __init__(self,
emr_backend, emr_backend,
name, name,
@ -135,17 +140,24 @@ class FakeCluster(object):
'instance_type': instance_attrs['slave_instance_type'], 'instance_type': instance_attrs['slave_instance_type'],
'market': 'ON_DEMAND', 'market': 'ON_DEMAND',
'name': 'slave'}]) 'name': 'slave'}])
self.additional_master_security_groups = instance_attrs.get('additional_master_security_groups') self.additional_master_security_groups = instance_attrs.get(
self.additional_slave_security_groups = instance_attrs.get('additional_slave_security_groups') 'additional_master_security_groups')
self.additional_slave_security_groups = instance_attrs.get(
'additional_slave_security_groups')
self.availability_zone = instance_attrs.get('availability_zone') self.availability_zone = instance_attrs.get('availability_zone')
self.ec2_key_name = instance_attrs.get('ec2_key_name') self.ec2_key_name = instance_attrs.get('ec2_key_name')
self.ec2_subnet_id = instance_attrs.get('ec2_subnet_id') self.ec2_subnet_id = instance_attrs.get('ec2_subnet_id')
self.hadoop_version = instance_attrs.get('hadoop_version') self.hadoop_version = instance_attrs.get('hadoop_version')
self.keep_job_flow_alive_when_no_steps = instance_attrs.get('keep_job_flow_alive_when_no_steps') self.keep_job_flow_alive_when_no_steps = instance_attrs.get(
self.master_security_group = instance_attrs.get('emr_managed_master_security_group') 'keep_job_flow_alive_when_no_steps')
self.service_access_security_group = instance_attrs.get('service_access_security_group') self.master_security_group = instance_attrs.get(
self.slave_security_group = instance_attrs.get('emr_managed_slave_security_group') 'emr_managed_master_security_group')
self.termination_protected = instance_attrs.get('termination_protected') self.service_access_security_group = instance_attrs.get(
'service_access_security_group')
self.slave_security_group = instance_attrs.get(
'emr_managed_slave_security_group')
self.termination_protected = instance_attrs.get(
'termination_protected')
self.release_label = release_label self.release_label = release_label
self.requested_ami_version = requested_ami_version self.requested_ami_version = requested_ami_version
@ -286,7 +298,8 @@ class ElasticMapReduceBackend(BaseBackend):
clusters = self.clusters.values() clusters = self.clusters.values()
within_two_month = datetime.now(pytz.utc) - timedelta(days=60) within_two_month = datetime.now(pytz.utc) - timedelta(days=60)
clusters = [c for c in clusters if c.creation_datetime >= within_two_month] clusters = [
c for c in clusters if c.creation_datetime >= within_two_month]
if job_flow_ids: if job_flow_ids:
clusters = [c for c in clusters if c.id in job_flow_ids] clusters = [c for c in clusters if c.id in job_flow_ids]
@ -294,10 +307,12 @@ class ElasticMapReduceBackend(BaseBackend):
clusters = [c for c in clusters if c.state in job_flow_states] clusters = [c for c in clusters if c.state in job_flow_states]
if created_after: if created_after:
created_after = dtparse(created_after) created_after = dtparse(created_after)
clusters = [c for c in clusters if c.creation_datetime > created_after] clusters = [
c for c in clusters if c.creation_datetime > created_after]
if created_before: if created_before:
created_before = dtparse(created_before) created_before = dtparse(created_before)
clusters = [c for c in clusters if c.creation_datetime < created_before] clusters = [
c for c in clusters if c.creation_datetime < created_before]
# Amazon EMR can return a maximum of 512 job flow descriptions # Amazon EMR can return a maximum of 512 job flow descriptions
return sorted(clusters, key=lambda x: x.id)[:512] return sorted(clusters, key=lambda x: x.id)[:512]
@ -322,7 +337,8 @@ class ElasticMapReduceBackend(BaseBackend):
max_items = 50 max_items = 50
actions = self.clusters[cluster_id].bootstrap_actions actions = self.clusters[cluster_id].bootstrap_actions
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
marker = None if len(actions) <= start_idx + max_items else str(start_idx + max_items) marker = None if len(actions) <= start_idx + \
max_items else str(start_idx + max_items)
return actions[start_idx:start_idx + max_items], marker return actions[start_idx:start_idx + max_items], marker
def list_clusters(self, cluster_states=None, created_after=None, def list_clusters(self, cluster_states=None, created_after=None,
@ -333,13 +349,16 @@ class ElasticMapReduceBackend(BaseBackend):
clusters = [c for c in clusters if c.state in cluster_states] clusters = [c for c in clusters if c.state in cluster_states]
if created_after: if created_after:
created_after = dtparse(created_after) created_after = dtparse(created_after)
clusters = [c for c in clusters if c.creation_datetime > created_after] clusters = [
c for c in clusters if c.creation_datetime > created_after]
if created_before: if created_before:
created_before = dtparse(created_before) created_before = dtparse(created_before)
clusters = [c for c in clusters if c.creation_datetime < created_before] clusters = [
c for c in clusters if c.creation_datetime < created_before]
clusters = sorted(clusters, key=lambda x: x.id) clusters = sorted(clusters, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
marker = None if len(clusters) <= start_idx + max_items else str(start_idx + max_items) marker = None if len(clusters) <= start_idx + \
max_items else str(start_idx + max_items)
return clusters[start_idx:start_idx + max_items], marker return clusters[start_idx:start_idx + max_items], marker
def list_instance_groups(self, cluster_id, marker=None): def list_instance_groups(self, cluster_id, marker=None):
@ -347,7 +366,8 @@ class ElasticMapReduceBackend(BaseBackend):
groups = sorted(self.clusters[cluster_id].instance_groups, groups = sorted(self.clusters[cluster_id].instance_groups,
key=lambda x: x.id) key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
marker = None if len(groups) <= start_idx + max_items else str(start_idx + max_items) marker = None if len(groups) <= start_idx + \
max_items else str(start_idx + max_items)
return groups[start_idx:start_idx + max_items], marker return groups[start_idx:start_idx + max_items], marker
def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None): def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None):
@ -358,7 +378,8 @@ class ElasticMapReduceBackend(BaseBackend):
if step_states: if step_states:
steps = [s for s in steps if s.state in step_states] steps = [s for s in steps if s.state in step_states]
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
marker = None if len(steps) <= start_idx + max_items else str(start_idx + max_items) marker = None if len(steps) <= start_idx + \
max_items else str(start_idx + max_items)
return steps[start_idx:start_idx + max_items], marker return steps[start_idx:start_idx + max_items], marker
def modify_instance_groups(self, instance_groups): def modify_instance_groups(self, instance_groups):

View File

@ -29,7 +29,8 @@ def generate_boto3_response(operation):
{'x-amzn-requestid': '2690d7eb-ed86-11dd-9877-6fad448a8419', {'x-amzn-requestid': '2690d7eb-ed86-11dd-9877-6fad448a8419',
'date': datetime.now(pytz.utc).strftime('%a, %d %b %Y %H:%M:%S %Z'), 'date': datetime.now(pytz.utc).strftime('%a, %d %b %Y %H:%M:%S %Z'),
'content-type': 'application/x-amz-json-1.1'}) 'content-type': 'application/x-amz-json-1.1'})
resp = xml_to_json_response(self.aws_service_spec, operation, rendered) resp = xml_to_json_response(
self.aws_service_spec, operation, rendered)
return '' if resp is None else json.dumps(resp) return '' if resp is None else json.dumps(resp)
return rendered return rendered
return f return f
@ -63,14 +64,16 @@ class ElasticMapReduceResponse(BaseResponse):
instance_groups = self._get_list_prefix('InstanceGroups.member') instance_groups = self._get_list_prefix('InstanceGroups.member')
for item in instance_groups: for item in instance_groups:
item['instance_count'] = int(item['instance_count']) item['instance_count'] = int(item['instance_count'])
instance_groups = self.backend.add_instance_groups(jobflow_id, instance_groups) instance_groups = self.backend.add_instance_groups(
jobflow_id, instance_groups)
template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE) template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups) return template.render(instance_groups=instance_groups)
@generate_boto3_response('AddJobFlowSteps') @generate_boto3_response('AddJobFlowSteps')
def add_job_flow_steps(self): def add_job_flow_steps(self):
job_flow_id = self._get_param('JobFlowId') job_flow_id = self._get_param('JobFlowId')
steps = self.backend.add_job_flow_steps(job_flow_id, steps_from_query_string(self._get_list_prefix('Steps.member'))) steps = self.backend.add_job_flow_steps(
job_flow_id, steps_from_query_string(self._get_list_prefix('Steps.member')))
template = self.response_template(ADD_JOB_FLOW_STEPS_TEMPLATE) template = self.response_template(ADD_JOB_FLOW_STEPS_TEMPLATE)
return template.render(steps=steps) return template.render(steps=steps)
@ -104,7 +107,8 @@ class ElasticMapReduceResponse(BaseResponse):
created_before = self._get_param('CreatedBefore') created_before = self._get_param('CreatedBefore')
job_flow_ids = self._get_multi_param("JobFlowIds.member") job_flow_ids = self._get_multi_param("JobFlowIds.member")
job_flow_states = self._get_multi_param('JobFlowStates.member') job_flow_states = self._get_multi_param('JobFlowStates.member')
clusters = self.backend.describe_job_flows(job_flow_ids, job_flow_states, created_after, created_before) clusters = self.backend.describe_job_flows(
job_flow_ids, job_flow_states, created_after, created_before)
template = self.response_template(DESCRIBE_JOB_FLOWS_TEMPLATE) template = self.response_template(DESCRIBE_JOB_FLOWS_TEMPLATE)
return template.render(clusters=clusters) return template.render(clusters=clusters)
@ -123,7 +127,8 @@ class ElasticMapReduceResponse(BaseResponse):
def list_bootstrap_actions(self): def list_bootstrap_actions(self):
cluster_id = self._get_param('ClusterId') cluster_id = self._get_param('ClusterId')
marker = self._get_param('Marker') marker = self._get_param('Marker')
bootstrap_actions, marker = self.backend.list_bootstrap_actions(cluster_id, marker) bootstrap_actions, marker = self.backend.list_bootstrap_actions(
cluster_id, marker)
template = self.response_template(LIST_BOOTSTRAP_ACTIONS_TEMPLATE) template = self.response_template(LIST_BOOTSTRAP_ACTIONS_TEMPLATE)
return template.render(bootstrap_actions=bootstrap_actions, marker=marker) return template.render(bootstrap_actions=bootstrap_actions, marker=marker)
@ -133,7 +138,8 @@ class ElasticMapReduceResponse(BaseResponse):
created_after = self._get_param('CreatedAfter') created_after = self._get_param('CreatedAfter')
created_before = self._get_param('CreatedBefore') created_before = self._get_param('CreatedBefore')
marker = self._get_param('Marker') marker = self._get_param('Marker')
clusters, marker = self.backend.list_clusters(cluster_states, created_after, created_before, marker) clusters, marker = self.backend.list_clusters(
cluster_states, created_after, created_before, marker)
template = self.response_template(LIST_CLUSTERS_TEMPLATE) template = self.response_template(LIST_CLUSTERS_TEMPLATE)
return template.render(clusters=clusters, marker=marker) return template.render(clusters=clusters, marker=marker)
@ -141,7 +147,8 @@ class ElasticMapReduceResponse(BaseResponse):
def list_instance_groups(self): def list_instance_groups(self):
cluster_id = self._get_param('ClusterId') cluster_id = self._get_param('ClusterId')
marker = self._get_param('Marker') marker = self._get_param('Marker')
instance_groups, marker = self.backend.list_instance_groups(cluster_id, marker=marker) instance_groups, marker = self.backend.list_instance_groups(
cluster_id, marker=marker)
template = self.response_template(LIST_INSTANCE_GROUPS_TEMPLATE) template = self.response_template(LIST_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups, marker=marker) return template.render(instance_groups=instance_groups, marker=marker)
@ -154,7 +161,8 @@ class ElasticMapReduceResponse(BaseResponse):
marker = self._get_param('Marker') marker = self._get_param('Marker')
step_ids = self._get_multi_param('StepIds.member') step_ids = self._get_multi_param('StepIds.member')
step_states = self._get_multi_param('StepStates.member') step_states = self._get_multi_param('StepStates.member')
steps, marker = self.backend.list_steps(cluster_id, marker=marker, step_ids=step_ids, step_states=step_states) steps, marker = self.backend.list_steps(
cluster_id, marker=marker, step_ids=step_ids, step_states=step_states)
template = self.response_template(LIST_STEPS_TEMPLATE) template = self.response_template(LIST_STEPS_TEMPLATE)
return template.render(steps=steps, marker=marker) return template.render(steps=steps, marker=marker)
@ -178,19 +186,27 @@ class ElasticMapReduceResponse(BaseResponse):
@generate_boto3_response('RunJobFlow') @generate_boto3_response('RunJobFlow')
def run_job_flow(self): def run_job_flow(self):
instance_attrs = dict( instance_attrs = dict(
master_instance_type=self._get_param('Instances.MasterInstanceType'), master_instance_type=self._get_param(
'Instances.MasterInstanceType'),
slave_instance_type=self._get_param('Instances.SlaveInstanceType'), slave_instance_type=self._get_param('Instances.SlaveInstanceType'),
instance_count=self._get_int_param('Instances.InstanceCount', 1), instance_count=self._get_int_param('Instances.InstanceCount', 1),
ec2_key_name=self._get_param('Instances.Ec2KeyName'), ec2_key_name=self._get_param('Instances.Ec2KeyName'),
ec2_subnet_id=self._get_param('Instances.Ec2SubnetId'), ec2_subnet_id=self._get_param('Instances.Ec2SubnetId'),
hadoop_version=self._get_param('Instances.HadoopVersion'), hadoop_version=self._get_param('Instances.HadoopVersion'),
availability_zone=self._get_param('Instances.Placement.AvailabilityZone', self.backend.region_name + 'a'), availability_zone=self._get_param(
keep_job_flow_alive_when_no_steps=self._get_bool_param('Instances.KeepJobFlowAliveWhenNoSteps', False), 'Instances.Placement.AvailabilityZone', self.backend.region_name + 'a'),
termination_protected=self._get_bool_param('Instances.TerminationProtected', False), keep_job_flow_alive_when_no_steps=self._get_bool_param(
emr_managed_master_security_group=self._get_param('Instances.EmrManagedMasterSecurityGroup'), 'Instances.KeepJobFlowAliveWhenNoSteps', False),
emr_managed_slave_security_group=self._get_param('Instances.EmrManagedSlaveSecurityGroup'), termination_protected=self._get_bool_param(
service_access_security_group=self._get_param('Instances.ServiceAccessSecurityGroup'), 'Instances.TerminationProtected', False),
additional_master_security_groups=self._get_multi_param('Instances.AdditionalMasterSecurityGroups.member.'), emr_managed_master_security_group=self._get_param(
'Instances.EmrManagedMasterSecurityGroup'),
emr_managed_slave_security_group=self._get_param(
'Instances.EmrManagedSlaveSecurityGroup'),
service_access_security_group=self._get_param(
'Instances.ServiceAccessSecurityGroup'),
additional_master_security_groups=self._get_multi_param(
'Instances.AdditionalMasterSecurityGroups.member.'),
additional_slave_security_groups=self._get_multi_param('Instances.AdditionalSlaveSecurityGroups.member.')) additional_slave_security_groups=self._get_multi_param('Instances.AdditionalSlaveSecurityGroups.member.'))
kwargs = dict( kwargs = dict(
@ -198,8 +214,10 @@ class ElasticMapReduceResponse(BaseResponse):
log_uri=self._get_param('LogUri'), log_uri=self._get_param('LogUri'),
job_flow_role=self._get_param('JobFlowRole'), job_flow_role=self._get_param('JobFlowRole'),
service_role=self._get_param('ServiceRole'), service_role=self._get_param('ServiceRole'),
steps=steps_from_query_string(self._get_list_prefix('Steps.member')), steps=steps_from_query_string(
visible_to_all_users=self._get_bool_param('VisibleToAllUsers', False), self._get_list_prefix('Steps.member')),
visible_to_all_users=self._get_bool_param(
'VisibleToAllUsers', False),
instance_attrs=instance_attrs, instance_attrs=instance_attrs,
) )
@ -225,7 +243,8 @@ class ElasticMapReduceResponse(BaseResponse):
if key.startswith('properties.'): if key.startswith('properties.'):
config.pop(key) config.pop(key)
config['properties'] = {} config['properties'] = {}
map_items = self._get_map_prefix('Configurations.member.{0}.Properties.entry'.format(idx)) map_items = self._get_map_prefix(
'Configurations.member.{0}.Properties.entry'.format(idx))
config['properties'] = map_items config['properties'] = map_items
kwargs['configurations'] = configurations kwargs['configurations'] = configurations
@ -239,7 +258,8 @@ class ElasticMapReduceResponse(BaseResponse):
'Only one AMI version and release label may be specified. ' 'Only one AMI version and release label may be specified. '
'Provided AMI: {0}, release label: {1}.').format( 'Provided AMI: {0}, release label: {1}.').format(
ami_version, release_label) ami_version, release_label)
raise EmrError(error_type="ValidationException", message=message, template='single_error') raise EmrError(error_type="ValidationException",
message=message, template='single_error')
else: else:
if ami_version: if ami_version:
kwargs['requested_ami_version'] = ami_version kwargs['requested_ami_version'] = ami_version
@ -256,7 +276,8 @@ class ElasticMapReduceResponse(BaseResponse):
self.backend.add_applications( self.backend.add_applications(
cluster.id, [{'Name': 'Hadoop', 'Version': '0.18'}]) cluster.id, [{'Name': 'Hadoop', 'Version': '0.18'}])
instance_groups = self._get_list_prefix('Instances.InstanceGroups.member') instance_groups = self._get_list_prefix(
'Instances.InstanceGroups.member')
if instance_groups: if instance_groups:
for ig in instance_groups: for ig in instance_groups:
ig['instance_count'] = int(ig['instance_count']) ig['instance_count'] = int(ig['instance_count'])
@ -274,7 +295,8 @@ class ElasticMapReduceResponse(BaseResponse):
def set_termination_protection(self): def set_termination_protection(self):
termination_protection = self._get_param('TerminationProtected') termination_protection = self._get_param('TerminationProtected')
job_ids = self._get_multi_param('JobFlowIds.member') job_ids = self._get_multi_param('JobFlowIds.member')
self.backend.set_termination_protection(job_ids, termination_protection) self.backend.set_termination_protection(
job_ids, termination_protection)
template = self.response_template(SET_TERMINATION_PROTECTION_TEMPLATE) template = self.response_template(SET_TERMINATION_PROTECTION_TEMPLATE)
return template.render() return template.render()

View File

@ -32,7 +32,8 @@ def tags_from_query_string(querystring_dict):
tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0] tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0]
tag_value_key = "Tags.{0}.Value".format(tag_index) tag_value_key = "Tags.{0}.Value".format(tag_index)
if tag_value_key in querystring_dict: if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0] response_values[tag_key] = querystring_dict.get(tag_value_key)[
0]
else: else:
response_values[tag_key] = None response_values[tag_key] = None
return response_values return response_values
@ -42,7 +43,8 @@ def steps_from_query_string(querystring_dict):
steps = [] steps = []
for step in querystring_dict: for step in querystring_dict:
step['jar'] = step.pop('hadoop_jar_step._jar') step['jar'] = step.pop('hadoop_jar_step._jar')
step['properties'] = dict((o['Key'], o['Value']) for o in step.get('properties', [])) step['properties'] = dict((o['Key'], o['Value'])
for o in step.get('properties', []))
step['args'] = [] step['args'] = []
idx = 1 idx = 1
keyfmt = 'hadoop_jar_step._args.member.{0}' keyfmt = 'hadoop_jar_step._args.member.{0}'

View File

@ -53,7 +53,8 @@ class EventsBackend(BaseBackend):
def __init__(self): def __init__(self):
self.rules = {} self.rules = {}
# This array tracks the order in which the rules have been added, since 2.6 doesn't have OrderedDicts. # This array tracks the order in which the rules have been added, since
# 2.6 doesn't have OrderedDicts.
self.rules_order = [] self.rules_order = []
self.next_tokens = {} self.next_tokens = {}
@ -106,7 +107,8 @@ class EventsBackend(BaseBackend):
matching_rules = [] matching_rules = []
return_obj = {} return_obj = {}
start_index, end_index, new_next_token = self._process_token_and_limits(len(self.rules), next_token, limit) start_index, end_index, new_next_token = self._process_token_and_limits(
len(self.rules), next_token, limit)
for i in range(start_index, end_index): for i in range(start_index, end_index):
rule = self._get_rule_by_index(i) rule = self._get_rule_by_index(i)
@ -130,7 +132,8 @@ class EventsBackend(BaseBackend):
matching_rules = [] matching_rules = []
return_obj = {} return_obj = {}
start_index, end_index, new_next_token = self._process_token_and_limits(len(self.rules), next_token, limit) start_index, end_index, new_next_token = self._process_token_and_limits(
len(self.rules), next_token, limit)
for i in range(start_index, end_index): for i in range(start_index, end_index):
rule = self._get_rule_by_index(i) rule = self._get_rule_by_index(i)
@ -144,10 +147,12 @@ class EventsBackend(BaseBackend):
return return_obj return return_obj
def list_targets_by_rule(self, rule, next_token=None, limit=None): def list_targets_by_rule(self, rule, next_token=None, limit=None):
# We'll let a KeyError exception be thrown for response to handle if rule doesn't exist. # We'll let a KeyError exception be thrown for response to handle if
# rule doesn't exist.
rule = self.rules[rule] rule = self.rules[rule]
start_index, end_index, new_next_token = self._process_token_and_limits(len(rule.targets), next_token, limit) start_index, end_index, new_next_token = self._process_token_and_limits(
len(rule.targets), next_token, limit)
returned_targets = [] returned_targets = []
return_obj = {} return_obj = {}
@ -188,4 +193,5 @@ class EventsBackend(BaseBackend):
def test_event_pattern(self): def test_event_pattern(self):
raise NotImplementedError() raise NotImplementedError()
events_backend = EventsBackend() events_backend = EventsBackend()

View File

@ -87,7 +87,8 @@ class EventsHandler(BaseResponse):
if not target_arn: if not target_arn:
return self.error('ValidationException', 'Parameter TargetArn is required.') return self.error('ValidationException', 'Parameter TargetArn is required.')
rule_names = events_backend.list_rule_names_by_target(target_arn, next_token, limit) rule_names = events_backend.list_rule_names_by_target(
target_arn, next_token, limit)
return json.dumps(rule_names), self.response_headers return json.dumps(rule_names), self.response_headers
@ -118,7 +119,8 @@ class EventsHandler(BaseResponse):
return self.error('ValidationException', 'Parameter Rule is required.') return self.error('ValidationException', 'Parameter Rule is required.')
try: try:
targets = events_backend.list_targets_by_rule(rule_name, next_token, limit) targets = events_backend.list_targets_by_rule(
rule_name, next_token, limit)
except KeyError: except KeyError:
return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.')
@ -140,7 +142,8 @@ class EventsHandler(BaseResponse):
try: try:
json.loads(event_pattern) json.loads(event_pattern)
except ValueError: except ValueError:
# Not quite as informative as the real error, but it'll work for now. # Not quite as informative as the real error, but it'll work
# for now.
return self.error('InvalidEventPatternException', 'Event pattern is not valid.') return self.error('InvalidEventPatternException', 'Event pattern is not valid.')
if sched_exp: if sched_exp:

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import glacier_backends from .models import glacier_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
glacier_backend = glacier_backends['us-east-1'] glacier_backend = glacier_backends['us-east-1']
mock_glacier = base_decorator(glacier_backends) mock_glacier = base_decorator(glacier_backends)

View File

@ -36,6 +36,7 @@ class ArchiveJob(object):
class Vault(object): class Vault(object):
def __init__(self, vault_name, region): def __init__(self, vault_name, region):
self.vault_name = vault_name self.vault_name = vault_name
self.region = region self.region = region

View File

@ -128,7 +128,8 @@ class GlacierResponse(_TemplateEnvironmentMixin):
archive_id = json_body['ArchiveId'] archive_id = json_body['ArchiveId']
job_id = self.backend.initiate_job(vault_name, archive_id) job_id = self.backend.initiate_job(vault_name, archive_id)
headers['x-amz-job-id'] = job_id headers['x-amz-job-id'] = job_id
headers['Location'] = "/{0}/vaults/{1}/jobs/{2}".format(account_id, vault_name, job_id) headers[
'Location'] = "/{0}/vaults/{1}/jobs/{2}".format(account_id, vault_name, job_id)
return 202, headers, "" return 202, headers, ""
@classmethod @classmethod

View File

@ -3,4 +3,4 @@ from .models import iam_backend
iam_backends = {"global": iam_backend} iam_backends = {"global": iam_backend}
mock_iam = iam_backend.decorator mock_iam = iam_backend.decorator
mock_iam_deprecated = iam_backend.deprecated_decorator mock_iam_deprecated = iam_backend.deprecated_decorator

View File

@ -97,6 +97,7 @@ class Role(object):
class InstanceProfile(object): class InstanceProfile(object):
def __init__(self, instance_profile_id, name, path, roles): def __init__(self, instance_profile_id, name, path, roles):
self.id = instance_profile_id self.id = instance_profile_id
self.name = name self.name = name
@ -126,6 +127,7 @@ class InstanceProfile(object):
class Certificate(object): class Certificate(object):
def __init__(self, cert_name, cert_body, private_key, cert_chain=None, path=None): def __init__(self, cert_name, cert_body, private_key, cert_chain=None, path=None):
self.cert_name = cert_name self.cert_name = cert_name
self.cert_body = cert_body self.cert_body = cert_body
@ -139,6 +141,7 @@ class Certificate(object):
class AccessKey(object): class AccessKey(object):
def __init__(self, user_name): def __init__(self, user_name):
self.user_name = user_name self.user_name = user_name
self.access_key_id = random_access_key() self.access_key_id = random_access_key()
@ -157,6 +160,7 @@ class AccessKey(object):
class Group(object): class Group(object):
def __init__(self, name, path='/'): def __init__(self, name, path='/'):
self.name = name self.name = name
self.id = random_resource_id() self.id = random_resource_id()
@ -176,6 +180,7 @@ class Group(object):
class User(object): class User(object):
def __init__(self, name, path=None): def __init__(self, name, path=None):
self.name = name self.name = name
self.id = random_resource_id() self.id = random_resource_id()
@ -184,7 +189,8 @@ class User(object):
datetime.utcnow(), datetime.utcnow(),
"%Y-%m-%d-%H-%M-%S" "%Y-%m-%d-%H-%M-%S"
) )
self.arn = 'arn:aws:iam::123456789012:user{0}{1}'.format(self.path, name) self.arn = 'arn:aws:iam::123456789012:user{0}{1}'.format(
self.path, name)
self.policies = {} self.policies = {}
self.access_keys = [] self.access_keys = []
self.password = None self.password = None
@ -194,7 +200,8 @@ class User(object):
try: try:
policy_json = self.policies[policy_name] policy_json = self.policies[policy_name]
except KeyError: except KeyError:
raise IAMNotFoundException("Policy {0} not found".format(policy_name)) raise IAMNotFoundException(
"Policy {0} not found".format(policy_name))
return { return {
'policy_name': policy_name, 'policy_name': policy_name,
@ -207,7 +214,8 @@ class User(object):
def delete_policy(self, policy_name): def delete_policy(self, policy_name):
if policy_name not in self.policies: if policy_name not in self.policies:
raise IAMNotFoundException("Policy {0} not found".format(policy_name)) raise IAMNotFoundException(
"Policy {0} not found".format(policy_name))
del self.policies[policy_name] del self.policies[policy_name]
@ -225,7 +233,8 @@ class User(object):
self.access_keys.remove(key) self.access_keys.remove(key)
break break
else: else:
raise IAMNotFoundException("Key {0} not found".format(access_key_id)) raise IAMNotFoundException(
"Key {0} not found".format(access_key_id))
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -261,16 +270,18 @@ class User(object):
access_key_2_last_rotated = date_created.strftime(date_format) access_key_2_last_rotated = date_created.strftime(date_format)
return '{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A'.format(self.name, return '{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A'.format(self.name,
self.arn, self.arn,
date_created.strftime(date_format), date_created.strftime(
password_enabled, date_format),
password_last_used, password_enabled,
date_created.strftime(date_format), password_last_used,
access_key_1_active, date_created.strftime(
access_key_1_last_rotated, date_format),
access_key_2_active, access_key_1_active,
access_key_2_last_rotated access_key_1_last_rotated,
) access_key_2_active,
access_key_2_last_rotated
)
# predefine AWS managed policies # predefine AWS managed policies
@ -439,7 +450,8 @@ class IAMBackend(BaseBackend):
if scope == 'AWS': if scope == 'AWS':
policies = [p for p in policies if isinstance(p, AWSManagedPolicy)] policies = [p for p in policies if isinstance(p, AWSManagedPolicy)]
elif scope == 'Local': elif scope == 'Local':
policies = [p for p in policies if not isinstance(p, AWSManagedPolicy)] policies = [p for p in policies if not isinstance(
p, AWSManagedPolicy)]
if path_prefix: if path_prefix:
policies = [p for p in policies if p.path.startswith(path_prefix)] policies = [p for p in policies if p.path.startswith(path_prefix)]
@ -492,7 +504,8 @@ class IAMBackend(BaseBackend):
instance_profile_id = random_resource_id() instance_profile_id = random_resource_id()
roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids] roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids]
instance_profile = InstanceProfile(instance_profile_id, name, path, roles) instance_profile = InstanceProfile(
instance_profile_id, name, path, roles)
self.instance_profiles[instance_profile_id] = instance_profile self.instance_profiles[instance_profile_id] = instance_profile
return instance_profile return instance_profile
@ -501,7 +514,8 @@ class IAMBackend(BaseBackend):
if profile.name == profile_name: if profile.name == profile_name:
return profile return profile
raise IAMNotFoundException("Instance profile {0} not found".format(profile_name)) raise IAMNotFoundException(
"Instance profile {0} not found".format(profile_name))
def get_instance_profiles(self): def get_instance_profiles(self):
return self.instance_profiles.values() return self.instance_profiles.values()
@ -546,7 +560,8 @@ class IAMBackend(BaseBackend):
def create_group(self, group_name, path='/'): def create_group(self, group_name, path='/'):
if group_name in self.groups: if group_name in self.groups:
raise IAMConflictException("Group {0} already exists".format(group_name)) raise IAMConflictException(
"Group {0} already exists".format(group_name))
group = Group(group_name, path) group = Group(group_name, path)
self.groups[group_name] = group self.groups[group_name] = group
@ -557,7 +572,8 @@ class IAMBackend(BaseBackend):
try: try:
group = self.groups[group_name] group = self.groups[group_name]
except KeyError: except KeyError:
raise IAMNotFoundException("Group {0} not found".format(group_name)) raise IAMNotFoundException(
"Group {0} not found".format(group_name))
return group return group
@ -575,7 +591,8 @@ class IAMBackend(BaseBackend):
def create_user(self, user_name, path='/'): def create_user(self, user_name, path='/'):
if user_name in self.users: if user_name in self.users:
raise IAMConflictException("EntityAlreadyExists", "User {0} already exists".format(user_name)) raise IAMConflictException(
"EntityAlreadyExists", "User {0} already exists".format(user_name))
user = User(user_name, path) user = User(user_name, path)
self.users[user_name] = user self.users[user_name] = user
@ -595,7 +612,8 @@ class IAMBackend(BaseBackend):
try: try:
users = self.users.values() users = self.users.values()
except KeyError: except KeyError:
raise IAMNotFoundException("Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items)) raise IAMNotFoundException(
"Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items))
return users return users
@ -603,13 +621,15 @@ class IAMBackend(BaseBackend):
# This does not currently deal with PasswordPolicyViolation. # This does not currently deal with PasswordPolicyViolation.
user = self.get_user(user_name) user = self.get_user(user_name)
if user.password: if user.password:
raise IAMConflictException("User {0} already has password".format(user_name)) raise IAMConflictException(
"User {0} already has password".format(user_name))
user.password = password user.password = password
def delete_login_profile(self, user_name): def delete_login_profile(self, user_name):
user = self.get_user(user_name) user = self.get_user(user_name)
if not user.password: if not user.password:
raise IAMNotFoundException("Login profile for {0} not found".format(user_name)) raise IAMNotFoundException(
"Login profile for {0} not found".format(user_name))
user.password = None user.password = None
def add_user_to_group(self, group_name, user_name): def add_user_to_group(self, group_name, user_name):
@ -623,7 +643,8 @@ class IAMBackend(BaseBackend):
try: try:
group.users.remove(user) group.users.remove(user)
except ValueError: except ValueError:
raise IAMNotFoundException("User {0} not in group {1}".format(user_name, group_name)) raise IAMNotFoundException(
"User {0} not in group {1}".format(user_name, group_name))
def get_user_policy(self, user_name, policy_name): def get_user_policy(self, user_name, policy_name):
user = self.get_user(user_name) user = self.get_user(user_name)
@ -672,4 +693,5 @@ class IAMBackend(BaseBackend):
report += self.users[user].to_csv() report += self.users[user].to_csv()
return base64.b64encode(report.encode('ascii')).decode('ascii') return base64.b64encode(report.encode('ascii')).decode('ascii')
iam_backend = IAMBackend() iam_backend = IAMBackend()

View File

@ -18,7 +18,8 @@ class IamResponse(BaseResponse):
path = self._get_param('Path') path = self._get_param('Path')
policy_document = self._get_param('PolicyDocument') policy_document = self._get_param('PolicyDocument')
policy_name = self._get_param('PolicyName') policy_name = self._get_param('PolicyName')
policy = iam_backend.create_policy(description, path, policy_document, policy_name) policy = iam_backend.create_policy(
description, path, policy_document, policy_name)
template = self.response_template(CREATE_POLICY_TEMPLATE) template = self.response_template(CREATE_POLICY_TEMPLATE)
return template.render(policy=policy) return template.render(policy=policy)
@ -27,7 +28,8 @@ class IamResponse(BaseResponse):
max_items = self._get_int_param('MaxItems', 100) max_items = self._get_int_param('MaxItems', 100)
path_prefix = self._get_param('PathPrefix', '/') path_prefix = self._get_param('PathPrefix', '/')
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
policies, marker = iam_backend.list_attached_role_policies(role_name, marker=marker, max_items=max_items, path_prefix=path_prefix) policies, marker = iam_backend.list_attached_role_policies(
role_name, marker=marker, max_items=max_items, path_prefix=path_prefix)
template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker) return template.render(policies=policies, marker=marker)
@ -37,16 +39,19 @@ class IamResponse(BaseResponse):
only_attached = self._get_bool_param('OnlyAttached', False) only_attached = self._get_bool_param('OnlyAttached', False)
path_prefix = self._get_param('PathPrefix', '/') path_prefix = self._get_param('PathPrefix', '/')
scope = self._get_param('Scope', 'All') scope = self._get_param('Scope', 'All')
policies, marker = iam_backend.list_policies(marker, max_items, only_attached, path_prefix, scope) policies, marker = iam_backend.list_policies(
marker, max_items, only_attached, path_prefix, scope)
template = self.response_template(LIST_POLICIES_TEMPLATE) template = self.response_template(LIST_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker) return template.render(policies=policies, marker=marker)
def create_role(self): def create_role(self):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
path = self._get_param('Path') path = self._get_param('Path')
assume_role_policy_document = self._get_param('AssumeRolePolicyDocument') assume_role_policy_document = self._get_param(
'AssumeRolePolicyDocument')
role = iam_backend.create_role(role_name, assume_role_policy_document, path) role = iam_backend.create_role(
role_name, assume_role_policy_document, path)
template = self.response_template(CREATE_ROLE_TEMPLATE) template = self.response_template(CREATE_ROLE_TEMPLATE)
return template.render(role=role) return template.render(role=role)
@ -74,7 +79,8 @@ class IamResponse(BaseResponse):
def get_role_policy(self): def get_role_policy(self):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
policy_name = self._get_param('PolicyName') policy_name = self._get_param('PolicyName')
policy_name, policy_document = iam_backend.get_role_policy(role_name, policy_name) policy_name, policy_document = iam_backend.get_role_policy(
role_name, policy_name)
template = self.response_template(GET_ROLE_POLICY_TEMPLATE) template = self.response_template(GET_ROLE_POLICY_TEMPLATE)
return template.render(role_name=role_name, return template.render(role_name=role_name,
policy_name=policy_name, policy_name=policy_name,
@ -91,7 +97,8 @@ class IamResponse(BaseResponse):
profile_name = self._get_param('InstanceProfileName') profile_name = self._get_param('InstanceProfileName')
path = self._get_param('Path') path = self._get_param('Path')
profile = iam_backend.create_instance_profile(profile_name, path, role_ids=[]) profile = iam_backend.create_instance_profile(
profile_name, path, role_ids=[])
template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE)
return template.render(profile=profile) return template.render(profile=profile)
@ -107,7 +114,8 @@ class IamResponse(BaseResponse):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
iam_backend.add_role_to_instance_profile(profile_name, role_name) iam_backend.add_role_to_instance_profile(profile_name, role_name)
template = self.response_template(ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) template = self.response_template(
ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE)
return template.render() return template.render()
def remove_role_from_instance_profile(self): def remove_role_from_instance_profile(self):
@ -115,7 +123,8 @@ class IamResponse(BaseResponse):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
iam_backend.remove_role_from_instance_profile(profile_name, role_name) iam_backend.remove_role_from_instance_profile(profile_name, role_name)
template = self.response_template(REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) template = self.response_template(
REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE)
return template.render() return template.render()
def list_roles(self): def list_roles(self):
@ -132,9 +141,11 @@ class IamResponse(BaseResponse):
def list_instance_profiles_for_role(self): def list_instance_profiles_for_role(self):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
profiles = iam_backend.get_instance_profiles_for_role(role_name=role_name) profiles = iam_backend.get_instance_profiles_for_role(
role_name=role_name)
template = self.response_template(LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) template = self.response_template(
LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE)
return template.render(instance_profiles=profiles) return template.render(instance_profiles=profiles)
def upload_server_certificate(self): def upload_server_certificate(self):
@ -144,7 +155,8 @@ class IamResponse(BaseResponse):
private_key = self._get_param('PrivateKey') private_key = self._get_param('PrivateKey')
cert_chain = self._get_param('CertificateName') cert_chain = self._get_param('CertificateName')
cert = iam_backend.upload_server_cert(cert_name, cert_body, private_key, cert_chain=cert_chain, path=path) cert = iam_backend.upload_server_cert(
cert_name, cert_body, private_key, cert_chain=cert_chain, path=path)
template = self.response_template(UPLOAD_CERT_TEMPLATE) template = self.response_template(UPLOAD_CERT_TEMPLATE)
return template.render(certificate=cert) return template.render(certificate=cert)

View File

@ -1,4 +1,4 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import instance_metadata_backend from .models import instance_metadata_backend
instance_metadata_backends = {"global": instance_metadata_backend} instance_metadata_backends = {"global": instance_metadata_backend}

View File

@ -4,4 +4,5 @@ from moto.core.models import BaseBackend
class InstanceMetadataBackend(BaseBackend): class InstanceMetadataBackend(BaseBackend):
pass pass
instance_metadata_backend = InstanceMetadataBackend() instance_metadata_backend = InstanceMetadataBackend()

View File

@ -7,6 +7,7 @@ from moto.core.responses import BaseResponse
class InstanceMetadataResponse(BaseResponse): class InstanceMetadataResponse(BaseResponse):
def metadata_response(self, request, full_url, headers): def metadata_response(self, request, full_url, headers):
""" """
Mock response for localhost metadata Mock response for localhost metadata
@ -43,5 +44,6 @@ class InstanceMetadataResponse(BaseResponse):
elif path == 'iam/security-credentials/default-role': elif path == 'iam/security-credentials/default-role':
result = json.dumps(credentials) result = json.dumps(credentials)
else: else:
raise NotImplementedError("The {0} metadata path has not been implemented".format(path)) raise NotImplementedError(
"The {0} metadata path has not been implemented".format(path))
return 200, headers, result return 200, headers, result

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import kinesis_backends from .models import kinesis_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
kinesis_backend = kinesis_backends['us-east-1'] kinesis_backend = kinesis_backends['us-east-1']
mock_kinesis = base_decorator(kinesis_backends) mock_kinesis = base_decorator(kinesis_backends)

View File

@ -5,6 +5,7 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest): class ResourceNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundError, self).__init__() super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps({
@ -14,6 +15,7 @@ class ResourceNotFoundError(BadRequest):
class ResourceInUseError(BadRequest): class ResourceInUseError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundError, self).__init__() super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps({
@ -23,18 +25,21 @@ class ResourceInUseError(BadRequest):
class StreamNotFoundError(ResourceNotFoundError): class StreamNotFoundError(ResourceNotFoundError):
def __init__(self, stream_name): def __init__(self, stream_name):
super(StreamNotFoundError, self).__init__( super(StreamNotFoundError, self).__init__(
'Stream {0} under account 123456789012 not found.'.format(stream_name)) 'Stream {0} under account 123456789012 not found.'.format(stream_name))
class ShardNotFoundError(ResourceNotFoundError): class ShardNotFoundError(ResourceNotFoundError):
def __init__(self, shard_id): def __init__(self, shard_id):
super(ShardNotFoundError, self).__init__( super(ShardNotFoundError, self).__init__(
'Shard {0} under account 123456789012 not found.'.format(shard_id)) 'Shard {0} under account 123456789012 not found.'.format(shard_id))
class InvalidArgumentError(BadRequest): class InvalidArgumentError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(InvalidArgumentError, self).__init__() super(InvalidArgumentError, self).__init__()
self.description = json.dumps({ self.description = json.dumps({

View File

@ -18,6 +18,7 @@ from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose
class Record(object): class Record(object):
def __init__(self, partition_key, data, sequence_number, explicit_hash_key): def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
self.partition_key = partition_key self.partition_key = partition_key
self.data = data self.data = data
@ -33,6 +34,7 @@ class Record(object):
class Shard(object): class Shard(object):
def __init__(self, shard_id, starting_hash, ending_hash): def __init__(self, shard_id, starting_hash, ending_hash):
self._shard_id = shard_id self._shard_id = shard_id
self.starting_hash = starting_hash self.starting_hash = starting_hash
@ -64,7 +66,8 @@ class Shard(object):
else: else:
last_sequence_number = 0 last_sequence_number = 0
sequence_number = last_sequence_number + 1 sequence_number = last_sequence_number + 1
self.records[sequence_number] = Record(partition_key, data, sequence_number, explicit_hash_key) self.records[sequence_number] = Record(
partition_key, data, sequence_number, explicit_hash_key)
return sequence_number return sequence_number
def get_min_sequence_number(self): def get_min_sequence_number(self):
@ -107,8 +110,10 @@ class Stream(object):
izip_longest = itertools.izip_longest izip_longest = itertools.izip_longest
for index, start, end in izip_longest(range(shard_count), for index, start, end in izip_longest(range(shard_count),
range(0,2**128,2**128//shard_count), range(0, 2**128, 2 **
range(2**128//shard_count,2**128,2**128//shard_count), 128 // shard_count),
range(2**128 // shard_count, 2 **
128, 2**128 // shard_count),
fillvalue=2**128): fillvalue=2**128):
shard = Shard(index, start, end) shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard self.shards[shard.shard_id] = shard
@ -152,7 +157,8 @@ class Stream(object):
def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data): def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
shard = self.get_shard_for_key(partition_key, explicit_hash_key) shard = self.get_shard_for_key(partition_key, explicit_hash_key)
sequence_number = shard.put_record(partition_key, data, explicit_hash_key) sequence_number = shard.put_record(
partition_key, data, explicit_hash_key)
return sequence_number, shard.shard_id return sequence_number, shard.shard_id
def to_json(self): def to_json(self):
@ -168,12 +174,14 @@ class Stream(object):
class FirehoseRecord(object): class FirehoseRecord(object):
def __init__(self, record_data): def __init__(self, record_data):
self.record_id = 12345678 self.record_id = 12345678
self.record_data = record_data self.record_data = record_data
class DeliveryStream(object): class DeliveryStream(object):
def __init__(self, stream_name, **stream_kwargs): def __init__(self, stream_name, **stream_kwargs):
self.name = stream_name self.name = stream_name
self.redshift_username = stream_kwargs.get('redshift_username') self.redshift_username = stream_kwargs.get('redshift_username')
@ -185,14 +193,18 @@ class DeliveryStream(object):
self.s3_role_arn = stream_kwargs.get('s3_role_arn') self.s3_role_arn = stream_kwargs.get('s3_role_arn')
self.s3_bucket_arn = stream_kwargs.get('s3_bucket_arn') self.s3_bucket_arn = stream_kwargs.get('s3_bucket_arn')
self.s3_prefix = stream_kwargs.get('s3_prefix') self.s3_prefix = stream_kwargs.get('s3_prefix')
self.s3_compression_format = stream_kwargs.get('s3_compression_format', 'UNCOMPRESSED') self.s3_compression_format = stream_kwargs.get(
's3_compression_format', 'UNCOMPRESSED')
self.s3_buffering_hings = stream_kwargs.get('s3_buffering_hings') self.s3_buffering_hings = stream_kwargs.get('s3_buffering_hings')
self.redshift_s3_role_arn = stream_kwargs.get('redshift_s3_role_arn') self.redshift_s3_role_arn = stream_kwargs.get('redshift_s3_role_arn')
self.redshift_s3_bucket_arn = stream_kwargs.get('redshift_s3_bucket_arn') self.redshift_s3_bucket_arn = stream_kwargs.get(
'redshift_s3_bucket_arn')
self.redshift_s3_prefix = stream_kwargs.get('redshift_s3_prefix') self.redshift_s3_prefix = stream_kwargs.get('redshift_s3_prefix')
self.redshift_s3_compression_format = stream_kwargs.get('redshift_s3_compression_format', 'UNCOMPRESSED') self.redshift_s3_compression_format = stream_kwargs.get(
self.redshift_s3_buffering_hings = stream_kwargs.get('redshift_s3_buffering_hings') 'redshift_s3_compression_format', 'UNCOMPRESSED')
self.redshift_s3_buffering_hings = stream_kwargs.get(
'redshift_s3_buffering_hings')
self.records = [] self.records = []
self.status = 'ACTIVE' self.status = 'ACTIVE'
@ -231,9 +243,8 @@ class DeliveryStream(object):
}, },
"Username": self.redshift_username, "Username": self.redshift_username,
}, },
} }
] ]
def to_dict(self): def to_dict(self):
return { return {
@ -261,10 +272,9 @@ class KinesisBackend(BaseBackend):
self.streams = {} self.streams = {}
self.delivery_streams = {} self.delivery_streams = {}
def create_stream(self, stream_name, shard_count, region): def create_stream(self, stream_name, shard_count, region):
if stream_name in self.streams: if stream_name in self.streams:
raise ResourceInUseError(stream_name) raise ResourceInUseError(stream_name)
stream = Stream(stream_name, shard_count, region) stream = Stream(stream_name, shard_count, region)
self.streams[stream_name] = stream self.streams[stream_name] = stream
return stream return stream
@ -302,7 +312,8 @@ class KinesisBackend(BaseBackend):
records, last_sequence_id = shard.get_records(last_sequence_id, limit) records, last_sequence_id = shard.get_records(last_sequence_id, limit)
next_shard_iterator = compose_shard_iterator(stream_name, shard, last_sequence_id) next_shard_iterator = compose_shard_iterator(
stream_name, shard, last_sequence_id)
return next_shard_iterator, records return next_shard_iterator, records
@ -320,7 +331,7 @@ class KinesisBackend(BaseBackend):
response = { response = {
"FailedRecordCount": 0, "FailedRecordCount": 0,
"Records" : [] "Records": []
} }
for record in records: for record in records:
@ -342,7 +353,7 @@ class KinesisBackend(BaseBackend):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_name)
if shard_to_split not in stream.shards: if shard_to_split not in stream.shards:
raise ResourceNotFoundError(shard_to_split) raise ResourceNotFoundError(shard_to_split)
if not re.match(r'0|([1-9]\d{0,38})', new_starting_hash_key): if not re.match(r'0|([1-9]\d{0,38})', new_starting_hash_key):
raise InvalidArgumentError(new_starting_hash_key) raise InvalidArgumentError(new_starting_hash_key)
@ -350,10 +361,12 @@ class KinesisBackend(BaseBackend):
shard = stream.shards[shard_to_split] shard = stream.shards[shard_to_split]
last_id = sorted(stream.shards.values(), key=attrgetter('_shard_id'))[-1]._shard_id last_id = sorted(stream.shards.values(),
key=attrgetter('_shard_id'))[-1]._shard_id
if shard.starting_hash < new_starting_hash_key < shard.ending_hash: if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
new_shard = Shard(last_id+1, new_starting_hash_key, shard.ending_hash) new_shard = Shard(
last_id + 1, new_starting_hash_key, shard.ending_hash)
shard.ending_hash = new_starting_hash_key shard.ending_hash = new_starting_hash_key
stream.shards[new_shard.shard_id] = new_shard stream.shards[new_shard.shard_id] = new_shard
else: else:
@ -372,10 +385,10 @@ class KinesisBackend(BaseBackend):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_name)
if shard_to_merge not in stream.shards: if shard_to_merge not in stream.shards:
raise ResourceNotFoundError(shard_to_merge) raise ResourceNotFoundError(shard_to_merge)
if adjacent_shard_to_merge not in stream.shards: if adjacent_shard_to_merge not in stream.shards:
raise ResourceNotFoundError(adjacent_shard_to_merge) raise ResourceNotFoundError(adjacent_shard_to_merge)
shard1 = stream.shards[shard_to_merge] shard1 = stream.shards[shard_to_merge]
shard2 = stream.shards[adjacent_shard_to_merge] shard2 = stream.shards[adjacent_shard_to_merge]
@ -390,9 +403,11 @@ class KinesisBackend(BaseBackend):
del stream.shards[shard2.shard_id] del stream.shards[shard2.shard_id]
for index in shard2.records: for index in shard2.records:
record = shard2.records[index] record = shard2.records[index]
shard1.put_record(record.partition_key, record.data, record.explicit_hash_key) shard1.put_record(record.partition_key,
record.data, record.explicit_hash_key)
''' Firehose ''' ''' Firehose '''
def create_delivery_stream(self, stream_name, **stream_kwargs): def create_delivery_stream(self, stream_name, **stream_kwargs):
stream = DeliveryStream(stream_name, **stream_kwargs) stream = DeliveryStream(stream_name, **stream_kwargs)
self.delivery_streams[stream_name] = stream self.delivery_streams[stream_name] = stream
@ -416,19 +431,19 @@ class KinesisBackend(BaseBackend):
return record return record
def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None): def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_name)
tags = [] tags = []
result = { result = {
'HasMoreTags': False, 'HasMoreTags': False,
'Tags': tags 'Tags': tags
} }
for key, val in sorted(stream.tags.items(), key=lambda x:x[0]): for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
if limit and len(res) >= limit: if limit and len(tags) >= limit:
result['HasMoreTags'] = True result['HasMoreTags'] = True
break break
if exclusive_start_tag_key and key < exexclusive_start_tag_key: if exclusive_start_tag_key and key < exclusive_start_tag_key:
continue continue
tags.append({ tags.append({
'Key': key, 'Key': key,
@ -438,14 +453,14 @@ class KinesisBackend(BaseBackend):
return result return result
def add_tags_to_stream(self, stream_name, tags): def add_tags_to_stream(self, stream_name, tags):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_name)
stream.tags.update(tags) stream.tags.update(tags)
def remove_tags_from_stream(self, stream_name, tag_keys): def remove_tags_from_stream(self, stream_name, tag_keys):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_name)
for key in tag_keys: for key in tag_keys:
if key in stream.tags: if key in stream.tags:
del stream.tags[key] del stream.tags[key]
kinesis_backends = {} kinesis_backends = {}

View File

@ -4,7 +4,6 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kinesis_backends from .models import kinesis_backends
from werkzeug.exceptions import BadRequest
class KinesisResponse(BaseResponse): class KinesisResponse(BaseResponse):
@ -25,7 +24,8 @@ class KinesisResponse(BaseResponse):
def create_stream(self): def create_stream(self):
stream_name = self.parameters.get('StreamName') stream_name = self.parameters.get('StreamName')
shard_count = self.parameters.get('ShardCount') shard_count = self.parameters.get('ShardCount')
self.kinesis_backend.create_stream(stream_name, shard_count, self.region) self.kinesis_backend.create_stream(
stream_name, shard_count, self.region)
return "" return ""
def describe_stream(self): def describe_stream(self):
@ -50,7 +50,8 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_id = self.parameters.get("ShardId") shard_id = self.parameters.get("ShardId")
shard_iterator_type = self.parameters.get("ShardIteratorType") shard_iterator_type = self.parameters.get("ShardIteratorType")
starting_sequence_number = self.parameters.get("StartingSequenceNumber") starting_sequence_number = self.parameters.get(
"StartingSequenceNumber")
shard_iterator = self.kinesis_backend.get_shard_iterator( shard_iterator = self.kinesis_backend.get_shard_iterator(
stream_name, shard_id, shard_iterator_type, starting_sequence_number, stream_name, shard_id, shard_iterator_type, starting_sequence_number,
@ -64,7 +65,8 @@ class KinesisResponse(BaseResponse):
shard_iterator = self.parameters.get("ShardIterator") shard_iterator = self.parameters.get("ShardIterator")
limit = self.parameters.get("Limit") limit = self.parameters.get("Limit")
next_shard_iterator, records = self.kinesis_backend.get_records(shard_iterator, limit) next_shard_iterator, records = self.kinesis_backend.get_records(
shard_iterator, limit)
return json.dumps({ return json.dumps({
"NextShardIterator": next_shard_iterator, "NextShardIterator": next_shard_iterator,
@ -77,7 +79,8 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
partition_key = self.parameters.get("PartitionKey") partition_key = self.parameters.get("PartitionKey")
explicit_hash_key = self.parameters.get("ExplicitHashKey") explicit_hash_key = self.parameters.get("ExplicitHashKey")
sequence_number_for_ordering = self.parameters.get("SequenceNumberForOrdering") sequence_number_for_ordering = self.parameters.get(
"SequenceNumberForOrdering")
data = self.parameters.get("Data") data = self.parameters.get("Data")
sequence_number, shard_id = self.kinesis_backend.put_record( sequence_number, shard_id = self.kinesis_backend.put_record(
@ -105,7 +108,7 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_to_split = self.parameters.get("ShardToSplit") shard_to_split = self.parameters.get("ShardToSplit")
new_starting_hash_key = self.parameters.get("NewStartingHashKey") new_starting_hash_key = self.parameters.get("NewStartingHashKey")
response = self.kinesis_backend.split_shard( self.kinesis_backend.split_shard(
stream_name, shard_to_split, new_starting_hash_key stream_name, shard_to_split, new_starting_hash_key
) )
return "" return ""
@ -114,15 +117,17 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_to_merge = self.parameters.get("ShardToMerge") shard_to_merge = self.parameters.get("ShardToMerge")
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge") adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
response = self.kinesis_backend.merge_shards( self.kinesis_backend.merge_shards(
stream_name, shard_to_merge, adjacent_shard_to_merge stream_name, shard_to_merge, adjacent_shard_to_merge
) )
return "" return ""
''' Firehose ''' ''' Firehose '''
def create_delivery_stream(self): def create_delivery_stream(self):
stream_name = self.parameters['DeliveryStreamName'] stream_name = self.parameters['DeliveryStreamName']
redshift_config = self.parameters.get('RedshiftDestinationConfiguration') redshift_config = self.parameters.get(
'RedshiftDestinationConfiguration')
if redshift_config: if redshift_config:
redshift_s3_config = redshift_config['S3Configuration'] redshift_s3_config = redshift_config['S3Configuration']
@ -149,7 +154,8 @@ class KinesisResponse(BaseResponse):
's3_compression_format': s3_config.get('CompressionFormat'), 's3_compression_format': s3_config.get('CompressionFormat'),
's3_buffering_hings': s3_config['BufferingHints'], 's3_buffering_hings': s3_config['BufferingHints'],
} }
stream = self.kinesis_backend.create_delivery_stream(stream_name, **stream_kwargs) stream = self.kinesis_backend.create_delivery_stream(
stream_name, **stream_kwargs)
return json.dumps({ return json.dumps({
'DeliveryStreamARN': stream.arn 'DeliveryStreamARN': stream.arn
}) })
@ -177,7 +183,8 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters['DeliveryStreamName'] stream_name = self.parameters['DeliveryStreamName']
record_data = self.parameters['Record']['Data'] record_data = self.parameters['Record']['Data']
record = self.kinesis_backend.put_firehose_record(stream_name, record_data) record = self.kinesis_backend.put_firehose_record(
stream_name, record_data)
return json.dumps({ return json.dumps({
"RecordId": record.record_id, "RecordId": record.record_id,
}) })
@ -188,7 +195,8 @@ class KinesisResponse(BaseResponse):
request_responses = [] request_responses = []
for record in records: for record in records:
record_response = self.kinesis_backend.put_firehose_record(stream_name, record['Data']) record_response = self.kinesis_backend.put_firehose_record(
stream_name, record['Data'])
request_responses.append({ request_responses.append({
"RecordId": record_response.record_id "RecordId": record_response.record_id
}) })
@ -207,7 +215,8 @@ class KinesisResponse(BaseResponse):
stream_name = self.parameters.get('StreamName') stream_name = self.parameters.get('StreamName')
exclusive_start_tag_key = self.parameters.get('ExclusiveStartTagKey') exclusive_start_tag_key = self.parameters.get('ExclusiveStartTagKey')
limit = self.parameters.get('Limit') limit = self.parameters.get('Limit')
response = self.kinesis_backend.list_tags_for_stream(stream_name, exclusive_start_tag_key, limit) response = self.kinesis_backend.list_tags_for_stream(
stream_name, exclusive_start_tag_key, limit)
return json.dumps(response) return json.dumps(response)
def remove_tags_from_stream(self): def remove_tags_from_stream(self):

View File

@ -13,7 +13,8 @@ def compose_new_shard_iterator(stream_name, shard, shard_iterator_type, starting
elif shard_iterator_type == "LATEST": elif shard_iterator_type == "LATEST":
last_sequence_id = shard.get_max_sequence_number() last_sequence_id = shard.get_max_sequence_number()
else: else:
raise InvalidArgumentError("Invalid ShardIteratorType: {0}".format(shard_iterator_type)) raise InvalidArgumentError(
"Invalid ShardIteratorType: {0}".format(shard_iterator_type))
return compose_shard_iterator(stream_name, shard, last_sequence_id) return compose_shard_iterator(stream_name, shard, last_sequence_id)

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import kms_backends from .models import kms_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
kms_backend = kms_backends['us-east-1'] kms_backend = kms_backends['us-east-1']
mock_kms = base_decorator(kms_backends) mock_kms = base_decorator(kms_backends)

View File

@ -7,6 +7,7 @@ from collections import defaultdict
class Key(object): class Key(object):
def __init__(self, policy, key_usage, description, region): def __init__(self, policy, key_usage, description, region):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
@ -77,7 +78,8 @@ class KmsBackend(BaseBackend):
return self.keys.pop(key_id) return self.keys.pop(key_id)
def describe_key(self, key_id): def describe_key(self, key_id):
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to describe key not just KeyId # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
# describe key not just KeyId
key_id = self.get_key_id(key_id) key_id = self.get_key_id(key_id)
if r'alias/' in str(key_id).lower(): if r'alias/' in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id.split('alias/')[1]) key_id = self.get_key_id_from_alias(key_id.split('alias/')[1])
@ -128,6 +130,7 @@ class KmsBackend(BaseBackend):
def get_key_policy(self, key_id): def get_key_policy(self, key_id):
return self.keys[self.get_key_id(key_id)].policy return self.keys[self.get_key_id(key_id)].policy
kms_backends = {} kms_backends = {}
for region in boto.kms.regions(): for region in boto.kms.regions():
kms_backends[region.name] = KmsBackend() kms_backends[region.name] = KmsBackend()

View File

@ -18,6 +18,7 @@ reserved_aliases = [
'alias/aws/rds', 'alias/aws/rds',
] ]
class KmsResponse(BaseResponse): class KmsResponse(BaseResponse):
@property @property
@ -33,13 +34,15 @@ class KmsResponse(BaseResponse):
key_usage = self.parameters.get('KeyUsage') key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description') description = self.parameters.get('Description')
key = self.kms_backend.create_key(policy, key_usage, description, self.region) key = self.kms_backend.create_key(
policy, key_usage, description, self.region)
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def describe_key(self): def describe_key(self):
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
try: try:
key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id)) key = self.kms_backend.describe_key(
self.kms_backend.get_key_id(key_id))
except KeyError: except KeyError:
headers = dict(self.headers) headers = dict(self.headers)
headers['status'] = 404 headers['status'] = 404
@ -70,7 +73,8 @@ class KmsResponse(BaseResponse):
body={'message': 'Invalid identifier', '__type': 'ValidationException'}) body={'message': 'Invalid identifier', '__type': 'ValidationException'})
if alias_name in reserved_aliases: if alias_name in reserved_aliases:
raise JSONResponseError(400, 'Bad Request', body={'__type': 'NotAuthorizedException'}) raise JSONResponseError(400, 'Bad Request', body={
'__type': 'NotAuthorizedException'})
if ':' in alias_name: if ':' in alias_name:
raise JSONResponseError(400, 'Bad Request', body={ raise JSONResponseError(400, 'Bad Request', body={
@ -81,7 +85,7 @@ class KmsResponse(BaseResponse):
raise JSONResponseError(400, 'Bad Request', body={ raise JSONResponseError(400, 'Bad Request', body={
'message': "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$" 'message': "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$"
.format(**locals()), .format(**locals()),
'__type': 'ValidationException'}) '__type': 'ValidationException'})
if self.kms_backend.alias_exists(target_key_id): if self.kms_backend.alias_exists(target_key_id):
raise JSONResponseError(400, 'Bad Request', body={ raise JSONResponseError(400, 'Bad Request', body={
@ -120,7 +124,7 @@ class KmsResponse(BaseResponse):
response_aliases = [ response_aliases = [
{ {
'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region,
reserved_alias=reserved_alias), reserved_alias=reserved_alias),
'AliasName': reserved_alias 'AliasName': reserved_alias
} for reserved_alias in reserved_aliases } for reserved_alias in reserved_aliases
] ]
@ -147,7 +151,7 @@ class KmsResponse(BaseResponse):
self.kms_backend.enable_key_rotation(key_id) self.kms_backend.enable_key_rotation(key_id)
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
return json.dumps(None) return json.dumps(None)
@ -159,7 +163,7 @@ class KmsResponse(BaseResponse):
self.kms_backend.disable_key_rotation(key_id) self.kms_backend.disable_key_rotation(key_id)
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
return json.dumps(None) return json.dumps(None)
@ -170,7 +174,7 @@ class KmsResponse(BaseResponse):
rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) rotation_enabled = self.kms_backend.get_key_rotation_status(key_id)
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
return json.dumps({'KeyRotationEnabled': rotation_enabled}) return json.dumps({'KeyRotationEnabled': rotation_enabled})
@ -185,7 +189,7 @@ class KmsResponse(BaseResponse):
self.kms_backend.put_key_policy(key_id, policy) self.kms_backend.put_key_policy(key_id, policy)
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
return json.dumps(None) return json.dumps(None)
@ -200,7 +204,7 @@ class KmsResponse(BaseResponse):
return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)})
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
def list_key_policies(self): def list_key_policies(self):
@ -210,7 +214,7 @@ class KmsResponse(BaseResponse):
self.kms_backend.describe_key(key_id) self.kms_backend.describe_key(key_id)
except KeyError: except KeyError:
raise JSONResponseError(404, 'Not Found', body={ raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region,key_id=key_id), 'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'}) '__type': 'NotFoundException'})
return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
@ -233,7 +237,9 @@ class KmsResponse(BaseResponse):
def _assert_valid_key_id(key_id): def _assert_valid_key_id(key_id):
if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE):
raise JSONResponseError(404, 'Not Found', body={'message': ' Invalid keyId', '__type': 'NotFoundException'}) raise JSONResponseError(404, 'Not Found', body={
'message': ' Invalid keyId', '__type': 'NotFoundException'})
def _assert_default_policy(policy_name): def _assert_default_policy(policy_name):
if policy_name != 'default': if policy_name != 'default':

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import opsworks_backends from .models import opsworks_backends
from ..core.models import MockAWS, base_decorator, HttprettyMockAWS, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
opsworks_backend = opsworks_backends['us-east-1'] opsworks_backend = opsworks_backends['us-east-1']
mock_opsworks = base_decorator(opsworks_backends) mock_opsworks = base_decorator(opsworks_backends)

View File

@ -5,6 +5,7 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundException(BadRequest): class ResourceNotFoundException(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundException, self).__init__() super(ResourceNotFoundException, self).__init__()
self.description = json.dumps({ self.description = json.dumps({
@ -14,6 +15,7 @@ class ResourceNotFoundException(BadRequest):
class ValidationException(BadRequest): class ValidationException(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ValidationException, self).__init__() super(ValidationException, self).__init__()
self.description = json.dumps({ self.description = json.dumps({

Some files were not shown because too many files have changed in this diff Show More