TechDebt - enable pylint rule redefined-outer-scope (#5518)

This commit is contained in:
Bert Blommers 2022-10-04 16:28:30 +00:00 committed by GitHub
parent 696b809b5a
commit 4f84e2f154
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 236 additions and 396 deletions

View File

@ -157,11 +157,11 @@ class _DockerDataVolumeContext:
raise # multiple processes trying to use same volume?
def _zipfile_content(zipfile):
def _zipfile_content(zipfile_content):
try:
to_unzip_code = base64.b64decode(bytes(zipfile, "utf-8"))
to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8"))
except Exception:
to_unzip_code = base64.b64decode(zipfile)
to_unzip_code = base64.b64decode(zipfile_content)
sha_code = hashlib.sha256(to_unzip_code)
base64ed_sha = base64.b64encode(sha_code.digest()).decode("utf-8")

View File

@ -292,23 +292,23 @@ class CloudFrontBackend(BaseBackend):
return dist
return False
def update_distribution(self, DistributionConfig, Id, IfMatch):
def update_distribution(self, dist_config, _id, if_match):
"""
The IfMatch-value is ignored - any value is considered valid.
Calling this function without a value is invalid, per AWS' behaviour
"""
if Id not in self.distributions or Id is None:
if _id not in self.distributions or _id is None:
raise NoSuchDistribution
if not IfMatch:
if not if_match:
raise InvalidIfMatchVersion
if not DistributionConfig:
if not dist_config:
raise NoSuchDistribution
dist = self.distributions[Id]
dist = self.distributions[_id]
aliases = DistributionConfig["Aliases"]["Items"]["CNAME"]
dist.distribution_config.config = DistributionConfig
aliases = dist_config["Aliases"]["Items"]["CNAME"]
dist.distribution_config.config = dist_config
dist.distribution_config.aliases = aliases
self.distributions[Id] = dist
self.distributions[_id] = dist
dist.advance()
return dist, dist.location, dist.etag

View File

@ -83,9 +83,9 @@ class CloudFrontResponse(BaseResponse):
if_match = headers["If-Match"]
dist, location, e_tag = self.backend.update_distribution(
DistributionConfig=distribution_config,
Id=dist_id,
IfMatch=if_match,
dist_config=distribution_config,
_id=dist_id,
if_match=if_match,
)
template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS)

View File

@ -146,28 +146,28 @@ class convert_flask_to_responses_response(object):
return status, headers, response
def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_with_milliseconds(value):
return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
# Even Python does not support nanoseconds, other languages like Go do (needed for Terraform)
def iso_8601_datetime_with_nanoseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")
def iso_8601_datetime_with_nanoseconds(value):
return value.strftime("%Y-%m-%dT%H:%M:%S.%f000Z")
def iso_8601_datetime_without_milliseconds(datetime):
return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
def iso_8601_datetime_without_milliseconds(value):
return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%SZ")
def iso_8601_datetime_without_milliseconds_s3(datetime):
return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%S.000Z")
def iso_8601_datetime_without_milliseconds_s3(value):
return None if value is None else value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime):
return datetime.strftime(RFC1123)
def rfc_1123_datetime(src):
return src.strftime(RFC1123)
def str_to_rfc_1123_datetime(value):

View File

@ -12,14 +12,14 @@ INSTANCE_FAMILIES = list(set([i.split(".")[0] for i in INSTANCE_TYPES.keys()]))
root = pathlib.Path(__file__).parent
offerings_path = "../resources/instance_type_offerings"
INSTANCE_TYPE_OFFERINGS = {}
for location_type in listdir(root / offerings_path):
INSTANCE_TYPE_OFFERINGS[location_type] = {}
for _region in listdir(root / offerings_path / location_type):
full_path = offerings_path + "/" + location_type + "/" + _region
for _location_type in listdir(root / offerings_path):
INSTANCE_TYPE_OFFERINGS[_location_type] = {}
for _region in listdir(root / offerings_path / _location_type):
full_path = offerings_path + "/" + _location_type + "/" + _region
res = load_resource(__name__, full_path)
for instance in res:
instance["LocationType"] = location_type
INSTANCE_TYPE_OFFERINGS[location_type][_region.replace(".json", "")] = res
instance["LocationType"] = _location_type
INSTANCE_TYPE_OFFERINGS[_location_type][_region.replace(".json", "")] = res
class InstanceType(dict):

View File

@ -1021,19 +1021,12 @@ class SecurityGroupBackend:
if cidr_item.get("CidrIp6") == item.get("CidrIp6"):
cidr_item["Description"] = item.get("Description")
for item in security_rule.source_groups:
for group in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
if source_group.get("GroupId") == group.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
for item in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
) or source_group.get("GroupName") == group.get("GroupName"):
source_group["Description"] = group.get("Description")
def _remove_items_from_rule(self, ip_ranges, _source_groups, prefix_list_ids, rule):
for item in ip_ranges:

View File

@ -297,42 +297,42 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
self.name = name
self.modification_time = datetime.now(timezone.utc).isoformat()
def associate_ip_address(self, ip_address):
self.ip_addresses.append(ip_address)
def associate_ip_address(self, value):
self.ip_addresses.append(value)
self.ip_address_count = len(self.ip_addresses)
eni_id = f"rni-{mock_random.get_random_hex(17)}"
self.subnets[ip_address["SubnetId"]][ip_address["Ip"]] = eni_id
self.subnets[value["SubnetId"]][value["Ip"]] = eni_id
eni_info = self.ec2_backend.create_network_interface(
description=f"Route 53 Resolver: {self.id}:{eni_id}",
group_ids=self.security_group_ids,
interface_type="interface",
private_ip_address=ip_address.get("Ip"),
private_ip_address=value.get("Ip"),
private_ip_addresses=[
{"Primary": True, "PrivateIpAddress": ip_address.get("Ip")}
{"Primary": True, "PrivateIpAddress": value.get("Ip")}
],
subnet=ip_address.get("SubnetId"),
subnet=value.get("SubnetId"),
)
self.eni_ids.append(eni_info.id)
def disassociate_ip_address(self, ip_address):
if not ip_address.get("Ip") and ip_address.get("IpId"):
for ip_addr, eni_id in self.subnets[ip_address.get("SubnetId")].items():
if ip_address.get("IpId") == eni_id:
ip_address["Ip"] = ip_addr
if ip_address.get("Ip"):
def disassociate_ip_address(self, value):
if not value.get("Ip") and value.get("IpId"):
for ip_addr, eni_id in self.subnets[value.get("SubnetId")].items():
if value.get("IpId") == eni_id:
value["Ip"] = ip_addr
if value.get("Ip"):
self.ip_addresses = list(
filter(lambda i: i["Ip"] != ip_address.get("Ip"), self.ip_addresses)
filter(lambda i: i["Ip"] != value.get("Ip"), self.ip_addresses)
)
if len(self.subnets[ip_address["SubnetId"]]) == 1:
self.subnets.pop(ip_address["SubnetId"])
if len(self.subnets[value["SubnetId"]]) == 1:
self.subnets.pop(value["SubnetId"])
else:
self.subnets[ip_address["SubnetId"]].pop(ip_address["Ip"])
self.subnets[value["SubnetId"]].pop(value["Ip"])
for eni_id in self.eni_ids:
eni_info = self.ec2_backend.get_network_interface(eni_id)
if eni_info.private_ip_address == ip_address.get("Ip"):
if eni_info.private_ip_address == value.get("Ip"):
self.ec2_backend.delete_network_interface(eni_id)
self.eni_ids.remove(eni_id)
self.ip_address_count = len(self.ip_addresses)
@ -873,32 +873,30 @@ class Route53ResolverBackend(BaseBackend):
resolver_endpoint.update_name(name)
return resolver_endpoint
def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, ip_address):
def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
if not ip_address.get("Ip"):
if not value.get("Ip"):
subnet_info = self.ec2_backend.get_all_subnets(
subnet_ids=[ip_address.get("SubnetId")]
subnet_ids=[value.get("SubnetId")]
)[0]
ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
self._verify_subnet_ips([ip_address], False)
value["Ip"] = subnet_info.get_available_subnet_ip(self)
self._verify_subnet_ips([value], False)
resolver_endpoint.associate_ip_address(ip_address)
resolver_endpoint.associate_ip_address(value)
return resolver_endpoint
def disassociate_resolver_endpoint_ip_address(
self, resolver_endpoint_id, ip_address
):
def disassociate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value):
self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
if not (ip_address.get("Ip") or ip_address.get("IpId")):
if not (value.get("Ip") or value.get("IpId")):
raise InvalidRequestException(
"[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
)
resolver_endpoint.disassociate_ip_address(ip_address)
resolver_endpoint.disassociate_ip_address(value)
return resolver_endpoint

View File

@ -266,7 +266,7 @@ class Route53ResolverResponse(BaseResponse):
resolver_endpoint = (
self.route53resolver_backend.associate_resolver_endpoint_ip_address(
resolver_endpoint_id=resolver_endpoint_id,
ip_address=ip_address,
value=ip_address,
)
)
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
@ -278,7 +278,7 @@ class Route53ResolverResponse(BaseResponse):
resolver_endpoint = (
self.route53resolver_backend.disassociate_resolver_endpoint_ip_address(
resolver_endpoint_id=resolver_endpoint_id,
ip_address=ip_address,
value=ip_address,
)
)
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})

View File

@ -1,8 +1,8 @@
def name(secret, names):
def name_filter(secret, names):
return _matcher(names, [secret.name])
def description(secret, descriptions):
def description_filter(secret, descriptions):
return _matcher(descriptions, [secret.description])

View File

@ -18,13 +18,19 @@ from .exceptions import (
ClientError,
)
from .utils import random_password, secret_arn, get_secret_name_from_arn
from .list_secrets.filters import filter_all, tag_key, tag_value, description, name
from .list_secrets.filters import (
filter_all,
tag_key,
tag_value,
description_filter,
name_filter,
)
_filter_functions = {
"all": filter_all,
"name": name,
"description": description,
"name": name_filter,
"description": description_filter,
"tag-key": tag_key,
"tag-value": tag_value,
}

View File

@ -470,24 +470,24 @@ class SESBackend(BaseBackend):
text_part = str.replace(str(text_part), "{{%s}}" % key, value)
html_part = str.replace(str(html_part), "{{%s}}" % key, value)
email = MIMEMultipart("alternative")
email_obj = MIMEMultipart("alternative")
mime_text = MIMEBase("text", "plain;charset=UTF-8")
mime_text.set_payload(text_part.encode("utf-8"))
encode_7or8bit(mime_text)
email.attach(mime_text)
email_obj.attach(mime_text)
mime_html = MIMEBase("text", "html;charset=UTF-8")
mime_html.set_payload(html_part.encode("utf-8"))
encode_7or8bit(mime_html)
email.attach(mime_html)
email_obj.attach(mime_html)
now = datetime.datetime.now().isoformat()
rendered_template = "Date: %s\r\nSubject: %s\r\n%s" % (
now,
subject_part,
email.as_string(),
email_obj.as_string(),
)
return rendered_template

View File

@ -141,10 +141,10 @@ class Message(BaseModel):
)
@staticmethod
def utf8(string):
if isinstance(string, str):
return string.encode("utf-8")
return string
def utf8(value):
if isinstance(value, str):
return value.encode("utf-8")
return value
@property
def body(self):

View File

@ -16,11 +16,11 @@ class TelemetryRecords(BaseModel):
self.records = records
@classmethod
def from_json(cls, json):
instance_id = json.get("EC2InstanceId", None)
hostname = json.get("Hostname")
resource_arn = json.get("ResourceARN")
telemetry_records = json["TelemetryRecords"]
def from_json(cls, src):
instance_id = src.get("EC2InstanceId", None)
hostname = src.get("Hostname")
resource_arn = src.get("ResourceARN")
telemetry_records = src["TelemetryRecords"]
return cls(instance_id, hostname, resource_arn, telemetry_records)
@ -242,8 +242,8 @@ class XRayBackend(BaseBackend):
service_region, zones, "xray"
)
def add_telemetry_records(self, json):
self._telemetry_records.append(TelemetryRecords.from_json(json))
def add_telemetry_records(self, src):
self._telemetry_records.append(TelemetryRecords.from_json(src))
def process_segment(self, doc):
try:

View File

@ -14,5 +14,5 @@ ignore-paths=moto/packages
[pylint.'MESSAGES CONTROL']
disable = W,C,R,E
# future sensible checks = super-init-not-called, redefined-outer-name, unspecified-encoding, undefined-loop-variable
enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
# future sensible checks = super-init-not-called, unspecified-encoding, undefined-loop-variable
enable = arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import

View File

@ -139,7 +139,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
"FunctionArn"
]
template = event_source_mapping_template.substitute(
esm_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
@ -149,7 +149,7 @@ def test_event_source_mapping_create_from_cloudformation_json():
}
)
cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
cf.create_stack(StackName=random_stack_name(), TemplateBody=esm_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)
@ -174,7 +174,7 @@ def test_event_source_mapping_delete_stack():
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
template = event_source_mapping_template.substitute(
esm_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
@ -184,7 +184,9 @@ def test_event_source_mapping_delete_stack():
}
)
esm_stack = cf.create_stack(StackName=random_stack_name(), TemplateBody=template)
esm_stack = cf.create_stack(
StackName=random_stack_name(), TemplateBody=esm_template
)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)

View File

@ -18,8 +18,8 @@ def test_basic_decorator():
client.describe_addresses()["Addresses"].should.equal([])
@pytest.fixture
def aws_credentials(monkeypatch):
@pytest.fixture(name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")

View File

@ -6,8 +6,8 @@ from moto import settings
from unittest import SkipTest
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
@pytest.fixture(scope="function", name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
if settings.TEST_SERVER_MODE:
raise SkipTest("No point in testing this in ServerMode.")
"""Mocked AWS Credentials for moto."""

View File

@ -495,13 +495,13 @@ def test_creating_table_with_0_global_indexes():
@mock_dynamodb
def test_multiple_transactions_on_same_item():
table_schema = {
schema = {
"KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
"AttributeDefinitions": [{"AttributeName": "id", "AttributeType": "S"}],
}
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
)
# Insert an item
dynamodb.put_item(TableName="test-table", Item={"id": {"S": "foo"}})
@ -533,13 +533,13 @@ def test_multiple_transactions_on_same_item():
@mock_dynamodb
def test_transact_write_items__too_many_transactions():
table_schema = {
schema = {
"KeySchema": [{"AttributeName": "pk", "KeyType": "HASH"}],
"AttributeDefinitions": [{"AttributeName": "pk", "AttributeType": "S"}],
}
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
TableName="test-table", BillingMode="PAY_PER_REQUEST", **schema
)
def update_email_transact(email):

View File

@ -15,8 +15,8 @@ TABLE_NAME = "my_table_name"
TABLE_WITH_RANGE_NAME = "my_table_with_range_name"
@pytest.fixture(autouse=True)
def test_client():
@pytest.fixture(autouse=True, name="test_client")
def fixture_test_client():
backend = server.create_backend_app("dynamodb_v20111205")
test_client = backend.test_client()

View File

@ -0,0 +1,16 @@
import boto3
import pytest
from moto import mock_ec2, mock_efs
@pytest.fixture(scope="function", name="ec2")
def fixture_ec2():
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function", name="efs")
def fixture_efs():
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")

View File

@ -1,26 +1,9 @@
import boto3
import pytest
from moto import mock_efs
from . import fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp

View File

@ -1,28 +1,12 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from . import fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp

View File

@ -1,11 +1,10 @@
import re
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
from tests.test_efs.junk_drawer import has_status_code
from . import fixture_efs # noqa # pylint: disable=unused-import
ARN_PATT = r"^arn:(?P<Partition>[^:\n]*):(?P<Service>[^:\n]*):(?P<Region>[^:\n]*):(?P<AccountID>[^:\n]*):(?P<Ignore>(?P<ResourceType>[^:\/\n]*)[:\/])?(?P<Resource>.*)$"
STRICT_ARN_PATT = r"^arn:aws:[a-z]+:[a-z]{2}-[a-z]+-[0-9]:[0-9]+:[a-z-]+\/[a-z0-9-]+$"
@ -30,21 +29,6 @@ SAMPLE_2_PARAMS = {
}
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
# Testing Create
# ==============

View File

@ -1,22 +1,4 @@
import boto3
import pytest
from moto import mock_efs
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
from . import fixture_efs # noqa # pylint: disable=unused-import
def test_list_tags_for_resource__without_tags(efs):

View File

@ -1,23 +1,7 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_efs
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
from . import fixture_efs # noqa # pylint: disable=unused-import
def test_describe_filesystem_config__unknown(efs):

View File

@ -2,13 +2,12 @@ import re
import sys
from ipaddress import IPv4Network
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_ec2, mock_efs
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from tests.test_efs.junk_drawer import has_status_code
from . import fixture_ec2, fixture_efs # noqa # pylint: disable=unused-import
# Handle the fact that `subnet_of` is not a feature before 3.7.
@ -36,36 +35,15 @@ else:
)
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def ec2(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp
@pytest.fixture(scope="function")
def subnet(ec2):
@pytest.fixture(scope="function", name="subnet")
def fixture_subnet(ec2):
desc_sn_resp = ec2.describe_subnets()
subnet = desc_sn_resp["Subnets"][0]
yield subnet

View File

@ -1,40 +1,18 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_ec2, mock_efs
from . import fixture_ec2, fixture_efs # noqa # pylint: disable=unused-import
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def ec2(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
yield boto3.client("ec2", region_name="us-east-1")
@pytest.fixture(scope="function")
def efs(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield boto3.client("efs", region_name="us-east-1")
@pytest.fixture(scope="function")
def file_system(efs):
@pytest.fixture(scope="function", name="file_system")
def fixture_file_system(efs):
create_fs_resp = efs.create_file_system(CreationToken="foobarbaz")
create_fs_resp.pop("ResponseMetadata")
yield create_fs_resp
@pytest.fixture(scope="function")
def subnet(ec2):
@pytest.fixture(scope="function", name="subnet")
def fixture_subnet(ec2):
desc_sn_resp = ec2.describe_subnets()
subnet = desc_sn_resp["Subnets"][0]
yield subnet

View File

@ -9,8 +9,8 @@ FILE_SYSTEMS = "/2015-02-01/file-systems"
MOUNT_TARGETS = "/2015-02-01/mount-targets"
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
@pytest.fixture(scope="function", name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
@ -18,14 +18,14 @@ def aws_credentials(monkeypatch):
monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
@pytest.fixture(scope="function")
def efs_client(aws_credentials): # pylint: disable=unused-argument
@pytest.fixture(scope="function", name="efs_client")
def fixture_efs_client(aws_credentials): # pylint: disable=unused-argument
with mock_efs():
yield server.create_backend_app("efs").test_client()
@pytest.fixture(scope="function")
def subnet_id(aws_credentials): # pylint: disable=unused-argument
@pytest.fixture(scope="function", name="subnet_id")
def fixture_subnet_id(aws_credentials): # pylint: disable=unused-argument
with mock_ec2():
ec2_client = server.create_backend_app("ec2").test_client()
resp = ec2_client.get("/?Action=DescribeSubnets")
@ -33,8 +33,8 @@ def subnet_id(aws_credentials): # pylint: disable=unused-argument
yield subnet_ids[0]
@pytest.fixture(scope="function")
def file_system_id(efs_client):
@pytest.fixture(scope="function", name="file_system_id")
def fixture_file_system_id(efs_client):
resp = efs_client.post(
FILE_SYSTEMS, json={"CreationToken": "foobarbaz", "Backup": True}
)

View File

@ -72,8 +72,8 @@ from .test_eks_utils import (
)
@pytest.fixture(scope="function")
def ClusterBuilder():
@pytest.fixture(scope="function", name="ClusterBuilder")
def fixture_ClusterBuilder():
class ClusterTestDataFactory:
def __init__(self, client, count, minimal):
# Generate 'count' number of random Cluster objects.
@ -104,8 +104,8 @@ def ClusterBuilder():
yield _execute
@pytest.fixture(scope="function")
def FargateProfileBuilder(ClusterBuilder):
@pytest.fixture(scope="function", name="FargateProfileBuilder")
def fixture_FargateProfileBuilder(ClusterBuilder):
class FargateProfileTestDataFactory:
def __init__(self, client, cluster, count, minimal):
self.cluster_name = cluster.existing_cluster_name
@ -142,8 +142,8 @@ def FargateProfileBuilder(ClusterBuilder):
return _execute
@pytest.fixture(scope="function")
def NodegroupBuilder(ClusterBuilder):
@pytest.fixture(scope="function", name="NodegroupBuilder")
def fixture_NodegroupBuilder(ClusterBuilder):
class NodegroupTestDataFactory:
def __init__(self, client, cluster, count, minimal):
self.cluster_name = cluster.existing_cluster_name

View File

@ -79,14 +79,14 @@ class TestNodegroup:
]
@pytest.fixture(autouse=True)
def test_client():
@pytest.fixture(autouse=True, name="test_client")
def fixture_test_client():
backend = server.create_backend_app(service=SERVICE)
yield backend.test_client()
@pytest.fixture(scope="function")
def create_cluster(test_client):
@pytest.fixture(scope="function", name="create_cluster")
def fixtue_create_cluster(test_client):
def create_and_verify_cluster(client, name):
"""Creates one valid cluster and verifies return status code 200."""
data = deepcopy(TestCluster.data)
@ -106,8 +106,8 @@ def create_cluster(test_client):
yield _execute
@pytest.fixture(scope="function", autouse=True)
def create_nodegroup(test_client):
@pytest.fixture(scope="function", autouse=True, name="create_nodegroup")
def fixture_create_nodegroup(test_client):
def create_and_verify_nodegroup(client, name):
"""Creates one valid nodegroup and verifies return status code 200."""
data = deepcopy(TestNodegroup.data)

View File

@ -610,10 +610,8 @@ def test_put_remove_auto_scaling_policy():
("AutoScalingPolicy" not in core_instance_group).should.equal(True)
def _patch_cluster_id_placeholder_in_autoscaling_policy(
auto_scaling_policy, cluster_id
):
policy_copy = deepcopy(auto_scaling_policy)
def _patch_cluster_id_placeholder_in_autoscaling_policy(policy, cluster_id):
policy_copy = deepcopy(policy)
for rule in policy_copy["Rules"]:
for dimension in rule["Trigger"]["CloudWatchAlarmDefinition"]["Dimensions"]:
dimension["Value"] = cluster_id

View File

@ -15,14 +15,14 @@ from unittest.mock import patch
from moto.emrcontainers import REGION as DEFAULT_REGION
@pytest.fixture(scope="function")
def client():
@pytest.fixture(scope="function", name="client")
def fixture_client():
with mock_emrcontainers():
yield boto3.client("emr-containers", region_name=DEFAULT_REGION)
@pytest.fixture(scope="function")
def virtual_cluster_factory(client):
@pytest.fixture(scope="function", name="virtual_cluster_factory")
def fixture_virtual_cluster_factory(client):
if settings.TEST_SERVER_MODE:
raise SkipTest("Cant manipulate time in server mode")
@ -47,8 +47,8 @@ def virtual_cluster_factory(client):
yield cluster_list
@pytest.fixture()
def job_factory(client, virtual_cluster_factory):
@pytest.fixture(name="job_factory")
def fixture_job_factory(client, virtual_cluster_factory):
virtual_cluster_id = virtual_cluster_factory[0]
default_job_driver = {
"sparkSubmitJobDriver": {

View File

@ -19,14 +19,14 @@ def does_not_raise():
yield
@pytest.fixture(scope="function")
def client():
@pytest.fixture(scope="function", name="client")
def fixture_client():
with mock_emrserverless():
yield boto3.client("emr-serverless", region_name=DEFAULT_REGION)
@pytest.fixture(scope="function")
def application_factory(client):
@pytest.fixture(scope="function", name="application_factory")
def fixture_application_factory(client):
application_list = []
if settings.TEST_SERVER_MODE:

View File

@ -42,8 +42,8 @@ from .fixtures.schema_registry import (
)
@pytest.fixture
def client():
@pytest.fixture(name="client")
def fixture_client():
with mock_glue():
yield boto3.client("glue", region_name="us-east-1")

View File

@ -6,19 +6,19 @@ from botocore.exceptions import ClientError
from moto import mock_iot, mock_cognitoidentity
@pytest.fixture
def region_name():
@pytest.fixture(name="region_name")
def fixture_region_name():
return "ap-northeast-1"
@pytest.fixture
def iot_client(region_name):
@pytest.fixture(name="iot_client")
def fixture_iot_client(region_name):
with mock_iot():
yield boto3.client("iot", region_name=region_name)
@pytest.fixture
def policy(iot_client):
@pytest.fixture(name="policy")
def fixture_policy(iot_client):
return iot_client.create_policy(policyName="my-policy", policyDocument="{}")

View File

@ -6,13 +6,13 @@ PLAINTEXT = b"text"
REGION = "us-east-1"
@pytest.fixture
def backend():
@pytest.fixture(name="backend")
def fixture_backend():
return KmsBackend(REGION)
@pytest.fixture
def key(backend):
@pytest.fixture(name="key")
def fixture_key(backend):
return backend.create_key(
None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Test key", None
)

View File

@ -16,13 +16,11 @@ from moto.logs.models import MAX_RESOURCE_POLICIES_PER_REGION
TEST_REGION = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"
@pytest.fixture
def json_policy_doc():
"""Returns a policy document in JSON format.
"""Returns a policy document in JSON format.
The ARN is bogus, but that shouldn't matter for the test.
"""
return json.dumps(
The ARN is bogus, but that shouldn't matter for the test.
"""
json_policy_doc = json.dumps(
{
"Version": "2012-10-17",
"Statement": [
@ -35,7 +33,7 @@ def json_policy_doc():
}
],
}
)
)
@mock_logs
@ -589,7 +587,7 @@ def test_put_resource_policy():
@mock_logs
def test_put_resource_policy_too_many(json_policy_doc):
def test_put_resource_policy_too_many():
client = boto3.client("logs", TEST_REGION)
# Create the maximum number of resource policies.
@ -617,7 +615,7 @@ def test_put_resource_policy_too_many(json_policy_doc):
@mock_logs
def test_delete_resource_policy(json_policy_doc):
def test_delete_resource_policy():
client = boto3.client("logs", TEST_REGION)
# Create a bunch of resource policies so we can give delete a workout.
@ -649,7 +647,7 @@ def test_delete_resource_policy(json_policy_doc):
@mock_logs
def test_describe_resource_policies(json_policy_doc):
def test_describe_resource_policies():
client = boto3.client("logs", TEST_REGION)
# Create the maximum number of resource policies so there's something

View File

@ -13,8 +13,8 @@ INVALID_ID_ERROR_MESSAGE = (
RESOURCE_NOT_FOUND_ERROR_MESSAGE = "Query does not exist."
@pytest.fixture(autouse=True)
def client():
@pytest.fixture(autouse=True, name="client")
def fixture_client():
yield boto3.client("redshift-data", region_name=REGION)

View File

@ -18,8 +18,8 @@ def headers(action):
}
@pytest.fixture(autouse=True)
def client():
@pytest.fixture(autouse=True, name="client")
def fixture_client():
backend = server.create_backend_app("redshift-data")
yield backend.test_client()

View File

@ -42,9 +42,10 @@ TEST_SERVERLESS_PRODUCTION_VARIANTS = [
]
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
def create_endpoint_config_helper(sagemaker_client, production_variants):
@ -72,7 +73,6 @@ def create_endpoint_config_helper(sagemaker_client, production_variants):
resp["ProductionVariants"].should.equal(production_variants)
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
@ -85,7 +85,6 @@ def test_create_endpoint_config(sagemaker_client):
create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_create_endpoint_config_serverless(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
@ -98,7 +97,6 @@ def test_create_endpoint_config_serverless(sagemaker_client):
create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_delete_endpoint_config(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
resp = sagemaker_client.create_endpoint_config(
@ -140,7 +138,6 @@ def test_delete_endpoint_config(sagemaker_client):
)
@mock_sagemaker
def test_create_endpoint_invalid_instance_type(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
@ -160,7 +157,6 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint_invalid_memory_size(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
@ -180,7 +176,6 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client):
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint(
@ -221,7 +216,6 @@ def test_create_endpoint(sagemaker_client):
assert resp["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -237,7 +231,6 @@ def test_delete_endpoint(sagemaker_client):
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
@mock_sagemaker
def test_add_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -253,7 +246,6 @@ def test_add_tags_endpoint(sagemaker_client):
assert response["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -273,7 +265,6 @@ def test_delete_tags_endpoint(sagemaker_client):
assert response["Tags"] == []
@mock_sagemaker
def test_list_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -298,7 +289,6 @@ def test_list_tags_endpoint(sagemaker_client):
assert response["Tags"] == tags[50:]
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
@ -342,7 +332,6 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
production_variants = [
{
@ -422,7 +411,6 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
sagemaker_client,
):
@ -459,7 +447,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari
resp.should.equal(old_resp)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
sagemaker_client,
):
@ -497,7 +484,6 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp
resp.should.equal(old_resp)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
sagemaker_client,
):

View File

@ -8,12 +8,12 @@ TEST_REGION_NAME = "us-east-1"
TEST_EXPERIMENT_NAME = "MyExperimentName"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_create_experiment(sagemaker_client):
resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -29,7 +29,6 @@ def test_create_experiment(sagemaker_client):
)
@mock_sagemaker
def test_list_experiments(sagemaker_client):
experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
@ -57,7 +56,6 @@ def test_list_experiments(sagemaker_client):
assert resp.get("NextToken") is None
@mock_sagemaker
def test_delete_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -70,7 +68,6 @@ def test_delete_experiment(sagemaker_client):
assert len(resp["ExperimentSummaries"]) == 0
@mock_sagemaker
def test_add_tags_to_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
@ -89,7 +86,6 @@ def test_add_tags_to_experiment(sagemaker_client):
assert resp["Tags"] == tags
@mock_sagemaker
def test_delete_tags_to_experiment(sagemaker_client):
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)

View File

@ -12,9 +12,10 @@ TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
TEST_MODEL_NAME = "MyModelName"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MySageMakerModel(object):
@ -36,7 +37,6 @@ class MySageMakerModel(object):
return resp
@mock_sagemaker
def test_describe_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -44,14 +44,12 @@ def test_describe_model(sagemaker_client):
assert model.get("ModelName").should.equal(TEST_MODEL_NAME)
@mock_sagemaker
def test_describe_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.describe_model(ModelName="unknown")
assert err.value.response["Error"]["Message"].should.contain("Could not find model")
@mock_sagemaker
def test_create_model(sagemaker_client):
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
model = sagemaker_client.create_model(
@ -64,7 +62,6 @@ def test_create_model(sagemaker_client):
)
@mock_sagemaker
def test_delete_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -74,14 +71,12 @@ def test_delete_model(sagemaker_client):
assert len(sagemaker_client.list_models()["Models"]).should.equal(0)
@mock_sagemaker
def test_delete_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.delete_model(ModelName="blah")
assert err.value.response["Error"]["Code"].should.equal("404")
@mock_sagemaker
def test_list_models(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
@ -93,7 +88,6 @@ def test_list_models(sagemaker_client):
)
@mock_sagemaker
def test_list_models_multiple(sagemaker_client):
name_model_1 = "blah"
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
@ -108,13 +102,11 @@ def test_list_models_multiple(sagemaker_client):
assert len(models["Models"]).should.equal(2)
@mock_sagemaker
def test_list_models_none(sagemaker_client):
models = sagemaker_client.list_models()
assert len(models["Models"]).should.equal(0)
@mock_sagemaker
def test_add_tags_to_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]
@ -129,7 +121,6 @@ def test_add_tags_to_model(sagemaker_client):
assert response["Tags"] == tags
@mock_sagemaker
def test_delete_tags_from_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]

View File

@ -26,9 +26,10 @@ FAKE_NAME_PARAM = "MyNotebookInstance"
FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
def _get_notebook_instance_arn(notebook_name):
@ -39,7 +40,6 @@ def _get_notebook_instance_lifecycle_arn(lifecycle_name):
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
@mock_sagemaker
def test_create_notebook_instance_minimal_params(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -68,7 +68,6 @@ def test_create_notebook_instance_minimal_params(sagemaker_client):
# assert resp["RootAccess"] == True # ToDo: Not sure if this defaults...
@mock_sagemaker
def test_create_notebook_instance_params(sagemaker_client):
fake_direct_internet_access_param = "Enabled"
volume_size_in_gb_param = 7
@ -121,7 +120,6 @@ def test_create_notebook_instance_params(sagemaker_client):
assert resp["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
instance_type = "undefined_instance_type"
args = {
@ -139,7 +137,6 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
assert expected_message in ex.value.response["Error"]["Message"]
@mock_sagemaker
def test_notebook_instance_lifecycle(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -193,14 +190,12 @@ def test_notebook_instance_lifecycle(sagemaker_client):
assert ex.value.response["Error"]["Message"] == "RecordNotFound"
@mock_sagemaker
def test_describe_nonexistent_model(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.describe_model(ModelName="Nonexistent")
assert e.value.response["Error"]["Message"].startswith("Could not find model")
@mock_sagemaker
def test_notebook_instance_lifecycle_config(sagemaker_client):
name = "MyLifeCycleConfig"
on_create = [{"Content": "Create Script Line 1"}]
@ -252,7 +247,6 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
)
@mock_sagemaker
def test_add_tags_to_notebook(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
@ -272,7 +266,6 @@ def test_add_tags_to_notebook(sagemaker_client):
assert response["Tags"] == tags
@mock_sagemaker
def test_delete_tags_from_notebook(sagemaker_client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,

View File

@ -12,9 +12,10 @@ FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
TEST_REGION_NAME = "us-east-1"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MyProcessingJobModel(object):
@ -103,7 +104,6 @@ class MyProcessingJobModel(object):
return sagemaker_client.create_processing_job(**params)
@mock_sagemaker
def test_create_processing_job(sagemaker_client):
bucket = "my-bucket"
prefix = "my-prefix"
@ -150,7 +150,6 @@ def test_create_processing_job(sagemaker_client):
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
@mock_sagemaker
def test_list_processing_jobs(sagemaker_client):
test_processing_job = MyProcessingJobModel(
processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
@ -170,7 +169,6 @@ def test_list_processing_jobs(sagemaker_client):
assert processing_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_processing_jobs_multiple(sagemaker_client):
name_job_1 = "blah"
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
@ -193,13 +191,11 @@ def test_list_processing_jobs_multiple(sagemaker_client):
assert processing_jobs.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_none(sagemaker_client):
processing_jobs = sagemaker_client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
@mock_sagemaker
def test_list_processing_jobs_should_validate_input(sagemaker_client):
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
@ -218,7 +214,6 @@ def test_list_processing_jobs_should_validate_input(sagemaker_client):
)
@mock_sagemaker
def test_list_processing_jobs_with_name_filters(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -243,7 +238,6 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client):
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
@mock_sagemaker
def test_list_processing_jobs_paginated(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -273,7 +267,6 @@ def test_list_processing_jobs_paginated(sagemaker_client):
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -316,7 +309,6 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
assert vgg_processing_job_10.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
for i in range(5):
name = "xgboost-{}".format(i)
@ -357,7 +349,6 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
@mock_sagemaker
def test_add_and_delete_tags_in_training_job(sagemaker_client):
processing_job_name = "MyProcessingJob"
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)

View File

@ -8,12 +8,12 @@ from moto import mock_sagemaker
TEST_REGION_NAME = "us-east-1"
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_search(sagemaker_client):
experiment_name = "experiment_name"
trial_component_name = "trial_component_name"
@ -60,7 +60,6 @@ def test_search(sagemaker_client):
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
@mock_sagemaker
def test_search_trial_component_with_experiment_name(sagemaker_client):
experiment_name = "experiment_name"
trial_component_name = "trial_component_name"

View File

@ -888,10 +888,10 @@ def test_state_machine_get_execution_history_contains_expected_success_events_wh
execution_history["events"].should.equal(expected_events)
@pytest.mark.parametrize("region", ["us-west-2", "cn-northwest-1"])
@pytest.mark.parametrize("test_region", ["us-west-2", "cn-northwest-1"])
@mock_stepfunctions
def test_stepfunction_regions(region):
client = boto3.client("stepfunctions", region_name=region)
def test_stepfunction_regions(test_region):
client = boto3.client("stepfunctions", region_name=test_region)
resp = client.list_state_machines()
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)