Refactor BotocoreStubber/BaseBackend into separate files (#5122)

This commit is contained in:
Bert Blommers 2022-05-12 09:02:27 +00:00 committed by GitHub
parent e49e67aba5
commit 31737bc81e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 234 additions and 257 deletions

View File

@ -1,4 +1,5 @@
from .models import BaseBackend, get_account_id, ACCOUNT_ID # noqa from .models import get_account_id, ACCOUNT_ID # noqa
from .base_backend import BaseBackend # noqa
from .common_models import BaseModel # noqa from .common_models import BaseModel # noqa
from .common_models import CloudFormationModel, CloudWatchMetricProvider # noqa from .common_models import CloudFormationModel, CloudWatchMetricProvider # noqa
from .models import patch_client, patch_resource # noqa from .models import patch_client, patch_resource # noqa

160
moto/core/base_backend.py Normal file
View File

@ -0,0 +1,160 @@
import random
import re
import string
from collections import defaultdict
from .utils import convert_regex_to_flask_path
model_data = defaultdict(dict)
class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == "BaseModel":
return cls
service = cls.__module__.split(".")[1]
if name not in model_data[service]:
model_data[service][name] = cls
cls.instances = []
return cls
class BaseBackend:
def _reset_model_refs(self):
# Remove all references to the models stored
for models in model_data.values():
for model in models.values():
model.instances = []
def reset(self):
self._reset_model_refs()
self.__dict__ = {}
self.__init__()
@property
def _url_module(self):
backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module
@property
def urls(self):
"""
A dictionary of the urls to be mocked with this service and the handlers
that should be called in their place
"""
url_bases = self.url_bases
unformatted_paths = self._url_module.url_paths
urls = {}
for url_base in url_bases:
# The default URL_base will look like: http://service.[..].amazonaws.com/...
# This extension ensures support for the China regions
cn_url_base = re.sub(r"amazonaws\\?.com$", "amazonaws.com.cn", url_base)
for url_path, handler in unformatted_paths.items():
url = url_path.format(url_base)
urls[url] = handler
cn_url = url_path.format(cn_url_base)
urls[cn_url] = handler
return urls
@property
def url_paths(self):
"""
A dictionary of the paths of the urls to be mocked with this service and
the handlers that should be called in their place
"""
unformatted_paths = self._url_module.url_paths
paths = {}
for unformatted_path, handler in unformatted_paths.items():
path = unformatted_path.format("")
paths[path] = handler
return paths
@property
def url_bases(self):
"""
A list containing the url_bases extracted from urls.py
"""
return self._url_module.url_bases
@property
def flask_paths(self):
"""
The url paths that will be used for the flask server
"""
paths = {}
for url_path, handler in self.url_paths.items():
url_path = convert_regex_to_flask_path(url_path)
paths[url_path] = handler
return paths
@staticmethod
def default_vpc_endpoint_service(
service_region, zones
): # pylint: disable=unused-argument
"""Invoke the factory method for any VPC endpoint(s) services."""
return None
@staticmethod
def vpce_random_number():
"""Return random number for a VPC endpoint service ID."""
return "".join([random.choice(string.hexdigits.lower()) for i in range(17)])
@staticmethod
def default_vpc_endpoint_service_factory(
service_region,
zones,
service="",
service_type="Interface",
private_dns_names=True,
special_service_name="",
policy_supported=True,
base_endpoint_dns_names=None,
): # pylint: disable=too-many-arguments
"""List of dicts representing default VPC endpoints for this service."""
if special_service_name:
service_name = f"com.amazonaws.{service_region}.{special_service_name}"
else:
service_name = f"com.amazonaws.{service_region}.{service}"
if not base_endpoint_dns_names:
base_endpoint_dns_names = [f"{service}.{service_region}.vpce.amazonaws.com"]
endpoint_service = {
"AcceptanceRequired": False,
"AvailabilityZones": zones,
"BaseEndpointDnsNames": base_endpoint_dns_names,
"ManagesVpcEndpoints": False,
"Owner": "amazon",
"ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
"ServiceName": service_name,
"ServiceType": [{"ServiceType": service_type}],
"Tags": [],
"VpcEndpointPolicySupported": policy_supported,
}
# Don't know how private DNS names are different, so for now just
# one will be added.
if private_dns_names:
endpoint_service[
"PrivateDnsName"
] = f"{service}.{service_region}.amazonaws.com"
endpoint_service["PrivateDnsNameVerificationState"] = "verified"
endpoint_service["PrivateDnsNames"] = [
{"PrivateDnsName": f"{service}.{service_region}.amazonaws.com"}
]
return [endpoint_service]
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
# raise NotImplementedError()

View File

@ -0,0 +1,65 @@
from collections import defaultdict
from io import BytesIO
from botocore.awsrequest import AWSResponse
from moto.core.exceptions import HTTPException
class MockRawResponse(BytesIO):
def __init__(self, response_input):
if isinstance(response_input, str):
response_input = response_input.encode("utf-8")
super().__init__(response_input)
def stream(self, **kwargs): # pylint: disable=unused-argument
contents = self.read()
while contents:
yield contents
contents = self.read()
class BotocoreStubber:
def __init__(self):
self.enabled = False
self.methods = defaultdict(list)
def reset(self):
self.methods.clear()
def register_response(self, method, pattern, response):
matchers = self.methods[method]
matchers.append((pattern, response))
def __call__(self, event_name, request, **kwargs):
if not self.enabled:
return None
response = None
response_callback = None
found_index = None
matchers = self.methods.get(request.method)
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url):
if found_index is None:
found_index = i
response_callback = callback
else:
matchers.pop(found_index)
break
if response_callback is not None:
for header, value in request.headers.items():
if isinstance(value, bytes):
request.headers[header] = value.decode("utf-8")
try:
status, headers, body = response_callback(
request, request.url, request.headers
)
except HTTPException as e:
status = e.code
headers = e.get_headers()
body = e.get_body()
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
return response

View File

@ -1,5 +1,5 @@
from abc import abstractmethod from abc import abstractmethod
from .models import InstanceTrackerMeta from .base_backend import InstanceTrackerMeta
class BaseModel(metaclass=InstanceTrackerMeta): class BaseModel(metaclass=InstanceTrackerMeta):

View File

@ -2,31 +2,26 @@ import functools
import inspect import inspect
import itertools import itertools
import os import os
import random
import re import re
import string
import unittest import unittest
from collections import defaultdict
from io import BytesIO
from types import FunctionType from types import FunctionType
from unittest.mock import patch from unittest.mock import patch
import boto3 import boto3
import botocore import botocore
import responses import responses
from botocore.awsrequest import AWSResponse
from botocore.config import Config from botocore.config import Config
from botocore.handlers import BUILTIN_HANDLERS from botocore.handlers import BUILTIN_HANDLERS
from moto import settings from moto import settings
from moto.core.exceptions import HTTPException from .botocore_stubber import BotocoreStubber
from .custom_responses_mock import ( from .custom_responses_mock import (
get_response_mock, get_response_mock,
CallbackResponse, CallbackResponse,
not_implemented_callback, not_implemented_callback,
reset_responses_mock, reset_responses_mock,
) )
from .utils import convert_regex_to_flask_path, convert_flask_to_responses_response from .utils import convert_flask_to_responses_response
ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012") ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012")
@ -244,67 +239,6 @@ responses_mock = get_response_mock()
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"] BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
class MockRawResponse(BytesIO):
def __init__(self, response_input):
if isinstance(response_input, str):
response_input = response_input.encode("utf-8")
super().__init__(response_input)
def stream(self, **kwargs): # pylint: disable=unused-argument
contents = self.read()
while contents:
yield contents
contents = self.read()
class BotocoreStubber:
def __init__(self):
self.enabled = False
self.methods = defaultdict(list)
def reset(self):
self.methods.clear()
def register_response(self, method, pattern, response):
matchers = self.methods[method]
matchers.append((pattern, response))
def __call__(self, event_name, request, **kwargs):
if not self.enabled:
return None
response = None
response_callback = None
found_index = None
matchers = self.methods.get(request.method)
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url):
if found_index is None:
found_index = i
response_callback = callback
else:
matchers.pop(found_index)
break
if response_callback is not None:
for header, value in request.headers.items():
if isinstance(value, bytes):
request.headers[header] = value.decode("utf-8")
try:
status, headers, body = response_callback(
request, request.url, request.headers
)
except HTTPException as e:
status = e.code
headers = e.get_headers()
body = e.get_body()
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
return response
botocore_stubber = BotocoreStubber() botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(("before-send", botocore_stubber)) BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
@ -455,184 +389,6 @@ class ServerModeMockAWS(BaseMockAWS):
self._resource_patcher.stop() self._resource_patcher.stop()
class Model(type):
def __new__(self, clsname, bases, namespace):
cls = super().__new__(self, clsname, bases, namespace)
cls.__models__ = {}
for name, value in namespace.items():
model = getattr(value, "__returns_model__", False)
if model is not False:
cls.__models__[model] = name
for base in bases:
cls.__models__.update(getattr(base, "__models__", {}))
return cls
@staticmethod
def prop(model_name):
"""decorator to mark a class method as returning model values"""
def dec(f):
f.__returns_model__ = model_name
return f
return dec
model_data = defaultdict(dict)
class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == "BaseModel":
return cls
service = cls.__module__.split(".")[1]
if name not in model_data[service]:
model_data[service][name] = cls
cls.instances = []
return cls
class BaseBackend:
def _reset_model_refs(self):
# Remove all references to the models stored
for models in model_data.values():
for model in models.values():
model.instances = []
def reset(self):
self._reset_model_refs()
self.__dict__ = {}
self.__init__()
@property
def _url_module(self):
backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module
@property
def urls(self):
"""
A dictionary of the urls to be mocked with this service and the handlers
that should be called in their place
"""
url_bases = self.url_bases
unformatted_paths = self._url_module.url_paths
urls = {}
for url_base in url_bases:
# The default URL_base will look like: http://service.[..].amazonaws.com/...
# This extension ensures support for the China regions
cn_url_base = re.sub(r"amazonaws\\?.com$", "amazonaws.com.cn", url_base)
for url_path, handler in unformatted_paths.items():
url = url_path.format(url_base)
urls[url] = handler
cn_url = url_path.format(cn_url_base)
urls[cn_url] = handler
return urls
@property
def url_paths(self):
"""
A dictionary of the paths of the urls to be mocked with this service and
the handlers that should be called in their place
"""
unformatted_paths = self._url_module.url_paths
paths = {}
for unformatted_path, handler in unformatted_paths.items():
path = unformatted_path.format("")
paths[path] = handler
return paths
@property
def url_bases(self):
"""
A list containing the url_bases extracted from urls.py
"""
return self._url_module.url_bases
@property
def flask_paths(self):
"""
The url paths that will be used for the flask server
"""
paths = {}
for url_path, handler in self.url_paths.items():
url_path = convert_regex_to_flask_path(url_path)
paths[url_path] = handler
return paths
@staticmethod
def default_vpc_endpoint_service(
service_region, zones
): # pylint: disable=unused-argument
"""Invoke the factory method for any VPC endpoint(s) services."""
return None
@staticmethod
def vpce_random_number():
"""Return random number for a VPC endpoint service ID."""
return "".join([random.choice(string.hexdigits.lower()) for i in range(17)])
@staticmethod
def default_vpc_endpoint_service_factory(
service_region,
zones,
service="",
service_type="Interface",
private_dns_names=True,
special_service_name="",
policy_supported=True,
base_endpoint_dns_names=None,
): # pylint: disable=too-many-arguments
"""List of dicts representing default VPC endpoints for this service."""
if special_service_name:
service_name = f"com.amazonaws.{service_region}.{special_service_name}"
else:
service_name = f"com.amazonaws.{service_region}.{service}"
if not base_endpoint_dns_names:
base_endpoint_dns_names = [f"{service}.{service_region}.vpce.amazonaws.com"]
endpoint_service = {
"AcceptanceRequired": False,
"AvailabilityZones": zones,
"BaseEndpointDnsNames": base_endpoint_dns_names,
"ManagesVpcEndpoints": False,
"Owner": "amazon",
"ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
"ServiceName": service_name,
"ServiceType": [{"ServiceType": service_type}],
"Tags": [],
"VpcEndpointPolicySupported": policy_supported,
}
# Don't know how private DNS names are different, so for now just
# one will be added.
if private_dns_names:
endpoint_service[
"PrivateDnsName"
] = f"{service}.{service_region}.amazonaws.com"
endpoint_service["PrivateDnsNameVerificationState"] = "verified"
endpoint_service["PrivateDnsNames"] = [
{"PrivateDnsName": f"{service}.{service_region}.amazonaws.com"}
]
return [endpoint_service]
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
# raise NotImplementedError()
class base_decorator: class base_decorator:
mock_backend = MockAWS mock_backend = MockAWS

View File

@ -1,3 +0,0 @@
url_bases = []
url_paths = {}

View File

@ -1,7 +1,6 @@
from collections import defaultdict from collections import defaultdict
from moto.core.common_models import CloudFormationModel from moto.core.common_models import CloudFormationModel
from moto.core.models import Model
from moto.packages.boto.ec2.launchspecification import LaunchSpecification from moto.packages.boto.ec2.launchspecification import LaunchSpecification
from moto.packages.boto.ec2.spotinstancerequest import ( from moto.packages.boto.ec2.spotinstancerequest import (
SpotInstanceRequest as BotoSpotRequest, SpotInstanceRequest as BotoSpotRequest,
@ -117,7 +116,7 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
return instance return instance
class SpotRequestBackend(object, metaclass=Model): class SpotRequestBackend(object):
def __init__(self): def __init__(self):
self.spot_instance_requests = {} self.spot_instance_requests = {}
super().__init__() super().__init__()
@ -176,7 +175,6 @@ class SpotRequestBackend(object, metaclass=Model):
requests.append(request) requests.append(request)
return requests return requests
@Model.prop("SpotInstanceRequest")
def describe_spot_instance_requests(self, filters=None, spot_instance_ids=None): def describe_spot_instance_requests(self, filters=None, spot_instance_ids=None):
requests = self.spot_instance_requests.copy().values() requests = self.spot_instance_requests.copy().values()

View File

@ -1,4 +1,4 @@
from moto.core.models import BaseBackend from moto.core import BaseBackend
class InstanceMetadataBackend(BaseBackend): class InstanceMetadataBackend(BaseBackend):

View File

@ -1,4 +1,4 @@
from moto.core.models import BaseBackend from moto.core import BaseBackend
class MotoAPIBackend(BaseBackend): class MotoAPIBackend(BaseBackend):

View File

@ -39,7 +39,7 @@ class MotoAPIResponse(BaseResponse):
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers): # pylint: disable=unused-argument def model_data(self, request, full_url, headers): # pylint: disable=unused-argument
from moto.core.models import model_data from moto.core.base_backend import model_data
results = {} results = {}
for service in sorted(model_data): for service in sorted(model_data):