TechDebt - enable pylint rule redefined-outer-scope (#5518)

This commit is contained in:
Bert Blommers 2022-10-04 16:28:30 +00:00 committed by GitHub
parent 696b809b5a
commit 4f84e2f154
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 236 additions and 396 deletions

View File

@ -157,11 +157,11 @@ class _DockerDataVolumeContext:
raise # multiple processes trying to use same volume?
def _zipfile_content(zipfile):
def _zipfile_content(zipfile_content):
try:
to_unzip_code = base64.b64decode(bytes(zipfile, "utf-8"))
to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8"))
except Exception:
to_unzip_code = base64.b64decode(zipfile)
to_unzip_code = base64.b64decode(zipfile_content)
sha_code = hashlib.sha256(to_unzip_code)
base64ed_sha = base64.b64encode(sha_code.digest()).decode("utf-8")

View File

@ -292,23 +292,23 @@ class CloudFrontBackend(BaseBackend):
return dist
return False
def update_distribution(self, DistributionConfig, Id, IfMatch):
def update_distribution(self, dist_config, _id, if_match):
"""
The IfMatch-value is ignored - any value is considered valid.
Calling this function without a value is invalid, per AWS' behaviour
"""
if Id not in self.distributions or Id is None:
if _id not in self.distributions or _id is None:
raise NoSuchDistribution
if not IfMatch:
if not if_match:
raise InvalidIfMatchVersion
if not DistributionConfig:
if not dist_config:
raise NoSuchDistribution
dist = self.distributions[Id]
dist = self.distributions[_id]
aliases = DistributionConfig["Aliases"]["Items"]["CNAME"]
dist.distribution_config.config = DistributionConfig
aliases = dist_config["Aliases"]["Items"]["CNAME"]
dist.distribution_config.config = dist_config
dist.distribution_config.aliases = aliases
self.distributions[Id] = dist
self.distributions[_id] = dist
dist.advance()
return dist, dist.location, dist.etag

View File

@ -83,9 +83,9 @@ class CloudFrontResponse(BaseResponse):
if_match = headers["If-Match"]
dist, location, e_tag = self.backend.update_distribution(
DistributionConfig=distribution_config,
Id=dist_id,
IfMatch=if_match,
dist_config=distribution_config,
_id=dist_id,
if_match=if_match,
)
template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS)

View File

@ -146,28 +146,28 @@ class convert_flask_to_responses_response(object):
return status, headers, response
def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_with_milliseconds(value):
return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
# Even Python does not support nanoseconds, other languages like Go do (needed for Terraform)
def iso_8601_datetime_with_nanoseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")
def iso_8601_datetime_with_nanoseconds(value):
return value.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")
def iso_8601_datetime_without_milliseconds(datetime):
return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
def iso_8601_datetime_without_milliseconds(value):
return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%SZ")
def iso_8601_datetime_without_milliseconds_s3(datetime):
return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%S.000Z")
def iso_8601_datetime_without_milliseconds_s3(value):
return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime):
return datetime.strftime(RFC1123)
def rfc_1123_datetime(src):
return src.strftime(RFC1123)
def str_to_rfc_1123_datetime(value):

View File

@ -12,14 +12,14 @@ INSTANCE_FAMILIES = list(set([i.split(".")[0] for i in INSTANCE_TYPES.keys()]))
root = pathlib.Path(__file__).parent
offerings_path = "../resources/instance_type_offerings"
INSTANCE_TYPE_OFFERINGS = {}
for location_type in listdir(root / offerings_path):
INSTANCE_TYPE_OFFERINGS[location_type] = {}
for _region in listdir(root / offerings_path / location_type):
full_path = offerings_path + "/" + location_type + "/" + _region
for _location_type in listdir(root / offerings_path):
INSTANCE_TYPE_OFFERINGS[_location_type] = {}
for _region in listdir(root / offerings_path / _location_type):
full_path = offerings_path + "/" + _location_type + "/" + _region
res = load_resource(__name__, full_path)
for instance in res:
instance["LocationType"] = location_type
INSTANCE_TYPE_OFFERINGS[location_type][_region.replace(".json", "")] = res
instance["LocationType"] = _location_type
INSTANCE_TYPE_OFFERINGS[_location_type][_region.replace(".json", "")] = res
class InstanceType(dict):

View File

@ -1021,19 +1021,12 @@ class SecurityGroupBackend:
if cidr_item.get("CidrIp6") == item.get("CidrIp6"):
cidr_item["Description"] = item.get("Description")
for item in security_rule.source_groups:
for group in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
if source_group.get("GroupId") == group.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
for item in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
) or source_group.get("GroupName") == group.get("GroupName"):
source_group["Description"] = group.get("Description")
def _remove_items_from_rule(self, ip_ranges, _source_groups, prefix_list_ids, rule):
for item in ip_ranges:

View File

@ -297,42 +297,42 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
self.name = name
self.modification_time = datetime.now(timezone.utc).isoformat()
def associate_ip_address(self, ip_address):
self.ip_addresses.append(ip_address)
def associate_ip_address(self, value):
self.ip_addresses.append(value)
self.ip_address_count = len(self.ip_addresses)
eni_id = f"rni-{mock_random.get_random_hex(17)}"
self.subnets[ip_address["SubnetId"]][ip_address["Ip"]] = eni_id
self.subnets[value["SubnetId"]][value["Ip"]] = eni_id
eni_info = self.ec2_backend.create_network_interface(
description=f"Route 53 Resolver: {self.id}:{eni_id}",
group_ids=self.security_group_ids,
interface_type="interface",
private_ip_address=ip_address.get("Ip"),
private_ip_address=value.get("Ip"),
private_ip_addresses=[
{"Primary": True, "PrivateIpAddress": ip_address.get("Ip")}
{"Primary": True, "PrivateIpAddress": value.get("Ip")}
],
subnet=ip_address.get("SubnetId"),
subnet=value.get("SubnetId"),
)
self.eni_ids.append(eni_info.id)
def disassociate_ip_address(self, ip_address):
if not ip_address.get("Ip") and ip_address.get("IpId"):
for ip_addr, eni_id in self.subnets[ip_address.get("SubnetId")].items():
if ip_address.get("IpId") == eni_id:
ip_address["Ip"] = ip_addr
if ip_address.get("Ip"):
def disassociate_ip_address(self, value):
if not value.get("Ip") and value.get("IpId"):
for ip_addr, eni_id in self.subnets[value.get("SubnetId")].items():
if value.get("IpId") == eni_id:
value["Ip"] = ip_addr
if value.get("Ip"):
self.ip_addresses = list(
filter(lambda i: i["Ip"] != ip_address.get("Ip"), self.ip_addresses)
filter(lambda i: i["Ip"] != value.get("Ip"), self.ip_addresses)
)
if len(self.subnets[ip_address["SubnetId"]]) == 1:
self.subnets.pop(ip_address["SubnetId"])
if len(self.subnets[value["SubnetId"]]) == 1:
self.subnets.pop(value["SubnetId"])
else:
self.subnets[ip_address["SubnetId"]].pop(ip_address["Ip"])
self.subnets[value["SubnetId"]].pop(value["Ip"])
for eni_id in self.eni_ids:
eni_info = self.ec2_backend.get_network_interface(eni_id)
if eni_info.private_ip_address == ip_address.get("Ip"):
if eni_info.private_ip_address == value.get("Ip"):
self.ec2_backend.delete_network_interface(eni_id)
self.eni_ids.remove(eni_id)
self.ip_address_count = len(self.ip_addresses)
@ -873,32 +873,30 @@ class Route53ResolverBackend(BaseBackend):
resolver_endpoint.update_name(name)
return resolver_endpoint
def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, ip_address):
def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
if not ip_address.get("Ip"):
if not value.get("Ip"):
subnet_info = self.ec2_backend.get_all_subnets(
subnet_ids=[ip_address.get("SubnetId")]
subnet_ids=[value.get("SubnetId")]
)[0]
ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
self._verify_subnet_ips([ip_address], False)
value["Ip"] = subnet_info.get_available_subnet_ip(self)
self._verify_subnet_ips([value], False)
resolver_endpoint.associate_ip_address(ip_address)
resolver_endpoint.associate_ip_address(value)
return resolver_endpoint
def disassociate_resolver_endpoint_ip_address(
self, resolver_endpoint_id, ip_address
):
def disassociate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
if not (ip_address.get("Ip") or ip_address.get("IpId")):
if not (value.get("Ip") or value.get("IpId")):
raise InvalidRequestException(
"[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
)
resolver_endpoint.disassociate_ip_address(ip_address)
resolver_endpoint.disassociate_ip_address(value)
return resolver_endpoint

View File

@ -266,7 +266,7 @@ class Route53ResolverResponse(BaseResponse):
resolver_endpoint = (
self.route53resolver_backend.associate_resolver_endpoint_ip_address(
resolver_endpoint_id=resolver_endpoint_id,
ip_address=ip_address,
value=ip_address,
)
)
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
@ -278,7 +278,7 @@ class Route53ResolverResponse(BaseResponse):
resolver_endpoint = (
self.route53resolver_backend.disassociate_resolver_endpoint_ip_address(
resolver_endpoint_id=resolver_endpoint_id,
ip_address=ip_address,
value=ip_address,
)
)
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})

View File

@ -1,8 +1,8 @@
def name(secret, names):
def name_filter(secret, names):
return _matcher(names, [secret.name])
def description(secret, descriptions):
def description_filter(secret, descriptions):
return _matcher(descriptions, [secret.description])

View File

@ -18,13 +18,19 @@ from .exceptions import (
ClientError,
)
from .utils import random_password, secret_arn, get_secret_name_from_arn
from .list_secrets.filters import filter_all, tag_key, tag_value, description, name
from .list_secrets.filters import (
filter_all,
tag_key,
tag_value,
description_filter,
name_filter,
)
_filter_functions = {
"all": filter_all,
"name": name,
"description": description,
"name": name_filter,
"description": description_filter,
"tag-key": tag_key,
"tag-value": tag_value,
}

View File

@ -470,24 +470,24 @@ class SESBackend(BaseBackend):
text_part = str.replace(str(text_part), "{{%s}}" % key, value)
html_part = str.replace(str(html_part), "{{%s}}" % key, value)
email = MIMEMultipart("alternative")
email_obj = MIMEMultipart("alternative")
mime_text = MIMEBase("text", "plain;charset=UTF-8")
mime_text.set_payload(text_part.encode("utf-8"))
encode_7or8bit(mime_text)
email.attach(mime_text)
email_obj.attach(mime_text)
mime_html = MIMEBase("text", "html;charset=UTF-8")
mime_html.set_payload(html_part.encode("utf-8"))
encode_7or8bit(mime_html)
email.attach(mime_html)
email_obj.attach(mime_html)
now = datetime.datetime.now().isoformat()
rendered_template = "Date: %s\r\nSubject: %s\r\n%s" % (
now,
subject_part,
email.as_string(),
email_obj.as_string(),
)
return rendered_template

View File

@ -141,10 +141,10 @@ class Message(BaseModel):
)
@staticmethod
def utf8(string):
if isinstance(string, str):
return string.encode("utf-8")
return string
def utf8(value):
if isinstance(value, str):
return value.encode("utf-8")
return value
@property
def body(self):

View File

@ -16,11 +16,11 @@ class TelemetryRecords(BaseModel):
self.records = records
@classmethod
def from_json(cls, json):
instance_id = json.get("EC2InstanceId", None)
hostname = json.get("Hostname")
resource_arn = json.get("ResourceARN")
telemetry_records = json["TelemetryRecords"]
def from_json(cls, src):
instance_id = src.get("EC2InstanceId", None)
hostname = src.get("Hostname")
resource_arn = src.get("ResourceARN")
telemetry_records = src["TelemetryRecords"]
return cls(instance_id, hostname, resource_arn, telemetry_records)
@ -242,8 +242,8 @@ class XRayBackend(BaseBackend):
service_region, zones, "xray"
)
def add_telemetry_records(self, json):
self._telemetry_records.append(TelemetryRecords.from_json(json))
def add_telemetry_records(self, src):
self._telemetry_records.append(TelemetryRecords.from_json(src))
def process_segment(self, doc):
try:

View File

@ -14,5 +14,5 @@ ignore-paths=moto/packages
[pylint.'MESSAGES CONTROL']
disable = W,C,R,E
# future sensible checks = super-init-not-called, redefined-outer-name, unspecified-encoding, undefined-loop-variable
enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
# future sensible checks = super-init-not-called, unspecified-encoding, undefined-loop-variable
enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import

View File

@ -139,7 +139,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
"FunctionArn"
]
template = event_source_mapping_template.substitute(
esm_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
@ -149,7 +149,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
}
)
cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
cf.create_stack(StackName=random_stack_name(), TemplateBody=esm_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)
@ -174,7 +174,7 @@ def test_event_source_mapping_delete_stack():
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
template = event_source_mapping_template.substitute(
esm_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
@ -184,7 +184,9 @@ def test_event_source_mapping_delete_stack():
}
)
esm_stack = cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
esm_stack = cf.create_stack(
StackName=random_stack_name(), TemplateBody=esm_template
)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)

View File

@ -18,8 +18,8 @@ def test_basic_decorator():
client.describe_addresses()["Addresses"].should.equal([])
@pytest.fixture
def aws_credentials(monkeypatch):
@pytest.fixture(name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")

View File

@ -6,8 +6,8 @@ from moto import settings
from unittest import SkipTest
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
@pytest.fixture(scope="function", name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
if settings.TEST_SERVER_MODE:
raise SkipTest("No point in testing this in ServerMode.")
"""Mocked AWS Credentials for moto."""

View File

@ -495,13 +495,13 @@ def test_creating_table_with_0_global_indexes():
@mock_dynamodb
def test_multiple_transactions_on_same_item():
table_schema = {
schema = {
"KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
"AttributeDefinitions": [{"AttributeName": "id", "AttributeType": "S"}],
}
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
)
# Insert an item
dynamodb.put_item(TableName="test-table", Item={"id": {"S": "foo"}})
@ -533,13 +533,13 @@ def test_multiple_transactions_on_same_item():
@mock_dynamodb
def test_transact_write_items__too_many_transactions():
table_schema = {
schema = {
"KeySchema": [{"AttributeName": "pk", "KeyType": "HASH"}],
"AttributeDefinitions": [{"AttributeName": "pk", "AttributeType": "S"}],
}
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
)
def update_email_transact(email):

View File

@ -15,8 +15,8 @@ TABLE_NAME = "my_table_name"
TABLE_WITH_RANGE_NAME = "my_table_with_range_name"
@pytest.fixture(autouse=True)
def test_client():
@pytest.fixture(autouse=True, name="test_client")
def fixture_test_client():
backend = server.create_backend_app("dynamodb_v20111205")
test_client = backend.test_client()

View File

@ -0,0 +1,16 @@
import boto3
import pytest
from moto import mock_ec2, mock_efs
@pytest.fixture(scope="function", name="ec2")
def fixture_ec2():
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function", name="efs")
def fixture_efs():
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")

View File

@ -1,26 +1,9 @@
import boto3
import pytest
from moto import mock_efs
from . import fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp

View File

@ -1,28 +1,12 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from . import fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp

View File

@ -1,11 +1,10 @@
import re
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
from tests.test_efs.junk_drawer import has_status_code
from . import fixture_efs # noqa # pylint: disable=unused-import
ARN_PATT = r"^arn:(?P<Partition>[^:\n]*):(?P<Service>[^:\n]*):(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$"
STRICT_ARN_PATT = r"^arn:aws:[a-z]+:[a-z]{2}-[a-z]+-[0-9]:[0-9]+:[a-z-]+\/[a-z0-9-]+$"
@ -30,21 +29,6 @@ SAMPLE_2_PARAMS = {
}
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
# Testing Create
# ==============

View File

@ -1,22 +1,4 @@
import boto3
import pytest
from moto import mock_efs
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
from . import fixture_efs # noqa # pylint: disable=unused-import
def test_list_tags_for_resource__without_tags(efs):

View File

@ -1,23 +1,7 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
from . import fixture_efs # noqa # pylint: disable=unused-import
def test_describe_filesystem_config__unknown(efs):

View File

@ -2,13 +2,12 @@ import re
import sys
from ipaddress import IPv4Network
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_ec2, mock_efs
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from tests.test_efs.junk_drawer import has_status_code
from . import fixture_ec2, fixture_efs # noqa # pylint: disable=unused-import
# Handle the fact that `subnet_of` is not a feature before 3.7.
@ -36,36 +35,15 @@ else:
)
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def ec2(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp
@pytest.fixture(scope="function")
def subnet(ec2):
@pytest.fixture(scope="function", name="subnet")
def fixture_subnet(ec2):
desc_sn_resp = ec2.describe_subnets()
subnet = desc_sn_resp["Subnets"][0]
yield subnet

View File

@ -1,40 +1,18 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_ec2, mock_efs
from . import fixture_ec2, fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def ec2(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp
@pytest.fixture(scope="function")
def subnet(ec2):
@pytest.fixture(scope="function", name="subnet")
def fixture_subnet(ec2):
desc_sn_resp = ec2.describe_subnets()
subnet = desc_sn_resp["Subnets"][0]
yield subnet

View File

@ -9,8 +9,8 @@ FILE_SYSTEMS = "/2015-02-01/file-systems"
MOUNT_TARGETS = "/2015-02-01/mount-targets"
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
@pytest.fixture(scope="function", name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
@ -18,14 +18,14 @@ def aws_credentials(monkeypatch):
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs_client(aws_credentials): # pylint: disable=unused-argument
@pytest.fixture(scope="function", name="efs_client")
def fixture_efs_client(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield server.create_backend_app("efs").test_client()
@pytest.fixture(scope="function")
def subnet_id(aws_credentials): # pylint: disable=unused-argument
@pytest.fixture(scope="function", name="subnet_id")
def fixture_subnet_id(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
ec2_client = server.create_backend_app("ec2").test_client()
resp = ec2_client.get("/?Action=DescribeSubnets")
@ -33,8 +33,8 @@ def subnet_id(aws_credentials): # pylint: disable=unused-argument
yield subnet_ids[0]
@pytest.fixture(scope="function")
def file_system_id(efs_client):
@pytest.fixture(scope="function", name="file_system_id")
def fixture_file_system_id(efs_client):
resp = efs_client.post(
FILE_SYSTEMS, json={"CreationToken": "foobarbaz", "Backup": True}
)

View File

@ -72,8 +72,8 @@ from .test_eks_utils import (
)
@pytest.fixture(scope="function")
def ClusterBuilder():
@pytest.fixture(scope="function", name="ClusterBuilder")
def fixture_ClusterBuilder():
class ClusterTestDataFactory:
def __init__(self, client, count, minimal):
# Generate 'count' number of random Cluster objects.
@ -104,8 +104,8 @@ def ClusterBuilder():
yield _execute
@pytest.fixture(scope="function")
def FargateProfileBuilder(ClusterBuilder):
@pytest.fixture(scope="function", name="FargateProfileBuilder")
def fixture_FargateProfileBuilder(ClusterBuilder):
class FargateProfileTestDataFactory:
def __init__(self, client, cluster, count, minimal):
self.cluster_name = cluster.existing_cluster_name
@ -142,8 +142,8 @@ def FargateProfileBuilder(ClusterBuilder):
return _execute
@pytest.fixture(scope="function")
def NodegroupBuilder(ClusterBuilder):
@pytest.fixture(scope="function", name="NodegroupBuilder")
def fixture_NodegroupBuilder(ClusterBuilder):
class NodegroupTestDataFactory:
def __init__(self, client, cluster, count, minimal):
self.cluster_name = cluster.existing_cluster_name

View File

@ -79,14 +79,14 @@ class TestNodegroup:
]
@pytest.fixture(autouse=True)
def test_client():
@pytest.fixture(autouse=True, name="test_client")
def fixture_test_client():
backend = server.create_backend_app(service=SERVICE)
yield backend.test_client()
@pytest.fixture(scope="function")
def create_cluster(test_client):
@pytest.fixture(scope="function", name="create_cluster")
def fixtue_create_cluster(test_client):
def create_and_verify_cluster(client, name):
"""Creates one valid cluster and verifies return status code 200."""
data = deepcopy(TestCluster.data)
@ -106,8 +106,8 @@ def create_cluster(test_client):
yield _execute
@pytest.fixture(scope="function", autouse=True)
def create_nodegroup(test_client):
@pytest.fixture(scope="function", autouse=True, name="create_nodegroup")
def fixture_create_nodegroup(test_client):
def create_and_verify_nodegroup(client, name):
"""Creates one valid nodegroup and verifies return status code 200."""
data = deepcopy(TestNodegroup.data)

View File

@ -610,10 +610,8 @@ def test_put_remove_auto_scaling_policy():
("AutoScalingPolicy" not in core_instance_group).should.equal(True)
def _patch_cluster_id_placeholder_in_autoscaling_policy(
auto_scaling_policy, cluster_id
):
policy_copy = deepcopy(auto_scaling_policy)
def _patch_cluster_id_placeholder_in_autoscaling_policy(policy, cluster_id):
policy_copy = deepcopy(policy)
for rule in policy_copy["Rules"]:
for dimension in rule["Trigger"]["CloudWatchAlarmDefinition"]["Dimensions"]:
dimension["Value"] = cluster_id

View File

@ -15,14 +15,14 @@ from unittest.mock import patch
from moto.emrcontainers import REGION as DEFAULT_REGION
@pytest.fixture(scope="function")
def client():
@pytest.fixture(scope="function", name="client")
def fixture_client():
with mock_emrcontainers():
yield boto3.client("emr-containers", region_name=DEFAULT_REGION)
@pytest.fixture(scope="function")
def virtual_cluster_factory(client):
@pytest.fixture(scope="function", name="virtual_cluster_factory")
def fixture_virtual_cluster_factory(client):
if settings.TEST_SERVER_MODE:
raise SkipTest("Cant manipulate time in server mode")
@ -47,8 +47,8 @@ def virtual_cluster_factory(client):
yield cluster_list
@pytest.fixture()
def job_factory(client, virtual_cluster_factory):
@pytest.fixture(name="job_factory")
def fixture_job_factory(client, virtual_cluster_factory):
virtual_cluster_id = virtual_cluster_factory[0]
default_job_driver = {
"sparkSubmitJobDriver": {

View File

@ -19,14 +19,14 @@ def does_not_raise():
yield
@pytest.fixture(scope="function")
def client():
@pytest.fixture(scope="function", name="client")
def fixture_client():
with mock_emrserverless():
yield boto3.client("emr-serverless", region_name=DEFAULT_REGION)
@pytest.fixture(scope="function")
def application_factory(client):
@pytest.fixture(scope="function", name="application_factory")
def fixture_application_factory(client):
application_list = []
if settings.TEST_SERVER_MODE:

View File

@ -42,8 +42,8 @@ from .fixtures.schema_registry import (
)
@pytest.fixture
def client():
@pytest.fixture(name="client")
def fixture_client():
with mock_glue():
yield boto3.client("glue", region_name="us-east-1")

View File

@ -6,19 +6,19 @@ from botocore.exceptions import ClientError
from moto import mock_iot, mock_cognitoidentity
@pytest.fixture
def region_name():
@pytest.fixture(name="region_name")
def fixture_region_name():
return "ap-northeast-1"
@pytest.fixture
def iot_client(region_name):
@pytest.fixture(name="iot_client")
def fixture_iot_client(region_name):
with mock_iot():
yield boto3.client("iot", region_name=region_name)
@pytest.fixture
def policy(iot_client):
@pytest.fixture(name="policy")
def fixture_policy(iot_client):
return iot_client.create_policy(policyName="my-policy", policyDocument="{}")

View File

@ -6,13 +6,13 @@ PLAINTEXT = b"text"
REGION = "us-east-1"
@pytest.fixture
def backend():
@pytest.fixture(name="backend")
def fixture_backend():
return KmsBackend(REGION)
@pytest.fixture
def key(backend):
@pytest.fixture(name="key")
def fixture_key(backend):
return backend.create_key(
None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Test key", None
)

View File

@ -16,26 +16,24 @@ from moto.logs.models import MAX_RESOURCE_POLICIES_PER_REGION
TEST_REGION = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"
@pytest.fixture
def json_policy_doc():
"""Returns a policy document in JSON format.
"""Returns a policy document in JSON format.
The ARN is bogus, but that shouldn't matter for the test.
"""
return json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "Route53LogsToCloudWatchLogs",
"Effect": "Allow",
"Principal": {"Service": ["route53.amazonaws.com"]},
"Action": "logs:PutLogEvents",
"Resource": "log_arn",
}
],
}
)
The ARN is bogus, but that shouldn't matter for the test.
"""
json_policy_doc = json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "Route53LogsToCloudWatchLogs",
"Effect": "Allow",
"Principal": {"Service": ["route53.amazonaws.com"]},
"Action": "logs:PutLogEvents",
"Resource": "log_arn",
}
],
}
)
@mock_logs
@ -589,7 +587,7 @@ def test_put_resource_policy():
@mock_logs
def test_put_resource_policy_too_many(json_policy_doc):
def test_put_resource_policy_too_many():
client = boto3.client("logs", TEST_REGION)
# Create the maximum number of resource policies.
@ -617,7 +615,7 @@ def test_put_resource_policy_too_many(json_policy_doc):
@mock_logs
def test_delete_resource_policy(json_policy_doc):
def test_delete_resource_policy():
client = boto3.client("logs", TEST_REGION)
# Create a bunch of resource policies so we can give delete a workout.
@ -649,7 +647,7 @@ def test_delete_resource_policy(json_policy_doc):
@mock_logs
def test_describe_resource_policies(json_policy_doc):
def test_describe_resource_policies():
client = boto3.client("logs", TEST_REGION)
# Create the maximum number of resource policies so there's something

View File

@ -13,8 +13,8 @@ INVALID_ID_ERROR_MESSAGE = (
RESOURCE_NOT_FOUND_ERROR_MESSAGE = "Query does not exist."
@pytest.fixture(autouse=True)
def client():
@pytest.fixture(autouse=True, name="client")
def fixture_client():
yield boto3.client("redshift-data", region_name=REGION)

View File

@ -18,8 +18,8 @@ def headers(action):
}
@pytest.fixture(autouse=True)
def client():
@pytest.fixture(autouse=True, name="client")
def fixture_client():
backend = server.create_backend_app("redshift-data")
yield backend.test_client()

View File

@ -42,9 +42,10 @@ TEST_SERVERLESS_PRODUCTION_VARIANTS = [
]
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
def create_endpoint_config_helper(sagemaker_client, production_variants):
@ -72,7 +73,6 @@ def create_endpoint_config_helper(sagemaker_client, production_variants):
resp["ProductionVariants"].should.equal(production_variants)
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
@ -85,7 +85,6 @@ def test_create_endpoint_config(sagemaker_client):
create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_create_endpoint_config_serverless(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
@ -98,7 +97,6 @@ def test_create_endpoint_config_serverless(sagemaker_client):
create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_delete_endpoint_config(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
resp = sagemaker_client.create_endpoint_config(
@ -140,7 +138,6 @@ def test_delete_endpoint_config(sagemaker_client):
)
@mock_sagemaker
def test_create_endpoint_invalid_instance_type(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
@ -160,7 +157,6 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint_invalid_memory_size(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
@ -180,7 +176,6 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client):
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint(
@ -221,7 +216,6 @@ def test_create_endpoint(sagemaker_client):
assert resp["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -237,7 +231,6 @@ def test_delete_endpoint(sagemaker_client):
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
@mock_sagemaker
def test_add_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -253,7 +246,6 @@ def test_add_tags_endpoint(sagemaker_client):
assert response["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -273,7 +265,6 @@ def test_delete_tags_endpoint(sagemaker_client):
assert response["Tags"] == []
@mock_sagemaker
def test_list_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -298,7 +289,6 @@ def test_list_tags_endpoint(sagemaker_client):
assert response["Tags"] == tags[50:]
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -342,7 +332,6 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
production_variants = [
{
@ -422,7 +411,6 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
sagemaker_client,
):
@ -459,7 +447,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari
resp.should.equal(old_resp)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
sagemaker_client,
):
@ -497,7 +484,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp
resp.should.equal(old_resp)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
sagemaker_client,
):

View File

@ -8,12 +8,12 @@ TEST_REGION_NAME = "us-east-1"
TEST_EXPERIMENT_NAME = "MyExperimentName"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_create_experiment(sagemaker_client):
resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -29,7 +29,6 @@ def test_create_experiment(sagemaker_client):
)
@mock_sagemaker
def test_list_experiments(sagemaker_client):
experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
@ -57,7 +56,6 @@ def test_list_experiments(sagemaker_client):
assert resp.get("NextToken") is None
@mock_sagemaker
def test_delete_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -70,7 +68,6 @@ def test_delete_experiment(sagemaker_client):
assert len(resp["ExperimentSummaries"]) == 0
@mock_sagemaker
def test_add_tags_to_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -89,7 +86,6 @@ def test_add_tags_to_experiment(sagemaker_client):
assert resp["Tags"] == tags
@mock_sagemaker
def test_delete_tags_to_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)

View File

@ -12,9 +12,10 @@ TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
TEST_MODEL_NAME = "MyModelName"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MySageMakerModel(object):
@ -36,7 +37,6 @@ class MySageMakerModel(object):
return resp
@mock_sagemaker
def test_describe_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -44,14 +44,12 @@ def test_describe_model(sagemaker_client):
assert model.get("ModelName").should.equal(TEST_MODEL_NAME)
@mock_sagemaker
def test_describe_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.describe_model(ModelName="unknown")
assert err.value.response["Error"]["Message"].should.contain("Could not find model")
@mock_sagemaker
def test_create_model(sagemaker_client):
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
model = sagemaker_client.create_model(
@ -64,7 +62,6 @@ def test_create_model(sagemaker_client):
)
@mock_sagemaker
def test_delete_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -74,14 +71,12 @@ def test_delete_model(sagemaker_client):
assert len(sagemaker_client.list_models()["Models"]).should.equal(0)
@mock_sagemaker
def test_delete_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.delete_model(ModelName="blah")
assert err.value.response["Error"]["Code"].should.equal("404")
@mock_sagemaker
def test_list_models(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -93,7 +88,6 @@ def test_list_models(sagemaker_client):
)
@mock_sagemaker
def test_list_models_multiple(sagemaker_client):
name_model_1 = "blah"
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
@ -108,13 +102,11 @@ def test_list_models_multiple(sagemaker_client):
assert len(models["Models"]).should.equal(2)
@mock_sagemaker
def test_list_models_none(sagemaker_client):
models = sagemaker_client.list_models()
assert len(models["Models"]).should.equal(0)
@mock_sagemaker
def test_add_tags_to_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]
@ -129,7 +121,6 @@ def test_add_tags_to_model(sagemaker_client):
assert response["Tags"] == tags
@mock_sagemaker
def test_delete_tags_from_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]

View File

@ -26,9 +26,10 @@ FAKE_NAME_PARAM = "MyNotebookInstance"
FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
def _get_notebook_instance_arn(notebook_name):
@ -39,7 +40,6 @@ def _get_notebook_instance_lifecycle_arn(lifecycle_name):
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
@mock_sagemaker
def test_create_notebook_instance_minimal_params(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -68,7 +68,6 @@ def test_create_notebook_instance_minimal_params(sagemaker_client):
# assert resp["RootAccess"] == True # ToDo: Not sure if this defaults...
@mock_sagemaker
def test_create_notebook_instance_params(sagemaker_client):
fake_direct_internet_access_param = "Enabled"
volume_size_in_gb_param = 7
@ -121,7 +120,6 @@ def test_create_notebook_instance_params(sagemaker_client):
assert resp["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
instance_type = "undefined_instance_type"
args = {
@ -139,7 +137,6 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
assert expected_message in ex.value.response["Error"]["Message"]
@mock_sagemaker
def test_notebook_instance_lifecycle(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -193,14 +190,12 @@ def test_notebook_instance_lifecycle(sagemaker_client):
assert ex.value.response["Error"]["Message"] == "RecordNotFound"
@mock_sagemaker
def test_describe_nonexistent_model(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.describe_model(ModelName="Nonexistent")
assert e.value.response["Error"]["Message"].startswith("Could not find model")
@mock_sagemaker
def test_notebook_instance_lifecycle_config(sagemaker_client):
name = "MyLifeCycleConfig"
on_create = [{"Content": "Create Script Line 1"}]
@ -252,7 +247,6 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
)
@mock_sagemaker
def test_add_tags_to_notebook(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -272,7 +266,6 @@ def test_add_tags_to_notebook(sagemaker_client):
assert response["Tags"] == tags
@mock_sagemaker
def test_delete_tags_from_notebook(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,

View File

@ -12,9 +12,10 @@ FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
TEST_REGION_NAME = "us-east-1"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MyProcessingJobModel(object):
@ -103,7 +104,6 @@ class MyProcessingJobModel(object):
return sagemaker_client.create_processing_job(**params)
@mock_sagemaker
def test_create_processing_job(sagemaker_client):
bucket = "my-bucket"
prefix = "my-prefix"
@ -150,7 +150,6 @@ def test_create_processing_job(sagemaker_client):
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
@mock_sagemaker
def test_list_processing_jobs(sagemaker_client):
test_processing_job = MyProcessingJobModel(
processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
@ -170,7 +169,6 @@ def test_list_processing_jobs(sagemaker_client):
assert processing_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_processing_jobs_multiple(sagemaker_client):
name_job_1 = "blah"
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
@ -193,13 +191,11 @@ def test_list_processing_jobs_multiple(sagemaker_client):
assert processing_jobs.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_none(sagemaker_client):
processing_jobs = sagemaker_client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
@mock_sagemaker
def test_list_processing_jobs_should_validate_input(sagemaker_client):
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
@ -218,7 +214,6 @@ def test_list_processing_jobs_should_validate_input(sagemaker_client):
)
@mock_sagemaker
def test_list_processing_jobs_with_name_filters(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -243,7 +238,6 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client):
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
@mock_sagemaker
def test_list_processing_jobs_paginated(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -273,7 +267,6 @@ def test_list_processing_jobs_paginated(sagemaker_client):
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -316,7 +309,6 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
assert vgg_processing_job_10.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -357,7 +349,6 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
@mock_sagemaker
def test_add_and_delete_tags_in_training_job(sagemaker_client):
processing_job_name = "MyProcessingJob"
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)

View File

@ -8,12 +8,12 @@ from moto import mock_sagemaker
TEST_REGION_NAME = "us-east-1"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_search(sagemaker_client):
experiment_name = "experiment_name"
trial_component_name = "trial_component_name"
@ -60,7 +60,6 @@ def test_search(sagemaker_client):
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
@mock_sagemaker
def test_search_trial_component_with_experiment_name(sagemaker_client):
experiment_name = "experiment_name"
trial_component_name = "trial_component_name"

View File

@ -888,10 +888,10 @@ def test_state_machine_get_execution_history_contains_expected_success_events_wh
execution_history["events"].should.equal(expected_events)
@pytest.mark.parametrize("region", ["us-west-2", "cn-northwest-1"])
@pytest.mark.parametrize("test_region", ["us-west-2", "cn-northwest-1"])
@mock_stepfunctions
def test_stepfunction_regions(region):
client = boto3.client("stepfunctions", region_name=region)
def test_stepfunction_regions(test_region):
client = boto3.client("stepfunctions", region_name=test_region)
resp = client.list_state_machines()
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)