Cleanup Server host parsing.

This commit is contained in:
Steve Pulec 2017-02-20 14:31:19 -05:00
parent d0fe1a0956
commit 51df02e7cf
3 changed files with 25 additions and 16 deletions

View File

@ -3,6 +3,7 @@ from __future__ import absolute_import
import functools
import inspect
import os
import re
from moto.packages.responses import responses
@ -48,7 +49,9 @@ class BaseMockAWS(object):
if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().')
self.disable_patching()
if self.__class__.nested_count == 0:
self.disable_patching()
def decorate_callable(self, func, reset):
def wrapper(*args, **kwargs):
@ -108,9 +111,8 @@ class HttprettyMockAWS(BaseMockAWS):
)
def disable_patching(self):
if self.__class__.nested_count == 0:
HTTPretty.disable()
HTTPretty.reset()
HTTPretty.disable()
HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD,
@ -142,14 +144,15 @@ class ResponsesMockAWS(BaseMockAWS):
pattern['stream'] = True
def disable_patching(self):
if self.__class__.nested_count == 0:
try:
responses.stop()
except AttributeError:
pass
responses.reset()
try:
responses.stop()
except AttributeError:
pass
responses.reset()
MockAWS = ResponsesMockAWS
class Model(type):
def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace)

View File

@ -42,8 +42,14 @@ class DomainDispatcherApplication(object):
raise RuntimeError('Invalid host: "%s"' % host)
def get_application(self, host):
host = host.split(':')[0]
def get_application(self, environ):
host = environ['HTTP_HOST'].split(':')[0]
if host == "localhost":
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
_, _, region, service, _ = environ['HTTP_AUTHORIZATION'].split(",")[0].split()[1].split("/")
host = "{service}.{region}.amazonaws.com".format(service=service, region=region)
with self.lock:
backend = self.get_backend_for_host(host)
app = self.app_instances.get(backend, None)
@ -53,7 +59,7 @@ class DomainDispatcherApplication(object):
return app
def __call__(self, environ, start_response):
backend_app = self.get_application(environ['HTTP_HOST'])
backend_app = self.get_application(environ)
return backend_app(environ, start_response)

View File

@ -32,19 +32,19 @@ def test_port_argument(run_simple):
def test_domain_dispatched():
dispatcher = DomainDispatcherApplication(create_backend_app)
backend_app = dispatcher.get_application("email.us-east1.amazonaws.com")
backend_app = dispatcher.get_application({"HTTP_HOST": "email.us-east1.amazonaws.com"})
keys = list(backend_app.view_functions.keys())
keys[0].should.equal('EmailResponse.dispatch')
def test_domain_without_matches():
dispatcher = DomainDispatcherApplication(create_backend_app)
dispatcher.get_application.when.called_with("not-matching-anything.com").should.throw(RuntimeError)
dispatcher.get_application.when.called_with({"HTTP_HOST": "not-matching-anything.com"}).should.throw(RuntimeError)
def test_domain_dispatched_with_service():
# If we pass a particular service, always return that.
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
backend_app = dispatcher.get_application("s3.us-east1.amazonaws.com")
backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"})
keys = set(backend_app.view_functions.keys())
keys.should.contain('ResponseObject.key_response')