Cleanup headers and encoding.

This commit is contained in:
Steve Pulec 2017-02-16 22:51:04 -05:00
parent 468a1b970c
commit cad185c74d
19 changed files with 138 additions and 101 deletions

View File

@ -331,7 +331,7 @@ class RestAPI(object):
def update_integration_mocks(self, stage_name): def update_integration_mocks(self, stage_name):
stage_url = STAGE_URL.format(api_id=self.id, region_name=self.region_name, stage_name=stage_name) stage_url = STAGE_URL.format(api_id=self.id, region_name=self.region_name, stage_name=stage_name)
responses.add_callback(responses.GET, stage_url, callback=self.resource_callback) responses.add_callback(responses.GET, stage_url.lower(), callback=self.resource_callback)
def create_stage(self, name, deployment_id,variables=None,description='',cacheClusterEnabled=None,cacheClusterSize=None): def create_stage(self, name, deployment_id,variables=None,description='',cacheClusterEnabled=None,cacheClusterSize=None):
if variables is None: if variables is None:

View File

@ -10,11 +10,11 @@ from .exceptions import StageNotFoundException
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def _get_param(self, key): def _get_param(self, key):
return json.loads(self.body.decode("ascii")).get(key) return json.loads(self.body).get(key)
def _get_param_with_default_value(self, key, default): def _get_param_with_default_value(self, key, default):
jsonbody = json.loads(self.body.decode("ascii")) jsonbody = json.loads(self.body)
if key in jsonbody: if key in jsonbody:
return jsonbody.get(key) return jsonbody.get(key)
@ -30,14 +30,14 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
apis = self.backend.list_apis() apis = self.backend.list_apis()
return 200, headers, json.dumps({"item": [ return 200, {}, json.dumps({"item": [
api.to_dict() for api in apis api.to_dict() for api in apis
]}) ]})
elif self.method == 'POST': elif self.method == 'POST':
name = self._get_param('name') name = self._get_param('name')
description = self._get_param('description') description = self._get_param('description')
rest_api = self.backend.create_rest_api(name, description) rest_api = self.backend.create_rest_api(name, description)
return 200, headers, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
def restapis_individual(self, request, full_url, headers): def restapis_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -45,10 +45,10 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
rest_api = self.backend.get_rest_api(function_id) rest_api = self.backend.get_rest_api(function_id)
return 200, headers, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
elif self.method == 'DELETE': elif self.method == 'DELETE':
rest_api = self.backend.delete_rest_api(function_id) rest_api = self.backend.delete_rest_api(function_id)
return 200, headers, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
def resources(self, request, full_url, headers): def resources(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -56,7 +56,7 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
resources = self.backend.list_resources(function_id) resources = self.backend.list_resources(function_id)
return 200, headers, json.dumps({"item": [ return 200, {}, json.dumps({"item": [
resource.to_dict() for resource in resources resource.to_dict() for resource in resources
]}) ]})
@ -72,7 +72,7 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.create_resource(function_id, resource_id, path_part) resource = self.backend.create_resource(function_id, resource_id, path_part)
elif self.method == 'DELETE': elif self.method == 'DELETE':
resource = self.backend.delete_resource(function_id, resource_id) resource = self.backend.delete_resource(function_id, resource_id)
return 200, headers, json.dumps(resource.to_dict()) return 200, {}, json.dumps(resource.to_dict())
def resource_methods(self, request, full_url, headers): def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -83,11 +83,11 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
method = self.backend.get_method(function_id, resource_id, method_type) method = self.backend.get_method(function_id, resource_id, method_type)
return 200, headers, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == 'PUT': elif self.method == 'PUT':
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
method = self.backend.create_method(function_id, resource_id, method_type, authorization_type) method = self.backend.create_method(function_id, resource_id, method_type, authorization_type)
return 200, headers, json.dumps(method) return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers): def resource_method_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -103,7 +103,7 @@ class APIGatewayResponse(BaseResponse):
method_response = self.backend.create_method_response(function_id, resource_id, method_type, response_code) method_response = self.backend.create_method_response(function_id, resource_id, method_type, response_code)
elif self.method == 'DELETE': elif self.method == 'DELETE':
method_response = self.backend.delete_method_response(function_id, resource_id, method_type, response_code) method_response = self.backend.delete_method_response(function_id, resource_id, method_type, response_code)
return 200, headers, json.dumps(method_response) return 200, {}, json.dumps(method_response)
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -123,9 +123,9 @@ class APIGatewayResponse(BaseResponse):
cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize)
elif self.method == 'GET': elif self.method == 'GET':
stages = self.backend.get_stages(function_id) stages = self.backend.get_stages(function_id)
return 200, headers, json.dumps({"item": stages}) return 200, {}, json.dumps({"item": stages})
return 200, headers, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
def stages(self, request, full_url, headers): def stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -137,11 +137,11 @@ class APIGatewayResponse(BaseResponse):
try: try:
stage_response = self.backend.get_stage(function_id, stage_name) stage_response = self.backend.get_stage(function_id, stage_name)
except StageNotFoundException as error: except StageNotFoundException as error:
return error.code, headers,'{{"message":"{0}","code":"{1}"}}'.format(error.message,error.error_type) return error.code, {},'{{"message":"{0}","code":"{1}"}}'.format(error.message,error.error_type)
elif self.method == 'PATCH': elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations') patch_operations = self._get_param('patchOperations')
stage_response = self.backend.update_stage(function_id, stage_name, patch_operations) stage_response = self.backend.update_stage(function_id, stage_name, patch_operations)
return 200, headers, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -159,7 +159,7 @@ class APIGatewayResponse(BaseResponse):
integration_response = self.backend.create_integration(function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) integration_response = self.backend.create_integration(function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates)
elif self.method == 'DELETE': elif self.method == 'DELETE':
integration_response = self.backend.delete_integration(function_id, resource_id, method_type) integration_response = self.backend.delete_integration(function_id, resource_id, method_type)
return 200, headers, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -182,7 +182,7 @@ class APIGatewayResponse(BaseResponse):
integration_response = self.backend.delete_integration_response( integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code function_id, resource_id, method_type, status_code
) )
return 200, headers, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
def deployments(self, request, full_url, headers): def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -190,13 +190,13 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
deployments = self.backend.get_deployments(function_id) deployments = self.backend.get_deployments(function_id)
return 200, headers, json.dumps({"item": deployments}) return 200, {}, json.dumps({"item": deployments})
elif self.method == 'POST': elif self.method == 'POST':
name = self._get_param("stageName") name = self._get_param("stageName")
description = self._get_param_with_default_value("description","") description = self._get_param_with_default_value("description","")
stage_variables = self._get_param_with_default_value('variables',{}) stage_variables = self._get_param_with_default_value('variables',{})
deployment = self.backend.create_deployment(function_id, name, description,stage_variables) deployment = self.backend.create_deployment(function_id, name, description,stage_variables)
return 200, headers, json.dumps(deployment) return 200, {}, json.dumps(deployment)
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -208,4 +208,4 @@ class APIGatewayResponse(BaseResponse):
deployment = self.backend.get_deployment(function_id, deployment_id) deployment = self.backend.get_deployment(function_id, deployment_id)
elif self.method == 'DELETE': elif self.method == 'DELETE':
deployment = self.backend.delete_deployment(function_id, deployment_id) deployment = self.backend.delete_deployment(function_id, deployment_id)
return 200, headers, json.dumps(deployment) return 200, {}, json.dumps(deployment)

View File

@ -196,7 +196,7 @@ class LambdaBackend(BaseBackend):
def __init__(self): def __init__(self):
self._functions = {} self._functions = {}
def has_function(self, function_name): def has_function(self, function_name):
return function_name in self._functions return function_name in self._functions

View File

@ -36,6 +36,7 @@ class LambdaResponse(BaseResponse):
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def _invoke(self, request, full_url, headers): def _invoke(self, request, full_url, headers):
response_headers = {}
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url path = request.path if hasattr(request, 'path') else request.path_url
@ -43,15 +44,15 @@ class LambdaResponse(BaseResponse):
if lambda_backend.has_function(function_name): if lambda_backend.has_function(function_name):
fn = lambda_backend.get_function(function_name) fn = lambda_backend.get_function(function_name)
payload = fn.invoke(request, headers) payload = fn.invoke(request, response_headers)
headers['Content-Length'] = str(len(payload)) response_headers['Content-Length'] = str(len(payload))
return 202, headers, payload return 202, response_headers, payload
else: else:
return 404, headers, "{}" return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers): def _list_functions(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
return 200, headers, json.dumps({ return 200, {}, json.dumps({
"Functions": [fn.get_configuration() for fn in lambda_backend.list_functions()], "Functions": [fn.get_configuration() for fn in lambda_backend.list_functions()],
# "NextMarker": str(uuid.uuid4()), # "NextMarker": str(uuid.uuid4()),
}) })
@ -62,10 +63,10 @@ class LambdaResponse(BaseResponse):
try: try:
fn = lambda_backend.create_function(spec) fn = lambda_backend.create_function(spec)
except ValueError as e: except ValueError as e:
return 400, headers, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
else: else:
config = fn.get_configuration() config = fn.get_configuration()
return 201, headers, json.dumps(config) return 201, {}, json.dumps(config)
def _delete_function(self, request, full_url, headers): def _delete_function(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
@ -75,9 +76,9 @@ class LambdaResponse(BaseResponse):
if lambda_backend.has_function(function_name): if lambda_backend.has_function(function_name):
lambda_backend.delete_function(function_name) lambda_backend.delete_function(function_name)
return 204, headers, "" return 204, {}, ""
else: else:
return 404, headers, "{}" return 404, {}, "{}"
def _get_function(self, request, full_url, headers): def _get_function(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
@ -88,9 +89,9 @@ class LambdaResponse(BaseResponse):
if lambda_backend.has_function(function_name): if lambda_backend.has_function(function_name):
fn = lambda_backend.get_function(function_name) fn = lambda_backend.get_function(function_name)
code = fn.get_code() code = fn.get_code()
return 200, headers, json.dumps(code) return 200, {}, json.dumps(code)
else: else:
return 404, headers, "{}" return 404, {}, "{}"
def get_lambda_backend(self, full_url): def get_lambda_backend(self, full_url):
from moto.awslambda.models import lambda_backends from moto.awslambda.models import lambda_backends

View File

@ -8,7 +8,11 @@ import re
from moto.packages.responses import responses from moto.packages.responses import responses
from moto.packages.httpretty import HTTPretty from moto.packages.httpretty import HTTPretty
from .responses import metadata_response from .responses import metadata_response
from .utils import convert_regex_to_flask_path, convert_flask_to_responses_response from .utils import (
convert_httpretty_response,
convert_regex_to_flask_path,
convert_flask_to_responses_response,
)
class BaseMockAWS(object): class BaseMockAWS(object):
nested_count = 0 nested_count = 0
@ -93,14 +97,14 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.register_uri( HTTPretty.register_uri(
method=method, method=method,
uri=re.compile(key), uri=re.compile(key),
body=value, body=convert_httpretty_response(value),
) )
# Mock out localhost instance metadata # Mock out localhost instance metadata
HTTPretty.register_uri( HTTPretty.register_uri(
method=method, method=method,
uri=re.compile('http://169.254.169.254/latest/meta-data/.*'), uri=re.compile('http://169.254.169.254/latest/meta-data/.*'),
body=metadata_response body=convert_httpretty_response(metadata_response),
) )
def disable_patching(self): def disable_patching(self):

View File

@ -123,14 +123,14 @@ class BaseResponse(_TemplateEnvironmentMixin):
for key, value in request.form.items(): for key, value in request.form.items():
querystring[key] = [value, ] querystring[key] = [value, ]
if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8')
if not querystring: if not querystring:
querystring.update(parse_qs(urlparse(full_url).query, keep_blank_values=True)) querystring.update(parse_qs(urlparse(full_url).query, keep_blank_values=True))
if not querystring: if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec: if 'json' in request.headers.get('content-type', []) and self.aws_service_spec:
if isinstance(self.body, six.binary_type): decoded = json.loads(self.body)
decoded = json.loads(self.body.decode('utf-8'))
else:
decoded = json.loads(self.body)
target = request.headers.get('x-amz-target') or request.headers.get('X-Amz-Target') target = request.headers.get('x-amz-target') or request.headers.get('X-Amz-Target')
service, method = target.split('.') service, method = target.split('.')
@ -154,7 +154,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.headers = request.headers self.headers = request.headers
if 'host' not in self.headers: if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc self.headers['host'] = urlparse(full_url).netloc
self.response_headers = headers self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, full_url): def get_region_from_url(self, full_url):
match = re.search(self.region_regex, full_url) match = re.search(self.region_regex, full_url)

View File

@ -79,6 +79,29 @@ def convert_regex_to_flask_path(url_path):
return url_path return url_path
class convert_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@property
def __name__(self):
# For instance methods, use class and method names. Otherwise
# use module and method name
if inspect.ismethod(self.callback):
outer = self.callback.__self__.__class__.__name__
else:
outer = self.callback.__module__
return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers)
status, headers, response = result
if 'server' not in headers:
headers["server"] = "amazon.com"
return status, headers, response
class convert_flask_to_httpretty_response(object): class convert_flask_to_httpretty_response(object):
def __init__(self, callback): def __init__(self, callback):
@ -119,8 +142,11 @@ class convert_flask_to_responses_response(object):
return "{0}.{1}".format(outer, self.callback.__name__) return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, request, *args, **kwargs): def __call__(self, request, *args, **kwargs):
for key, val in request.headers.items():
if isinstance(val, six.binary_type):
request.headers[key] = val.decode("utf-8")
result = self.callback(request, request.url, request.headers) result = self.callback(request, request.url, request.headers)
# result is a status, headers, response tuple
status, headers, response = result status, headers, response = result
return status, headers, response return status, headers, response

View File

@ -12,7 +12,7 @@ class DataPipelineResponse(BaseResponse):
def parameters(self): def parameters(self):
# TODO this should really be moved to core/responses.py # TODO this should really be moved to core/responses.py
if self.body: if self.body:
return json.loads(self.body.decode("utf-8")) return json.loads(self.body)
else: else:
return self.querystring return self.querystring

View File

@ -51,7 +51,7 @@ class DynamoHandler(BaseResponse):
return status, self.response_headers, dynamo_json_dump({'__type': type_}) return status, self.response_headers, dynamo_json_dump({'__type': type_})
def call_action(self): def call_action(self):
body = self.body.decode('utf-8') body = self.body
if 'GetSessionToken' in body: if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler() return 200, self.response_headers, sts_handler()

View File

@ -52,7 +52,7 @@ class DynamoHandler(BaseResponse):
return status, self.response_headers, dynamo_json_dump({'__type': type_}) return status, self.response_headers, dynamo_json_dump({'__type': type_})
def call_action(self): def call_action(self):
body = self.body.decode('utf-8') body = self.body
if 'GetSessionToken' in body: if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler() return 200, self.response_headers, sts_handler()

View File

@ -14,7 +14,7 @@ class EC2ContainerServiceResponse(BaseResponse):
@property @property
def request_params(self): def request_params(self):
try: try:
return json.loads(self.body.decode()) return json.loads(self.body)
except ValueError: except ValueError:
return {} return {}

View File

@ -19,7 +19,7 @@ class EventsHandler(BaseResponse):
} }
def load_body(self): def load_body(self):
decoded_body = self.body.decode('utf-8') decoded_body = self.body
return json.loads(decoded_body or '{}') return json.loads(decoded_body or '{}')
def error(self, type_, message='', status=400): def error(self, type_, message='', status=400):

View File

@ -11,7 +11,7 @@ class KinesisResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body.decode("utf-8")) return json.loads(self.body)
@property @property
def kinesis_backend(self): def kinesis_backend(self):

View File

@ -22,7 +22,7 @@ class KmsResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body.decode("utf-8")) return json.loads(self.body)
@property @property
def kms_backend(self): def kms_backend(self):

View File

@ -10,7 +10,7 @@ class OpsWorksResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body.decode("utf-8")) return json.loads(self.body)
@property @property
def opsworks_backend(self): def opsworks_backend(self):

View File

@ -104,10 +104,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
try: try:
response = self._bucket_response(request, full_url, headers) response = self._bucket_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, headers, s3error.description response = s3error.code, {}, s3error.description
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, headers, response.encode("utf-8") return 200, {}, response.encode("utf-8")
else: else:
status_code, headers, response_content = response status_code, headers, response_content = response
return status_code, headers, response_content.encode("utf-8") return status_code, headers, response_content.encode("utf-8")
@ -133,8 +133,9 @@ class ResponseObject(_TemplateEnvironmentMixin):
# Flask server # Flask server
body = request.data body = request.data
if body is None: if body is None:
body = '' body = b''
body = body.decode('utf-8') if isinstance(body, six.binary_type):
body = body.decode('utf-8')
if method == 'HEAD': if method == 'HEAD':
return self._bucket_response_head(bucket_name, headers) return self._bucket_response_head(bucket_name, headers)
@ -151,7 +152,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
def _bucket_response_head(self, bucket_name, headers): def _bucket_response_head(self, bucket_name, headers):
self.backend.get_bucket(bucket_name) self.backend.get_bucket(bucket_name)
return 200, headers, "" return 200, {}, ""
def _bucket_response_get(self, bucket_name, querystring, headers): def _bucket_response_get(self, bucket_name, querystring, headers):
if 'uploads' in querystring: if 'uploads' in querystring:
@ -173,7 +174,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
elif 'lifecycle' in querystring: elif 'lifecycle' in querystring:
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
if not bucket.rules: if not bucket.rules:
return 404, headers, "NoSuchLifecycleConfiguration" return 404, {}, "NoSuchLifecycleConfiguration"
template = self.response_template(S3_BUCKET_LIFECYCLE_CONFIGURATION) template = self.response_template(S3_BUCKET_LIFECYCLE_CONFIGURATION)
return template.render(rules=bucket.rules) return template.render(rules=bucket.rules)
elif 'versioning' in querystring: elif 'versioning' in querystring:
@ -184,8 +185,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
policy = self.backend.get_bucket_policy(bucket_name) policy = self.backend.get_bucket_policy(bucket_name)
if not policy: if not policy:
template = self.response_template(S3_NO_POLICY) template = self.response_template(S3_NO_POLICY)
return 404, headers, template.render(bucket_name=bucket_name) return 404, {}, template.render(bucket_name=bucket_name)
return 200, headers, policy return 200, {}, policy
elif 'website' in querystring: elif 'website' in querystring:
website_configuration = self.backend.get_bucket_website_configuration(bucket_name) website_configuration = self.backend.get_bucket_website_configuration(bucket_name)
return website_configuration return website_configuration
@ -211,7 +212,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
version_id_marker=version_id_marker version_id_marker=version_id_marker
) )
template = self.response_template(S3_BUCKET_GET_VERSIONS) template = self.response_template(S3_BUCKET_GET_VERSIONS)
return 200, headers, template.render( return 200, {}, template.render(
key_list=versions, key_list=versions,
bucket=bucket, bucket=bucket,
prefix='', prefix='',
@ -220,14 +221,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
is_truncated='false', is_truncated='false',
) )
elif querystring.get('list-type', [None])[0] == '2': elif querystring.get('list-type', [None])[0] == '2':
return 200, headers, self._handle_list_objects_v2(bucket_name, querystring) return 200, {}, self._handle_list_objects_v2(bucket_name, querystring)
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
prefix = querystring.get('prefix', [None])[0] prefix = querystring.get('prefix', [None])[0]
delimiter = querystring.get('delimiter', [None])[0] delimiter = querystring.get('delimiter', [None])[0]
result_keys, result_folders = self.backend.prefix_query(bucket, prefix, delimiter) result_keys, result_folders = self.backend.prefix_query(bucket, prefix, delimiter)
template = self.response_template(S3_BUCKET_GET_RESPONSE) template = self.response_template(S3_BUCKET_GET_RESPONSE)
return 200, headers, template.render( return 200, {}, template.render(
bucket=bucket, bucket=bucket,
prefix=prefix, prefix=prefix,
delimiter=delimiter, delimiter=delimiter,
@ -286,7 +287,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
template = self.response_template(S3_BUCKET_VERSIONING) template = self.response_template(S3_BUCKET_VERSIONING)
return template.render(bucket_versioning_status=ver.group(1)) return template.render(bucket_versioning_status=ver.group(1))
else: else:
return 404, headers, "" return 404, {}, ""
elif 'lifecycle' in querystring: elif 'lifecycle' in querystring:
rules = xmltodict.parse(body)['LifecycleConfiguration']['Rule'] rules = xmltodict.parse(body)['LifecycleConfiguration']['Rule']
if not isinstance(rules, list): if not isinstance(rules, list):
@ -315,27 +316,27 @@ class ResponseObject(_TemplateEnvironmentMixin):
else: else:
raise raise
template = self.response_template(S3_BUCKET_CREATE_RESPONSE) template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
return 200, headers, template.render(bucket=new_bucket) return 200, {}, template.render(bucket=new_bucket)
def _bucket_response_delete(self, body, bucket_name, querystring, headers): def _bucket_response_delete(self, body, bucket_name, querystring, headers):
if 'policy' in querystring: if 'policy' in querystring:
self.backend.delete_bucket_policy(bucket_name, body) self.backend.delete_bucket_policy(bucket_name, body)
return 204, headers, "" return 204, {}, ""
elif 'lifecycle' in querystring: elif 'lifecycle' in querystring:
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
bucket.delete_lifecycle() bucket.delete_lifecycle()
return 204, headers, "" return 204, {}, ""
removed_bucket = self.backend.delete_bucket(bucket_name) removed_bucket = self.backend.delete_bucket(bucket_name)
if removed_bucket: if removed_bucket:
# Bucket exists # Bucket exists
template = self.response_template(S3_DELETE_BUCKET_SUCCESS) template = self.response_template(S3_DELETE_BUCKET_SUCCESS)
return 204, headers, template.render(bucket=removed_bucket) return 204, {}, template.render(bucket=removed_bucket)
else: else:
# Tried to delete a bucket that still has keys # Tried to delete a bucket that still has keys
template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, headers, template.render(bucket=removed_bucket) return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name, headers): def _bucket_response_post(self, request, body, bucket_name, headers):
path = request.path if hasattr(request, 'path') else request.path_url path = request.path if hasattr(request, 'path') else request.path_url
@ -349,7 +350,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
else: else:
# HTTPretty, build new form object # HTTPretty, build new form object
form = {} form = {}
for kv in body.decode('utf-8').split('&'): for kv in body.split('&'):
k, v = kv.split('=') k, v = kv.split('=')
form[k] = v form[k] = v
@ -365,7 +366,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
metadata = metadata_from_headers(form) metadata = metadata_from_headers(form)
new_key.set_metadata(metadata) new_key.set_metadata(metadata)
return 200, headers, "" return 200, {}, ""
def _bucket_response_delete_keys(self, request, body, bucket_name, headers): def _bucket_response_delete_keys(self, request, body, bucket_name, headers):
template = self.response_template(S3_DELETE_KEYS_RESPONSE) template = self.response_template(S3_DELETE_KEYS_RESPONSE)
@ -382,9 +383,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
else: else:
error_names.append(key_name) error_names.append(key_name)
return 200, headers, template.render(deleted=deleted_names, delete_errors=error_names) return 200, {}, template.render(deleted=deleted_names, delete_errors=error_names)
def _handle_range_header(self, request, headers, response_content): def _handle_range_header(self, request, headers, response_content):
response_headers = {}
length = len(response_content) length = len(response_content)
last = length - 1 last = length - 1
_, rspec = request.headers.get('range').split('=') _, rspec = request.headers.get('range').split('=')
@ -399,28 +401,29 @@ class ResponseObject(_TemplateEnvironmentMixin):
begin = length - min(end, length) begin = length - min(end, length)
end = last end = last
else: else:
return 400, headers, "" return 400, response_headers, ""
if begin < 0 or end > last or begin > min(end, last): if begin < 0 or end > last or begin > min(end, last):
return 416, headers, "" return 416, response_headers, ""
headers['content-range'] = "bytes {0}-{1}/{2}".format( response_headers['content-range'] = "bytes {0}-{1}/{2}".format(
begin, end, length) begin, end, length)
return 206, headers, response_content[begin:end + 1] return 206, response_headers, response_content[begin:end + 1]
def key_response(self, request, full_url, headers): def key_response(self, request, full_url, headers):
response_headers = {}
try: try:
response = self._key_response(request, full_url, headers) response = self._key_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, headers, s3error.description response = s3error.code, {}, s3error.description
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
status_code = 200 status_code = 200
response_content = response response_content = response
else: else:
status_code, headers, response_content = response status_code, response_headers, response_content = response
if status_code == 200 and 'range' in request.headers: if status_code == 200 and 'range' in request.headers:
return self._handle_range_header(request, headers, response_content) return self._handle_range_header(request, response_headers, response_content)
return status_code, headers, response_content return status_code, response_headers, response_content
def _key_response(self, request, full_url, headers): def _key_response(self, request, full_url, headers):
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -455,11 +458,12 @@ class ResponseObject(_TemplateEnvironmentMixin):
raise NotImplementedError("Method {0} has not been impelemented in the S3 backend yet".format(method)) raise NotImplementedError("Method {0} has not been impelemented in the S3 backend yet".format(method))
def _key_response_get(self, bucket_name, query, key_name, headers): def _key_response_get(self, bucket_name, query, key_name, headers):
response_headers = {}
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
parts = self.backend.list_multipart(bucket_name, upload_id) parts = self.backend.list_multipart(bucket_name, upload_id)
template = self.response_template(S3_MULTIPART_LIST_RESPONSE) template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
return 200, headers, template.render( return 200, response_headers, template.render(
bucket_name=bucket_name, bucket_name=bucket_name,
key_name=key_name, key_name=key_name,
upload_id=upload_id, upload_id=upload_id,
@ -471,13 +475,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
bucket_name, key_name, version_id=version_id) bucket_name, key_name, version_id=version_id)
if 'acl' in query: if 'acl' in query:
template = self.response_template(S3_OBJECT_ACL_RESPONSE) template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, headers, template.render(obj=key) return 200, response_headers, template.render(obj=key)
headers.update(key.metadata) response_headers.update(key.metadata)
headers.update(key.response_dict) response_headers.update(key.response_dict)
return 200, headers, key.value return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name, headers): def _key_response_put(self, request, body, bucket_name, query, key_name, headers):
response_headers = {}
if query.get('uploadId') and query.get('partNumber'): if query.get('uploadId') and query.get('partNumber'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
part_number = int(query['partNumber'][0]) part_number = int(query['partNumber'][0])
@ -501,8 +506,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
key = self.backend.set_part( key = self.backend.set_part(
bucket_name, upload_id, part_number, body) bucket_name, upload_id, part_number, body)
response = "" response = ""
headers.update(key.response_dict) response_headers.update(key.response_dict)
return 200, headers, response return 200, response_headers, response
storage_class = request.headers.get('x-amz-storage-class', 'STANDARD') storage_class = request.headers.get('x-amz-storage-class', 'STANDARD')
acl = self._acl_from_headers(request.headers) acl = self._acl_from_headers(request.headers)
@ -511,7 +516,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
key = self.backend.get_key(bucket_name, key_name) key = self.backend.get_key(bucket_name, key_name)
# TODO: Support the XML-based ACL format # TODO: Support the XML-based ACL format
key.set_acl(acl) key.set_acl(acl)
return 200, headers, "" return 200, response_headers, ""
if 'x-amz-copy-source' in request.headers: if 'x-amz-copy-source' in request.headers:
# Copy key # Copy key
@ -526,8 +531,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
new_key.set_metadata(metadata, replace=True) new_key.set_metadata(metadata, replace=True)
template = self.response_template(S3_OBJECT_COPY_RESPONSE) template = self.response_template(S3_OBJECT_COPY_RESPONSE)
headers.update(new_key.response_dict) response_headers.update(new_key.response_dict)
return 200, headers, template.render(key=new_key) return 200, response_headers, template.render(key=new_key)
streaming_request = hasattr(request, 'streaming') and request.streaming streaming_request = hasattr(request, 'streaming') and request.streaming
closing_connection = headers.get('connection') == 'close' closing_connection = headers.get('connection') == 'close'
if closing_connection and streaming_request: if closing_connection and streaming_request:
@ -546,18 +551,19 @@ class ResponseObject(_TemplateEnvironmentMixin):
new_key.set_acl(acl) new_key.set_acl(acl)
template = self.response_template(S3_OBJECT_RESPONSE) template = self.response_template(S3_OBJECT_RESPONSE)
headers.update(new_key.response_dict) response_headers.update(new_key.response_dict)
return 200, headers, template.render(key=new_key) return 200, response_headers, template.render(key=new_key)
def _key_response_head(self, bucket_name, query, key_name, headers): def _key_response_head(self, bucket_name, query, key_name, headers):
response_headers = {}
version_id = query.get('versionId', [None])[0] version_id = query.get('versionId', [None])[0]
key = self.backend.get_key(bucket_name, key_name, version_id=version_id) key = self.backend.get_key(bucket_name, key_name, version_id=version_id)
if key: if key:
headers.update(key.metadata) response_headers.update(key.metadata)
headers.update(key.response_dict) response_headers.update(key.response_dict)
return 200, headers, "" return 200, response_headers, ""
else: else:
return 404, headers, "" return 404, response_headers, ""
def _acl_from_headers(self, headers): def _acl_from_headers(self, headers):
canned_acl = headers.get('x-amz-acl', '') canned_acl = headers.get('x-amz-acl', '')
@ -595,10 +601,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
self.backend.cancel_multipart(bucket_name, upload_id) self.backend.cancel_multipart(bucket_name, upload_id)
return 204, headers, "" return 204, {}, ""
self.backend.delete_key(bucket_name, key_name) self.backend.delete_key(bucket_name, key_name)
template = self.response_template(S3_DELETE_OBJECT_SUCCESS) template = self.response_template(S3_DELETE_OBJECT_SUCCESS)
return 204, headers, template.render() return 204, {}, template.render()
def _complete_multipart_body(self, body): def _complete_multipart_body(self, body):
ps = minidom.parseString(body).getElementsByTagName('Part') ps = minidom.parseString(body).getElementsByTagName('Part')
@ -620,7 +626,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
key_name=key_name, key_name=key_name,
upload_id=multipart.id, upload_id=multipart.id,
) )
return 200, headers, response return 200, {}, response
if query.get('uploadId'): if query.get('uploadId'):
body = self._complete_multipart_body(body) body = self._complete_multipart_body(body)
@ -640,7 +646,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
if key.expiry_date is not None: if key.expiry_date is not None:
r = 200 r = 200
key.restore(int(days)) key.restore(int(days))
return r, headers, "" return r, {}, ""
else: else:
raise NotImplementedError("Method POST had only been implemented for multipart uploads and restore operations, so far") raise NotImplementedError("Method POST had only been implemented for multipart uploads and restore operations, so far")

View File

@ -16,7 +16,7 @@ class SWFResponse(BaseResponse):
# SWF parameters are passed through a JSON body, so let's ease retrieval # SWF parameters are passed through a JSON body, so let's ease retrieval
@property @property
def _params(self): def _params(self):
return json.loads(self.body.decode("utf-8")) return json.loads(self.body)
def _check_int(self, parameter): def _check_int(self, parameter):
if not isinstance(parameter, int): if not isinstance(parameter, int):

View File

@ -70,7 +70,7 @@ def test_publish_to_http():
last_request = responses.calls[-1].request last_request = responses.calls[-1].request
last_request.method.should.equal("POST") last_request.method.should.equal("POST")
parse_qs(last_request.body.decode('utf-8')).should.equal({ parse_qs(last_request.body).should.equal({
"Type": ["Notification"], "Type": ["Notification"],
"MessageId": [message_id], "MessageId": [message_id],
"TopicArn": ["arn:aws:sns:{0}:123456789012:some-topic".format(conn.region.name)], "TopicArn": ["arn:aws:sns:{0}:123456789012:some-topic".format(conn.region.name)],

View File

@ -75,7 +75,7 @@ def test_publish_to_http():
last_request = responses.calls[-2].request last_request = responses.calls[-2].request
last_request.method.should.equal("POST") last_request.method.should.equal("POST")
parse_qs(last_request.body.decode('utf-8')).should.equal({ parse_qs(last_request.body).should.equal({
"Type": ["Notification"], "Type": ["Notification"],
"MessageId": [message_id], "MessageId": [message_id],
"TopicArn": ["arn:aws:sns:{0}:123456789012:some-topic".format(conn._client_config.region_name)], "TopicArn": ["arn:aws:sns:{0}:123456789012:some-topic".format(conn._client_config.region_name)],