TechDebt - enable pylint rule redefined-outer-name (#5518)

Bert Blommers 2022-10-04 16:28:30 +00:00 committed by GitHub
parent 696b809b5a
commit 4f84e2f154
46 changed files with 236 additions and 396 deletions
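The rule being enabled is pylint's redefined-outer-name (W0621), which warns when an inner scope rebinds a name that already exists in an enclosing scope: a parameter shadowing an imported module, a loop variable shadowing a module-level name, or a test argument shadowing the pytest fixture function of the same name. A minimal sketch of the kind of code it flags (illustrative, not taken from this commit):

import datetime

def iso_8601(datetime):  # W0621: parameter shadows the datetime module imported above
    return datetime.strftime("%Y-%m-%dT%H:%M:%SZ")

def iso_8601_fixed(value):  # renaming the parameter, as the hunks below do, clears the warning
    return value.strftime("%Y-%m-%dT%H:%M:%SZ")

The changes below fall into two groups: renaming shadowing parameters and loop variables in moto itself, and renaming pytest fixture functions in the test suite while keeping their public names via the fixture's name= argument.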


@@ -157,11 +157,11 @@ class _DockerDataVolumeContext:
             raise  # multiple processes trying to use same volume?

-def _zipfile_content(zipfile):
+def _zipfile_content(zipfile_content):
     try:
-        to_unzip_code = base64.b64decode(bytes(zipfile, "utf-8"))
+        to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8"))
     except Exception:
-        to_unzip_code = base64.b64decode(zipfile)
+        to_unzip_code = base64.b64decode(zipfile_content)

     sha_code = hashlib.sha256(to_unzip_code)
     base64ed_sha = base64.b64encode(sha_code.digest()).decode("utf-8")


@@ -292,23 +292,23 @@ class CloudFrontBackend(BaseBackend):
                 return dist
         return False

-    def update_distribution(self, DistributionConfig, Id, IfMatch):
+    def update_distribution(self, dist_config, _id, if_match):
         """
         The IfMatch-value is ignored - any value is considered valid.
         Calling this function without a value is invalid, per AWS' behaviour
         """
-        if Id not in self.distributions or Id is None:
+        if _id not in self.distributions or _id is None:
             raise NoSuchDistribution
-        if not IfMatch:
+        if not if_match:
             raise InvalidIfMatchVersion
-        if not DistributionConfig:
+        if not dist_config:
             raise NoSuchDistribution
-        dist = self.distributions[Id]
+        dist = self.distributions[_id]
-        aliases = DistributionConfig["Aliases"]["Items"]["CNAME"]
-        dist.distribution_config.config = DistributionConfig
+        aliases = dist_config["Aliases"]["Items"]["CNAME"]
+        dist.distribution_config.config = dist_config
         dist.distribution_config.aliases = aliases
-        self.distributions[Id] = dist
+        self.distributions[_id] = dist
         dist.advance()
         return dist, dist.location, dist.etag


@@ -83,9 +83,9 @@ class CloudFrontResponse(BaseResponse):
         if_match = headers["If-Match"]
         dist, location, e_tag = self.backend.update_distribution(
-            DistributionConfig=distribution_config,
-            Id=dist_id,
-            IfMatch=if_match,
+            dist_config=distribution_config,
+            _id=dist_id,
+            if_match=if_match,
         )
         template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
         response = template.render(distribution=dist, xmlns=XMLNS)


@@ -146,28 +146,28 @@ class convert_flask_to_responses_response(object):
         return status, headers, response

-def iso_8601_datetime_with_milliseconds(datetime):
-    return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
+def iso_8601_datetime_with_milliseconds(value):
+    return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"

 # Even Python does not support nanoseconds, other languages like Go do (needed for Terraform)
-def iso_8601_datetime_with_nanoseconds(datetime):
-    return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")
+def iso_8601_datetime_with_nanoseconds(value):
+    return value.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")

-def iso_8601_datetime_without_milliseconds(datetime):
-    return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
+def iso_8601_datetime_without_milliseconds(value):
+    return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%SZ")

-def iso_8601_datetime_without_milliseconds_s3(datetime):
-    return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%S.000Z")
+def iso_8601_datetime_without_milliseconds_s3(value):
+    return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

 RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"

-def rfc_1123_datetime(datetime):
-    return datetime.strftime(RFC1123)
+def rfc_1123_datetime(src):
+    return src.strftime(RFC1123)

 def str_to_rfc_1123_datetime(value):


@@ -12,14 +12,14 @@ INSTANCE_FAMILIES = list(set([i.split(".")[0] for i in INSTANCE_TYPES.keys()]))
 root = pathlib.Path(__file__).parent
 offerings_path = "../resources/instance_type_offerings"
 INSTANCE_TYPE_OFFERINGS = {}
-for location_type in listdir(root / offerings_path):
-    INSTANCE_TYPE_OFFERINGS[location_type] = {}
-    for _region in listdir(root / offerings_path / location_type):
-        full_path = offerings_path + "/" + location_type + "/" + _region
+for _location_type in listdir(root / offerings_path):
+    INSTANCE_TYPE_OFFERINGS[_location_type] = {}
+    for _region in listdir(root / offerings_path / _location_type):
+        full_path = offerings_path + "/" + _location_type + "/" + _region
         res = load_resource(__name__, full_path)
         for instance in res:
-            instance["LocationType"] = location_type
-            INSTANCE_TYPE_OFFERINGS[location_type][_region.replace(".json", "")] = res
+            instance["LocationType"] = _location_type
+            INSTANCE_TYPE_OFFERINGS[_location_type][_region.replace(".json", "")] = res

 class InstanceType(dict):


@@ -1021,19 +1021,12 @@ class SecurityGroupBackend:
                 if cidr_item.get("CidrIp6") == item.get("CidrIp6"):
                     cidr_item["Description"] = item.get("Description")

-            for item in security_rule.source_groups:
+            for group in security_rule.source_groups:
                 for source_group in rule.source_groups:
-                    if source_group.get("GroupId") == item.get(
+                    if source_group.get("GroupId") == group.get(
                         "GroupId"
-                    ) or source_group.get("GroupName") == item.get("GroupName"):
-                        source_group["Description"] = item.get("Description")
-
-            for item in security_rule.source_groups:
-                for source_group in rule.source_groups:
-                    if source_group.get("GroupId") == item.get(
-                        "GroupId"
-                    ) or source_group.get("GroupName") == item.get("GroupName"):
-                        source_group["Description"] = item.get("Description")
+                    ) or source_group.get("GroupName") == group.get("GroupName"):
+                        source_group["Description"] = group.get("Description")

     def _remove_items_from_rule(self, ip_ranges, _source_groups, prefix_list_ids, rule):
         for item in ip_ranges:


@@ -297,42 +297,42 @@ class ResolverEndpoint(BaseModel):  # pylint: disable=too-many-instance-attribut
         self.name = name
         self.modification_time = datetime.now(timezone.utc).isoformat()

-    def associate_ip_address(self, ip_address):
-        self.ip_addresses.append(ip_address)
+    def associate_ip_address(self, value):
+        self.ip_addresses.append(value)
         self.ip_address_count = len(self.ip_addresses)
         eni_id = f"rni-{mock_random.get_random_hex(17)}"
-        self.subnets[ip_address["SubnetId"]][ip_address["Ip"]] = eni_id
+        self.subnets[value["SubnetId"]][value["Ip"]] = eni_id
         eni_info = self.ec2_backend.create_network_interface(
             description=f"Route 53 Resolver: {self.id}:{eni_id}",
             group_ids=self.security_group_ids,
             interface_type="interface",
-            private_ip_address=ip_address.get("Ip"),
+            private_ip_address=value.get("Ip"),
             private_ip_addresses=[
-                {"Primary": True, "PrivateIpAddress": ip_address.get("Ip")}
+                {"Primary": True, "PrivateIpAddress": value.get("Ip")}
             ],
-            subnet=ip_address.get("SubnetId"),
+            subnet=value.get("SubnetId"),
         )
         self.eni_ids.append(eni_info.id)

-    def disassociate_ip_address(self, ip_address):
-        if not ip_address.get("Ip") and ip_address.get("IpId"):
-            for ip_addr, eni_id in self.subnets[ip_address.get("SubnetId")].items():
-                if ip_address.get("IpId") == eni_id:
-                    ip_address["Ip"] = ip_addr
-        if ip_address.get("Ip"):
+    def disassociate_ip_address(self, value):
+        if not value.get("Ip") and value.get("IpId"):
+            for ip_addr, eni_id in self.subnets[value.get("SubnetId")].items():
+                if value.get("IpId") == eni_id:
+                    value["Ip"] = ip_addr
+        if value.get("Ip"):
             self.ip_addresses = list(
-                filter(lambda i: i["Ip"] != ip_address.get("Ip"), self.ip_addresses)
+                filter(lambda i: i["Ip"] != value.get("Ip"), self.ip_addresses)
             )
-            if len(self.subnets[ip_address["SubnetId"]]) == 1:
-                self.subnets.pop(ip_address["SubnetId"])
+            if len(self.subnets[value["SubnetId"]]) == 1:
+                self.subnets.pop(value["SubnetId"])
             else:
-                self.subnets[ip_address["SubnetId"]].pop(ip_address["Ip"])
+                self.subnets[value["SubnetId"]].pop(value["Ip"])
             for eni_id in self.eni_ids:
                 eni_info = self.ec2_backend.get_network_interface(eni_id)
-                if eni_info.private_ip_address == ip_address.get("Ip"):
+                if eni_info.private_ip_address == value.get("Ip"):
                     self.ec2_backend.delete_network_interface(eni_id)
                     self.eni_ids.remove(eni_id)
         self.ip_address_count = len(self.ip_addresses)
@@ -873,32 +873,30 @@ class Route53ResolverBackend(BaseBackend):
         resolver_endpoint.update_name(name)
         return resolver_endpoint

-    def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, ip_address):
+    def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
         self._validate_resolver_endpoint_id(resolver_endpoint_id)
         resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
-        if not ip_address.get("Ip"):
+        if not value.get("Ip"):
             subnet_info = self.ec2_backend.get_all_subnets(
-                subnet_ids=[ip_address.get("SubnetId")]
+                subnet_ids=[value.get("SubnetId")]
             )[0]
-            ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
-        self._verify_subnet_ips([ip_address], False)
-        resolver_endpoint.associate_ip_address(ip_address)
+            value["Ip"] = subnet_info.get_available_subnet_ip(self)
+        self._verify_subnet_ips([value], False)
+        resolver_endpoint.associate_ip_address(value)
         return resolver_endpoint

-    def disassociate_resolver_endpoint_ip_address(
-        self, resolver_endpoint_id, ip_address
-    ):
+    def disassociate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
         self._validate_resolver_endpoint_id(resolver_endpoint_id)
         resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
-        if not (ip_address.get("Ip") or ip_address.get("IpId")):
+        if not (value.get("Ip") or value.get("IpId")):
             raise InvalidRequestException(
                 "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
             )
-        resolver_endpoint.disassociate_ip_address(ip_address)
+        resolver_endpoint.disassociate_ip_address(value)
         return resolver_endpoint


@@ -266,7 +266,7 @@ class Route53ResolverResponse(BaseResponse):
         resolver_endpoint = (
             self.route53resolver_backend.associate_resolver_endpoint_ip_address(
                 resolver_endpoint_id=resolver_endpoint_id,
-                ip_address=ip_address,
+                value=ip_address,
             )
         )
         return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
@@ -278,7 +278,7 @@ class Route53ResolverResponse(BaseResponse):
         resolver_endpoint = (
             self.route53resolver_backend.disassociate_resolver_endpoint_ip_address(
                 resolver_endpoint_id=resolver_endpoint_id,
-                ip_address=ip_address,
+                value=ip_address,
             )
         )
         return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})


@@ -1,8 +1,8 @@
-def name(secret, names):
+def name_filter(secret, names):
     return _matcher(names, [secret.name])

-def description(secret, descriptions):
+def description_filter(secret, descriptions):
     return _matcher(descriptions, [secret.description])


@@ -18,13 +18,19 @@ from .exceptions import (
     ClientError,
 )
 from .utils import random_password, secret_arn, get_secret_name_from_arn
-from .list_secrets.filters import filter_all, tag_key, tag_value, description, name
+from .list_secrets.filters import (
+    filter_all,
+    tag_key,
+    tag_value,
+    description_filter,
+    name_filter,
+)

 _filter_functions = {
     "all": filter_all,
-    "name": name,
-    "description": description,
+    "name": name_filter,
+    "description": description_filter,
     "tag-key": tag_key,
     "tag-value": tag_value,
 }


@@ -470,24 +470,24 @@ class SESBackend(BaseBackend):
             text_part = str.replace(str(text_part), "{{%s}}" % key, value)
             html_part = str.replace(str(html_part), "{{%s}}" % key, value)

-        email = MIMEMultipart("alternative")
+        email_obj = MIMEMultipart("alternative")
         mime_text = MIMEBase("text", "plain;charset=UTF-8")
         mime_text.set_payload(text_part.encode("utf-8"))
         encode_7or8bit(mime_text)
-        email.attach(mime_text)
+        email_obj.attach(mime_text)
         mime_html = MIMEBase("text", "html;charset=UTF-8")
         mime_html.set_payload(html_part.encode("utf-8"))
         encode_7or8bit(mime_html)
-        email.attach(mime_html)
+        email_obj.attach(mime_html)

         now = datetime.datetime.now().isoformat()
         rendered_template = "Date: %s\r\nSubject: %s\r\n%s" % (
             now,
             subject_part,
-            email.as_string(),
+            email_obj.as_string(),
         )
         return rendered_template


@@ -141,10 +141,10 @@ class Message(BaseModel):
         )

     @staticmethod
-    def utf8(string):
-        if isinstance(string, str):
-            return string.encode("utf-8")
-        return string
+    def utf8(value):
+        if isinstance(value, str):
+            return value.encode("utf-8")
+        return value

     @property
     def body(self):


@@ -16,11 +16,11 @@ class TelemetryRecords(BaseModel):
         self.records = records

     @classmethod
-    def from_json(cls, json):
-        instance_id = json.get("EC2InstanceId", None)
-        hostname = json.get("Hostname")
-        resource_arn = json.get("ResourceARN")
-        telemetry_records = json["TelemetryRecords"]
+    def from_json(cls, src):
+        instance_id = src.get("EC2InstanceId", None)
+        hostname = src.get("Hostname")
+        resource_arn = src.get("ResourceARN")
+        telemetry_records = src["TelemetryRecords"]
         return cls(instance_id, hostname, resource_arn, telemetry_records)
@@ -242,8 +242,8 @@ class XRayBackend(BaseBackend):
             service_region, zones, "xray"
         )

-    def add_telemetry_records(self, json):
-        self._telemetry_records.append(TelemetryRecords.from_json(json))
+    def add_telemetry_records(self, src):
+        self._telemetry_records.append(TelemetryRecords.from_json(src))

     def process_segment(self, doc):
         try:


@@ -14,5 +14,5 @@ ignore-paths=moto/packages
 [pylint.'MESSAGES CONTROL']
 disable = W,C,R,E
-# future sensible checks = super-init-not-called, redefined-outer-name, unspecified-encoding, undefined-loop-variable
-enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
+# future sensible checks = super-init-not-called, unspecified-encoding, undefined-loop-variable
+enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
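The only change to the enforced set is moving redefined-outer-name out of the "future sensible checks" comment and into the enable list. As a rough standalone check, the same rule can also be exercised with stock pylint flags, assuming pylint is installed (command sketch, not part of the commit):

pylint --disable=all --enable=redefined-outer-name moto/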


@@ -139,7 +139,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
         "FunctionArn"
     ]

-    template = event_source_mapping_template.substitute(
+    esm_template = event_source_mapping_template.substitute(
         {
             "resource_name": "Foo",
             "batch_size": 1,
@@ -149,7 +149,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
         }
     )

-    cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
+    cf.create_stack(StackName=random_stack_name(), TemplateBody=esm_template)
     event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
     event_sources["EventSourceMappings"].should.have.length_of(1)
@@ -174,7 +174,7 @@ def test_event_source_mapping_delete_stack():
     _, lambda_stack = create_stack(cf, s3)
     created_fn_name = get_created_function_name(cf, lambda_stack)

-    template = event_source_mapping_template.substitute(
+    esm_template = event_source_mapping_template.substitute(
         {
             "resource_name": "Foo",
             "batch_size": 1,
@@ -184,7 +184,9 @@ def test_event_source_mapping_delete_stack():
         }
     )

-    esm_stack = cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
+    esm_stack = cf.create_stack(
+        StackName=random_stack_name(), TemplateBody=esm_template
+    )
     event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
     event_sources["EventSourceMappings"].should.have.length_of(1)


@@ -18,8 +18,8 @@ def test_basic_decorator():
     client.describe_addresses()["Addresses"].should.equal([])

-@pytest.fixture
-def aws_credentials(monkeypatch):
+@pytest.fixture(name="aws_credentials")
+def fixture_aws_credentials(monkeypatch):
     """Mocked AWS Credentials for moto."""
     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")


@@ -6,8 +6,8 @@ from moto import settings
 from unittest import SkipTest

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
+@pytest.fixture(scope="function", name="aws_credentials")
+def fixture_aws_credentials(monkeypatch):
     if settings.TEST_SERVER_MODE:
         raise SkipTest("No point in testing this in ServerMode.")
     """Mocked AWS Credentials for moto."""


@@ -495,13 +495,13 @@ def test_creating_table_with_0_global_indexes():
 @mock_dynamodb
 def test_multiple_transactions_on_same_item():
-    table_schema = {
+    schema = {
         "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
         "AttributeDefinitions": [{"AttributeName": "id", "AttributeType": "S"}],
     }
     dynamodb = boto3.client("dynamodb", region_name="us-east-1")
     dynamodb.create_table(
-        TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
+        TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
     )
     # Insert an item
     dynamodb.put_item(TableName="test-table", Item={"id": {"S": "foo"}})
@@ -533,13 +533,13 @@ def test_multiple_transactions_on_same_item():
 @mock_dynamodb
 def test_transact_write_items__too_many_transactions():
-    table_schema = {
+    schema = {
         "KeySchema": [{"AttributeName": "pk", "KeyType": "HASH"}],
         "AttributeDefinitions": [{"AttributeName": "pk", "AttributeType": "S"}],
     }
     dynamodb = boto3.client("dynamodb", region_name="us-east-1")
     dynamodb.create_table(
-        TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
+        TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
     )

     def update_email_transact(email):


@@ -15,8 +15,8 @@ TABLE_NAME = "my_table_name"
 TABLE_WITH_RANGE_NAME = "my_table_with_range_name"

-@pytest.fixture(autouse=True)
-def test_client():
+@pytest.fixture(autouse=True, name="test_client")
+def fixture_test_client():
     backend = server.create_backend_app("dynamodb_v20111205")
     test_client = backend.test_client()


@@ -0,0 +1,16 @@
+import boto3
+import pytest
+
+from moto import mock_ec2, mock_efs
+
+
+@pytest.fixture(scope="function", name="ec2")
+def fixture_ec2():
+    with mock_ec2():
+        yield boto3.client("ec2", region_name="us-east-1")
+
+
+@pytest.fixture(scope="function", name="efs")
+def fixture_efs():
+    with mock_efs():
+        yield boto3.client("efs", region_name="us-east-1")


@@ -1,26 +1,9 @@
-import boto3
 import pytest
+from . import fixture_efs  # noqa # pylint: disable=unused-import

-from moto import mock_efs
-
-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def file_system(efs):
+@pytest.fixture(scope="function", name="file_system")
+def fixture_file_system(efs):
     create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
     create_fs_resp.pop("ResponseMetadata")
     yield create_fs_resp


@@ -1,28 +1,12 @@
-import boto3
 import pytest
 from botocore.exceptions import ClientError

-from moto import mock_efs
 from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
+from . import fixture_efs  # noqa # pylint: disable=unused-import

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def file_system(efs):
+@pytest.fixture(scope="function", name="file_system")
+def fixture_file_system(efs):
     create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
     create_fs_resp.pop("ResponseMetadata")
     yield create_fs_resp


@@ -1,11 +1,10 @@
 import re

-import boto3
 import pytest
 from botocore.exceptions import ClientError

-from moto import mock_efs
 from tests.test_efs.junk_drawer import has_status_code
+from . import fixture_efs  # noqa # pylint: disable=unused-import

 ARN_PATT = r"^arn:(?P<Partition>[^:\n]*):(?P<Service>[^:\n]*):(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$"
 STRICT_ARN_PATT = r"^arn:aws:[a-z]+:[a-z]{2}-[a-z]+-[0-9]:[0-9]+:[a-z-]+\/[a-z0-9-]+$"
@@ -30,21 +29,6 @@ SAMPLE_2_PARAMS = {
 }

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
-
 # Testing Create
 # ==============


@@ -1,22 +1,4 @@
-import boto3
-import pytest
-from moto import mock_efs
-
-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
+from . import fixture_efs  # noqa # pylint: disable=unused-import

 def test_list_tags_for_resource__without_tags(efs):


@@ -1,23 +1,7 @@
-import boto3
 import pytest
 from botocore.exceptions import ClientError

-from moto import mock_efs
-
-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
+from . import fixture_efs  # noqa # pylint: disable=unused-import

 def test_describe_filesystem_config__unknown(efs):


@@ -2,13 +2,12 @@ import re
 import sys
 from ipaddress import IPv4Network

-import boto3
 import pytest
 from botocore.exceptions import ClientError

-from moto import mock_ec2, mock_efs
 from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
 from tests.test_efs.junk_drawer import has_status_code
+from . import fixture_ec2, fixture_efs  # noqa # pylint: disable=unused-import

 # Handle the fact that `subnet_of` is not a feature before 3.7.
@@ -36,36 +35,15 @@ else:
     )

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def ec2(aws_credentials):  # pylint: disable=unused-argument
-    with mock_ec2():
-        yield boto3.client("ec2", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def file_system(efs):
+@pytest.fixture(scope="function", name="file_system")
+def fixture_file_system(efs):
     create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
     create_fs_resp.pop("ResponseMetadata")
     yield create_fs_resp

-@pytest.fixture(scope="function")
-def subnet(ec2):
+@pytest.fixture(scope="function", name="subnet")
+def fixture_subnet(ec2):
     desc_sn_resp = ec2.describe_subnets()
     subnet = desc_sn_resp["Subnets"][0]
     yield subnet


@@ -1,40 +1,18 @@
-import boto3
 import pytest
 from botocore.exceptions import ClientError

-from moto import mock_ec2, mock_efs
+from . import fixture_ec2, fixture_efs  # noqa # pylint: disable=unused-import

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
-    """Mocked AWS Credentials for moto."""
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
-    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
-    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
-
-@pytest.fixture(scope="function")
-def ec2(aws_credentials):  # pylint: disable=unused-argument
-    with mock_ec2():
-        yield boto3.client("ec2", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def efs(aws_credentials):  # pylint: disable=unused-argument
-    with mock_efs():
-        yield boto3.client("efs", region_name="us-east-1")
-
-@pytest.fixture(scope="function")
-def file_system(efs):
+@pytest.fixture(scope="function", name="file_system")
+def fixture_file_system(efs):
     create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
     create_fs_resp.pop("ResponseMetadata")
     yield create_fs_resp

-@pytest.fixture(scope="function")
-def subnet(ec2):
+@pytest.fixture(scope="function", name="subnet")
+def fixture_subnet(ec2):
     desc_sn_resp = ec2.describe_subnets()
     subnet = desc_sn_resp["Subnets"][0]
     yield subnet


@@ -9,8 +9,8 @@ FILE_SYSTEMS = "/2015-02-01/file-systems"
 MOUNT_TARGETS = "/2015-02-01/mount-targets"

-@pytest.fixture(scope="function")
-def aws_credentials(monkeypatch):
+@pytest.fixture(scope="function", name="aws_credentials")
+def fixture_aws_credentials(monkeypatch):
     """Mocked AWS Credentials for moto."""
     monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
     monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
@@ -18,14 +18,14 @@ def aws_credentials(monkeypatch):
     monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")

-@pytest.fixture(scope="function")
-def efs_client(aws_credentials):  # pylint: disable=unused-argument
+@pytest.fixture(scope="function", name="efs_client")
+def fixture_efs_client(aws_credentials):  # pylint: disable=unused-argument
     with mock_efs():
         yield server.create_backend_app("efs").test_client()

-@pytest.fixture(scope="function")
-def subnet_id(aws_credentials):  # pylint: disable=unused-argument
+@pytest.fixture(scope="function", name="subnet_id")
+def fixture_subnet_id(aws_credentials):  # pylint: disable=unused-argument
     with mock_ec2():
         ec2_client = server.create_backend_app("ec2").test_client()
         resp = ec2_client.get("/?Action=DescribeSubnets")
@@ -33,8 +33,8 @@ def subnet_id(aws_credentials):  # pylint: disable=unused-argument
     yield subnet_ids[0]

-@pytest.fixture(scope="function")
-def file_system_id(efs_client):
+@pytest.fixture(scope="function", name="file_system_id")
+def fixture_file_system_id(efs_client):
     resp = efs_client.post(
         FILE_SYSTEMS, json={"CreationToken": "foobarbaz", "Backup": True}
     )


@@ -72,8 +72,8 @@ from .test_eks_utils import (
 )

-@pytest.fixture(scope="function")
-def ClusterBuilder():
+@pytest.fixture(scope="function", name="ClusterBuilder")
+def fixture_ClusterBuilder():
     class ClusterTestDataFactory:
         def __init__(self, client, count, minimal):
             # Generate 'count' number of random Cluster objects.
@@ -104,8 +104,8 @@ def ClusterBuilder():
     yield _execute

-@pytest.fixture(scope="function")
-def FargateProfileBuilder(ClusterBuilder):
+@pytest.fixture(scope="function", name="FargateProfileBuilder")
+def fixture_FargateProfileBuilder(ClusterBuilder):
     class FargateProfileTestDataFactory:
         def __init__(self, client, cluster, count, minimal):
             self.cluster_name = cluster.existing_cluster_name
@@ -142,8 +142,8 @@ def FargateProfileBuilder(ClusterBuilder):
     return _execute

-@pytest.fixture(scope="function")
-def NodegroupBuilder(ClusterBuilder):
+@pytest.fixture(scope="function", name="NodegroupBuilder")
+def fixture_NodegroupBuilder(ClusterBuilder):
     class NodegroupTestDataFactory:
         def __init__(self, client, cluster, count, minimal):
             self.cluster_name = cluster.existing_cluster_name


@@ -79,14 +79,14 @@ class TestNodegroup:
     ]

-@pytest.fixture(autouse=True)
-def test_client():
+@pytest.fixture(autouse=True, name="test_client")
+def fixture_test_client():
     backend = server.create_backend_app(service=SERVICE)
     yield backend.test_client()

-@pytest.fixture(scope="function")
-def create_cluster(test_client):
+@pytest.fixture(scope="function", name="create_cluster")
+def fixtue_create_cluster(test_client):
     def create_and_verify_cluster(client, name):
         """Creates one valid cluster and verifies return status code 200."""
         data = deepcopy(TestCluster.data)
@@ -106,8 +106,8 @@ def create_cluster(test_client):
     yield _execute

-@pytest.fixture(scope="function", autouse=True)
-def create_nodegroup(test_client):
+@pytest.fixture(scope="function", autouse=True, name="create_nodegroup")
+def fixture_create_nodegroup(test_client):
     def create_and_verify_nodegroup(client, name):
         """Creates one valid nodegroup and verifies return status code 200."""
         data = deepcopy(TestNodegroup.data)


@@ -610,10 +610,8 @@ def test_put_remove_auto_scaling_policy():
     ("AutoScalingPolicy" not in core_instance_group).should.equal(True)

-def _patch_cluster_id_placeholder_in_autoscaling_policy(
-    auto_scaling_policy, cluster_id
-):
-    policy_copy = deepcopy(auto_scaling_policy)
+def _patch_cluster_id_placeholder_in_autoscaling_policy(policy, cluster_id):
+    policy_copy = deepcopy(policy)
     for rule in policy_copy["Rules"]:
         for dimension in rule["Trigger"]["CloudWatchAlarmDefinition"]["Dimensions"]:
             dimension["Value"] = cluster_id


@@ -15,14 +15,14 @@ from unittest.mock import patch
 from moto.emrcontainers import REGION as DEFAULT_REGION

-@pytest.fixture(scope="function")
-def client():
+@pytest.fixture(scope="function", name="client")
+def fixture_client():
     with mock_emrcontainers():
         yield boto3.client("emr-containers", region_name=DEFAULT_REGION)

-@pytest.fixture(scope="function")
-def virtual_cluster_factory(client):
+@pytest.fixture(scope="function", name="virtual_cluster_factory")
+def fixture_virtual_cluster_factory(client):
     if settings.TEST_SERVER_MODE:
         raise SkipTest("Cant manipulate time in server mode")
@@ -47,8 +47,8 @@ def virtual_cluster_factory(client):
     yield cluster_list

-@pytest.fixture()
-def job_factory(client, virtual_cluster_factory):
+@pytest.fixture(name="job_factory")
+def fixture_job_factory(client, virtual_cluster_factory):
     virtual_cluster_id = virtual_cluster_factory[0]
     default_job_driver = {
         "sparkSubmitJobDriver": {


@@ -19,14 +19,14 @@ def does_not_raise():
     yield

-@pytest.fixture(scope="function")
-def client():
+@pytest.fixture(scope="function", name="client")
+def fixture_client():
     with mock_emrserverless():
         yield boto3.client("emr-serverless", region_name=DEFAULT_REGION)

-@pytest.fixture(scope="function")
-def application_factory(client):
+@pytest.fixture(scope="function", name="application_factory")
+def fixture_application_factory(client):
     application_list = []
     if settings.TEST_SERVER_MODE:


@@ -42,8 +42,8 @@ from .fixtures.schema_registry import (
 )

-@pytest.fixture
-def client():
+@pytest.fixture(name="client")
+def fixture_client():
     with mock_glue():
         yield boto3.client("glue", region_name="us-east-1")


@@ -6,19 +6,19 @@ from botocore.exceptions import ClientError
 from moto import mock_iot, mock_cognitoidentity

-@pytest.fixture
-def region_name():
+@pytest.fixture(name="region_name")
+def fixture_region_name():
     return "ap-northeast-1"

-@pytest.fixture
-def iot_client(region_name):
+@pytest.fixture(name="iot_client")
+def fixture_iot_client(region_name):
     with mock_iot():
         yield boto3.client("iot", region_name=region_name)

-@pytest.fixture
-def policy(iot_client):
+@pytest.fixture(name="policy")
+def fixture_policy(iot_client):
     return iot_client.create_policy(policyName="my-policy", policyDocument="{}")


@@ -6,13 +6,13 @@ PLAINTEXT = b"text"
 REGION = "us-east-1"

-@pytest.fixture
-def backend():
+@pytest.fixture(name="backend")
+def fixture_backend():
     return KmsBackend(REGION)

-@pytest.fixture
-def key(backend):
+@pytest.fixture(name="key")
+def fixture_key(backend):
     return backend.create_key(
         None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Test key", None
     )


@@ -16,13 +16,11 @@ from moto.logs.models import MAX_RESOURCE_POLICIES_PER_REGION
 TEST_REGION = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"

-@pytest.fixture
-def json_policy_doc():
 """Returns a policy document in JSON format.
 The ARN is bogus, but that shouldn't matter for the test.
 """
-    return json.dumps(
+json_policy_doc = json.dumps(
     {
         "Version": "2012-10-17",
         "Statement": [
@@ -589,7 +587,7 @@ def test_put_resource_policy():
 @mock_logs
-def test_put_resource_policy_too_many(json_policy_doc):
+def test_put_resource_policy_too_many():
     client = boto3.client("logs", TEST_REGION)
     # Create the maximum number of resource policies.
@@ -617,7 +615,7 @@ def test_put_resource_policy_too_many(json_policy_doc):
 @mock_logs
-def test_delete_resource_policy(json_policy_doc):
+def test_delete_resource_policy():
     client = boto3.client("logs", TEST_REGION)
     # Create a bunch of resource policies so we can give delete a workout.
@@ -649,7 +647,7 @@ def test_delete_resource_policy(json_policy_doc):
 @mock_logs
-def test_describe_resource_policies(json_policy_doc):
+def test_describe_resource_policies():
     client = boto3.client("logs", TEST_REGION)
     # Create the maximum number of resource policies so there's something


@@ -13,8 +13,8 @@ INVALID_ID_ERROR_MESSAGE = (
 RESOURCE_NOT_FOUND_ERROR_MESSAGE = "Query does not exist."

-@pytest.fixture(autouse=True)
-def client():
+@pytest.fixture(autouse=True, name="client")
+def fixture_client():
     yield boto3.client("redshift-data", region_name=REGION)


@@ -18,8 +18,8 @@ def headers(action):
     }

-@pytest.fixture(autouse=True)
-def client():
+@pytest.fixture(autouse=True, name="client")
+def fixture_client():
     backend = server.create_backend_app("redshift-data")
     yield backend.test_client()


@@ -42,9 +42,10 @@ TEST_SERVERLESS_PRODUCTION_VARIANTS = [
 ]

-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)

 def create_endpoint_config_helper(sagemaker_client, production_variants):
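Because the fixture now opens the mock_sagemaker context itself and yields a client created inside it, the per-test @mock_sagemaker decorators removed in the hunks below become redundant; every test that takes sagemaker_client already runs against the mocked backend. A minimal sketch of the pattern (illustrative test body):

import boto3
import pytest
from moto import mock_sagemaker

@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    with mock_sagemaker():
        yield boto3.client("sagemaker", region_name="us-east-1")

def test_list_models_none(sagemaker_client):
    assert sagemaker_client.list_models()["Models"] == []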
@@ -72,7 +73,6 @@ def create_endpoint_config_helper(sagemaker_client, production_variants):
     resp["ProductionVariants"].should.equal(production_variants)

-@mock_sagemaker
 def test_create_endpoint_config(sagemaker_client):
     with pytest.raises(ClientError) as e:
         sagemaker_client.create_endpoint_config(
@@ -85,7 +85,6 @@ def test_create_endpoint_config(sagemaker_client):
     create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)

-@mock_sagemaker
 def test_create_endpoint_config_serverless(sagemaker_client):
     with pytest.raises(ClientError) as e:
         sagemaker_client.create_endpoint_config(
@@ -98,7 +97,6 @@ def test_create_endpoint_config_serverless(sagemaker_client):
     create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS)

-@mock_sagemaker
 def test_delete_endpoint_config(sagemaker_client):
     _create_model(sagemaker_client, TEST_MODEL_NAME)
     resp = sagemaker_client.create_endpoint_config(
@@ -140,7 +138,6 @@ def test_delete_endpoint_config(sagemaker_client):
     )

-@mock_sagemaker
 def test_create_endpoint_invalid_instance_type(sagemaker_client):
     _create_model(sagemaker_client, TEST_MODEL_NAME)
@@ -160,7 +157,6 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
     assert expected_message in e.value.response["Error"]["Message"]

-@mock_sagemaker
 def test_create_endpoint_invalid_memory_size(sagemaker_client):
     _create_model(sagemaker_client, TEST_MODEL_NAME)
@@ -180,7 +176,6 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client):
     assert expected_message in e.value.response["Error"]["Message"]

-@mock_sagemaker
 def test_create_endpoint(sagemaker_client):
     with pytest.raises(ClientError) as e:
         sagemaker_client.create_endpoint(
@@ -221,7 +216,6 @@ def test_create_endpoint(sagemaker_client):
     assert resp["Tags"] == GENERIC_TAGS_PARAM

-@mock_sagemaker
 def test_delete_endpoint(sagemaker_client):
     _set_up_sagemaker_resources(
         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@@ -237,7 +231,6 @@ def test_delete_endpoint(sagemaker_client):
     assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")

-@mock_sagemaker
 def test_add_tags_endpoint(sagemaker_client):
     _set_up_sagemaker_resources(
         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@@ -253,7 +246,6 @@ def test_add_tags_endpoint(sagemaker_client):
     assert response["Tags"] == GENERIC_TAGS_PARAM

-@mock_sagemaker
 def test_delete_tags_endpoint(sagemaker_client):
     _set_up_sagemaker_resources(
         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@@ -273,7 +265,6 @@ def test_delete_tags_endpoint(sagemaker_client):
     assert response["Tags"] == []

-@mock_sagemaker
 def test_list_tags_endpoint(sagemaker_client):
     _set_up_sagemaker_resources(
         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@@ -298,7 +289,6 @@ def test_list_tags_endpoint(sagemaker_client):
     assert response["Tags"] == tags[50:]

-@mock_sagemaker
 def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
     _set_up_sagemaker_resources(
         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@@ -342,7 +332,6 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
     resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)

-@mock_sagemaker
 def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
     production_variants = [
         {
@@ -422,7 +411,6 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
     resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)

-@mock_sagemaker
 def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
     sagemaker_client,
 ):
@@ -459,7 +447,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari
     resp.should.equal(old_resp)

-@mock_sagemaker
 def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
     sagemaker_client,
 ):
@@ -497,7 +484,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp
     resp.should.equal(old_resp)

-@mock_sagemaker
 def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
     sagemaker_client,
 ):


@@ -8,12 +8,12 @@ TEST_REGION_NAME = "us-east-1"
 TEST_EXPERIMENT_NAME = "MyExperimentName"

-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)

-@mock_sagemaker
 def test_create_experiment(sagemaker_client):
     resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@@ -29,7 +29,6 @@ def test_create_experiment(sagemaker_client):
     )

-@mock_sagemaker
 def test_list_experiments(sagemaker_client):
     experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
@@ -57,7 +56,6 @@ def test_list_experiments(sagemaker_client):
     assert resp.get("NextToken") is None

-@mock_sagemaker
 def test_delete_experiment(sagemaker_client):
     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@@ -70,7 +68,6 @@ def test_delete_experiment(sagemaker_client):
     assert len(resp["ExperimentSummaries"]) == 0

-@mock_sagemaker
 def test_add_tags_to_experiment(sagemaker_client):
     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@@ -89,7 +86,6 @@ def test_add_tags_to_experiment(sagemaker_client):
     assert resp["Tags"] == tags

-@mock_sagemaker
 def test_delete_tags_to_experiment(sagemaker_client):
     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)


@@ -12,9 +12,10 @@ TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
 TEST_MODEL_NAME = "MyModelName"

-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)

 class MySageMakerModel(object):
@@ -36,7 +37,6 @@ class MySageMakerModel(object):
         return resp

-@mock_sagemaker
 def test_describe_model(sagemaker_client):
     test_model = MySageMakerModel()
     test_model.save(sagemaker_client)
@@ -44,14 +44,12 @@ def test_describe_model(sagemaker_client):
     assert model.get("ModelName").should.equal(TEST_MODEL_NAME)

-@mock_sagemaker
 def test_describe_model_not_found(sagemaker_client):
     with pytest.raises(ClientError) as err:
         sagemaker_client.describe_model(ModelName="unknown")
     assert err.value.response["Error"]["Message"].should.contain("Could not find model")

-@mock_sagemaker
 def test_create_model(sagemaker_client):
     vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
     model = sagemaker_client.create_model(
@@ -64,7 +62,6 @@ def test_create_model(sagemaker_client):
     )

-@mock_sagemaker
 def test_delete_model(sagemaker_client):
     test_model = MySageMakerModel()
     test_model.save(sagemaker_client)
@@ -74,14 +71,12 @@ def test_delete_model(sagemaker_client):
     assert len(sagemaker_client.list_models()["Models"]).should.equal(0)

-@mock_sagemaker
 def test_delete_model_not_found(sagemaker_client):
     with pytest.raises(ClientError) as err:
         sagemaker_client.delete_model(ModelName="blah")
     assert err.value.response["Error"]["Code"].should.equal("404")

-@mock_sagemaker
 def test_list_models(sagemaker_client):
     test_model = MySageMakerModel()
     test_model.save(sagemaker_client)
@@ -93,7 +88,6 @@ def test_list_models(sagemaker_client):
     )

-@mock_sagemaker
 def test_list_models_multiple(sagemaker_client):
     name_model_1 = "blah"
     arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
@@ -108,13 +102,11 @@ def test_list_models_multiple(sagemaker_client):
     assert len(models["Models"]).should.equal(2)

-@mock_sagemaker
 def test_list_models_none(sagemaker_client):
     models = sagemaker_client.list_models()
     assert len(models["Models"]).should.equal(0)

-@mock_sagemaker
 def test_add_tags_to_model(sagemaker_client):
     model = MySageMakerModel().save(sagemaker_client)
     resource_arn = model["ModelArn"]
@@ -129,7 +121,6 @@ def test_add_tags_to_model(sagemaker_client):
     assert response["Tags"] == tags

-@mock_sagemaker
 def test_delete_tags_from_model(sagemaker_client):
     model = MySageMakerModel().save(sagemaker_client)
     resource_arn = model["ModelArn"]


@@ -26,9 +26,10 @@ FAKE_NAME_PARAM = "MyNotebookInstance"
 FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
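Dropping the per-test decorators is safe because the fixture is function-scoped: every test enters and leaves its own `mock_sagemaker()` context, so mocked state cannot leak between tests. A small illustrative sketch, assuming the fixture defined in the hunk above; these two tests and their placeholder parameter values are not part of the diff:

```python
def test_first_creates_notebook(sagemaker_client):
    # State created here exists only inside this test's mock context.
    sagemaker_client.create_notebook_instance(
        NotebookInstanceName="temp-instance",
        InstanceType="ml.t2.medium",
        RoleArn="arn:aws:iam::123456789012:role/FakeRole",
    )
    resp = sagemaker_client.list_notebook_instances()
    assert len(resp["NotebookInstances"]) == 1


def test_second_sees_clean_backend(sagemaker_client):
    # The previous test's mock was torn down when its fixture exited,
    # so the backend starts out empty again here.
    assert sagemaker_client.list_notebook_instances()["NotebookInstances"] == []
```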
 def _get_notebook_instance_arn(notebook_name):
@@ -39,7 +40,6 @@ def _get_notebook_instance_lifecycle_arn(lifecycle_name):
     return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
-@mock_sagemaker
 def test_create_notebook_instance_minimal_params(sagemaker_client):
     args = {
         "NotebookInstanceName": FAKE_NAME_PARAM,
@@ -68,7 +68,6 @@ def test_create_notebook_instance_minimal_params(sagemaker_client):
     # assert resp["RootAccess"] == True # ToDo: Not sure if this defaults...
-@mock_sagemaker
 def test_create_notebook_instance_params(sagemaker_client):
     fake_direct_internet_access_param = "Enabled"
     volume_size_in_gb_param = 7
@@ -121,7 +120,6 @@ def test_create_notebook_instance_params(sagemaker_client):
     assert resp["Tags"] == GENERIC_TAGS_PARAM
-@mock_sagemaker
 def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
     instance_type = "undefined_instance_type"
     args = {
@@ -139,7 +137,6 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
     assert expected_message in ex.value.response["Error"]["Message"]
-@mock_sagemaker
 def test_notebook_instance_lifecycle(sagemaker_client):
     args = {
         "NotebookInstanceName": FAKE_NAME_PARAM,
@@ -193,14 +190,12 @@ def test_notebook_instance_lifecycle(sagemaker_client):
     assert ex.value.response["Error"]["Message"] == "RecordNotFound"
-@mock_sagemaker
 def test_describe_nonexistent_model(sagemaker_client):
     with pytest.raises(ClientError) as e:
         sagemaker_client.describe_model(ModelName="Nonexistent")
     assert e.value.response["Error"]["Message"].startswith("Could not find model")
-@mock_sagemaker
 def test_notebook_instance_lifecycle_config(sagemaker_client):
     name = "MyLifeCycleConfig"
     on_create = [{"Content": "Create Script Line 1"}]
@@ -252,7 +247,6 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
     )
-@mock_sagemaker
 def test_add_tags_to_notebook(sagemaker_client):
     args = {
         "NotebookInstanceName": FAKE_NAME_PARAM,
@@ -272,7 +266,6 @@ def test_add_tags_to_notebook(sagemaker_client):
     assert response["Tags"] == tags
-@mock_sagemaker
 def test_delete_tags_from_notebook(sagemaker_client):
     args = {
         "NotebookInstanceName": FAKE_NAME_PARAM,


@@ -12,9 +12,10 @@ FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
 TEST_REGION_NAME = "us-east-1"
-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
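The processing-job tests use a helper class whose `save()` receives the fixture-provided client, so the helper also runs inside the mocked context without any decorator of its own. A condensed, hypothetical version of that flow; only the required `create_processing_job` fields are shown, and the constructor arguments and resource values are placeholders rather than the real helper's signature:

```python
class MyProcessingJobModel:
    """Condensed stand-in for the helper class used in this test module."""

    def __init__(self, processing_job_name, role_arn):
        self.processing_job_name = processing_job_name
        self.role_arn = role_arn

    def save(self, sagemaker_client):
        # The client comes from the fixture, so this call already runs inside
        # the mock_sagemaker() context that the fixture opened.
        return sagemaker_client.create_processing_job(
            ProcessingJobName=self.processing_job_name,
            RoleArn=self.role_arn,
            # FAKE_CONTAINER in the real module
            AppSpecification={
                "ImageUri": "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
            },
            ProcessingResources={
                "ClusterConfig": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.m5.large",
                    "VolumeSizeInGB": 10,
                }
            },
        )
```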
 class MyProcessingJobModel(object):
@@ -103,7 +104,6 @@ class MyProcessingJobModel(object):
         return sagemaker_client.create_processing_job(**params)
-@mock_sagemaker
 def test_create_processing_job(sagemaker_client):
     bucket = "my-bucket"
     prefix = "my-prefix"
@@ -150,7 +150,6 @@ def test_create_processing_job(sagemaker_client):
     assert isinstance(resp["LastModifiedTime"], datetime.datetime)
-@mock_sagemaker
 def test_list_processing_jobs(sagemaker_client):
     test_processing_job = MyProcessingJobModel(
         processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
@@ -170,7 +169,6 @@ def test_list_processing_jobs(sagemaker_client):
     assert processing_jobs.get("NextToken") is None
-@mock_sagemaker
 def test_list_processing_jobs_multiple(sagemaker_client):
     name_job_1 = "blah"
     arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
@@ -193,13 +191,11 @@ def test_list_processing_jobs_multiple(sagemaker_client):
     assert processing_jobs.get("NextToken").should.be.none
-@mock_sagemaker
 def test_list_processing_jobs_none(sagemaker_client):
     processing_jobs = sagemaker_client.list_processing_jobs()
     assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
-@mock_sagemaker
 def test_list_processing_jobs_should_validate_input(sagemaker_client):
     junk_status_equals = "blah"
     with pytest.raises(ClientError) as ex:
@@ -218,7 +214,6 @@ def test_list_processing_jobs_should_validate_input(sagemaker_client):
     )
-@mock_sagemaker
 def test_list_processing_jobs_with_name_filters(sagemaker_client):
     for i in range(5):
         name = "xgboost-{}".format(i)
@@ -243,7 +238,6 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client):
     assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
-@mock_sagemaker
 def test_list_processing_jobs_paginated(sagemaker_client):
     for i in range(5):
         name = "xgboost-{}".format(i)
@@ -273,7 +267,6 @@ def test_list_processing_jobs_paginated(sagemaker_client):
     assert xgboost_processing_job_next.get("NextToken").should_not.be.none
-@mock_sagemaker
 def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
     for i in range(5):
         name = "xgboost-{}".format(i)
@@ -316,7 +309,6 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
     assert vgg_processing_job_10.get("NextToken").should.be.none
-@mock_sagemaker
 def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
     for i in range(5):
         name = "xgboost-{}".format(i)
@@ -357,7 +349,6 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client
     assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
-@mock_sagemaker
 def test_add_and_delete_tags_in_training_job(sagemaker_client):
     processing_job_name = "MyProcessingJob"
     role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)


@@ -8,12 +8,12 @@ from moto import mock_sagemaker
 TEST_REGION_NAME = "us-east-1"
-@pytest.fixture
-def sagemaker_client():
-    return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+@pytest.fixture(name="sagemaker_client")
+def fixture_sagemaker_client():
+    with mock_sagemaker():
+        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
-@mock_sagemaker
 def test_search(sagemaker_client):
     experiment_name = "experiment_name"
     trial_component_name = "trial_component_name"
@@ -60,7 +60,6 @@ def test_search(sagemaker_client):
     assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
-@mock_sagemaker
 def test_search_trial_component_with_experiment_name(sagemaker_client):
     experiment_name = "experiment_name"
     trial_component_name = "trial_component_name"


@@ -888,10 +888,10 @@ def test_state_machine_get_execution_history_contains_expected_success_events_wh
     execution_history["events"].should.equal(expected_events)
-@pytest.mark.parametrize("region", ["us-west-2", "cn-northwest-1"])
+@pytest.mark.parametrize("test_region", ["us-west-2", "cn-northwest-1"])
 @mock_stepfunctions
-def test_stepfunction_regions(region):
-    client = boto3.client("stepfunctions", region_name=region)
+def test_stepfunction_regions(test_region):
+    client = boto3.client("stepfunctions", region_name=test_region)
     resp = client.list_state_machines()
     resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)