Techdebt: Improve some type annotations (#7072)

This commit is contained in:
Bert Blommers 2023-11-28 21:01:56 -01:00 committed by GitHub
parent bfac8a8a07
commit d6377ff905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 230 additions and 196 deletions

View File

@ -30,6 +30,10 @@ DEFAULT_COOLDOWN = 300
ASG_NAME_TAG = "aws:autoscaling:groupName"
def make_int(value: Union[None, str, int]) -> Optional[int]:
return int(value) if value is not None else value
class InstanceState:
def __init__(
self,
@ -406,13 +410,13 @@ class FakeAutoScalingGroup(CloudFormationModel):
min_size: Optional[int],
launch_config_name: str,
launch_template: Dict[str, Any],
vpc_zone_identifier: str,
vpc_zone_identifier: Optional[str],
default_cooldown: Optional[int],
health_check_period: Optional[int],
health_check_type: Optional[str],
load_balancers: List[str],
target_group_arns: List[str],
placement_group: str,
placement_group: Optional[str],
termination_policies: List[str],
autoscaling_backend: "AutoScalingBackend",
ec2_backend: EC2Backend,
@ -941,9 +945,6 @@ class AutoScalingBackend(BaseBackend):
def delete_launch_configuration(self, launch_configuration_name: str) -> None:
self.launch_configurations.pop(launch_configuration_name, None)
def make_int(self, value: Union[None, str, int]) -> Optional[int]:
return int(value) if value is not None else value
def put_scheduled_update_group_action(
self,
name: str,
@ -955,9 +956,9 @@ class AutoScalingBackend(BaseBackend):
end_time: str,
recurrence: str,
) -> FakeScheduledAction:
max_size = self.make_int(max_size)
min_size = self.make_int(min_size)
desired_capacity = self.make_int(desired_capacity)
max_size = make_int(max_size)
min_size = make_int(min_size)
desired_capacity = make_int(desired_capacity)
scheduled_action = FakeScheduledAction(
name=name,
@ -1010,13 +1011,13 @@ class AutoScalingBackend(BaseBackend):
min_size: Union[None, str, int],
launch_config_name: str,
launch_template: Dict[str, Any],
vpc_zone_identifier: str,
vpc_zone_identifier: Optional[str],
default_cooldown: Optional[int],
health_check_period: Union[None, str, int],
health_check_type: Optional[str],
load_balancers: List[str],
target_group_arns: List[str],
placement_group: str,
placement_group: Optional[str],
termination_policies: List[str],
tags: List[Dict[str, str]],
capacity_rebalance: bool = False,
@ -1024,10 +1025,10 @@ class AutoScalingBackend(BaseBackend):
instance_id: Optional[str] = None,
mixed_instance_policy: Optional[Dict[str, Any]] = None,
) -> FakeAutoScalingGroup:
max_size = self.make_int(max_size)
min_size = self.make_int(min_size)
desired_capacity = self.make_int(desired_capacity)
default_cooldown = self.make_int(default_cooldown)
max_size = make_int(max_size)
min_size = make_int(min_size)
desired_capacity = make_int(desired_capacity)
default_cooldown = make_int(default_cooldown)
# Verify only a single launch config-like parameter is provided.
params = [
@ -1064,9 +1065,7 @@ class AutoScalingBackend(BaseBackend):
launch_template=launch_template,
vpc_zone_identifier=vpc_zone_identifier,
default_cooldown=default_cooldown,
health_check_period=self.make_int(health_check_period)
if health_check_period
else 300,
health_check_period=make_int(health_check_period or 300),
health_check_type=health_check_type,
load_balancers=load_balancers,
target_group_arns=target_group_arns,

View File

@ -539,7 +539,7 @@ class LambdaFunction(CloudFormationModel, DockerModel):
self.run_time = spec.get("Runtime")
self.logs_backend = logs_backends[account_id][self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.policy: Optional[Policy] = None
self.policy = Policy(self)
self.url_config: Optional[FunctionUrlConfig] = None
self.state = "Active"
self.reserved_concurrency = spec.get("ReservedConcurrentExecutions", None)
@ -582,7 +582,7 @@ class LambdaFunction(CloudFormationModel, DockerModel):
self._set_function_code(self.code)
self.function_arn = make_function_arn(
self.function_arn: str = make_function_arn(
self.region, self.account_id, self.function_name
)
@ -659,7 +659,11 @@ class LambdaFunction(CloudFormationModel, DockerModel):
]
if not all(layer_versions):
raise UnknownLayerVersionException(layers_versions_arns)
return [{"Arn": lv.arn, "CodeSize": lv.code_size} for lv in layer_versions]
# The `if lv` part is not necessary - we know there are no None's, because of the `all()`-check earlier
# But MyPy does not seem to understand this
# The `type: ignore` is because `code_size` is an int, and we're returning Dict[str, str]
# We should convert the return-type into a TypedDict the moment we drop Py3.7 support
return [{"Arn": lv.arn, "CodeSize": lv.code_size} for lv in layer_versions if lv] # type: ignore
def get_code_signing_config(self) -> Dict[str, Any]:
return {
@ -1577,8 +1581,6 @@ class LambdaStorage(object):
self._functions[fn.function_name]["latest"] = fn
else:
self._functions[fn.function_name] = {"latest": fn, "versions": []}
# instantiate a new policy for this version of the lambda
fn.policy = Policy(fn)
self._arns[fn.function_arn] = fn
def publish_function(

View File

@ -5,14 +5,17 @@ from moto.awslambda.exceptions import (
UnknownPolicyException,
)
from moto.moto_api._internal import mock_random
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from .models import LambdaFunction
TYPE_IDENTITY = TypeVar("TYPE_IDENTITY")
class Policy:
def __init__(self, parent: Any): # Parent should be a LambdaFunction
def __init__(self, parent: "LambdaFunction"):
self.revision = str(mock_random.uuid4())
self.statements: List[Dict[str, Any]] = []
self.parent = parent
@ -72,8 +75,6 @@ class Policy:
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj: Dict[str, Any]) -> "Policy":
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")

View File

@ -114,16 +114,13 @@ class ComputeEnvironment(CloudFormationModel):
backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
env = backend.create_compute_environment(
return backend.create_compute_environment(
resource_name,
properties["Type"],
properties.get("State", "ENABLED"),
lowercase_first_key(properties["ComputeResources"]),
properties["ServiceRole"],
)
arn = env[1]
return backend.get_compute_environment_by_arn(arn)
class JobQueue(CloudFormationModel):
@ -209,16 +206,13 @@ class JobQueue(CloudFormationModel):
for dict_item in properties["ComputeEnvironmentOrder"]
]
queue = backend.create_job_queue(
return backend.create_job_queue(
queue_name=resource_name,
priority=properties["Priority"],
state=properties.get("State", "ENABLED"),
compute_env_order=compute_envs,
schedule_policy={},
)
arn = queue[1]
return backend.get_job_queue_by_arn(arn)
class JobDefinition(CloudFormationModel):
@ -447,7 +441,7 @@ class JobDefinition(CloudFormationModel):
) -> "JobDefinition":
backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
res = backend.register_job_definition(
return backend.register_job_definition(
def_name=resource_name,
parameters=lowercase_first_key(properties.get("Parameters", {})),
_type="container",
@ -467,9 +461,6 @@ class JobDefinition(CloudFormationModel):
platform_capabilities=None,
propagate_tags=None,
)
arn = res[1]
return backend.get_job_definition_by_arn(arn)
class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
@ -1219,7 +1210,7 @@ class BatchBackend(BaseBackend):
state: str,
compute_resources: Dict[str, Any],
service_role: str,
) -> Tuple[str, str]:
) -> ComputeEnvironment:
# Validate
if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None:
raise InvalidParameterValueException(
@ -1304,7 +1295,7 @@ class BatchBackend(BaseBackend):
ecs_cluster = self.ecs_backend.create_cluster(cluster_name)
new_comp_env.set_ecs(ecs_cluster.arn, cluster_name)
return compute_environment_name, new_comp_env.arn
return new_comp_env
def _validate_compute_resources(self, cr: Dict[str, Any]) -> None:
"""
@ -1504,7 +1495,7 @@ class BatchBackend(BaseBackend):
state: str,
compute_env_order: List[Dict[str, str]],
tags: Optional[Dict[str, str]] = None,
) -> Tuple[str, str]:
) -> JobQueue:
for variable, var_name in (
(queue_name, "jobQueueName"),
(priority, "priority"),
@ -1550,7 +1541,7 @@ class BatchBackend(BaseBackend):
)
self._job_queues[queue.arn] = queue
return queue_name, queue.arn
return queue
def describe_job_queues(
self, job_queues: Optional[List[str]] = None
@ -1644,7 +1635,7 @@ class BatchBackend(BaseBackend):
timeout: Dict[str, int],
platform_capabilities: List[str],
propagate_tags: bool,
) -> Tuple[str, str, int]:
) -> JobDefinition:
if def_name is None:
raise ClientException("jobDefinitionName must be provided")
@ -1683,7 +1674,7 @@ class BatchBackend(BaseBackend):
self._job_definitions[job_def.arn] = job_def
return def_name, job_def.arn, job_def.revision
return job_def
def deregister_job_definition(self, def_name: str) -> None:
job_def = self.get_job_definition_by_arn(def_name)

View File

@ -31,7 +31,7 @@ class BatchResponse(BaseResponse):
state = self._get_param("state")
_type = self._get_param("type")
name, arn = self.batch_backend.create_compute_environment(
env = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type,
state=state,
@ -39,7 +39,7 @@ class BatchResponse(BaseResponse):
service_role=service_role,
)
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
result = {"computeEnvironmentArn": env.arn, "computeEnvironmentName": env.name}
return json.dumps(result)
@ -91,7 +91,7 @@ class BatchResponse(BaseResponse):
state = self._get_param("state")
tags = self._get_param("tags")
name, arn = self.batch_backend.create_job_queue(
queue = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
schedule_policy=schedule_policy,
@ -100,7 +100,7 @@ class BatchResponse(BaseResponse):
tags=tags,
)
result = {"jobQueueArn": arn, "jobQueueName": name}
result = {"jobQueueArn": queue.arn, "jobQueueName": queue.name}
return json.dumps(result)
@ -157,7 +157,7 @@ class BatchResponse(BaseResponse):
timeout = self._get_param("timeout")
platform_capabilities = self._get_param("platformCapabilities")
propagate_tags = self._get_param("propagateTags")
name, arn, revision = self.batch_backend.register_job_definition(
job_def = self.batch_backend.register_job_definition(
def_name=def_name,
parameters=parameters,
_type=_type,
@ -171,9 +171,9 @@ class BatchResponse(BaseResponse):
)
result = {
"jobDefinitionArn": arn,
"jobDefinitionName": name,
"revision": revision,
"jobDefinitionArn": job_def.arn,
"jobDefinitionName": job_def.name,
"revision": job_def.revision,
}
return json.dumps(result)

View File

@ -140,7 +140,7 @@ class Trail(BaseModel):
)
def check_topic_exists(self) -> None:
if self.sns_topic_name:
if self.topic_arn:
from moto.sns import sns_backends
sns_backend = sns_backends[self.account_id][self.region_name]

View File

@ -69,7 +69,7 @@ class Repository(BaseObject, CloudFormationModel):
account_id: str,
region_name: str,
repository_name: str,
registry_id: str,
registry_id: Optional[str],
encryption_config: Optional[Dict[str, str]],
image_scan_config: str,
image_tag_mutablility: str,
@ -467,7 +467,7 @@ class ECRBackend(BaseBackend):
def create_repository(
self,
repository_name: str,
registry_id: str,
registry_id: Optional[str],
encryption_config: Dict[str, str],
image_scan_config: Any,
image_tag_mutablility: str,

View File

@ -693,7 +693,9 @@ class Service(BaseObject, CloudFormationModel):
):
# TODO: LoadBalancers
# TODO: Role
ecs_backend.delete_service(cluster_name, service_name)
ecs_backend.delete_service(
original_resource.cluster_name, service_name, force=True
)
return ecs_backend.create_service(
cluster_name,
new_resource_name,

View File

@ -169,16 +169,22 @@ class FakeLoadBalancer(CloudFormationModel):
port_policies: Dict[str, Any] = {}
for policy in policies:
policy_name = policy["PolicyName"]
other_policy = OtherPolicy(policy_name, "", [])
elb_backend.create_lb_other_policy(new_elb.name, other_policy)
policy_type_name = policy["PolicyType"]
policy_attrs = policy["Attributes"]
elb_backend.create_load_balancer_policy(
load_balancer_name=new_elb.name,
policy_name=policy_name,
policy_type_name=policy_type_name,
policy_attrs=policy_attrs,
)
for port in policy.get("InstancePorts", []):
policies_for_port: Any = port_policies.get(port, set())
policies_for_port.add(policy_name)
port_policies[port] = policies_for_port
for port, policies in port_policies.items():
elb_backend.set_load_balancer_policies_of_backend_server(
new_elb.name, port, list(policies)
elb_backend.set_load_balancer_policies_for_backend_server(
new_elb.name, int(port), list(policies)
)
health_check = properties.get("HealthCheck")
@ -552,7 +558,7 @@ class ELBBackend(BaseBackend):
if access_log:
load_balancer.attributes["access_log"] = access_log
def create_lb_other_policy(
def create_load_balancer_policy(
self,
load_balancer_name: str,
policy_name: str,
@ -586,7 +592,7 @@ class ELBBackend(BaseBackend):
load_balancer.policies.append(policy)
return load_balancer
def set_load_balancer_policies_of_backend_server(
def set_load_balancer_policies_for_backend_server(
self, load_balancer_name: str, instance_port: int, policies: List[str]
) -> FakeLoadBalancer:
load_balancer = self.get_load_balancer(load_balancer_name)

View File

@ -200,7 +200,7 @@ class ELBResponse(BaseResponse):
policy_type_name = self._get_param("PolicyTypeName")
policy_attrs = self._get_multi_param("PolicyAttributes.member.")
self.elb_backend.create_lb_other_policy(
self.elb_backend.create_load_balancer_policy(
load_balancer_name, policy_name, policy_type_name, policy_attrs
)
@ -269,7 +269,7 @@ class ELBResponse(BaseResponse):
]
if mb_backend:
policies = self._get_multi_param("PolicyNames.member")
self.elb_backend.set_load_balancer_policies_of_backend_server(
self.elb_backend.set_load_balancer_policies_for_backend_server(
load_balancer_name, instance_port, policies
)
# else: explode?

View File

@ -1075,6 +1075,9 @@ class AccessKeyLastUsed:
def timestamp(self) -> str:
return iso_8601_datetime_without_milliseconds(self._timestamp) # type: ignore
def strftime(self, date_format: str) -> str:
return self._timestamp.strftime(date_format)
class AccessKey(CloudFormationModel):
def __init__(
@ -1091,7 +1094,15 @@ class AccessKey(CloudFormationModel):
self.secret_access_key = random_alphanumeric(40)
self.status = status
self.create_date = utcnow()
self.last_used: Optional[datetime] = None
# Some users will set this field manually
# And they will be setting this value to a `datetime`
# https://github.com/getmoto/moto/issues/5927#issuecomment-1738188283
#
# The `to_csv` method calls `last_used.strptime`, which currently works on both AccessKeyLastUsed and datetime
# In the next major release we should communicate that this only accepts AccessKeyLastUsed
# (And rework to_csv accordingly)
self.last_used: Optional[AccessKeyLastUsed] = None
self.role_arn: Optional[str] = None
@property

View File

@ -1244,7 +1244,7 @@ class IoTBackend(BaseBackend):
cognito = cognitoidentity_backends[self.account_id][self.region_name]
identities = []
for identity_pool in cognito.identity_pools:
pool_identities = cognito.pools_identities.get(identity_pool, None)
pool_identities = cognito.pools_identities.get(identity_pool, [])
identities.extend(
[pi["IdentityId"] for pi in pool_identities.get("Identities", [])]
)

View File

@ -1271,13 +1271,16 @@ class LogsBackend(BaseBackend):
self.export_tasks[task_id].status["message"] = "Task is completed"
return task_id
def describe_export_tasks(self, taskId: str = "") -> Tuple[List[ExportTask], str]:
if taskId:
if taskId not in self.export_tasks:
def describe_export_tasks(self, task_id: str) -> List[ExportTask]:
"""
Pagination is not yet implemented
"""
if task_id:
if task_id not in self.export_tasks:
raise ResourceNotFoundException()
return [self.export_tasks[taskId]], ""
return [self.export_tasks[task_id]]
else:
return list(self.export_tasks.values()), ""
return list(self.export_tasks.values())
def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_arn)

View File

@ -446,10 +446,8 @@ class LogsResponse(BaseResponse):
def describe_export_tasks(self) -> str:
task_id = self._get_param("taskId")
tasks, next_token = self.logs_backend.describe_export_tasks(taskId=task_id)
return json.dumps(
{"exportTasks": [task.to_json() for task in tasks], "nextToken": next_token}
)
tasks = self.logs_backend.describe_export_tasks(task_id=task_id)
return json.dumps({"exportTasks": [t.to_json() for t in tasks]})
def list_tags_for_resource(self) -> str:
resource_arn = self._get_param("resourceArn")

View File

@ -8,6 +8,7 @@ from dateutil.tz import tzutc
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.ec2 import ec2_backends
from moto.ec2.models.security_groups import SecurityGroup as EC2SecurityGroup
from moto.moto_api._internal import mock_random
from .exceptions import (
ClusterAlreadyExistsFaultError,
@ -242,7 +243,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
]
@property
def vpc_security_groups(self) -> List["SecurityGroup"]:
def vpc_security_groups(self) -> List["EC2SecurityGroup"]:
return [
security_group
for security_group in self.redshift_backend.ec2_backend.describe_security_groups()

View File

@ -230,9 +230,8 @@ class RecordSet(CloudFormationModel):
zone_name = properties.get("HostedZoneName")
backend = route53_backends[account_id]["global"]
if zone_name:
hosted_zone = backend.get_hosted_zone_by_name(zone_name)
else:
hosted_zone = backend.get_hosted_zone_by_name(zone_name) if zone_name else None
if hosted_zone is None:
hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"])
record_set = hosted_zone.add_rrset(properties)
return record_set
@ -267,9 +266,8 @@ class RecordSet(CloudFormationModel):
zone_name = properties.get("HostedZoneName")
backend = route53_backends[account_id]["global"]
if zone_name:
hosted_zone = backend.get_hosted_zone_by_name(zone_name)
else:
hosted_zone = backend.get_hosted_zone_by_name(zone_name) if zone_name else None
if hosted_zone is None:
hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"])
try:
@ -286,7 +284,11 @@ class RecordSet(CloudFormationModel):
) -> None:
"""Not exposed as part of the Route 53 API - used for CloudFormation"""
backend = route53_backends[account_id]["global"]
hosted_zone = backend.get_hosted_zone_by_name(self.hosted_zone_name)
hosted_zone = (
backend.get_hosted_zone_by_name(self.hosted_zone_name)
if self.hosted_zone_name
else None
)
if not hosted_zone:
hosted_zone = backend.get_hosted_zone(self.hosted_zone_id)
hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_})
@ -460,9 +462,8 @@ class RecordSetGroup(CloudFormationModel):
zone_name = properties.get("HostedZoneName")
backend = route53_backends[account_id]["global"]
if zone_name:
hosted_zone = backend.get_hosted_zone_by_name(zone_name)
else:
hosted_zone = backend.get_hosted_zone_by_name(zone_name) if zone_name else None
if hosted_zone is None:
hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"])
record_sets = properties["RecordSets"]
for record_set in record_sets:

View File

@ -41,6 +41,7 @@ from .exceptions import (
InvalidPermissionType,
InvalidResourceId,
InvalidResourceType,
ParameterAlreadyExists,
)
@ -211,6 +212,8 @@ class Parameter(CloudFormationModel):
tags: Optional[List[Dict[str, str]]] = None,
labels: Optional[List[str]] = None,
source_result: Optional[str] = None,
tier: Optional[str] = None,
policies: Optional[str] = None,
):
self.account_id = account_id
self.name = name
@ -224,6 +227,8 @@ class Parameter(CloudFormationModel):
self.tags = tags or []
self.labels = labels or []
self.source_result = source_result
self.tier = tier
self.policies = policies
if self.parameter_type == "SecureString":
if not self.keyid:
@ -255,7 +260,17 @@ class Parameter(CloudFormationModel):
"Version": self.version,
"LastModifiedDate": round(self.last_modified_date, 3),
"DataType": self.data_type,
"Tier": self.tier,
}
if self.policies:
try:
policy_list = json.loads(self.policies)
r["Policies"] = [
{"PolicyText": p, "PolicyType": p, "PolicyStatus": "Finished"}
for p in policy_list
]
except json.JSONDecodeError:
pass
if self.source_result:
r["SourceResult"] = self.source_result
@ -316,10 +331,10 @@ class Parameter(CloudFormationModel):
"overwrite": properties.get("Overwrite", False),
"tags": properties.get("Tags", None),
"data_type": properties.get("DataType", "text"),
"tier": properties.get("Tier"),
"policies": properties.get("Policies"),
}
ssm_backend.put_parameter(**parameter_args)
parameter = ssm_backend.get_parameter(properties.get("Name"))
return parameter
return ssm_backend.put_parameter(**parameter_args)
@classmethod
def update_from_cloudformation_json( # type: ignore[misc]
@ -442,7 +457,6 @@ class Documents(BaseModel):
version_name: Optional[str] = None,
strict: bool = True,
) -> "Document":
if document_version == "$LATEST":
ssm_document: Optional["Document"] = self.get_latest_version()
elif version_name and document_version:
@ -547,7 +561,6 @@ class Documents(BaseModel):
self.permissions.pop(account_id, None)
def describe_permissions(self) -> Dict[str, Any]:
permissions_ordered_by_date = sorted(
self.permissions.values(), key=lambda p: p.created_at
)
@ -678,7 +691,6 @@ class Command(BaseModel):
targets: Optional[List[Dict[str, Any]]] = None,
backend_region: str = "us-east-1",
):
if instance_ids is None:
instance_ids = []
@ -1336,7 +1348,6 @@ class SimpleSystemManagerBackend(BaseBackend):
def get_document(
self, name: str, document_version: str, version_name: str, document_format: str
) -> Dict[str, Any]:
documents = self._get_documents(name)
ssm_document = documents.find(document_version, version_name)
@ -1488,7 +1499,6 @@ class SimpleSystemManagerBackend(BaseBackend):
shared_document_version: str,
permission_type: str,
) -> None:
account_id_regex = re.compile(r"^(all|[0-9]{12})$", re.IGNORECASE)
version_regex = re.compile(r"^([$]LATEST|[$]DEFAULT|[$]ALL)$")
@ -1810,7 +1820,6 @@ class SimpleSystemManagerBackend(BaseBackend):
def get_parameter_history(
self, name: str, next_token: Optional[str], max_results: int = 50
) -> Tuple[Optional[List[Parameter]], Optional[str]]:
if max_results > PARAMETER_HISTORY_MAX_RESULTS:
raise ValidationException(
"1 validation error detected: "
@ -2034,7 +2043,9 @@ class SimpleSystemManagerBackend(BaseBackend):
overwrite: bool,
tags: List[Dict[str, str]],
data_type: str,
) -> Optional[int]:
tier: Optional[str],
policies: Optional[str],
) -> Parameter:
if not value:
raise ValidationException(
"1 validation error detected: Value '' at 'value' failed to satisfy"
@ -2097,7 +2108,7 @@ class SimpleSystemManagerBackend(BaseBackend):
version = previous_parameter.version + 1
if not overwrite:
return None
raise ParameterAlreadyExists
# overwriting a parameter, Type is not included in boto3 call
if not parameter_type and overwrite:
parameter_type = previous_parameter.parameter_type
@ -2123,29 +2134,32 @@ class SimpleSystemManagerBackend(BaseBackend):
data_type = (
data_type if data_type is not None else previous_parameter.data_type
)
tier = tier if tier is not None else previous_parameter.tier
policies = policies if policies is not None else previous_parameter.policies
last_modified_date = time.time()
self._parameters[name].append(
Parameter(
account_id=self.account_id,
name=name,
value=value,
parameter_type=parameter_type,
description=description,
allowed_pattern=allowed_pattern,
keyid=keyid,
last_modified_date=last_modified_date,
version=version,
tags=tags or [],
data_type=data_type,
)
new_param = Parameter(
account_id=self.account_id,
name=name,
value=value,
parameter_type=parameter_type,
description=description,
allowed_pattern=allowed_pattern,
keyid=keyid,
last_modified_date=last_modified_date,
version=version,
tags=tags or [],
data_type=data_type,
tier=tier,
policies=policies,
)
self._parameters[name].append(new_param)
if tags:
tag_dict = {t["Key"]: t["Value"] for t in tags}
self.add_tags_to_resource("Parameter", name, tag_dict)
return version
return new_param
def add_tags_to_resource(
self, resource_type: str, resource_id: str, tags: Dict[str, str]
@ -2460,7 +2474,6 @@ class SimpleSystemManagerBackend(BaseBackend):
cutoff_behavior: Optional[str],
alarm_configurations: Optional[Dict[str, Any]],
) -> str:
window = self.get_maintenance_window(window_id)
task = FakeMaintenanceWindowTask(
window_id,

View File

@ -1,9 +1,8 @@
import json
from typing import Any, Dict, Tuple, Union
import warnings
from moto.core.responses import BaseResponse
from .exceptions import ValidationException, ParameterAlreadyExists
from .exceptions import ValidationException
from .models import ssm_backends, SimpleSystemManagerBackend
@ -273,20 +272,10 @@ class SimpleSystemManagerResponse(BaseResponse):
overwrite = self._get_param("Overwrite", False)
tags = self._get_param("Tags")
data_type = self._get_param("DataType", "text")
# To be implemented arguments of put_parameter
tier = self._get_param("Tier")
if tier is not None:
warnings.warn(
"Tier configuration option is not yet implemented. Discarding."
)
policies = self._get_param("Policies")
if policies is not None:
warnings.warn(
"Policies configuration option is not yet implemented. Discarding."
)
result = self.ssm_backend.put_parameter(
param = self.ssm_backend.put_parameter(
name,
description,
value,
@ -296,12 +285,11 @@ class SimpleSystemManagerResponse(BaseResponse):
overwrite,
tags,
data_type,
tier=tier,
policies=policies,
)
if result is None:
raise ParameterAlreadyExists
response = {"Version": result}
response = {"Version": param.version}
return json.dumps(response)
def get_parameter_history(self) -> Union[str, Tuple[str, Dict[str, int]]]:

View File

@ -177,6 +177,7 @@ def test_update_service_through_cloudformation_should_trigger_replacement():
cfn_conn = boto3.client("cloudformation", region_name="us-west-1")
cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json1)
template2 = deepcopy(template1)
template2["Resources"]["testCluster"]["Properties"]["ClusterName"] = "updated name"
template2["Resources"]["testService"]["Properties"]["DesiredCount"] = 5
template2_json = json.dumps(template2)
cfn_conn.update_stack(StackName="test_stack", TemplateBody=template2_json)

View File

@ -112,7 +112,19 @@ def test_stack_elb_integration_with_update():
"Protocol": "HTTP",
}
],
"Policies": {"Ref": "AWS::NoValue"},
"Policies": [
{
"PolicyName": "My-SSLNegotiation-Policy",
"PolicyType": "SSLNegotiationPolicyType",
"InstancePorts": ["80"],
"Attributes": [
{
"Name": "Reference-Security-Policy",
"Value": "ELBSecurityPolicy-TLS-1-2-2017-01",
}
],
}
],
},
}
},
@ -127,6 +139,7 @@ def test_stack_elb_integration_with_update():
elb = boto3.client("elb", region_name="us-west-1")
load_balancer = elb.describe_load_balancers()["LoadBalancerDescriptions"][0]
assert load_balancer["AvailabilityZones"] == ["us-west-1a"]
assert load_balancer["Policies"]["OtherPolicies"] == ["My-SSLNegotiation-Policy"]
# when
elb_template["Resources"]["MyELB"]["Properties"]["AvailabilityZones"] = [

View File

@ -1,9 +1,12 @@
import datetime
import boto3
import csv
from moto import mock_ec2, mock_iam, mock_sts, settings
from moto.iam.models import iam_backends, IAMBackend
from dateutil.parser import parse
from tests import DEFAULT_ACCOUNT_ID
from unittest import SkipTest
@mock_ec2
@ -60,3 +63,35 @@ def test_mark_role_as_last_used():
if not settings.TEST_SERVER_MODE:
iam: IAMBackend = iam_backends[DEFAULT_ACCOUNT_ID]["global"]
assert iam.get_role(role_name).last_used is not None
@mock_ec2
@mock_iam
def test_get_credential_report_content__set_last_used_automatically():
if not settings.TEST_DECORATOR_MODE:
raise SkipTest("No point testing this in ServerMode")
# Ensure LAST_USED field is set
c_iam = boto3.client("iam", region_name="us-east-1")
c_iam.create_user(Path="my/path", UserName="fakeUser")
key = c_iam.create_access_key(UserName="fakeUser")
c_ec2 = boto3.client(
"ec2",
region_name="us-east-2",
aws_access_key_id=key["AccessKey"]["AccessKeyId"],
aws_secret_access_key=key["AccessKey"]["SecretAccessKey"],
)
c_ec2.describe_instances()
# VERIFY last_used can be retrieved
conn = boto3.client("iam", region_name="us-east-1")
result = conn.generate_credential_report()
while result["State"] != "COMPLETE":
result = conn.generate_credential_report()
result = conn.get_credential_report()
report = result["Content"].decode("utf-8")
report_dict = csv.DictReader(report.split("\n"))
user = next(report_dict)
assert parse(user["access_key_1_last_used_date"])

View File

@ -2,15 +2,13 @@ import datetime
import re
import string
import uuid
from unittest.mock import patch, Mock
from unittest import SkipTest
import boto3
import botocore.exceptions
from botocore.exceptions import ClientError
import pytest
from moto import mock_ec2, mock_ssm, settings
from moto import mock_ec2, mock_ssm
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from moto.ssm.models import PARAMETER_VERSION_LIMIT, PARAMETER_HISTORY_MAX_RESULTS
from tests import EXAMPLE_AMI_ID
@ -304,55 +302,6 @@ def test_put_parameter(name):
)
@mock_ssm
def test_put_parameter_unimplemented_parameters():
"""
Test to ensure coverage of unimplemented parameters. Remove for appropriate tests
once implemented
"""
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't test for warning logs in server mode")
mock_warn = Mock()
with patch("warnings.warn", mock_warn):
# Ensure that the ssm parameters are still working with the Mock
client = boto3.client("ssm", region_name="us-east-1")
name = "my-param"
response = client.put_parameter(
Name=name,
Description="A test parameter",
Value="value",
Type="String",
Tier="Advanced",
Policies="No way fam",
)
assert response["Version"] == 1
response = client.get_parameters(Names=[name], WithDecryption=False)
assert len(response["Parameters"]) == 1
assert response["Parameters"][0]["Name"] == name
assert response["Parameters"][0]["Value"] == "value"
assert response["Parameters"][0]["Type"] == "String"
assert response["Parameters"][0]["Version"] == 1
assert response["Parameters"][0]["DataType"] == "text"
assert isinstance(
response["Parameters"][0]["LastModifiedDate"], datetime.datetime
)
assert response["Parameters"][0]["ARN"] == (
f"arn:aws:ssm:us-east-1:{ACCOUNT_ID}:parameter/{name}"
)
# We got the argument warnings
mock_warn.assert_any_call(
"Tier configuration option is not yet implemented. Discarding."
)
mock_warn.assert_any_call(
"Policies configuration option is not yet implemented. Discarding."
)
@pytest.mark.parametrize("name", ["test", "my-cool-parameter"])
@mock_ssm
def test_put_parameter_overwrite_preserves_metadata(name):
@ -367,12 +316,11 @@ def test_put_parameter_overwrite_preserves_metadata(name):
Description=test_description,
Value="value",
Type="String",
Tags=[
{"Key": test_tag_key, "Value": test_tag_value},
],
Tags=[{"Key": test_tag_key, "Value": test_tag_value}],
AllowedPattern=test_pattern,
KeyId=test_key_id,
# TODO: add tier and policies support
Tier="Standard",
Policies='["Expiration"]',
)
assert response["Version"] == 1
@ -430,7 +378,6 @@ def test_put_parameter_overwrite_preserves_metadata(name):
assert response["Parameters"][0]["ARN"] == (
f"arn:aws:ssm:us-east-1:{ACCOUNT_ID}:parameter/{name}"
)
initial_modification_date = response["Parameters"][0]["LastModifiedDate"]
# Verify that tags are unchanged
response = client.list_tags_for_resource(ResourceType="Parameter", ResourceId=name)
@ -439,20 +386,42 @@ def test_put_parameter_overwrite_preserves_metadata(name):
assert response["TagList"][0]["Key"] == test_tag_key
assert response["TagList"][0]["Value"] == test_tag_value
# Verify description is unchanged
# Verify description/tier/policies is unchanged
response = client.describe_parameters(
ParameterFilters=[
{
"Key": "Name",
"Option": "Equals",
"Values": [name],
},
]
ParameterFilters=[{"Key": "Name", "Option": "Equals", "Values": [name]}]
)
assert len(response["Parameters"]) == 1
assert response["Parameters"][0]["Description"] == test_description
assert response["Parameters"][0]["AllowedPattern"] == test_pattern
assert response["Parameters"][0]["KeyId"] == test_key_id
assert response["Parameters"][0]["Tier"] == "Standard"
assert response["Parameters"][0]["Policies"] == [
{
"PolicyStatus": "Finished",
"PolicyText": "Expiration",
"PolicyType": "Expiration",
}
]
@mock_ssm
def test_put_parameter_with_invalid_policy():
name = "some_param"
test_description = "A test parameter"
client = boto3.client("ssm", region_name="us-east-1")
client.put_parameter(
Name=name,
Description=test_description,
Value="value",
Type="String",
Policies="invalid json",
)
# Verify that an invalid policy does not break anything
param = client.describe_parameters(
ParameterFilters=[{"Key": "Name", "Option": "Equals", "Values": [name]}]
)["Parameters"][0]
assert "Policies" not in param
@mock_ssm