Refactor BotocoreStubber/BaseBackend into separate files (#5122)
This commit is contained in:
parent
e49e67aba5
commit
31737bc81e
@ -1,4 +1,5 @@
|
||||
from .models import BaseBackend, get_account_id, ACCOUNT_ID # noqa
|
||||
from .models import get_account_id, ACCOUNT_ID # noqa
|
||||
from .base_backend import BaseBackend # noqa
|
||||
from .common_models import BaseModel # noqa
|
||||
from .common_models import CloudFormationModel, CloudWatchMetricProvider # noqa
|
||||
from .models import patch_client, patch_resource # noqa
|
||||
|
160
moto/core/base_backend.py
Normal file
160
moto/core/base_backend.py
Normal file
@ -0,0 +1,160 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from collections import defaultdict
|
||||
from .utils import convert_regex_to_flask_path
|
||||
|
||||
|
||||
model_data = defaultdict(dict)
|
||||
|
||||
|
||||
class InstanceTrackerMeta(type):
|
||||
def __new__(meta, name, bases, dct):
|
||||
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
|
||||
if name == "BaseModel":
|
||||
return cls
|
||||
|
||||
service = cls.__module__.split(".")[1]
|
||||
if name not in model_data[service]:
|
||||
model_data[service][name] = cls
|
||||
cls.instances = []
|
||||
return cls
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
def _reset_model_refs(self):
|
||||
# Remove all references to the models stored
|
||||
for models in model_data.values():
|
||||
for model in models.values():
|
||||
model.instances = []
|
||||
|
||||
def reset(self):
|
||||
self._reset_model_refs()
|
||||
self.__dict__ = {}
|
||||
self.__init__()
|
||||
|
||||
@property
|
||||
def _url_module(self):
|
||||
backend_module = self.__class__.__module__
|
||||
backend_urls_module_name = backend_module.replace("models", "urls")
|
||||
backend_urls_module = __import__(
|
||||
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
|
||||
)
|
||||
return backend_urls_module
|
||||
|
||||
@property
|
||||
def urls(self):
|
||||
"""
|
||||
A dictionary of the urls to be mocked with this service and the handlers
|
||||
that should be called in their place
|
||||
"""
|
||||
url_bases = self.url_bases
|
||||
unformatted_paths = self._url_module.url_paths
|
||||
|
||||
urls = {}
|
||||
for url_base in url_bases:
|
||||
# The default URL_base will look like: http://service.[..].amazonaws.com/...
|
||||
# This extension ensures support for the China regions
|
||||
cn_url_base = re.sub(r"amazonaws\\?.com$", "amazonaws.com.cn", url_base)
|
||||
for url_path, handler in unformatted_paths.items():
|
||||
url = url_path.format(url_base)
|
||||
urls[url] = handler
|
||||
cn_url = url_path.format(cn_url_base)
|
||||
urls[cn_url] = handler
|
||||
|
||||
return urls
|
||||
|
||||
@property
|
||||
def url_paths(self):
|
||||
"""
|
||||
A dictionary of the paths of the urls to be mocked with this service and
|
||||
the handlers that should be called in their place
|
||||
"""
|
||||
unformatted_paths = self._url_module.url_paths
|
||||
|
||||
paths = {}
|
||||
for unformatted_path, handler in unformatted_paths.items():
|
||||
path = unformatted_path.format("")
|
||||
paths[path] = handler
|
||||
|
||||
return paths
|
||||
|
||||
@property
|
||||
def url_bases(self):
|
||||
"""
|
||||
A list containing the url_bases extracted from urls.py
|
||||
"""
|
||||
return self._url_module.url_bases
|
||||
|
||||
@property
|
||||
def flask_paths(self):
|
||||
"""
|
||||
The url paths that will be used for the flask server
|
||||
"""
|
||||
paths = {}
|
||||
for url_path, handler in self.url_paths.items():
|
||||
url_path = convert_regex_to_flask_path(url_path)
|
||||
paths[url_path] = handler
|
||||
|
||||
return paths
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(
|
||||
service_region, zones
|
||||
): # pylint: disable=unused-argument
|
||||
"""Invoke the factory method for any VPC endpoint(s) services."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def vpce_random_number():
|
||||
"""Return random number for a VPC endpoint service ID."""
|
||||
return "".join([random.choice(string.hexdigits.lower()) for i in range(17)])
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service_factory(
|
||||
service_region,
|
||||
zones,
|
||||
service="",
|
||||
service_type="Interface",
|
||||
private_dns_names=True,
|
||||
special_service_name="",
|
||||
policy_supported=True,
|
||||
base_endpoint_dns_names=None,
|
||||
): # pylint: disable=too-many-arguments
|
||||
"""List of dicts representing default VPC endpoints for this service."""
|
||||
if special_service_name:
|
||||
service_name = f"com.amazonaws.{service_region}.{special_service_name}"
|
||||
else:
|
||||
service_name = f"com.amazonaws.{service_region}.{service}"
|
||||
|
||||
if not base_endpoint_dns_names:
|
||||
base_endpoint_dns_names = [f"{service}.{service_region}.vpce.amazonaws.com"]
|
||||
|
||||
endpoint_service = {
|
||||
"AcceptanceRequired": False,
|
||||
"AvailabilityZones": zones,
|
||||
"BaseEndpointDnsNames": base_endpoint_dns_names,
|
||||
"ManagesVpcEndpoints": False,
|
||||
"Owner": "amazon",
|
||||
"ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
|
||||
"ServiceName": service_name,
|
||||
"ServiceType": [{"ServiceType": service_type}],
|
||||
"Tags": [],
|
||||
"VpcEndpointPolicySupported": policy_supported,
|
||||
}
|
||||
|
||||
# Don't know how private DNS names are different, so for now just
|
||||
# one will be added.
|
||||
if private_dns_names:
|
||||
endpoint_service[
|
||||
"PrivateDnsName"
|
||||
] = f"{service}.{service_region}.amazonaws.com"
|
||||
endpoint_service["PrivateDnsNameVerificationState"] = "verified"
|
||||
endpoint_service["PrivateDnsNames"] = [
|
||||
{"PrivateDnsName": f"{service}.{service_region}.amazonaws.com"}
|
||||
]
|
||||
return [endpoint_service]
|
||||
|
||||
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
|
||||
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
|
||||
# raise NotImplementedError()
|
65
moto/core/botocore_stubber.py
Normal file
65
moto/core/botocore_stubber.py
Normal file
@ -0,0 +1,65 @@
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from botocore.awsrequest import AWSResponse
|
||||
from moto.core.exceptions import HTTPException
|
||||
|
||||
|
||||
class MockRawResponse(BytesIO):
|
||||
def __init__(self, response_input):
|
||||
if isinstance(response_input, str):
|
||||
response_input = response_input.encode("utf-8")
|
||||
super().__init__(response_input)
|
||||
|
||||
def stream(self, **kwargs): # pylint: disable=unused-argument
|
||||
contents = self.read()
|
||||
while contents:
|
||||
yield contents
|
||||
contents = self.read()
|
||||
|
||||
|
||||
class BotocoreStubber:
|
||||
def __init__(self):
|
||||
self.enabled = False
|
||||
self.methods = defaultdict(list)
|
||||
|
||||
def reset(self):
|
||||
self.methods.clear()
|
||||
|
||||
def register_response(self, method, pattern, response):
|
||||
matchers = self.methods[method]
|
||||
matchers.append((pattern, response))
|
||||
|
||||
def __call__(self, event_name, request, **kwargs):
|
||||
if not self.enabled:
|
||||
return None
|
||||
response = None
|
||||
response_callback = None
|
||||
found_index = None
|
||||
matchers = self.methods.get(request.method)
|
||||
|
||||
base_url = request.url.split("?", 1)[0]
|
||||
for i, (pattern, callback) in enumerate(matchers):
|
||||
if pattern.match(base_url):
|
||||
if found_index is None:
|
||||
found_index = i
|
||||
response_callback = callback
|
||||
else:
|
||||
matchers.pop(found_index)
|
||||
break
|
||||
|
||||
if response_callback is not None:
|
||||
for header, value in request.headers.items():
|
||||
if isinstance(value, bytes):
|
||||
request.headers[header] = value.decode("utf-8")
|
||||
try:
|
||||
status, headers, body = response_callback(
|
||||
request, request.url, request.headers
|
||||
)
|
||||
except HTTPException as e:
|
||||
status = e.code
|
||||
headers = e.get_headers()
|
||||
body = e.get_body()
|
||||
body = MockRawResponse(body)
|
||||
response = AWSResponse(request.url, status, headers, body)
|
||||
|
||||
return response
|
@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from .models import InstanceTrackerMeta
|
||||
from .base_backend import InstanceTrackerMeta
|
||||
|
||||
|
||||
class BaseModel(metaclass=InstanceTrackerMeta):
|
||||
|
@ -2,31 +2,26 @@ import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from types import FunctionType
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
import responses
|
||||
from botocore.awsrequest import AWSResponse
|
||||
from botocore.config import Config
|
||||
from botocore.handlers import BUILTIN_HANDLERS
|
||||
|
||||
from moto import settings
|
||||
from moto.core.exceptions import HTTPException
|
||||
from .botocore_stubber import BotocoreStubber
|
||||
from .custom_responses_mock import (
|
||||
get_response_mock,
|
||||
CallbackResponse,
|
||||
not_implemented_callback,
|
||||
reset_responses_mock,
|
||||
)
|
||||
from .utils import convert_regex_to_flask_path, convert_flask_to_responses_response
|
||||
from .utils import convert_flask_to_responses_response
|
||||
|
||||
ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012")
|
||||
|
||||
@ -244,67 +239,6 @@ responses_mock = get_response_mock()
|
||||
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
|
||||
|
||||
|
||||
class MockRawResponse(BytesIO):
|
||||
def __init__(self, response_input):
|
||||
if isinstance(response_input, str):
|
||||
response_input = response_input.encode("utf-8")
|
||||
super().__init__(response_input)
|
||||
|
||||
def stream(self, **kwargs): # pylint: disable=unused-argument
|
||||
contents = self.read()
|
||||
while contents:
|
||||
yield contents
|
||||
contents = self.read()
|
||||
|
||||
|
||||
class BotocoreStubber:
|
||||
def __init__(self):
|
||||
self.enabled = False
|
||||
self.methods = defaultdict(list)
|
||||
|
||||
def reset(self):
|
||||
self.methods.clear()
|
||||
|
||||
def register_response(self, method, pattern, response):
|
||||
matchers = self.methods[method]
|
||||
matchers.append((pattern, response))
|
||||
|
||||
def __call__(self, event_name, request, **kwargs):
|
||||
if not self.enabled:
|
||||
return None
|
||||
response = None
|
||||
response_callback = None
|
||||
found_index = None
|
||||
matchers = self.methods.get(request.method)
|
||||
|
||||
base_url = request.url.split("?", 1)[0]
|
||||
for i, (pattern, callback) in enumerate(matchers):
|
||||
if pattern.match(base_url):
|
||||
if found_index is None:
|
||||
found_index = i
|
||||
response_callback = callback
|
||||
else:
|
||||
matchers.pop(found_index)
|
||||
break
|
||||
|
||||
if response_callback is not None:
|
||||
for header, value in request.headers.items():
|
||||
if isinstance(value, bytes):
|
||||
request.headers[header] = value.decode("utf-8")
|
||||
try:
|
||||
status, headers, body = response_callback(
|
||||
request, request.url, request.headers
|
||||
)
|
||||
except HTTPException as e:
|
||||
status = e.code
|
||||
headers = e.get_headers()
|
||||
body = e.get_body()
|
||||
body = MockRawResponse(body)
|
||||
response = AWSResponse(request.url, status, headers, body)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
botocore_stubber = BotocoreStubber()
|
||||
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
|
||||
|
||||
@ -455,184 +389,6 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
self._resource_patcher.stop()
|
||||
|
||||
|
||||
class Model(type):
|
||||
def __new__(self, clsname, bases, namespace):
|
||||
cls = super().__new__(self, clsname, bases, namespace)
|
||||
cls.__models__ = {}
|
||||
for name, value in namespace.items():
|
||||
model = getattr(value, "__returns_model__", False)
|
||||
if model is not False:
|
||||
cls.__models__[model] = name
|
||||
for base in bases:
|
||||
cls.__models__.update(getattr(base, "__models__", {}))
|
||||
return cls
|
||||
|
||||
@staticmethod
|
||||
def prop(model_name):
|
||||
"""decorator to mark a class method as returning model values"""
|
||||
|
||||
def dec(f):
|
||||
f.__returns_model__ = model_name
|
||||
return f
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
model_data = defaultdict(dict)
|
||||
|
||||
|
||||
class InstanceTrackerMeta(type):
|
||||
def __new__(meta, name, bases, dct):
|
||||
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
|
||||
if name == "BaseModel":
|
||||
return cls
|
||||
|
||||
service = cls.__module__.split(".")[1]
|
||||
if name not in model_data[service]:
|
||||
model_data[service][name] = cls
|
||||
cls.instances = []
|
||||
return cls
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
def _reset_model_refs(self):
|
||||
# Remove all references to the models stored
|
||||
for models in model_data.values():
|
||||
for model in models.values():
|
||||
model.instances = []
|
||||
|
||||
def reset(self):
|
||||
self._reset_model_refs()
|
||||
self.__dict__ = {}
|
||||
self.__init__()
|
||||
|
||||
@property
|
||||
def _url_module(self):
|
||||
backend_module = self.__class__.__module__
|
||||
backend_urls_module_name = backend_module.replace("models", "urls")
|
||||
backend_urls_module = __import__(
|
||||
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
|
||||
)
|
||||
return backend_urls_module
|
||||
|
||||
@property
|
||||
def urls(self):
|
||||
"""
|
||||
A dictionary of the urls to be mocked with this service and the handlers
|
||||
that should be called in their place
|
||||
"""
|
||||
url_bases = self.url_bases
|
||||
unformatted_paths = self._url_module.url_paths
|
||||
|
||||
urls = {}
|
||||
for url_base in url_bases:
|
||||
# The default URL_base will look like: http://service.[..].amazonaws.com/...
|
||||
# This extension ensures support for the China regions
|
||||
cn_url_base = re.sub(r"amazonaws\\?.com$", "amazonaws.com.cn", url_base)
|
||||
for url_path, handler in unformatted_paths.items():
|
||||
url = url_path.format(url_base)
|
||||
urls[url] = handler
|
||||
cn_url = url_path.format(cn_url_base)
|
||||
urls[cn_url] = handler
|
||||
|
||||
return urls
|
||||
|
||||
@property
|
||||
def url_paths(self):
|
||||
"""
|
||||
A dictionary of the paths of the urls to be mocked with this service and
|
||||
the handlers that should be called in their place
|
||||
"""
|
||||
unformatted_paths = self._url_module.url_paths
|
||||
|
||||
paths = {}
|
||||
for unformatted_path, handler in unformatted_paths.items():
|
||||
path = unformatted_path.format("")
|
||||
paths[path] = handler
|
||||
|
||||
return paths
|
||||
|
||||
@property
|
||||
def url_bases(self):
|
||||
"""
|
||||
A list containing the url_bases extracted from urls.py
|
||||
"""
|
||||
return self._url_module.url_bases
|
||||
|
||||
@property
|
||||
def flask_paths(self):
|
||||
"""
|
||||
The url paths that will be used for the flask server
|
||||
"""
|
||||
paths = {}
|
||||
for url_path, handler in self.url_paths.items():
|
||||
url_path = convert_regex_to_flask_path(url_path)
|
||||
paths[url_path] = handler
|
||||
|
||||
return paths
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(
|
||||
service_region, zones
|
||||
): # pylint: disable=unused-argument
|
||||
"""Invoke the factory method for any VPC endpoint(s) services."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def vpce_random_number():
|
||||
"""Return random number for a VPC endpoint service ID."""
|
||||
return "".join([random.choice(string.hexdigits.lower()) for i in range(17)])
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service_factory(
|
||||
service_region,
|
||||
zones,
|
||||
service="",
|
||||
service_type="Interface",
|
||||
private_dns_names=True,
|
||||
special_service_name="",
|
||||
policy_supported=True,
|
||||
base_endpoint_dns_names=None,
|
||||
): # pylint: disable=too-many-arguments
|
||||
"""List of dicts representing default VPC endpoints for this service."""
|
||||
if special_service_name:
|
||||
service_name = f"com.amazonaws.{service_region}.{special_service_name}"
|
||||
else:
|
||||
service_name = f"com.amazonaws.{service_region}.{service}"
|
||||
|
||||
if not base_endpoint_dns_names:
|
||||
base_endpoint_dns_names = [f"{service}.{service_region}.vpce.amazonaws.com"]
|
||||
|
||||
endpoint_service = {
|
||||
"AcceptanceRequired": False,
|
||||
"AvailabilityZones": zones,
|
||||
"BaseEndpointDnsNames": base_endpoint_dns_names,
|
||||
"ManagesVpcEndpoints": False,
|
||||
"Owner": "amazon",
|
||||
"ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
|
||||
"ServiceName": service_name,
|
||||
"ServiceType": [{"ServiceType": service_type}],
|
||||
"Tags": [],
|
||||
"VpcEndpointPolicySupported": policy_supported,
|
||||
}
|
||||
|
||||
# Don't know how private DNS names are different, so for now just
|
||||
# one will be added.
|
||||
if private_dns_names:
|
||||
endpoint_service[
|
||||
"PrivateDnsName"
|
||||
] = f"{service}.{service_region}.amazonaws.com"
|
||||
endpoint_service["PrivateDnsNameVerificationState"] = "verified"
|
||||
endpoint_service["PrivateDnsNames"] = [
|
||||
{"PrivateDnsName": f"{service}.{service_region}.amazonaws.com"}
|
||||
]
|
||||
return [endpoint_service]
|
||||
|
||||
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
|
||||
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
|
||||
# raise NotImplementedError()
|
||||
|
||||
|
||||
class base_decorator:
|
||||
mock_backend = MockAWS
|
||||
|
||||
|
@ -1,3 +0,0 @@
|
||||
url_bases = []
|
||||
|
||||
url_paths = {}
|
@ -1,7 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from moto.core.common_models import CloudFormationModel
|
||||
from moto.core.models import Model
|
||||
from moto.packages.boto.ec2.launchspecification import LaunchSpecification
|
||||
from moto.packages.boto.ec2.spotinstancerequest import (
|
||||
SpotInstanceRequest as BotoSpotRequest,
|
||||
@ -117,7 +116,7 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
|
||||
return instance
|
||||
|
||||
|
||||
class SpotRequestBackend(object, metaclass=Model):
|
||||
class SpotRequestBackend(object):
|
||||
def __init__(self):
|
||||
self.spot_instance_requests = {}
|
||||
super().__init__()
|
||||
@ -176,7 +175,6 @@ class SpotRequestBackend(object, metaclass=Model):
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
@Model.prop("SpotInstanceRequest")
|
||||
def describe_spot_instance_requests(self, filters=None, spot_instance_ids=None):
|
||||
requests = self.spot_instance_requests.copy().values()
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from moto.core.models import BaseBackend
|
||||
from moto.core import BaseBackend
|
||||
|
||||
|
||||
class InstanceMetadataBackend(BaseBackend):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from moto.core.models import BaseBackend
|
||||
from moto.core import BaseBackend
|
||||
|
||||
|
||||
class MotoAPIBackend(BaseBackend):
|
||||
|
@ -39,7 +39,7 @@ class MotoAPIResponse(BaseResponse):
|
||||
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
|
||||
|
||||
def model_data(self, request, full_url, headers): # pylint: disable=unused-argument
|
||||
from moto.core.models import model_data
|
||||
from moto.core.base_backend import model_data
|
||||
|
||||
results = {}
|
||||
for service in sorted(model_data):
|
||||
|
Loading…
x
Reference in New Issue
Block a user