diff --git a/moto/ssm/exceptions.py b/moto/ssm/exceptions.py index e573fff84..4c6b69ca6 100644 --- a/moto/ssm/exceptions.py +++ b/moto/ssm/exceptions.py @@ -4,138 +4,138 @@ from moto.core.exceptions import JsonRESTError class InvalidFilterKey(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidFilterKey", message) class InvalidFilterOption(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidFilterOption", message) class InvalidFilterValue(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidFilterValue", message) class InvalidResourceId(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidResourceId", "Invalid Resource Id") class InvalidResourceType(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidResourceType", "Invalid Resource Type") class ParameterNotFound(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ParameterNotFound", message) class ParameterVersionNotFound(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ParameterVersionNotFound", message) class ParameterVersionLabelLimitExceeded(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ParameterVersionLabelLimitExceeded", message) class ValidationException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ValidationException", message) class DocumentAlreadyExists(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("DocumentAlreadyExists", message) class DocumentPermissionLimit(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("DocumentPermissionLimit", message) class InvalidPermissionType(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidPermissionType", message) class InvalidDocument(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidDocument", message) class InvalidDocumentOperation(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidDocumentOperation", message) class AccessDeniedException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("AccessDeniedException", message) class InvalidDocumentContent(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidDocumentContent", message) class InvalidDocumentVersion(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidDocumentVersion", message) class DuplicateDocumentVersionName(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("DuplicateDocumentVersionName", message) class DuplicateDocumentContent(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("DuplicateDocumentContent", message) class ParameterMaxVersionLimitExceeded(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ParameterMaxVersionLimitExceeded", message) diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 3ae5247b6..120a497c3 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -1,6 +1,7 @@ import re from dataclasses import dataclass -from typing import Dict +from typing import Any, Dict, List, Iterator, Optional, Tuple +from typing import DefaultDict from collections import defaultdict @@ -42,9 +43,9 @@ from .exceptions import ( ) -class ParameterDict(defaultdict): - def __init__(self, account_id, region_name): - # each value is a list of all of the versions for a parameter +class ParameterDict(DefaultDict[str, List["Parameter"]]): + def __init__(self, account_id: str, region_name: str): + # each value is a list of all the versions for a parameter # to get the current value, grab the last item of the list super().__init__(list) self.latest_amis_loaded = False @@ -54,7 +55,7 @@ class ParameterDict(defaultdict): self.account_id = account_id self.region_name = region_name - def _check_loading_status(self, key): + def _check_loading_status(self, key: str) -> None: key = str(key or "") if key.startswith("/aws/service/ami-amazon-linux-latest"): if not self.latest_amis_loaded: @@ -75,7 +76,7 @@ class ParameterDict(defaultdict): ) self.latest_ecs_amis_loaded = True - def _load_latest_amis(self): + def _load_latest_amis(self) -> None: try: latest_amis_linux = load_resource( __name__, f"resources/ami-amazon-linux-latest/{self.region_name}.json" @@ -99,7 +100,7 @@ class ParameterDict(defaultdict): ) ) - def _load_tree_parameters(self, path: str): + def _load_tree_parameters(self, path: str) -> None: try: params = convert_to_params(load_resource(__name__, path)) except FileNotFoundError: @@ -127,7 +128,7 @@ class ParameterDict(defaultdict): ) ) - def _get_secretsmanager_parameter(self, secret_name): + def _get_secretsmanager_parameter(self, secret_name: str) -> List["Parameter"]: secrets_backend = secretsmanager_backends[self.account_id][self.region_name] secret = secrets_backend.describe_secret(secret_name) version_id_to_stage = secret["VersionIdsToStages"] @@ -161,13 +162,13 @@ class ParameterDict(defaultdict): for val in values ] - def __getitem__(self, item): + def __getitem__(self, item: str) -> List["Parameter"]: if item.startswith("/aws/reference/secretsmanager/"): return self._get_secretsmanager_parameter("/".join(item.split("/")[4:])) self._check_loading_status(item) return super().__getitem__(item) - def __contains__(self, k): + def __contains__(self, k: str) -> bool: # type: ignore[override] if k and k.startswith("/aws/reference/secretsmanager/"): try: param = self._get_secretsmanager_parameter("/".join(k.split("/")[4:])) @@ -179,7 +180,7 @@ class ParameterDict(defaultdict): self._check_loading_status(k) return super().__contains__(k) - def get_keys_beginning_with(self, path, recursive): + def get_keys_beginning_with(self, path: str, recursive: bool) -> Iterator[str]: self._check_loading_status(path) for param_name in self: if path != "/" and not param_name.startswith(path): @@ -196,19 +197,19 @@ PARAMETER_HISTORY_MAX_RESULTS = 50 class Parameter(CloudFormationModel): def __init__( self, - account_id, - name, - value, - parameter_type, - description, - allowed_pattern, - keyid, - last_modified_date, - version, - data_type, - tags=None, - labels=None, - source_result=None, + account_id: str, + name: str, + value: str, + parameter_type: str, + description: Optional[str], + allowed_pattern: Optional[str], + keyid: Optional[str], + last_modified_date: float, + version: int, + data_type: str, + tags: Optional[List[Dict[str, str]]] = None, + labels: Optional[List[str]] = None, + source_result: Optional[str] = None, ): self.account_id = account_id self.name = name @@ -231,19 +232,22 @@ class Parameter(CloudFormationModel): else: self.value = value - def encrypt(self, value): + def encrypt(self, value: str) -> str: return f"kms:{self.keyid}:" + value - def decrypt(self, value): + def decrypt(self, value: str) -> Optional[str]: if self.type != "SecureString": return value prefix = f"kms:{self.keyid or 'default'}:" if value.startswith(prefix): return value[len(prefix) :] + return None - def response_object(self, decrypt=False, region=None): - r = { + def response_object( + self, decrypt: bool = False, region: Optional[str] = None + ) -> Dict[str, Any]: + r: Dict[str, Any] = { "Name": self.name, "Type": self.type, "Value": self.decrypt(self.value) if decrypt else self.value, @@ -259,8 +263,10 @@ class Parameter(CloudFormationModel): return r - def describe_response_object(self, decrypt=False, include_labels=False): - r = self.response_object(decrypt) + def describe_response_object( + self, decrypt: bool = False, include_labels: bool = False + ) -> Dict[str, Any]: + r: Dict[str, Any] = self.response_object(decrypt) r["LastModifiedDate"] = round(self.last_modified_date, 3) r["LastModifiedUser"] = "N/A" @@ -279,18 +285,23 @@ class Parameter(CloudFormationModel): return r @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Name" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ssm-parameter.html return "AWS::SSM::Parameter" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Parameter": ssm_backend = ssm_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -310,14 +321,14 @@ class Parameter(CloudFormationModel): return parameter @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Parameter": cls.delete_from_cloudformation_json( original_resource.name, cloudformation_json, account_id, region_name ) @@ -326,9 +337,13 @@ class Parameter(CloudFormationModel): ) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: ssm_backend = ssm_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -338,7 +353,9 @@ class Parameter(CloudFormationModel): MAX_TIMEOUT_SECONDS = 3600 -def generate_ssm_doc_param_list(parameters): +def generate_ssm_doc_param_list( + parameters: Dict[str, Any] +) -> Optional[List[Dict[str, Any]]]: if not parameters: return None param_list = [] @@ -370,24 +387,26 @@ def generate_ssm_doc_param_list(parameters): class AccountPermission: account_id: str version: str - created_at: datetime + created_at: datetime.datetime class Documents(BaseModel): - def __init__(self, ssm_document): + def __init__(self, ssm_document: "Document"): version = ssm_document.document_version self.versions = {version: ssm_document} self.default_version = version self.latest_version = version - self.permissions = {} # {AccountID: AccountPermission } + self.permissions: Dict[ + str, AccountPermission + ] = {} # {AccountID: AccountPermission } - def get_default_version(self): - return self.versions.get(self.default_version) + def get_default_version(self) -> "Document": + return self.versions[self.default_version] - def get_latest_version(self): - return self.versions.get(self.latest_version) + def get_latest_version(self) -> "Document": + return self.versions[self.latest_version] - def find_by_version_name(self, version_name): + def find_by_version_name(self, version_name: str) -> Optional["Document"]: return next( ( document @@ -397,10 +416,12 @@ class Documents(BaseModel): None, ) - def find_by_version(self, version): + def find_by_version(self, version: str) -> Optional["Document"]: return self.versions.get(version) - def find_by_version_and_version_name(self, version, version_name): + def find_by_version_and_version_name( + self, version: str, version_name: str + ) -> Optional["Document"]: return next( ( document @@ -410,10 +431,15 @@ class Documents(BaseModel): None, ) - def find(self, document_version=None, version_name=None, strict=True): + def find( + self, + document_version: Optional[str] = None, + version_name: Optional[str] = None, + strict: bool = True, + ) -> "Document": if document_version == "$LATEST": - ssm_document = self.get_latest_version() + ssm_document: Optional["Document"] = self.get_latest_version() elif version_name and document_version: ssm_document = self.find_by_version_and_version_name( document_version, version_name @@ -428,24 +454,26 @@ class Documents(BaseModel): if strict and not ssm_document: raise InvalidDocument("The specified document does not exist.") - return ssm_document + return ssm_document # type: ignore - def exists(self, document_version=None, version_name=None): + def exists( + self, document_version: Optional[str] = None, version_name: Optional[str] = None + ) -> bool: return self.find(document_version, version_name, strict=False) is not None - def add_new_version(self, new_document_version): + def add_new_version(self, new_document_version: "Document") -> None: version = new_document_version.document_version self.latest_version = version self.versions[version] = new_document_version - def update_default_version(self, version): + def update_default_version(self, version: str) -> "Document": ssm_document = self.find_by_version(version) if not ssm_document: raise InvalidDocument("The specified document does not exist.") self.default_version = version return ssm_document - def delete(self, *versions): + def delete(self, *versions: str) -> None: for version in versions: if version in self.versions: del self.versions[version] @@ -455,9 +483,14 @@ class Documents(BaseModel): new_latest_version = ordered_versions[-1] self.latest_version = new_latest_version - def describe(self, document_version=None, version_name=None, tags=None): + def describe( + self, + document_version: Optional[str] = None, + version_name: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + ) -> Dict[str, Any]: document = self.find(document_version, version_name) - base = { + base: Dict[str, Any] = { "Hash": document.hash, "HashType": "Sha256", "Name": document.name, @@ -483,7 +516,9 @@ class Documents(BaseModel): return base - def modify_permissions(self, accounts_to_add, accounts_to_remove, version): + def modify_permissions( + self, accounts_to_add: List[str], accounts_to_remove: List[str], version: str + ) -> None: version = version or "$DEFAULT" if accounts_to_add: if "all" in accounts_to_add: @@ -506,7 +541,7 @@ class Documents(BaseModel): for account_id in accounts_to_remove: self.permissions.pop(account_id, None) - def describe_permissions(self): + def describe_permissions(self) -> Dict[str, Any]: permissions_ordered_by_date = sorted( self.permissions.values(), key=lambda p: p.created_at @@ -520,23 +555,23 @@ class Documents(BaseModel): ], } - def is_shared(self): + def is_shared(self) -> bool: return len(self.permissions) > 0 class Document(BaseModel): def __init__( self, - account_id, - name, - version_name, - content, - document_type, - document_format, - requires, - attachments, - target_type, - document_version="1", + account_id: str, + name: str, + version_name: str, + content: str, + document_type: str, + document_format: str, + requires: List[Dict[str, str]], + attachments: List[Dict[str, Any]], + target_type: str, + document_version: str = "1", ): self.name = name self.version_name = version_name @@ -591,11 +626,13 @@ class Document(BaseModel): raise InvalidDocumentContent("The content for the document is not valid.") @property - def hash(self): + def hash(self) -> str: return hashlib.sha256(self.content.encode("utf-8")).hexdigest() - def list_describe(self, tags=None): - base = { + def list_describe( + self, tags: Optional[List[Dict[str, str]]] = None + ) -> Dict[str, Any]: + base: Dict[str, Any] = { "Name": self.name, "Owner": self.owner, "DocumentVersion": self.document_version, @@ -620,29 +657,26 @@ class Document(BaseModel): class Command(BaseModel): def __init__( self, - account_id, - comment="", - document_name="", - timeout_seconds=MAX_TIMEOUT_SECONDS, - instance_ids=None, - max_concurrency="", - max_errors="", - notification_config=None, - output_s3_bucket_name="", - output_s3_key_prefix="", - output_s3_region="", - parameters=None, - service_role_arn="", - targets=None, - backend_region="us-east-1", + account_id: str, + comment: str = "", + document_name: Optional[str] = "", + timeout_seconds: Optional[int] = MAX_TIMEOUT_SECONDS, + instance_ids: Optional[List[str]] = None, + max_concurrency: str = "", + max_errors: str = "", + notification_config: Optional[Dict[str, Any]] = None, + output_s3_bucket_name: str = "", + output_s3_key_prefix: str = "", + output_s3_region: str = "", + parameters: Optional[Dict[str, List[str]]] = None, + service_role_arn: str = "", + targets: Optional[List[Dict[str, Any]]] = None, + backend_region: str = "us-east-1", ): if instance_ids is None: instance_ids = [] - if notification_config is None: - notification_config = {} - if parameters is None: parameters = {} @@ -654,10 +688,11 @@ class Command(BaseModel): self.status_details = "Details placeholder" self.account_id = account_id + self.timeout_seconds = timeout_seconds or MAX_TIMEOUT_SECONDS self.requested_date_time = datetime.datetime.now() self.requested_date_time_iso = self.requested_date_time.isoformat() expires_after = self.requested_date_time + datetime.timedelta( - 0, timeout_seconds + 0, self.timeout_seconds ) self.expires_after = expires_after.isoformat() @@ -665,7 +700,11 @@ class Command(BaseModel): self.document_name = document_name self.max_concurrency = max_concurrency self.max_errors = max_errors - self.notification_config = notification_config + self.notification_config = notification_config or { + "NotificationArn": "string", + "NotificationEvents": ["Success"], + "NotificationType": "Command", + } self.output_s3_bucket_name = output_s3_bucket_name self.output_s3_key_prefix = output_s3_key_prefix self.output_s3_region = output_s3_region @@ -695,7 +734,7 @@ class Command(BaseModel): self.invocation_response(instance_id, "aws:runShellScript") ) - def _get_instance_ids_from_targets(self): + def _get_instance_ids_from_targets(self) -> List[str]: target_instance_ids = [] ec2_backend = ec2_backends[self.account_id][self.backend_region] ec2_filters = {target["Key"]: target["Values"] for target in self.targets} @@ -705,8 +744,8 @@ class Command(BaseModel): target_instance_ids.append(instance.id) return target_instance_ids - def response_object(self): - r = { + def response_object(self) -> Dict[str, Any]: + return { "CommandId": self.command_id, "Comment": self.comment, "CompletedCount": self.completed_count, @@ -728,11 +767,10 @@ class Command(BaseModel): "StatusDetails": self.status_details, "TargetCount": self.target_count, "Targets": self.targets, + "TimeoutSeconds": self.timeout_seconds, } - return r - - def invocation_response(self, instance_id, plugin_name): + def invocation_response(self, instance_id: str, plugin_name: str) -> Dict[str, Any]: # Calculate elapsed time from requested time and now. Use a hardcoded # elapsed time since there is no easy way to convert a timedelta to # an ISO 8601 duration string. @@ -740,7 +778,7 @@ class Command(BaseModel): elapsed_time_delta = datetime.timedelta(minutes=5) end_time = self.requested_date_time + elapsed_time_delta - r = { + return { "CommandId": self.command_id, "InstanceId": instance_id, "Comment": self.comment, @@ -757,9 +795,9 @@ class Command(BaseModel): "StandardErrorContent": "", } - return r - - def get_invocation(self, instance_id, plugin_name): + def get_invocation( + self, instance_id: str, plugin_name: Optional[str] + ) -> Dict[str, Any]: invocation = next( ( invocation @@ -784,13 +822,19 @@ class Command(BaseModel): return invocation -def _validate_document_format(document_format): +def _validate_document_format(document_format: str) -> None: aws_doc_formats = ["JSON", "YAML"] if document_format not in aws_doc_formats: raise ValidationException("Invalid document format " + str(document_format)) -def _validate_document_info(content, name, document_type, document_format, strict=True): +def _validate_document_info( + content: str, + name: str, + document_type: Optional[str], + document_format: str, + strict: bool = True, +) -> None: aws_ssm_name_regex = r"^[a-zA-Z0-9_\-.]{3,128}$" aws_name_reject_list = ["aws-", "amazon", "amzn"] aws_doc_types = [ @@ -821,21 +865,27 @@ def _validate_document_info(content, name, document_type, document_format, stric raise ValidationException("Invalid document type " + str(document_type)) -def _document_filter_equal_comparator(keyed_value, _filter): +def _document_filter_equal_comparator( + keyed_value: str, _filter: Dict[str, Any] +) -> bool: for v in _filter["Values"]: if keyed_value == v: return True return False -def _document_filter_list_includes_comparator(keyed_value_list, _filter): +def _document_filter_list_includes_comparator( + keyed_value_list: List[str], _filter: Dict[str, Any] +) -> bool: for v in _filter["Values"]: if v in keyed_value_list: return True return False -def _document_filter_match(account_id, filters, ssm_doc): +def _document_filter_match( + account_id: str, filters: List[Dict[str, Any]], ssm_doc: Document +) -> bool: for _filter in filters: if _filter["Key"] == "Name" and not _document_filter_equal_comparator( ssm_doc.name, _filter @@ -871,7 +921,7 @@ def _document_filter_match(account_id, filters, ssm_doc): return True -def _valid_parameter_type(type_): +def _valid_parameter_type(type_: str) -> bool: """ Parameter Type field only allows `SecureString`, `StringList` and `String` (not `str`) values @@ -879,7 +929,7 @@ def _valid_parameter_type(type_): return type_ in ("SecureString", "StringList", "String") -def _valid_parameter_data_type(data_type): +def _valid_parameter_data_type(data_type: str) -> bool: """ Parameter DataType field allows only `text` and `aws:ec2:image` values @@ -890,21 +940,19 @@ def _valid_parameter_data_type(data_type): class FakeMaintenanceWindow: def __init__( self, - name, - description, - enabled, - duration, - cutoff, - schedule, - schedule_timezone, - schedule_offset, - start_date, - end_date, + name: str, + description: str, + duration: int, + cutoff: int, + schedule: str, + schedule_timezone: str, + schedule_offset: int, + start_date: str, + end_date: str, ): self.id = FakeMaintenanceWindow.generate_id() self.name = name self.description = description - self.enabled = enabled self.duration = duration self.cutoff = cutoff self.schedule = schedule @@ -913,12 +961,12 @@ class FakeMaintenanceWindow: self.start_date = start_date self.end_date = end_date - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "WindowId": self.id, "Name": self.name, "Description": self.description, - "Enabled": self.enabled, + "Enabled": True, "Duration": self.duration, "Cutoff": self.cutoff, "Schedule": self.schedule, @@ -929,7 +977,7 @@ class FakeMaintenanceWindow: } @staticmethod - def generate_id(): + def generate_id() -> str: chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] return "mw-" + "".join(str(random.choice(chars)) for _ in range(17)) @@ -946,19 +994,23 @@ class SimpleSystemManagerBackend(BaseBackend): Integration with SecretsManager is also supported. """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self._parameters = ParameterDict(account_id, region_name) - self._resource_tags = defaultdict(lambda: defaultdict(dict)) - self._commands = [] - self._errors = [] + self._resource_tags: DefaultDict[ + str, DefaultDict[str, Dict[str, str]] + ] = defaultdict(lambda: defaultdict(dict)) + self._commands: List[Command] = [] + self._errors: List[str] = [] self._documents: Dict[str, Documents] = {} self.windows: Dict[str, FakeMaintenanceWindow] = dict() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint services.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "ssm" @@ -966,9 +1018,11 @@ class SimpleSystemManagerBackend(BaseBackend): service_region, zones, "ssmmessages" ) - def _generate_document_information(self, ssm_document, document_format): + def _generate_document_information( + self, ssm_document: Document, document_format: str + ) -> Dict[str, Any]: content = self._get_document_content(document_format, ssm_document) - base = { + base: Dict[str, Any] = { "Name": ssm_document.name, "DocumentVersion": ssm_document.document_version, "Status": ssm_document.status, @@ -987,7 +1041,7 @@ class SimpleSystemManagerBackend(BaseBackend): return base @staticmethod - def _get_document_content(document_format, ssm_document): + def _get_document_content(document_format: str, ssm_document: Document) -> str: if document_format == ssm_document.document_format: content = ssm_document.content elif document_format == "JSON": @@ -998,13 +1052,13 @@ class SimpleSystemManagerBackend(BaseBackend): raise ValidationException("Invalid document format " + str(document_format)) return content - def _get_documents(self, name): + def _get_documents(self, name: str) -> Documents: documents = self._documents.get(name) if not documents: raise InvalidDocument("The specified document does not exist.") return documents - def _get_documents_tags(self, name): + def _get_documents_tags(self, name: str) -> List[Dict[str, str]]: docs_tags = self._resource_tags.get("Document") if docs_tags: document_tags = docs_tags.get(name, {}) @@ -1015,16 +1069,16 @@ class SimpleSystemManagerBackend(BaseBackend): def create_document( self, - content, - requires, - attachments, - name, - version_name, - document_type, - document_format, - target_type, - tags, - ): + content: str, + requires: List[Dict[str, str]], + attachments: List[Dict[str, Any]], + name: str, + version_name: str, + document_type: str, + document_format: str, + target_type: str, + tags: List[Dict[str, str]], + ) -> Dict[str, Any]: ssm_document = Document( account_id=self.account_id, name=name, @@ -1056,7 +1110,9 @@ class SimpleSystemManagerBackend(BaseBackend): return documents.describe(tags=tags) - def delete_document(self, name, document_version, version_name, force): + def delete_document( + self, name: str, document_version: str, version_name: str, force: bool + ) -> None: documents = self._get_documents(name) if documents.is_shared(): @@ -1105,10 +1161,12 @@ class SimpleSystemManagerBackend(BaseBackend): documents.delete(*keys_to_delete) if len(documents.versions) == 0: - self._resource_tags.get("Document", {}).pop(name, None) + self._resource_tags.get("Document", {}).pop(name, None) # type: ignore del self._documents[name] - def get_document(self, name, document_version, version_name, document_format): + def get_document( + self, name: str, document_version: str, version_name: str, document_format: str + ) -> Dict[str, Any]: documents = self._get_documents(name) ssm_document = documents.find(document_version, version_name) @@ -1120,11 +1178,13 @@ class SimpleSystemManagerBackend(BaseBackend): return self._generate_document_information(ssm_document, document_format) - def update_document_default_version(self, name, document_version): + def update_document_default_version( + self, name: str, document_version: str + ) -> Dict[str, Any]: documents = self._get_documents(name) ssm_document = documents.update_default_version(document_version) - result = { + result: Dict[str, Any] = { "Name": ssm_document.name, "DefaultVersion": document_version, } @@ -1136,14 +1196,14 @@ class SimpleSystemManagerBackend(BaseBackend): def update_document( self, - content, - attachments, - name, - version_name, - document_version, - document_format, - target_type, - ): + content: str, + attachments: List[Dict[str, Any]], + name: str, + version_name: str, + document_version: str, + document_format: str, + target_type: str, + ) -> Dict[str, Any]: _validate_document_info( content=content, name=name, @@ -1198,21 +1258,27 @@ class SimpleSystemManagerBackend(BaseBackend): tags = self._get_documents_tags(name) return documents.describe(document_version=new_version, tags=tags) - def describe_document(self, name, document_version, version_name): + def describe_document( + self, name: str, document_version: str, version_name: str + ) -> Dict[str, Any]: documents = self._get_documents(name) tags = self._get_documents_tags(name) return documents.describe(document_version, version_name, tags=tags) def list_documents( - self, document_filter_list, filters, max_results=10, next_token="0" - ): + self, + document_filter_list: Any, + filters: List[Dict[str, Any]], + max_results: int = 10, + token: str = "0", + ) -> Tuple[List[Dict[str, Any]], str]: if document_filter_list: raise ValidationException( "DocumentFilterList is deprecated. Instead use Filters." ) - next_token = int(next_token) - results = [] + next_token = int(token or "0") + results: List[Dict[str, Any]] = [] dummy_token_tracker = 0 # Sort to maintain next token adjacency for _, documents in sorted(self._documents.items()): @@ -1235,10 +1301,10 @@ class SimpleSystemManagerBackend(BaseBackend): doc_describe = ssm_doc.list_describe(tags=tags) results.append(doc_describe) - # If we've fallen out of the loop, theres no more documents. No next token. + # If we've fallen out of the loop, there are no more documents. No next token. return results, "" - def describe_document_permission(self, name): + def describe_document_permission(self, name: str) -> Dict[str, Any]: """ Parameters max_results, permission_type, and next_token not yet implemented """ @@ -1247,12 +1313,12 @@ class SimpleSystemManagerBackend(BaseBackend): def modify_document_permission( self, - name, - account_ids_to_add, - account_ids_to_remove, - shared_document_version, - permission_type, - ): + name: str, + account_ids_to_add: List[str], + account_ids_to_remove: List[str], + shared_document_version: str, + permission_type: str, + ) -> None: account_id_regex = re.compile(r"^(all|[0-9]{12})$", re.IGNORECASE) version_regex = re.compile(r"^([$]LATEST|[$]DEFAULT|[$]ALL)$") @@ -1301,22 +1367,24 @@ class SimpleSystemManagerBackend(BaseBackend): account_ids_to_add, account_ids_to_remove, shared_document_version ) - def delete_parameter(self, name): - self._resource_tags.get("Parameter", {}).pop(name, None) - return self._parameters.pop(name, None) + def delete_parameter(self, name: str) -> Optional[Parameter]: + self._resource_tags.get("Parameter", {}).pop(name, None) # type: ignore + return self._parameters.pop(name, None) # type: ignore - def delete_parameters(self, names): + def delete_parameters(self, names: List[str]) -> List[str]: result = [] for name in names: try: del self._parameters[name] result.append(name) - self._resource_tags.get("Parameter", {}).pop(name, None) + self._resource_tags.get("Parameter", {}).pop(name, None) # type: ignore except KeyError: pass return result - def describe_parameters(self, filters, parameter_filters): + def describe_parameters( + self, filters: List[Dict[str, Any]], parameter_filters: List[Dict[str, Any]] + ) -> List[Parameter]: if filters and parameter_filters: raise ValidationException( "You can use either Filters or ParameterFilters in a single request." @@ -1326,29 +1394,27 @@ class SimpleSystemManagerBackend(BaseBackend): result = [] for param_name in self._parameters: - ssm_parameter = self.get_parameter(param_name) + ssm_parameter: Parameter = self.get_parameter(param_name) # type: ignore[assignment] if not self._match_filters(ssm_parameter, parameter_filters): continue if filters: for _filter in filters: if _filter["Key"] == "Name": - k = ssm_parameter.name for v in _filter["Values"]: - if k.startswith(v): + if ssm_parameter.name.startswith(v): result.append(ssm_parameter) break elif _filter["Key"] == "Type": - k = ssm_parameter.type for v in _filter["Values"]: - if k == v: + if ssm_parameter.type == v: result.append(ssm_parameter) break elif _filter["Key"] == "KeyId": - k = ssm_parameter.keyid - if k: + keyid = ssm_parameter.keyid + if keyid: for v in _filter["Values"]: - if k == v: + if keyid == v: result.append(ssm_parameter) break continue @@ -1357,7 +1423,9 @@ class SimpleSystemManagerBackend(BaseBackend): return result - def _validate_parameter_filters(self, parameter_filters, by_path): + def _validate_parameter_filters( + self, parameter_filters: Optional[List[Dict[str, Any]]], by_path: bool + ) -> None: for index, filter_obj in enumerate(parameter_filters or []): key = filter_obj["Key"] values = filter_obj.get("Values", []) @@ -1497,10 +1565,10 @@ class SimpleSystemManagerBackend(BaseBackend): filter_keys.append(key) - def _format_error(self, key, value, constraint): + def _format_error(self, key: str, value: Any, constraint: str) -> str: return f'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}' - def _raise_errors(self): + def _raise_errors(self) -> None: if self._errors: count = len(self._errors) plural = "s" if len(self._errors) > 1 else "" @@ -1511,13 +1579,7 @@ class SimpleSystemManagerBackend(BaseBackend): f"{count} validation error{plural} detected: {errors}" ) - def get_all_parameters(self): - result = [] - for k, _ in self._parameters.items(): - result.append(self._parameters[k]) - return result - - def get_parameters(self, names): + def get_parameters(self, names: List[str]) -> Dict[str, Parameter]: result = {} if len(names) > 10: @@ -1540,41 +1602,45 @@ class SimpleSystemManagerBackend(BaseBackend): def get_parameters_by_path( self, - path, - recursive, - filters=None, - next_token=None, - max_results=10, - ): + path: str, + recursive: bool, + filters: Optional[List[Dict[str, Any]]] = None, + next_token: Optional[str] = None, + max_results: int = 10, + ) -> Tuple[List[Parameter], Optional[str]]: """Implement the get-parameters-by-path-API in the backend.""" self._validate_parameter_filters(filters, by_path=True) - result = [] + result: List[Parameter] = [] # path could be with or without a trailing /. we handle this # difference here. path = path.rstrip("/") + "/" for param_name in self._parameters.get_keys_beginning_with(path, recursive): - parameter = self.get_parameter(param_name) + parameter: Parameter = self.get_parameter(param_name) # type: ignore[assignment] if not self._match_filters(parameter, filters): continue result.append(parameter) return self._get_values_nexttoken(result, max_results, next_token) - def _get_values_nexttoken(self, values_list, max_results, next_token=None): - if next_token is None: - next_token = 0 - next_token = int(next_token) + def _get_values_nexttoken( + self, + values_list: List[Parameter], + max_results: int, + token: Optional[str] = None, + ) -> Tuple[List[Parameter], Optional[str]]: + next_token = int(token or "0") max_results = int(max_results) values = values_list[next_token : next_token + max_results] - if len(values) == max_results: - next_token = str(next_token + max_results) - else: - next_token = None - return values, next_token + return ( + values, + str(next_token + max_results) if len(values) == max_results else None, + ) - def get_parameter_history(self, name, next_token, max_results=50): + def get_parameter_history( + self, name: str, next_token: Optional[str], max_results: int = 50 + ) -> Tuple[Optional[List[Parameter]], Optional[str]]: if max_results > PARAMETER_HISTORY_MAX_RESULTS: raise ValidationException( @@ -1589,10 +1655,10 @@ class SimpleSystemManagerBackend(BaseBackend): return None, None - def _get_history_nexttoken(self, history, next_token, max_results): - if next_token is None: - next_token = 0 - next_token = int(next_token) + def _get_history_nexttoken( + self, history: List[Parameter], token: Optional[str], max_results: int + ) -> Tuple[List[Parameter], Optional[str]]: + next_token = int(token or "0") max_results = int(max_results) history_to_return = history[next_token : next_token + max_results] if ( @@ -1603,7 +1669,9 @@ class SimpleSystemManagerBackend(BaseBackend): return history_to_return, str(new_next_token) return history_to_return, None - def _match_filters(self, parameter, filters=None): + def _match_filters( + self, parameter: Parameter, filters: Optional[List[Dict[str, Any]]] = None + ) -> bool: """Return True if the given parameter matches all the filters""" for filter_obj in filters or []: key = filter_obj["Key"] @@ -1614,7 +1682,7 @@ class SimpleSystemManagerBackend(BaseBackend): else: option = filter_obj.get("Option", "Equals") - what = None + what: Any = None if key == "KeyId": what = parameter.keyid elif key == "Name": @@ -1679,7 +1747,7 @@ class SimpleSystemManagerBackend(BaseBackend): # True if no false match (or no filters at all) return True - def get_parameter(self, name): + def get_parameter(self, name: str) -> Optional[Parameter]: name_parts = name.split(":") name_prefix = name_parts[0] @@ -1714,7 +1782,9 @@ class SimpleSystemManagerBackend(BaseBackend): return None - def label_parameter_version(self, name, version, labels): + def label_parameter_version( + self, name: str, version: int, labels: List[str] + ) -> Tuple[List[str], int]: previous_parameter_versions = self._parameters[name] if not previous_parameter_versions: raise ParameterNotFound(f"Parameter {name} not found.") @@ -1770,9 +1840,9 @@ class SimpleSystemManagerBackend(BaseBackend): for label in parameter.labels[:]: if label in labels_needing_removal: parameter.labels.remove(label) - return [invalid_labels, version] + return (invalid_labels, version) - def _check_for_parameter_version_limit_exception(self, name): + def _check_for_parameter_version_limit_exception(self, name: str) -> None: # https://docs.aws.amazon.com/systems-manager/latest/userguide/sysman-paramstore-versions.html parameter_versions = self._parameters[name] oldest_parameter = parameter_versions[0] @@ -1786,16 +1856,16 @@ class SimpleSystemManagerBackend(BaseBackend): def put_parameter( self, - name, - description, - value, - parameter_type, - allowed_pattern, - keyid, - overwrite, - tags, - data_type, - ): + name: str, + description: str, + value: str, + parameter_type: str, + allowed_pattern: str, + keyid: str, + overwrite: bool, + tags: List[Dict[str, str]], + data_type: str, + ) -> Optional[int]: if not value: raise ValidationException( "1 validation error detected: Value '' at 'value' failed to satisfy" @@ -1847,7 +1917,7 @@ class SimpleSystemManagerBackend(BaseBackend): version = previous_parameter.version + 1 if not overwrite: - return + return None if len(previous_parameter_versions) >= PARAMETER_VERSION_LIMIT: self._check_for_parameter_version_limit_exception(name) @@ -1871,28 +1941,36 @@ class SimpleSystemManagerBackend(BaseBackend): ) if tags: - tags = {t["Key"]: t["Value"] for t in tags} - self.add_tags_to_resource("Parameter", name, tags) + tag_dict = {t["Key"]: t["Value"] for t in tags} + self.add_tags_to_resource("Parameter", name, tag_dict) return version - def add_tags_to_resource(self, resource_type, resource_id, tags): + def add_tags_to_resource( + self, resource_type: str, resource_id: str, tags: Dict[str, str] + ) -> None: self._validate_resource_type_and_id(resource_type, resource_id) for key, value in tags.items(): self._resource_tags[resource_type][resource_id][key] = value - def remove_tags_from_resource(self, resource_type, resource_id, keys): + def remove_tags_from_resource( + self, resource_type: str, resource_id: str, keys: List[str] + ) -> None: self._validate_resource_type_and_id(resource_type, resource_id) tags = self._resource_tags[resource_type][resource_id] for key in keys: if key in tags: del tags[key] - def list_tags_for_resource(self, resource_type, resource_id): + def list_tags_for_resource( + self, resource_type: str, resource_id: str + ) -> Dict[str, str]: self._validate_resource_type_and_id(resource_type, resource_id) return self._resource_tags[resource_type][resource_id] - def _validate_resource_type_and_id(self, resource_type, resource_id): + def _validate_resource_type_and_id( + self, resource_type: str, resource_id: str + ) -> None: if resource_type == "Parameter": if resource_id not in self._parameters: raise InvalidResourceId() @@ -1915,51 +1993,57 @@ class SimpleSystemManagerBackend(BaseBackend): else: raise InvalidResourceId() - def send_command(self, **kwargs): + def send_command( + self, + comment: str, + document_name: Optional[str], + timeout_seconds: int, + instance_ids: List[str], + max_concurrency: str, + max_errors: str, + notification_config: Optional[Dict[str, Any]], + output_s3_bucket_name: str, + output_s3_key_prefix: str, + output_s3_region: str, + parameters: Dict[str, List[str]], + service_role_arn: str, + targets: List[Dict[str, Any]], + ) -> Command: command = Command( account_id=self.account_id, - comment=kwargs.get("Comment", ""), - document_name=kwargs.get("DocumentName"), - timeout_seconds=kwargs.get("TimeoutSeconds", 3600), - instance_ids=kwargs.get("InstanceIds", []), - max_concurrency=kwargs.get("MaxConcurrency", "50"), - max_errors=kwargs.get("MaxErrors", "0"), - notification_config=kwargs.get( - "NotificationConfig", - { - "NotificationArn": "string", - "NotificationEvents": ["Success"], - "NotificationType": "Command", - }, - ), - output_s3_bucket_name=kwargs.get("OutputS3BucketName", ""), - output_s3_key_prefix=kwargs.get("OutputS3KeyPrefix", ""), - output_s3_region=kwargs.get("OutputS3Region", ""), - parameters=kwargs.get("Parameters", {}), - service_role_arn=kwargs.get("ServiceRoleArn", ""), - targets=kwargs.get("Targets", []), + comment=comment, + document_name=document_name, + timeout_seconds=timeout_seconds, + instance_ids=instance_ids, + max_concurrency=max_concurrency, + max_errors=max_errors, + notification_config=notification_config, + output_s3_bucket_name=output_s3_bucket_name, + output_s3_key_prefix=output_s3_key_prefix, + output_s3_region=output_s3_region, + parameters=parameters, + service_role_arn=service_role_arn, + targets=targets, backend_region=self.region_name, ) self._commands.append(command) - return {"Command": command.response_object()} + return command - def list_commands(self, **kwargs): + def list_commands( + self, command_id: Optional[str], instance_id: Optional[str] + ) -> List[Command]: """ - https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_ListCommands.html + Pagination and the Filters-parameter is not yet implemented """ - commands = self._commands - - command_id = kwargs.get("CommandId", None) if command_id: - commands = [self.get_command_by_id(command_id)] - instance_id = kwargs.get("InstanceId", None) + return [self.get_command_by_id(command_id)] if instance_id: - commands = self.get_commands_by_instance_id(instance_id) + return self.get_commands_by_instance_id(instance_id) - return {"Commands": [command.response_object() for command in commands]} + return self._commands - def get_command_by_id(self, command_id): + def get_command_by_id(self, command_id: str) -> Command: command = next( (command for command in self._commands if command.command_id == command_id), None, @@ -1970,43 +2054,35 @@ class SimpleSystemManagerBackend(BaseBackend): return command - def get_commands_by_instance_id(self, instance_id): + def get_commands_by_instance_id(self, instance_id: str) -> List[Command]: return [ command for command in self._commands if instance_id in command.instance_ids ] - def get_command_invocation(self, **kwargs): - """ - https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetCommandInvocation.html - """ - - command_id = kwargs.get("CommandId") - instance_id = kwargs.get("InstanceId") - plugin_name = kwargs.get("PluginName", None) - + def get_command_invocation( + self, command_id: str, instance_id: str, plugin_name: Optional[str] + ) -> Dict[str, Any]: command = self.get_command_by_id(command_id) return command.get_invocation(instance_id, plugin_name) def create_maintenance_window( self, - name, - description, - enabled, - duration, - cutoff, - schedule, - schedule_timezone, - schedule_offset, - start_date, - end_date, - ): + name: str, + description: str, + duration: int, + cutoff: int, + schedule: str, + schedule_timezone: str, + schedule_offset: int, + start_date: str, + end_date: str, + ) -> str: """ Creates a maintenance window. No error handling or input validation has been implemented yet. """ window = FakeMaintenanceWindow( name, description, - enabled, duration, cutoff, schedule, @@ -2018,14 +2094,16 @@ class SimpleSystemManagerBackend(BaseBackend): self.windows[window.id] = window return window.id - def get_maintenance_window(self, window_id): + def get_maintenance_window(self, window_id: str) -> FakeMaintenanceWindow: """ The window is assumed to exist - no error handling has been implemented yet. The NextExecutionTime-field is not returned. """ return self.windows[window_id] - def describe_maintenance_windows(self, filters): + def describe_maintenance_windows( + self, filters: Optional[List[Dict[str, Any]]] + ) -> List[FakeMaintenanceWindow]: """ Returns all windows. No pagination has been implemented yet. Only filtering for Name is supported. The NextExecutionTime-field is not returned. @@ -2038,7 +2116,7 @@ class SimpleSystemManagerBackend(BaseBackend): res = [w for w in res if w.name in f["Values"]] return res - def delete_maintenance_window(self, window_id): + def delete_maintenance_window(self, window_id: str) -> None: """ Assumes the provided WindowId exists. No error handling has been implemented yet. """ diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index 1469c4a51..204da6fec 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -1,4 +1,5 @@ import json +from typing import Any, Dict, Tuple, Union from moto.core.responses import BaseResponse from .exceptions import ValidationException @@ -13,14 +14,7 @@ class SimpleSystemManagerResponse(BaseResponse): def ssm_backend(self) -> SimpleSystemManagerBackend: return ssm_backends[self.current_account][self.region] - @property - def request_params(self): - try: - return json.loads(self.body) - except ValueError: - return {} - - def create_document(self): + def create_document(self) -> str: content = self._get_param("Content") requires = self._get_param("Requires") attachments = self._get_param("Attachments") @@ -45,7 +39,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps({"DocumentDescription": result}) - def delete_document(self): + def delete_document(self) -> str: name = self._get_param("Name") document_version = self._get_param("DocumentVersion") version_name = self._get_param("VersionName") @@ -59,7 +53,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps({}) - def get_document(self): + def get_document(self) -> str: name = self._get_param("Name") version_name = self._get_param("VersionName") document_version = self._get_param("DocumentVersion") @@ -74,7 +68,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps(document) - def describe_document(self): + def describe_document(self) -> str: name = self._get_param("Name") document_version = self._get_param("DocumentVersion") version_name = self._get_param("VersionName") @@ -85,7 +79,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps({"Document": result}) - def update_document(self): + def update_document(self) -> str: content = self._get_param("Content") attachments = self._get_param("Attachments") name = self._get_param("Name") @@ -106,7 +100,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps({"DocumentDescription": result}) - def update_document_default_version(self): + def update_document_default_version(self) -> str: name = self._get_param("Name") document_version = self._get_param("DocumentVersion") @@ -115,7 +109,7 @@ class SimpleSystemManagerResponse(BaseResponse): ) return json.dumps({"Description": result}) - def list_documents(self): + def list_documents(self) -> str: document_filter_list = self._get_param("DocumentFilterList") filters = self._get_param("Filters") max_results = self._get_param("MaxResults", 10) @@ -125,18 +119,18 @@ class SimpleSystemManagerResponse(BaseResponse): document_filter_list=document_filter_list, filters=filters, max_results=max_results, - next_token=next_token, + token=next_token, ) return json.dumps({"DocumentIdentifiers": documents, "NextToken": token}) - def describe_document_permission(self): + def describe_document_permission(self) -> str: name = self._get_param("Name") result = self.ssm_backend.describe_document_permission(name=name) return json.dumps(result) - def modify_document_permission(self): + def modify_document_permission(self) -> str: account_ids_to_add = self._get_param("AccountIdsToAdd") account_ids_to_remove = self._get_param("AccountIdsToRemove") name = self._get_param("Name") @@ -150,11 +144,9 @@ class SimpleSystemManagerResponse(BaseResponse): shared_document_version=shared_document_version, permission_type=permission_type, ) + return "{}" - def _get_param(self, param_name, if_none=None): - return self.request_params.get(param_name, if_none) - - def delete_parameter(self): + def delete_parameter(self) -> Union[str, Tuple[str, Dict[str, int]]]: name = self._get_param("Name") result = self.ssm_backend.delete_parameter(name) if result is None: @@ -165,11 +157,11 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps(error), dict(status=400) return json.dumps({}) - def delete_parameters(self): + def delete_parameters(self) -> str: names = self._get_param("Names") result = self.ssm_backend.delete_parameters(names) - response = {"DeletedParameters": [], "InvalidParameters": []} + response: Dict[str, Any] = {"DeletedParameters": [], "InvalidParameters": []} for name in names: if name in result: @@ -178,7 +170,7 @@ class SimpleSystemManagerResponse(BaseResponse): response["InvalidParameters"].append(name) return json.dumps(response) - def get_parameter(self): + def get_parameter(self) -> Union[str, Tuple[str, Dict[str, int]]]: name = self._get_param("Name") with_decryption = self._get_param("WithDecryption") @@ -202,13 +194,13 @@ class SimpleSystemManagerResponse(BaseResponse): response = {"Parameter": result.response_object(with_decryption, self.region)} return json.dumps(response) - def get_parameters(self): + def get_parameters(self) -> str: names = self._get_param("Names") with_decryption = self._get_param("WithDecryption") result = self.ssm_backend.get_parameters(names) - response = {"Parameters": [], "InvalidParameters": []} + response: Dict[str, Any] = {"Parameters": [], "InvalidParameters": []} for name, parameter in result.items(): param_data = parameter.response_object(with_decryption, self.region) @@ -220,7 +212,7 @@ class SimpleSystemManagerResponse(BaseResponse): response["InvalidParameters"].append(name) return json.dumps(response) - def get_parameters_by_path(self): + def get_parameters_by_path(self) -> str: path = self._get_param("Path") with_decryption = self._get_param("WithDecryption") recursive = self._get_param("Recursive", False) @@ -236,7 +228,7 @@ class SimpleSystemManagerResponse(BaseResponse): max_results=max_results, ) - response = {"Parameters": [], "NextToken": next_token} + response: Dict[str, Any] = {"Parameters": [], "NextToken": next_token} for parameter in result: param_data = parameter.response_object(with_decryption, self.region) @@ -244,7 +236,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps(response) - def describe_parameters(self): + def describe_parameters(self) -> str: page_size = 10 filters = self._get_param("Filters") parameter_filters = self._get_param("ParameterFilters") @@ -257,7 +249,7 @@ class SimpleSystemManagerResponse(BaseResponse): result = self.ssm_backend.describe_parameters(filters, parameter_filters) - response = {"Parameters": []} + response: Dict[str, Any] = {"Parameters": []} end = token + page_size for parameter in result[token:]: @@ -270,7 +262,7 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps(response) - def put_parameter(self): + def put_parameter(self) -> Union[str, Tuple[str, Dict[str, int]]]: name = self._get_param("Name") description = self._get_param("Description") value = self._get_param("Value") @@ -303,7 +295,7 @@ class SimpleSystemManagerResponse(BaseResponse): response = {"Version": result} return json.dumps(response) - def get_parameter_history(self): + def get_parameter_history(self) -> Union[str, Tuple[str, Dict[str, int]]]: name = self._get_param("Name") with_decryption = self._get_param("WithDecryption") next_token = self._get_param("NextToken") @@ -320,19 +312,19 @@ class SimpleSystemManagerResponse(BaseResponse): } return json.dumps(error), dict(status=400) - response = {"Parameters": []} - for parameter_version in result: - param_data = parameter_version.describe_response_object( - decrypt=with_decryption, include_labels=True - ) - response["Parameters"].append(param_data) - - if new_next_token is not None: - response["NextToken"] = new_next_token + response = { + "Parameters": [ + p_v.describe_response_object( + decrypt=with_decryption, include_labels=True + ) + for p_v in result + ], + "NextToken": new_next_token, + } return json.dumps(response) - def label_parameter_version(self): + def label_parameter_version(self) -> str: name = self._get_param("Name") version = self._get_param("ParameterVersion") labels = self._get_param("Labels") @@ -344,7 +336,7 @@ class SimpleSystemManagerResponse(BaseResponse): response = {"InvalidLabels": invalid_labels, "ParameterVersion": version} return json.dumps(response) - def add_tags_to_resource(self): + def add_tags_to_resource(self) -> str: resource_id = self._get_param("ResourceId") resource_type = self._get_param("ResourceType") tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")} @@ -353,7 +345,7 @@ class SimpleSystemManagerResponse(BaseResponse): ) return json.dumps({}) - def remove_tags_from_resource(self): + def remove_tags_from_resource(self) -> str: resource_id = self._get_param("ResourceId") resource_type = self._get_param("ResourceType") keys = self._get_param("TagKeys") @@ -362,7 +354,7 @@ class SimpleSystemManagerResponse(BaseResponse): ) return json.dumps({}) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: resource_id = self._get_param("ResourceId") resource_type = self._get_param("ResourceType") tags = self.ssm_backend.list_tags_for_resource( @@ -372,21 +364,56 @@ class SimpleSystemManagerResponse(BaseResponse): response = {"TagList": tag_list} return json.dumps(response) - def send_command(self): - return json.dumps(self.ssm_backend.send_command(**self.request_params)) - - def list_commands(self): - return json.dumps(self.ssm_backend.list_commands(**self.request_params)) - - def get_command_invocation(self): - return json.dumps( - self.ssm_backend.get_command_invocation(**self.request_params) + def send_command(self) -> str: + comment = self._get_param("Comment", "") + document_name = self._get_param("DocumentName") + timeout_seconds = self._get_int_param("TimeoutSeconds") + instance_ids = self._get_param("InstanceIds", []) + max_concurrency = self._get_param("MaxConcurrency", "50") + max_errors = self._get_param("MaxErrors", "0") + notification_config = self._get_param("NotificationConfig") + output_s3_bucket_name = self._get_param("OutputS3BucketName", "") + output_s3_key_prefix = self._get_param("OutputS3KeyPrefix", "") + output_s3_region = self._get_param("OutputS3Region", "") + parameters = self._get_param("Parameters", {}) + service_role_arn = self._get_param("ServiceRoleArn", "") + targets = self._get_param("Targets", []) + command = self.ssm_backend.send_command( + comment=comment, + document_name=document_name, + timeout_seconds=timeout_seconds, + instance_ids=instance_ids, + max_concurrency=max_concurrency, + max_errors=max_errors, + notification_config=notification_config, + output_s3_bucket_name=output_s3_bucket_name, + output_s3_key_prefix=output_s3_key_prefix, + output_s3_region=output_s3_region, + parameters=parameters, + service_role_arn=service_role_arn, + targets=targets, ) + return json.dumps({"Command": command.response_object()}) - def create_maintenance_window(self): + def list_commands(self) -> str: + command_id = self._get_param("CommandId") + instance_id = self._get_param("InstanceId") + commands = self.ssm_backend.list_commands(command_id, instance_id) + response = {"Commands": [command.response_object() for command in commands]} + return json.dumps(response) + + def get_command_invocation(self) -> str: + command_id = self._get_param("CommandId") + instance_id = self._get_param("InstanceId") + plugin_name = self._get_param("PluginName") + response = self.ssm_backend.get_command_invocation( + command_id, instance_id, plugin_name + ) + return json.dumps(response) + + def create_maintenance_window(self) -> str: name = self._get_param("Name") desc = self._get_param("Description", None) - enabled = self._get_bool_param("Enabled", True) duration = self._get_int_param("Duration") cutoff = self._get_int_param("Cutoff") schedule = self._get_param("Schedule") @@ -397,7 +424,6 @@ class SimpleSystemManagerResponse(BaseResponse): window_id = self.ssm_backend.create_maintenance_window( name=name, description=desc, - enabled=enabled, duration=duration, cutoff=cutoff, schedule=schedule, @@ -408,12 +434,12 @@ class SimpleSystemManagerResponse(BaseResponse): ) return json.dumps({"WindowId": window_id}) - def get_maintenance_window(self): + def get_maintenance_window(self) -> str: window_id = self._get_param("WindowId") window = self.ssm_backend.get_maintenance_window(window_id) return json.dumps(window.to_json()) - def describe_maintenance_windows(self): + def describe_maintenance_windows(self) -> str: filters = self._get_param("Filters", None) windows = [ window.to_json() @@ -421,7 +447,7 @@ class SimpleSystemManagerResponse(BaseResponse): ] return json.dumps({"WindowIdentities": windows}) - def delete_maintenance_window(self): + def delete_maintenance_window(self) -> str: window_id = self._get_param("WindowId") self.ssm_backend.delete_maintenance_window(window_id) return "{}" diff --git a/moto/ssm/utils.py b/moto/ssm/utils.py index 121ce019c..b5dc590e1 100644 --- a/moto/ssm/utils.py +++ b/moto/ssm/utils.py @@ -1,16 +1,19 @@ -def parameter_arn(account_id, region, parameter_name): +from typing import Any, Dict, List + + +def parameter_arn(account_id: str, region: str, parameter_name: str) -> str: if parameter_name[0] == "/": parameter_name = parameter_name[1:] return f"arn:aws:ssm:{region}:{account_id}:parameter/{parameter_name}" -def convert_to_tree(parameters): +def convert_to_tree(parameters: List[Dict[str, Any]]) -> Dict[str, Any]: """ Convert input into a smaller, less redundant data set in tree form Input: [{"Name": "/a/b/c", "Value": "af-south-1", ...}, ..] Output: {"a": {"b": {"c": {"Value": af-south-1}, ..}, ..}, ..} """ - tree_dict = {} + tree_dict: Dict[str, Any] = {} for p in parameters: current_level = tree_dict for path in p["Name"].split("/"): @@ -23,18 +26,20 @@ def convert_to_tree(parameters): return tree_dict -def convert_to_params(tree): +def convert_to_params(tree: Dict[str, Any]) -> List[Dict[str, Any]]: """ Inverse of 'convert_to_tree' """ - def m(tree, params, current_path=""): + def m( + tree: Dict[str, Any], params: List[Dict[str, Any]], current_path: str = "" + ) -> None: for key, value in tree.items(): if key == "Value": params.append({"Name": current_path, "Value": value}) else: m(value, params, current_path + "/" + key) - params = [] + params: List[Dict[str, Any]] = [] m(tree, params) return params diff --git a/setup.cfg b/setup.cfg index c3ef08ec2..1679a8068 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/scheduler +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ssm,moto/scheduler show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index ad2074440..daf6a73a9 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -1739,8 +1739,12 @@ def test_send_command(): before = datetime.datetime.now() response = client.send_command( + Comment="some comment", InstanceIds=["i-123456"], DocumentName=ssm_document, + TimeoutSeconds=42, + MaxConcurrency="360", + MaxErrors="2", Parameters=params, OutputS3Region="us-east-2", OutputS3BucketName="the-bucket", @@ -1749,6 +1753,7 @@ def test_send_command(): cmd = response["Command"] cmd["CommandId"].should_not.be(None) + assert cmd["Comment"] == "some comment" cmd["DocumentName"].should.equal(ssm_document) cmd["Parameters"].should.equal(params) @@ -1759,6 +1764,10 @@ def test_send_command(): cmd["ExpiresAfter"].should.be.greater_than(before) cmd["DeliveryTimedOutCount"].should.equal(0) + assert cmd["TimeoutSeconds"] == 42 + assert cmd["MaxConcurrency"] == "360" + assert cmd["MaxErrors"] == "2" + # test sending a command without any optional parameters response = client.send_command(DocumentName=ssm_document)