diff --git a/Makefile b/Makefile index 0420a9ea3..fb83906d6 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ SHELL := /bin/bash init: - python setup.py develop - pip install -r requirements.txt + @python setup.py develop + @pip install -r requirements.txt test: rm -f .coverage - nosetests --with-coverage ./tests/ + @nosetests --with-coverage ./tests/ diff --git a/moto/core/models.py b/moto/core/models.py index e98c1eed3..c451fb11d 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -1,7 +1,7 @@ import functools import re -from moto.packages.httpretty import HTTPretty +from httpretty import HTTPretty from .responses import metadata_response from .utils import convert_regex_to_flask_path diff --git a/moto/core/responses.py b/moto/core/responses.py index d74bcd2e4..a25f5f26a 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -1,50 +1,75 @@ import datetime import json -from urlparse import parse_qs +from urlparse import parse_qs, urlparse from moto.core.utils import headers_to_dict, camelcase_to_underscores, method_names_from_class class BaseResponse(object): - def dispatch(self, uri, method, body, headers): - if body: - querystring = parse_qs(body) + + def dispatch(self, request, full_url, headers): + if hasattr(request, 'body'): + # Boto + self.body = request.body else: + # Flask server + self.body = request.data + + querystring = parse_qs(urlparse(full_url).query) + if not querystring: + querystring = parse_qs(self.body) + if not querystring: querystring = headers_to_dict(headers) - self.path = uri.path + self.uri = full_url + self.path = urlparse(full_url).path self.querystring = querystring + self.method = request.method - action = querystring.get('Action', [""])[0] + self.headers = dict(request.headers) + self.response_headers = headers + return self.call_action() + + def call_action(self): + headers = self.response_headers + action = self.querystring.get('Action', [""])[0] action = camelcase_to_underscores(action) - method_names = method_names_from_class(self.__class__) if action in method_names: method = getattr(self, action) - return method() + response = method() + if isinstance(response, basestring): + return 200, headers, response + else: + body, new_headers = response + status = new_headers.pop('status', 200) + headers.update(new_headers) + return status, headers, body raise NotImplementedError("The {} action has not been implemented".format(action)) -def metadata_response(uri, method, body, headers): +def metadata_response(request, full_url, headers): """ Mock response for localhost metadata http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AESDG-chapter-instancedata.html """ + parsed_url = urlparse(full_url) tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) - path = uri.path.lstrip("/latest/meta-data/") + path = parsed_url.path.lstrip("/latest/meta-data/") if path == '': - return "iam/" + result = "iam/" elif path == 'iam/': - return 'security-credentials/' + result = 'security-credentials/' elif path == 'iam/security-credentials/': - return 'default-role' + result = 'default-role' elif path == 'iam/security-credentials/default-role': - return json.dumps(dict( + result = json.dumps(dict( AccessKeyId="test-key", SecretAccessKey="test-secret-key", Token="test-session-token", Expiration=tomorrow.strftime("%Y-%m-%dT%H:%M:%SZ") )) + return 200, headers, result diff --git a/moto/core/utils.py b/moto/core/utils.py index 8532698cb..13aca14b0 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -1,4 +1,3 @@ -from collections import namedtuple import inspect import random import re @@ -91,23 +90,12 @@ class convert_flask_to_httpretty_response(object): return "{}.{}".format(outer, self.callback.__name__) def __call__(self, args=None, **kwargs): - hostname = request.host_url - method = request.method - path = request.path - query = request.query_string - - # Mimic the HTTPretty URIInfo class - URI = namedtuple('URI', 'hostname method path query') - uri = URI(hostname, method, path, query) - - body = request.data or query headers = dict(request.headers) - result = self.callback(uri, method, body, headers) + result = self.callback(request, request.url, headers) if isinstance(result, basestring): # result is just the response return result else: - # result is a responce, headers tuple - response, headers = result - status = headers.pop('status', None) + # result is a status, headers, response tuple + status, headers, response = result return response, status, headers diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index a75443a11..dece06542 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -1,6 +1,7 @@ import json -from moto.core.utils import headers_to_dict, camelcase_to_underscores +from moto.core.responses import BaseResponse +from moto.core.utils import camelcase_to_underscores from .models import dynamodb_backend, dynamo_json_dump @@ -27,17 +28,11 @@ GET_SESSION_TOKEN_RESULT = """ """ -def sts_handler(uri, method, body, headers): +def sts_handler(): return GET_SESSION_TOKEN_RESULT -class DynamoHandler(object): - - def __init__(self, uri, method, body, headers): - self.uri = uri - self.method = method - self.body = body - self.headers = headers +class DynamoHandler(BaseResponse): def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target @@ -45,22 +40,35 @@ class DynamoHandler(object): ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables """ - match = headers.get('X-Amz-Target') + # Headers are case-insensitive. Probably a better way to do this. + match = headers.get('x-amz-target') or headers.get('X-Amz-Target') if match: return match.split(".")[1] def error(self, type_, status=400): - return dynamo_json_dump({'__type': type_}), dict(status=400) + return status, self.response_headers, dynamo_json_dump({'__type': type_}) - def dispatch(self): + def call_action(self): + if 'GetSessionToken' in self.body: + return 200, self.response_headers, sts_handler() + + self.body = json.loads(self.body or '{}') endpoint = self.get_endpoint_name(self.headers) if endpoint: endpoint = camelcase_to_underscores(endpoint) - return getattr(self, endpoint)(self.uri, self.method, self.body, self.headers) - else: - return "", dict(status=404) + response = getattr(self, endpoint)() + if isinstance(response, basestring): + return 200, self.response_headers, response - def list_tables(self, uri, method, body, headers): + else: + status_code, new_headers, response_content = response + self.response_headers.update(new_headers) + return status_code, self.response_headers, response_content + else: + return 404, self.response_headers, "" + + def list_tables(self): + body = self.body limit = body.get('Limit') if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") @@ -77,7 +85,8 @@ class DynamoHandler(object): response["LastEvaluatedTableName"] = tables[-1] return dynamo_json_dump(response) - def create_table(self, uri, method, body, headers): + def create_table(self): + body = self.body name = body['TableName'] key_schema = body['KeySchema'] @@ -104,8 +113,8 @@ class DynamoHandler(object): ) return dynamo_json_dump(table.describe) - def delete_table(self, uri, method, body, headers): - name = body['TableName'] + def delete_table(self): + name = self.body['TableName'] table = dynamodb_backend.delete_table(name) if table: return dynamo_json_dump(table.describe) @@ -113,16 +122,16 @@ class DynamoHandler(object): er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) - def update_table(self, uri, method, body, headers): - name = body['TableName'] - throughput = body["ProvisionedThroughput"] + def update_table(self): + name = self.body['TableName'] + throughput = self.body["ProvisionedThroughput"] new_read_units = throughput["ReadCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"] table = dynamodb_backend.update_table_throughput(name, new_read_units, new_write_units) return dynamo_json_dump(table.describe) - def describe_table(self, uri, method, body, headers): - name = body['TableName'] + def describe_table(self): + name = self.body['TableName'] try: table = dynamodb_backend.tables[name] except KeyError: @@ -130,9 +139,9 @@ class DynamoHandler(object): return self.error(er) return dynamo_json_dump(table.describe) - def put_item(self, uri, method, body, headers): - name = body['TableName'] - item = body['Item'] + def put_item(self): + name = self.body['TableName'] + item = self.body['Item'] result = dynamodb_backend.put_item(name, item) if result: item_dict = result.to_json() @@ -142,8 +151,8 @@ class DynamoHandler(object): er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) - def batch_write_item(self, uri, method, body, headers): - table_batches = body['RequestItems'] + def batch_write_item(self): + table_batches = self.body['RequestItems'] for table_name, table_requests in table_batches.iteritems(): for table_request in table_requests: @@ -173,12 +182,12 @@ class DynamoHandler(object): return dynamo_json_dump(response) - def get_item(self, uri, method, body, headers): - name = body['TableName'] - key = body['Key'] + def get_item(self): + name = self.body['TableName'] + key = self.body['Key'] hash_key = key['HashKeyElement'] range_key = key.get('RangeKeyElement') - attrs_to_get = body.get('AttributesToGet') + attrs_to_get = self.body.get('AttributesToGet') item = dynamodb_backend.get_item(name, hash_key, range_key) if item: item_dict = item.describe_attrs(attrs_to_get) @@ -188,8 +197,8 @@ class DynamoHandler(object): er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) - def batch_get_item(self, uri, method, body, headers): - table_batches = body['RequestItems'] + def batch_get_item(self): + table_batches = self.body['RequestItems'] results = { "Responses": { @@ -211,10 +220,10 @@ class DynamoHandler(object): results["Responses"][table_name] = {"Items": items, "ConsumedCapacityUnits": 1} return dynamo_json_dump(results) - def query(self, uri, method, body, headers): - name = body['TableName'] - hash_key = body['HashKeyValue'] - range_condition = body.get('RangeKeyCondition') + def query(self): + name = self.body['TableName'] + hash_key = self.body['HashKeyValue'] + range_condition = self.body.get('RangeKeyCondition') if range_condition: range_comparison = range_condition['ComparisonOperator'] range_values = range_condition['AttributeValueList'] @@ -242,11 +251,11 @@ class DynamoHandler(object): # } return dynamo_json_dump(result) - def scan(self, uri, method, body, headers): - name = body['TableName'] + def scan(self): + name = self.body['TableName'] filters = {} - scan_filters = body.get('ScanFilter', {}) + scan_filters = self.body.get('ScanFilter', {}) for attribute_name, scan_filter in scan_filters.iteritems(): # Keys are attribute names. Values are tuples of (comparison, comparison_value) comparison_operator = scan_filter["ComparisonOperator"] @@ -274,12 +283,12 @@ class DynamoHandler(object): # } return dynamo_json_dump(result) - def delete_item(self, uri, method, body, headers): - name = body['TableName'] - key = body['Key'] + def delete_item(self): + name = self.body['TableName'] + key = self.body['Key'] hash_key = key['HashKeyElement'] range_key = key.get('RangeKeyElement') - return_values = body.get('ReturnValues', '') + return_values = self.body.get('ReturnValues', '') item = dynamodb_backend.delete_item(name, hash_key, range_key) if item: if return_values == 'ALL_OLD': @@ -291,10 +300,3 @@ class DynamoHandler(object): else: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) - - -def handler(uri, method, body, headers): - if 'GetSessionToken' in body: - return sts_handler(uri, method, body, headers) - body = json.loads(body or '{}') - return DynamoHandler(uri, method, body, headers_to_dict(headers)).dispatch() diff --git a/moto/dynamodb/urls.py b/moto/dynamodb/urls.py index 85634ef2f..6ed5e00d5 100644 --- a/moto/dynamodb/urls.py +++ b/moto/dynamodb/urls.py @@ -1,4 +1,4 @@ -from .responses import handler +from .responses import DynamoHandler url_bases = [ "https?://dynamodb.(.+).amazonaws.com", @@ -6,5 +6,5 @@ url_bases = [ ] url_paths = { - "{0}/": handler, + "{0}/": DynamoHandler().dispatch, } diff --git a/moto/ec2/responses/__init__.py b/moto/ec2/responses/__init__.py index 0a50797ee..690419438 100644 --- a/moto/ec2/responses/__init__.py +++ b/moto/ec2/responses/__init__.py @@ -1,6 +1,4 @@ -from urlparse import parse_qs - -from moto.core.utils import camelcase_to_underscores, method_names_from_class +from moto.core.responses import BaseResponse from .amazon_dev_pay import AmazonDevPay from .amis import AmisResponse @@ -32,53 +30,35 @@ from .vpn_connections import VPNConnections from .windows import Windows -class EC2Response(object): - - sub_responses = [ - AmazonDevPay, - AmisResponse, - AvailabilityZonesAndRegions, - CustomerGateways, - DHCPOptions, - ElasticBlockStore, - ElasticIPAddresses, - ElasticNetworkInterfaces, - General, - InstanceResponse, - InternetGateways, - IPAddresses, - KeyPairs, - Monitoring, - NetworkACLs, - PlacementGroups, - ReservedInstances, - RouteTables, - SecurityGroups, - SpotInstances, - Subnets, - TagResponse, - VirtualPrivateGateways, - VMExport, - VMImport, - VPCs, - VPNConnections, - Windows, - ] - - def dispatch(self, uri, method, body, headers): - if body: - querystring = parse_qs(body) - else: - querystring = parse_qs(headers) - - action = querystring.get('Action', [None])[0] - if action: - action = camelcase_to_underscores(action) - - for sub_response in self.sub_responses: - method_names = method_names_from_class(sub_response) - if action in method_names: - response = sub_response(querystring) - method = getattr(response, action) - return method() - raise NotImplementedError("The {} action has not been implemented".format(action)) +class EC2Response( + BaseResponse, + AmazonDevPay, + AmisResponse, + AvailabilityZonesAndRegions, + CustomerGateways, + DHCPOptions, + ElasticBlockStore, + ElasticIPAddresses, + ElasticNetworkInterfaces, + General, + InstanceResponse, + InternetGateways, + IPAddresses, + KeyPairs, + Monitoring, + NetworkACLs, + PlacementGroups, + ReservedInstances, + RouteTables, + SecurityGroups, + SpotInstances, + Subnets, + TagResponse, + VirtualPrivateGateways, + VMExport, + VMImport, + VPCs, + VPNConnections, + Windows, +): + pass diff --git a/moto/ec2/responses/amis.py b/moto/ec2/responses/amis.py index afce0bbb5..feddc89f1 100644 --- a/moto/ec2/responses/amis.py +++ b/moto/ec2/responses/amis.py @@ -5,14 +5,11 @@ from moto.ec2.utils import instance_ids_from_querystring class AmisResponse(object): - def __init__(self, querystring): - self.querystring = querystring - self.instance_ids = instance_ids_from_querystring(querystring) - def create_image(self): name = self.querystring.get('Name')[0] description = self.querystring.get('Description')[0] - instance_id = self.instance_ids[0] + instance_ids = instance_ids_from_querystring(self.querystring) + instance_id = instance_ids[0] image = ec2_backend.create_image(instance_id, name, description) if not image: return "There is not instance with id {}".format(instance_id), dict(status=404) diff --git a/moto/ec2/responses/availability_zones_and_regions.py b/moto/ec2/responses/availability_zones_and_regions.py index 4faeda764..f216a644f 100644 --- a/moto/ec2/responses/availability_zones_and_regions.py +++ b/moto/ec2/responses/availability_zones_and_regions.py @@ -4,9 +4,6 @@ from moto.ec2.models import ec2_backend class AvailabilityZonesAndRegions(object): - def __init__(self, querystring): - self.querystring = querystring - def describe_availability_zones(self): zones = ec2_backend.describe_availability_zones() template = Template(DESCRIBE_ZONES_RESPONSE) diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index bdea18188..d81c61c9d 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -4,9 +4,6 @@ from moto.ec2.models import ec2_backend class ElasticBlockStore(object): - def __init__(self, querystring): - self.querystring = querystring - def attach_volume(self): volume_id = self.querystring.get('VolumeId')[0] instance_id = self.querystring.get('InstanceId')[0] diff --git a/moto/ec2/responses/general.py b/moto/ec2/responses/general.py index ad133a30c..5353bb99a 100644 --- a/moto/ec2/responses/general.py +++ b/moto/ec2/responses/general.py @@ -5,11 +5,8 @@ from moto.ec2.utils import instance_ids_from_querystring class General(object): - def __init__(self, querystring): - self.querystring = querystring - self.instance_ids = instance_ids_from_querystring(querystring) - def get_console_output(self): + self.instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self.instance_ids[0] instance = ec2_backend.get_instance(instance_id) if instance: diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 7c7c9d725..7170a0928 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -6,10 +6,6 @@ from moto.ec2.utils import instance_ids_from_querystring class InstanceResponse(object): - def __init__(self, querystring): - self.querystring = querystring - self.instance_ids = instance_ids_from_querystring(querystring) - def describe_instances(self): template = Template(EC2_DESCRIBE_INSTANCES) return template.render(reservations=ec2_backend.all_reservations()) @@ -22,22 +18,26 @@ class InstanceResponse(object): return template.render(reservation=new_reservation) def terminate_instances(self): - instances = ec2_backend.terminate_instances(self.instance_ids) + instance_ids = instance_ids_from_querystring(self.querystring) + instances = ec2_backend.terminate_instances(instance_ids) template = Template(EC2_TERMINATE_INSTANCES) return template.render(instances=instances) def reboot_instances(self): - instances = ec2_backend.reboot_instances(self.instance_ids) + instance_ids = instance_ids_from_querystring(self.querystring) + instances = ec2_backend.reboot_instances(instance_ids) template = Template(EC2_REBOOT_INSTANCES) return template.render(instances=instances) def stop_instances(self): - instances = ec2_backend.stop_instances(self.instance_ids) + instance_ids = instance_ids_from_querystring(self.querystring) + instances = ec2_backend.stop_instances(instance_ids) template = Template(EC2_STOP_INSTANCES) return template.render(instances=instances) def start_instances(self): - instances = ec2_backend.start_instances(self.instance_ids) + instance_ids = instance_ids_from_querystring(self.querystring) + instances = ec2_backend.start_instances(instance_ids) template = Template(EC2_START_INSTANCES) return template.render(instances=instances) @@ -45,7 +45,8 @@ class InstanceResponse(object): # TODO this and modify below should raise IncorrectInstanceState if instance not in stopped state attribute = self.querystring.get("Attribute")[0] key = camelcase_to_underscores(attribute) - instance_id = self.instance_ids[0] + instance_ids = instance_ids_from_querystring(self.querystring) + instance_id = instance_ids[0] instance, value = ec2_backend.describe_instance_attribute(instance_id, key) template = Template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) return template.render(instance=instance, attribute=attribute, value=value) @@ -57,7 +58,8 @@ class InstanceResponse(object): value = self.querystring.get(key)[0] normalized_attribute = camelcase_to_underscores(key.split(".")[0]) - instance_id = self.instance_ids[0] + instance_ids = instance_ids_from_querystring(self.querystring) + instance_id = instance_ids[0] ec2_backend.modify_instance_attribute(instance_id, normalized_attribute, value) return EC2_MODIFY_INSTANCE_ATTRIBUTE diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 2768494a8..1b40e182f 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -1,7 +1,6 @@ from jinja2 import Template from moto.ec2.models import ec2_backend -from moto.ec2.utils import resource_ids_from_querystring def process_rules_from_querystring(querystring): @@ -22,9 +21,6 @@ def process_rules_from_querystring(querystring): class SecurityGroups(object): - def __init__(self, querystring): - self.querystring = querystring - def authorize_security_group_egress(self): raise NotImplementedError('SecurityGroups.authorize_security_group_egress is not yet implemented') diff --git a/moto/ec2/responses/subnets.py b/moto/ec2/responses/subnets.py index 97a5da287..761f492e5 100644 --- a/moto/ec2/responses/subnets.py +++ b/moto/ec2/responses/subnets.py @@ -4,9 +4,6 @@ from moto.ec2.models import ec2_backend class Subnets(object): - def __init__(self, querystring): - self.querystring = querystring - def create_subnet(self): vpc_id = self.querystring.get('VpcId')[0] cidr_block = self.querystring.get('CidrBlock')[0] diff --git a/moto/ec2/responses/tags.py b/moto/ec2/responses/tags.py index 18478e9a5..dd8dce8e8 100644 --- a/moto/ec2/responses/tags.py +++ b/moto/ec2/responses/tags.py @@ -5,17 +5,16 @@ from moto.ec2.utils import resource_ids_from_querystring class TagResponse(object): - def __init__(self, querystring): - self.querystring = querystring - self.resource_ids = resource_ids_from_querystring(querystring) def create_tags(self): - for resource_id, tag in self.resource_ids.iteritems(): + resource_ids = resource_ids_from_querystring(self.querystring) + for resource_id, tag in resource_ids.iteritems(): ec2_backend.create_tag(resource_id, tag[0], tag[1]) return CREATE_RESPONSE def delete_tags(self): - for resource_id, tag in self.resource_ids.iteritems(): + resource_ids = resource_ids_from_querystring(self.querystring) + for resource_id, tag in resource_ids.iteritems(): ec2_backend.delete_tag(resource_id, tag[0]) template = Template(DELETE_RESPONSE) return template.render(reservations=ec2_backend.all_reservations()) diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 857b9b2bb..c2b16f9cd 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -4,9 +4,6 @@ from moto.ec2.models import ec2_backend class VPCs(object): - def __init__(self, querystring): - self.querystring = querystring - def create_vpc(self): cidr_block = self.querystring.get('CidrBlock')[0] vpc = ec2_backend.create_vpc(cidr_block) diff --git a/moto/packages/__init__.py b/moto/packages/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moto/packages/httpretty.py b/moto/packages/httpretty.py deleted file mode 100644 index ebd69e4ed..000000000 --- a/moto/packages/httpretty.py +++ /dev/null @@ -1,944 +0,0 @@ -# #!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (C) <2011-2013> Gabriel Falcão -# -# Permission is hereby granted, free of charge, to any person -# obtaining a copy of this software and associated documentation -# files (the "Software"), to deal in the Software without -# restriction, including without limitation the rights to use, -# copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following -# conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -# OTHER DEALINGS IN THE SOFTWARE. -from __future__ import unicode_literals - -version = '0.5.12' - -import re -import inspect -import socket -import functools -import itertools -import warnings -import logging -import sys -import traceback -import types - -PY3 = sys.version_info[0] == 3 -if PY3: - text_type = str - byte_type = bytes - basestring = (str, bytes) - - import io - StringIO = io.BytesIO - - class Py3kObject(object): - def __repr__(self): - return self.__str__() -else: - text_type = unicode - byte_type = str - import StringIO - StringIO = StringIO.StringIO - - -class Py3kObject(object): - def __repr__(self): - ret = self.__str__() - if PY3: - return ret - else: - ret.encode('utf-8') - -from datetime import datetime -from datetime import timedelta -try: - from urllib.parse import urlsplit, urlunsplit, parse_qs, quote, quote_plus -except ImportError: - from urlparse import urlsplit, urlunsplit, parse_qs - from urllib import quote, quote_plus - -try: - from http.server import BaseHTTPRequestHandler -except ImportError: - from BaseHTTPServer import BaseHTTPRequestHandler - -old_socket = socket.socket -old_create_connection = socket.create_connection -old_gethostbyname = socket.gethostbyname -old_gethostname = socket.gethostname -old_getaddrinfo = socket.getaddrinfo -old_socksocket = None -old_ssl_wrap_socket = None -old_sslwrap_simple = None -old_sslsocket = None - -try: - import socks - old_socksocket = socks.socksocket -except ImportError: - socks = None - -try: - import ssl - old_ssl_wrap_socket = ssl.wrap_socket - if not PY3: - old_sslwrap_simple = ssl.sslwrap_simple - old_sslsocket = ssl.SSLSocket -except ImportError: - ssl = None - - -ClassTypes = (type,) -if not PY3: - ClassTypes = (type, types.ClassType) - - -POTENTIAL_HTTP_PORTS = [80, 443] - - -class HTTPrettyError(Exception): - pass - - -def utf8(s): - if isinstance(s, text_type): - s = s.encode('utf-8') - - return byte_type(s) - - -def decode_utf8(s): - if isinstance(s, byte_type): - s = s.decode("utf-8") - - return text_type(s) - - -def parse_requestline(s): - """ - http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 - - >>> parse_requestline('GET / HTTP/1.0') - ('GET', '/', '1.0') - >>> parse_requestline('post /testurl htTP/1.1') - ('POST', '/testurl', '1.1') - >>> parse_requestline('Im not a RequestLine') - Traceback (most recent call last): - ... - ValueError: Not a Request-Line - """ - methods = b'|'.join(HTTPretty.METHODS) - m = re.match(br'(' + methods + b')\s+(.*)\s+HTTP/(1.[0|1])', s, re.I) - if m: - return m.group(1).upper(), m.group(2), m.group(3) - else: - raise ValueError('Not a Request-Line') - - -class HTTPrettyRequest(BaseHTTPRequestHandler, Py3kObject): - def __init__(self, headers, body=''): - self.body = utf8(body) - self.raw_headers = utf8(headers) - self.client_address = ['10.0.0.1'] - self.rfile = StringIO(b'\r\n\r\n'.join([headers.strip(), body])) - self.wfile = StringIO() - self.raw_requestline = self.rfile.readline() - self.error_code = self.error_message = None - self.parse_request() - self.method = self.command - self.querystring = parse_qs(self.path.split("?", 1)[-1]) - - def __str__(self): - return 'HTTPrettyRequest(headers={0}, body="{1}")'.format( - self.headers, - self.body, - ) - - -class EmptyRequestHeaders(dict): - pass - - -class HTTPrettyRequestEmpty(object): - body = '' - headers = EmptyRequestHeaders() - - -class FakeSockFile(StringIO): - pass - - -class FakeSSLSocket(object): - def __init__(self, sock, *args, **kw): - self._httpretty_sock = sock - - def __getattr__(self, attr): - if attr == '_httpretty_sock': - return super(FakeSSLSocket, self).__getattribute__(attr) - - return getattr(self._httpretty_sock, attr) - - -class fakesock(object): - class socket(object): - _entry = None - debuglevel = 0 - _sent_data = [] - - def __init__(self, family, type, protocol=6): - self.setsockopt(family, type, protocol) - self.truesock = old_socket(family, type, protocol) - self._closed = True - self.fd = FakeSockFile() - self.timeout = socket._GLOBAL_DEFAULT_TIMEOUT - self._sock = self - self.is_http = False - - def getpeercert(self, *a, **kw): - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - 'notAfter': shift.strftime('%b %d %H:%M:%S GMT'), - 'subjectAltName': ( - ('DNS', '*%s' % self._host), - ('DNS', self._host), - ('DNS', '*'), - ), - 'subject': ( - ( - ('organizationName', u'*.%s' % self._host), - ), - ( - ('organizationalUnitName', - u'Domain Control Validated'), - ), - ( - ('commonName', u'*.%s' % self._host), - ), - ), - } - - def ssl(self, sock, *args, **kw): - return sock - - def setsockopt(self, family, type, protocol): - self.family = family - self.protocol = protocol - self.type = type - - def connect(self, address): - self._address = (self._host, self._port) = address - self._closed = False - self.is_http = self._port in POTENTIAL_HTTP_PORTS - if not self.is_http: - self.truesock.connect(self._address) - - def close(self): - if not self._closed: - self.truesock.close() - self._closed = True - - def makefile(self, mode='r', bufsize=-1): - self._mode = mode - self._bufsize = bufsize - - if self._entry: - self._entry.fill_filekind(self.fd, self._request) - - return self.fd - - def _true_sendall(self, data, *args, **kw): - if self.is_http: - self.truesock.connect(self._address) - - self.truesock.sendall(data, *args, **kw) - - _d = True - while _d: - try: - _d = self.truesock.recv(16) - self.truesock.settimeout(0.0) - self.fd.write(_d) - - except socket.error: - break - - self.fd.seek(0) - - def sendall(self, data, *args, **kw): - - self._sent_data.append(data) - hostnames = [getattr(i.info, 'hostname', None) for i in HTTPretty._entries.keys()] - self.fd.seek(0) - try: - requestline, _ = data.split(b'\r\n', 1) - method, path, version = parse_requestline(requestline) - is_parsing_headers = True - except ValueError: - is_parsing_headers = False - - if not is_parsing_headers: - if len(self._sent_data) > 1: - headers, body = map(utf8, self._sent_data[-2:]) - - method, path, version = parse_requestline(headers) - split_url = urlsplit(path) - - info = URIInfo(hostname=self._host, port=self._port, - path=split_url.path, - query=split_url.query) - - # If we are sending more data to a dynamic response entry, - # we need to call the method again. - if self._entry and self._entry.dynamic_response: - self._entry.body(info, method, body, headers) - - try: - return HTTPretty.historify_request(headers, body, False) - - except Exception as e: - logging.error(traceback.format_exc(e)) - return self._true_sendall(data, *args, **kw) - - # path might come with - s = urlsplit(path) - POTENTIAL_HTTP_PORTS.append(int(s.port or 80)) - headers, body = map(utf8, data.split(b'\r\n\r\n', 1)) - - request = HTTPretty.historify_request(headers, body) - - info = URIInfo(hostname=self._host, port=self._port, - path=s.path, - query=s.query, - last_request=request) - - entries = [] - - for matcher, value in HTTPretty._entries.items(): - if matcher.matches(info): - entries = value - break - - if not entries: - self._true_sendall(data) - return - - self._entry = matcher.get_next_entry(method) - self._request = (info, body, headers) - - def debug(*a, **kw): - frame = inspect.stack()[0][0] - lines = map(utf8, traceback.format_stack(frame)) - - message = [ - "HTTPretty intercepted and unexpected socket method call.", - ("Please open an issue at " - "'https://github.com/gabrielfalcao/HTTPretty/issues'"), - "And paste the following traceback:\n", - "".join(decode_utf8(lines)), - ] - raise RuntimeError("\n".join(message)) - - def settimeout(self, new_timeout): - self.timeout = new_timeout - - sendto = send = recvfrom_into = recv_into = recvfrom = recv = debug - - -def fake_wrap_socket(s, *args, **kw): - return s - - -def create_fake_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): - s = fakesock.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) - if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: - s.settimeout(timeout) - if source_address: - s.bind(source_address) - s.connect(address) - return s - - -def fake_gethostbyname(host): - return host - - -def fake_gethostname(): - return 'localhost' - - -def fake_getaddrinfo( - host, port, family=None, socktype=None, proto=None, flags=None): - return [(2, 1, 6, '', (host, port))] - - -STATUSES = { - 100: "Continue", - 101: "Switching Protocols", - 102: "Processing", - 200: "OK", - 201: "Created", - 202: "Accepted", - 203: "Non-Authoritative Information", - 204: "No Content", - 205: "Reset Content", - 206: "Partial Content", - 207: "Multi-Status", - 208: "Already Reported", - 226: "IM Used", - 300: "Multiple Choices", - 301: "Moved Permanently", - 302: "Found", - 303: "See Other", - 304: "Not Modified", - 305: "Use Proxy", - 306: "Switch Proxy", - 307: "Temporary Redirect", - 308: "Permanent Redirect", - 400: "Bad Request", - 401: "Unauthorized", - 402: "Payment Required", - 403: "Forbidden", - 404: "Not Found", - 405: "Method Not Allowed", - 406: "Not Acceptable", - 407: "Proxy Authentication Required", - 408: "Request a Timeout", - 409: "Conflict", - 410: "Gone", - 411: "Length Required", - 412: "Precondition Failed", - 413: "Request Entity Too Large", - 414: "Request-URI Too Long", - 415: "Unsupported Media Type", - 416: "Requested Range Not Satisfiable", - 417: "Expectation Failed", - 418: "I'm a teapot", - 420: "Enhance Your Calm", - 422: "Unprocessable Entity", - 423: "Locked", - 424: "Failed Dependency", - 424: "Method Failure", - 425: "Unordered Collection", - 426: "Upgrade Required", - 428: "Precondition Required", - 429: "Too Many Requests", - 431: "Request Header Fields Too Large", - 444: "No Response", - 449: "Retry With", - 450: "Blocked by Windows Parental Controls", - 451: "Unavailable For Legal Reasons", - 451: "Redirect", - 494: "Request Header Too Large", - 495: "Cert Error", - 496: "No Cert", - 497: "HTTP to HTTPS", - 499: "Client Closed Request", - 500: "Internal Server Error", - 501: "Not Implemented", - 502: "Bad Gateway", - 503: "Service Unavailable", - 504: "Gateway Timeout", - 505: "HTTP Version Not Supported", - 506: "Variant Also Negotiates", - 507: "Insufficient Storage", - 508: "Loop Detected", - 509: "Bandwidth Limit Exceeded", - 510: "Not Extended", - 511: "Network Authentication Required", - 598: "Network read timeout error", - 599: "Network connect timeout error", -} - - -class Entry(Py3kObject): - def __init__(self, method, uri, body, - adding_headers=None, - forcing_headers=None, - status=200, - streaming=False, - **headers): - - self.method = method - self.uri = uri - - if callable(body): - self.dynamic_response = True - else: - self.dynamic_response = False - - self.body = body - self.streaming = streaming - - if self.dynamic_response or self.streaming: - self.body_length = 0 - else: - self.body_length = len(self.body or '') - - self.adding_headers = adding_headers or {} - self.forcing_headers = forcing_headers or {} - self.status = int(status) - - for k, v in headers.items(): - name = "-".join(k.split("_")).capitalize() - self.adding_headers[name] = v - - self.validate() - - def validate(self): - content_length_keys = 'Content-Length', 'content-length' - for key in content_length_keys: - got = self.adding_headers.get( - key, self.forcing_headers.get(key, None)) - - if got is None: - continue - - try: - igot = int(got) - except ValueError: - warnings.warn( - 'HTTPretty got to register the Content-Length header ' \ - 'with "%r" which is not a number' % got, - ) - - if igot > self.body_length: - raise HTTPrettyError( - 'HTTPretty got inconsistent parameters. The header ' \ - 'Content-Length you registered expects size "%d" but ' \ - 'the body you registered for that has actually length ' \ - '"%d".' % ( - igot, self.body_length, - ) - ) - - def __str__(self): - return r'' % ( - self.method, self.uri, self.status) - - def normalize_headers(self, headers): - new = {} - for k in headers: - new_k = '-'.join([s.lower() for s in k.split('-')]) - new[new_k] = headers[k] - - return new - - def fill_filekind(self, fk, request): - now = datetime.utcnow() - - headers = { - 'status': self.status, - 'date': now.strftime('%a, %d %b %Y %H:%M:%S GMT'), - 'server': 'Python/HTTPretty', - 'connection': 'close', - } - - if self.forcing_headers: - headers = self.forcing_headers - - if self.dynamic_response: - req_info, req_body, req_headers = request - response = self.body(req_info, self.method, req_body, req_headers) - if isinstance(response, basestring): - body = response - else: - body, new_headers = response - headers.update(new_headers) - else: - body = self.body - - if self.adding_headers: - headers.update(self.normalize_headers(self.adding_headers)) - - headers = self.normalize_headers(headers) - - status = headers.get('status', self.status) - string_list = [ - 'HTTP/1.1 %d %s' % (status, STATUSES[status]), - ] - - if 'date' in headers: - string_list.append('date: %s' % headers.pop('date')) - - if not self.forcing_headers: - content_type = headers.pop('content-type', - 'text/plain; charset=utf-8') - - body_length = self.body_length - if self.dynamic_response: - body_length = len(body) - content_length = headers.pop('content-length', body_length) - - string_list.append('content-type: %s' % content_type) - if not self.streaming: - string_list.append('content-length: %s' % content_length) - - string_list.append('server: %s' % headers.pop('server')) - - for k, v in headers.items(): - string_list.append( - '{0}: {1}'.format(k, v), - ) - - for item in string_list: - fk.write(utf8(item) + b'\n') - - fk.write(b'\r\n') - - if self.streaming: - self.body, body = itertools.tee(body) - for chunk in body: - fk.write(utf8(chunk)) - else: - fk.write(utf8(body)) - - fk.seek(0) - - -def url_fix(s, charset='utf-8'): - scheme, netloc, path, querystring, fragment = urlsplit(s) - path = quote(path, b'/%') - querystring = quote_plus(querystring, b':&=') - return urlunsplit((scheme, netloc, path, querystring, fragment)) - - -class URIInfo(Py3kObject): - def __init__(self, - username='', - password='', - hostname='', - port=80, - path='/', - query='', - fragment='', - scheme='', - last_request=None): - - self.username = username or '' - self.password = password or '' - self.hostname = hostname or '' - - if port: - port = int(port) - - elif scheme == 'https': - port = 443 - - self.port = port or 80 - self.path = path or '' - self.query = query or '' - self.scheme = scheme or (self.port is 80 and "http" or "https") - self.fragment = fragment or '' - self.last_request = last_request - - def __str__(self): - attrs = ( - 'username', - 'password', - 'hostname', - 'port', - 'path', - ) - fmt = ", ".join(['%s="%s"' % (k, getattr(self, k, '')) for k in attrs]) - return r'' % fmt - - def __hash__(self): - return hash(text_type(self)) - - def __eq__(self, other): - self_tuple = ( - self.port, - decode_utf8(self.hostname), - url_fix(decode_utf8(self.path)), - ) - other_tuple = ( - other.port, - decode_utf8(other.hostname), - url_fix(decode_utf8(other.path)), - ) - return self_tuple == other_tuple - - def full_url(self): - credentials = "" - if self.password: - credentials = "{0}:{1}@".format( - self.username, self.password) - - result = "{scheme}://{credentials}{host}{path}".format( - scheme=self.scheme, - credentials=credentials, - host=decode_utf8(self.hostname), - path=decode_utf8(self.path) - ) - return result - - @classmethod - def from_uri(cls, uri, entry): - result = urlsplit(uri) - POTENTIAL_HTTP_PORTS.append(int(result.port or 80)) - return cls(result.username, - result.password, - result.hostname, - result.port, - result.path, - result.query, - result.fragment, - result.scheme, - entry) - - -class URIMatcher(object): - regex = None - info = None - - def __init__(self, uri, entries): - if type(uri).__name__ == 'SRE_Pattern': - self.regex = uri - else: - self.info = URIInfo.from_uri(uri, entries) - - self.entries = entries - - #hash of current_entry pointers, per method. - self.current_entries = {} - - def matches(self, info): - if self.info: - return self.info == info - else: - return self.regex.search(info.full_url()) - - def __str__(self): - wrap = 'URLMatcher({0})' - if self.info: - return wrap.format(text_type(self.info)) - else: - return wrap.format(self.regex.pattern) - - def get_next_entry(self, method='GET'): - """Cycle through available responses, but only once. - Any subsequent requests will receive the last response""" - - if method not in self.current_entries: - self.current_entries[method] = 0 - - #restrict selection to entries that match the requested method - entries_for_method = [e for e in self.entries if e.method == method] - - if self.current_entries[method] >= len(entries_for_method): - self.current_entries[method] = -1 - - if not self.entries or not entries_for_method: - raise ValueError('I have no entries for method %s: %s' - % (method, self)) - - entry = entries_for_method[self.current_entries[method]] - if self.current_entries[method] != -1: - self.current_entries[method] += 1 - return entry - - def __hash__(self): - return hash(text_type(self)) - - def __eq__(self, other): - return text_type(self) == text_type(other) - - -class HTTPretty(Py3kObject): - u"""The URI registration class""" - _entries = {} - latest_requests = [] - GET = b'GET' - PUT = b'PUT' - POST = b'POST' - DELETE = b'DELETE' - HEAD = b'HEAD' - PATCH = b'PATCH' - METHODS = (GET, PUT, POST, DELETE, HEAD, PATCH) - last_request = HTTPrettyRequestEmpty() - _is_enabled = False - - @classmethod - def reset(cls): - cls._entries.clear() - cls.latest_requests = [] - cls.last_request = HTTPrettyRequestEmpty() - - @classmethod - def historify_request(cls, headers, body='', append=True): - request = HTTPrettyRequest(headers, body) - cls.last_request = request - if append: - cls.latest_requests.append(request) - else: - cls.latest_requests[-1] = request - return request - - @classmethod - def register_uri(cls, method, uri, body='HTTPretty :)', - adding_headers=None, - forcing_headers=None, - status=200, - responses=None, **headers): - - if isinstance(responses, list) and len(responses) > 0: - for response in responses: - response.uri = uri - response.method = method - entries_for_this_uri = responses - else: - headers['body'] = body - headers['adding_headers'] = adding_headers - headers['forcing_headers'] = forcing_headers - headers['status'] = status - - entries_for_this_uri = [ - cls.Response(method=method, uri=uri, **headers), - ] - - matcher = URIMatcher(uri, entries_for_this_uri) - if matcher in cls._entries: - matcher.entries.extend(cls._entries[matcher]) - del cls._entries[matcher] - - cls._entries[matcher] = entries_for_this_uri - - def __str__(self): - return u'' % len(self._entries) - - @classmethod - def Response(cls, body, method=None, uri=None, adding_headers=None, forcing_headers=None, - status=200, streaming=False, **headers): - - headers['body'] = body - headers['adding_headers'] = adding_headers - headers['forcing_headers'] = forcing_headers - headers['status'] = int(status) - headers['streaming'] = streaming - return Entry(method, uri, **headers) - - @classmethod - def disable(cls): - cls._is_enabled = False - socket.socket = old_socket - socket.SocketType = old_socket - socket._socketobject = old_socket - - socket.create_connection = old_create_connection - socket.gethostname = old_gethostname - socket.gethostbyname = old_gethostbyname - socket.getaddrinfo = old_getaddrinfo - socket.inet_aton = old_gethostbyname - - socket.__dict__['socket'] = old_socket - socket.__dict__['_socketobject'] = old_socket - socket.__dict__['SocketType'] = old_socket - - socket.__dict__['create_connection'] = old_create_connection - socket.__dict__['gethostname'] = old_gethostname - socket.__dict__['gethostbyname'] = old_gethostbyname - socket.__dict__['getaddrinfo'] = old_getaddrinfo - socket.__dict__['inet_aton'] = old_gethostbyname - - if socks: - socks.socksocket = old_socksocket - socks.__dict__['socksocket'] = old_socksocket - - if ssl: - ssl.wrap_socket = old_ssl_wrap_socket - ssl.SSLSocket = old_sslsocket - ssl.__dict__['wrap_socket'] = old_ssl_wrap_socket - ssl.__dict__['SSLSocket'] = old_sslsocket - - if not PY3: - ssl.sslwrap_simple = old_sslwrap_simple - ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple - - @classmethod - def is_enabled(cls): - return cls._is_enabled - - @classmethod - def enable(cls): - cls._is_enabled = True - socket.socket = fakesock.socket - socket._socketobject = fakesock.socket - socket.SocketType = fakesock.socket - - socket.create_connection = create_fake_connection - socket.gethostname = fake_gethostname - socket.gethostbyname = fake_gethostbyname - socket.getaddrinfo = fake_getaddrinfo - socket.inet_aton = fake_gethostbyname - - socket.__dict__['socket'] = fakesock.socket - socket.__dict__['_socketobject'] = fakesock.socket - socket.__dict__['SocketType'] = fakesock.socket - - socket.__dict__['create_connection'] = create_fake_connection - socket.__dict__['gethostname'] = fake_gethostname - socket.__dict__['gethostbyname'] = fake_gethostbyname - socket.__dict__['inet_aton'] = fake_gethostbyname - socket.__dict__['getaddrinfo'] = fake_getaddrinfo - - if socks: - socks.socksocket = fakesock.socket - socks.__dict__['socksocket'] = fakesock.socket - - if ssl: - ssl.wrap_socket = fake_wrap_socket - ssl.SSLSocket = FakeSSLSocket - - ssl.__dict__['wrap_socket'] = fake_wrap_socket - ssl.__dict__['SSLSocket'] = FakeSSLSocket - - if not PY3: - ssl.sslwrap_simple = fake_wrap_socket - ssl.__dict__['sslwrap_simple'] = fake_wrap_socket - - -def httprettified(test): - "A decorator tests that use HTTPretty" - def decorate_class(klass): - for attr in dir(klass): - if not attr.startswith('test_'): - continue - - attr_value = getattr(klass, attr) - if not hasattr(attr_value, "__call__"): - continue - - setattr(klass, attr, decorate_callable(attr_value)) - return klass - - def decorate_callable(test): - @functools.wraps(test) - def wrapper(*args, **kw): - HTTPretty.reset() - HTTPretty.enable() - try: - return test(*args, **kw) - finally: - HTTPretty.disable() - return wrapper - - if isinstance(test, ClassTypes): - return decorate_class(test) - return decorate_callable(test) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index e5a2bed65..974b2dc49 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1,10 +1,10 @@ -from urlparse import parse_qs +from urlparse import parse_qs, urlparse from jinja2 import Template from .models import s3_backend from moto.core.utils import headers_to_dict -from .utils import bucket_name_from_hostname +from .utils import bucket_name_from_url def all_buckets(): @@ -14,11 +14,23 @@ def all_buckets(): return template.render(buckets=all_buckets) -def bucket_response(uri, method, body, headers): - hostname = uri.hostname - querystring = parse_qs(uri.query) +def bucket_response(request, full_url, headers): + headers = headers_to_dict(headers) + response = _bucket_response(request, full_url, headers) + if isinstance(response, basestring): + return 200, headers, response - bucket_name = bucket_name_from_hostname(hostname) + else: + status_code, headers, response_content = response + return status_code, headers, response_content + + +def _bucket_response(request, full_url, headers): + parsed_url = urlparse(full_url) + querystring = parse_qs(parsed_url.query) + method = request.method + + bucket_name = bucket_name_from_url(full_url) if not bucket_name: # If no bucket specified, list all buckets return all_buckets() @@ -38,7 +50,7 @@ def bucket_response(uri, method, body, headers): result_folders=result_folders ) else: - return "", dict(status=404) + return 404, headers, "" elif method == 'PUT': new_bucket = s3_backend.create_bucket(bucket_name) template = Template(S3_BUCKET_CREATE_RESPONSE) @@ -48,37 +60,53 @@ def bucket_response(uri, method, body, headers): if removed_bucket is None: # Non-existant bucket template = Template(S3_DELETE_NON_EXISTING_BUCKET) - return template.render(bucket_name=bucket_name), dict(status=404) + return 404, headers, template.render(bucket_name=bucket_name) elif removed_bucket: # Bucket exists template = Template(S3_DELETE_BUCKET_SUCCESS) - return template.render(bucket=removed_bucket), dict(status=204) + return 204, headers, template.render(bucket=removed_bucket) else: # Tried to delete a bucket that still has keys template = Template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) - return template.render(bucket=removed_bucket), dict(status=409) + return 409, headers, template.render(bucket=removed_bucket) else: raise NotImplementedError("Method {} has not been impelemented in the S3 backend yet".format(method)) -def key_response(uri_info, method, body, headers): - - key_name = uri_info.path.lstrip('/') - hostname = uri_info.hostname +def key_response(request, full_url, headers): headers = headers_to_dict(headers) - bucket_name = bucket_name_from_hostname(hostname) + response = _key_response(request, full_url, headers) + if isinstance(response, basestring): + return 200, headers, response + else: + status_code, headers, response_content = response + return status_code, headers, response_content + + +def _key_response(request, full_url, headers): + parsed_url = urlparse(full_url) + method = request.method + + key_name = parsed_url.path.lstrip('/') + bucket_name = bucket_name_from_url(full_url) + if hasattr(request, 'body'): + # Boto + body = request.body + else: + # Flask server + body = request.data if method == 'GET': key = s3_backend.get_key(bucket_name, key_name) if key: return key.value else: - return "", dict(status=404) + return 404, headers, "" if method == 'PUT': - if 'x-amz-copy-source' in headers: + if 'x-amz-copy-source' in request.headers: # Copy key - src_bucket, src_key = headers.get("x-amz-copy-source").split("/") + src_bucket, src_key = request.headers.get("x-amz-copy-source").split("/") s3_backend.copy_key(src_bucket, src_key, bucket_name, key_name) template = Template(S3_OBJECT_COPY_RESPONSE) return template.render(key=src_key) @@ -92,20 +120,23 @@ def key_response(uri_info, method, body, headers): # empty string as part of closing the connection. new_key = s3_backend.set_key(bucket_name, key_name, body) template = Template(S3_OBJECT_RESPONSE) - return template.render(key=new_key), new_key.response_dict + headers.update(new_key.response_dict) + return 200, headers, template.render(key=new_key) key = s3_backend.get_key(bucket_name, key_name) if key: - return "", key.response_dict + headers.update(key.response_dict) + return 200, headers, "" elif method == 'HEAD': key = s3_backend.get_key(bucket_name, key_name) if key: - return S3_OBJECT_RESPONSE, key.response_dict + headers.update(key.response_dict) + return 200, headers, S3_OBJECT_RESPONSE else: - return "", dict(status=404) + return 404, headers, "" elif method == 'DELETE': removed_key = s3_backend.delete_key(bucket_name, key_name) template = Template(S3_DELETE_OBJECT_SUCCESS) - return template.render(bucket=removed_key), dict(status=204) + return 204, headers, template.render(bucket=removed_key) else: raise NotImplementedError("Method {} has not been impelemented in the S3 backend yet".format(method)) diff --git a/moto/s3/utils.py b/moto/s3/utils.py index d9e5671e9..765303743 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -5,20 +5,19 @@ import urlparse bucket_name_regex = re.compile("(.+).s3.amazonaws.com") -def bucket_name_from_hostname(hostname): - if 'amazonaws.com' in hostname: - bucket_result = bucket_name_regex.search(hostname) +def bucket_name_from_url(url): + domain = urlparse.urlparse(url).netloc + + # If 'www' prefixed, strip it. + domain = domain.lstrip("www.") + + if 'amazonaws.com' in domain: + bucket_result = bucket_name_regex.search(domain) if bucket_result: return bucket_result.groups()[0] else: - # In server mode. Use left-most part of subdomain for bucket name - split_url = urlparse.urlparse(hostname) - - # If 'www' prefixed, strip it. - clean_hostname = split_url.netloc.lstrip("www.") - - if '.' in clean_hostname: - return clean_hostname.split(".")[0] + if '.' in domain: + return domain.split(".")[0] else: # No subdomain found. return None diff --git a/requirements.txt b/requirements.txt index b3731770e..62f6f0a27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ coverage freezegun -#httpretty mock nose https://github.com/spulec/python-coveralls/tarball/796d9dba34b759664e42ba39e6414209a0f319ad diff --git a/setup.py b/setup.py index fab0a59d1..71244de60 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,11 @@ setup( packages=find_packages(), install_requires=[ "boto", - "Jinja2", "flask", + "httpretty", + "Jinja2", + ], + dependency_links=[ + "https://github.com/gabrielfalcao/HTTPretty/tarball/2347df40a3a3cd00e73f0353f5ea2670ad3405c1", ], ) diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py new file mode 100644 index 000000000..5b03d61fd --- /dev/null +++ b/tests/test_s3/test_s3_utils.py @@ -0,0 +1,14 @@ +from sure import expect +from moto.s3.utils import bucket_name_from_url + + +def test_base_url(): + expect(bucket_name_from_url('https://s3.amazonaws.com/')).should.equal(None) + + +def test_localhost_bucket(): + expect(bucket_name_from_url('https://foobar.localhost:5000/abc')).should.equal("foobar") + + +def test_localhost_without_bucket(): + expect(bucket_name_from_url('https://www.localhost:5000/def')).should.equal(None)