diff --git a/moto/core/common_models.py b/moto/core/common_models.py index 613c9f3ad..8848f1c29 100644 --- a/moto/core/common_models.py +++ b/moto/core/common_models.py @@ -108,7 +108,7 @@ class ConfigQueryModel: backend_region: Optional[str] = None, resource_region: Optional[str] = None, aggregator: Optional[Dict[str, Any]] = None, - ) -> Tuple[List[Dict[str, Any]], str]: + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """For AWS Config. This will list all of the resources of the given type and optional resource name and region. This supports both aggregated and non-aggregated listing. The following notes the difference: @@ -165,7 +165,7 @@ class ConfigQueryModel: resource_name: Optional[str] = None, backend_region: Optional[str] = None, resource_region: Optional[str] = None, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """For AWS Config. This will query the backend for the specific resource type configuration. This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests @@ -194,7 +194,7 @@ class ConfigQueryModel: raise NotImplementedError() -class CloudWatchMetricProvider(object): +class CloudWatchMetricProvider: @staticmethod @abstractmethod def get_cloudwatch_metrics(account_id: str) -> Any: # type: ignore[misc] diff --git a/moto/core/utils.py b/moto/core/utils.py index 833359fbd..b6bf065dd 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -154,7 +154,9 @@ def iso_8601_datetime_with_nanoseconds(value: datetime.datetime) -> str: return value.strftime("%Y-%m-%dT%H:%M:%S.%f000Z") -def iso_8601_datetime_without_milliseconds(value: datetime.datetime) -> Optional[str]: +def iso_8601_datetime_without_milliseconds( + value: Optional[datetime.datetime], +) -> Optional[str]: return value.strftime("%Y-%m-%dT%H:%M:%SZ") if value else None diff --git a/moto/iam/access_control.py b/moto/iam/access_control.py index 8cdd39137..059268463 100644 --- a/moto/iam/access_control.py +++ b/moto/iam/access_control.py @@ -17,6 +17,7 @@ import logging import re from abc import abstractmethod, ABCMeta from enum import Enum +from typing import Any, Dict, Optional, Match, List, Union from botocore.auth import SigV4Auth, S3SigV4Auth from botocore.awsrequest import AWSRequest @@ -39,12 +40,14 @@ from moto.s3.exceptions import ( S3SignatureDoesNotMatchError, ) from moto.sts.models import sts_backends -from .models import iam_backends, Policy +from .models import iam_backends, Policy, IAMBackend log = logging.getLogger(__name__) -def create_access_key(account_id, access_key_id, headers): +def create_access_key( + account_id: str, access_key_id: str, headers: Dict[str, str] +) -> Union["IAMUserAccessKey", "AssumedRoleAccessKey"]: if access_key_id.startswith("AKIA") or "X-Amz-Security-Token" not in headers: return IAMUserAccessKey(account_id, access_key_id, headers) else: @@ -53,10 +56,10 @@ def create_access_key(account_id, access_key_id, headers): class IAMUserAccessKey: @property - def backend(self): + def backend(self) -> IAMBackend: return iam_backends[self.account_id]["global"] - def __init__(self, account_id, access_key_id, headers): + def __init__(self, account_id: str, access_key_id: str, headers: Dict[str, str]): self.account_id = account_id iam_users = self.backend.list_users("/", None, None) @@ -72,13 +75,13 @@ class IAMUserAccessKey: raise CreateAccessKeyFailure(reason="InvalidId") @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:user/{self._owner_user_name}" - def create_credentials(self): + def create_credentials(self) -> Credentials: return Credentials(self._access_key_id, self._secret_access_key) - def collect_policies(self): + def collect_policies(self) -> List[Dict[str, str]]: user_policies = [] inline_policy_names = self.backend.list_user_policies(self._owner_user_name) @@ -112,12 +115,12 @@ class IAMUserAccessKey: return user_policies -class AssumedRoleAccessKey(object): +class AssumedRoleAccessKey: @property - def backend(self): + def backend(self) -> IAMBackend: # type: ignore[misc] return iam_backends[self.account_id]["global"] - def __init__(self, account_id, access_key_id, headers): + def __init__(self, account_id: str, access_key_id: str, headers: Dict[str, str]): self.account_id = account_id for assumed_role in sts_backends[account_id]["global"].assumed_roles: if assumed_role.access_key_id == access_key_id: @@ -132,15 +135,15 @@ class AssumedRoleAccessKey(object): raise CreateAccessKeyFailure(reason="InvalidId") @property - def arn(self): + def arn(self) -> str: return f"arn:aws:sts::{self.account_id}:assumed-role/{self._owner_role_name}/{self._session_name}" - def create_credentials(self): + def create_credentials(self) -> Credentials: return Credentials( self._access_key_id, self._secret_access_key, self._session_token ) - def collect_policies(self): + def collect_policies(self) -> List[str]: role_policies = [] inline_policy_names = self.backend.list_role_policies(self._owner_role_name) @@ -153,19 +156,26 @@ class AssumedRoleAccessKey(object): attached_policies, _ = self.backend.list_attached_role_policies( self._owner_role_name ) - role_policies += attached_policies + role_policies += attached_policies # type: ignore[arg-type] return role_policies class CreateAccessKeyFailure(Exception): - def __init__(self, reason, *args): - super().__init__(*args) + def __init__(self, reason: str): + super().__init__() self.reason = reason class IAMRequestBase(object, metaclass=ABCMeta): - def __init__(self, account_id, method, path, data, headers): + def __init__( + self, + account_id: str, + method: str, + path: str, + data: Dict[str, str], + headers: Dict[str, str], + ): log.debug( f"Creating {self.__class__.__name__} with method={method}, path={path}, data={data}, headers={headers}" ) @@ -198,7 +208,7 @@ class IAMRequestBase(object, metaclass=ABCMeta): except CreateAccessKeyFailure as e: self._raise_invalid_access_key(e.reason) - def check_signature(self): + def check_signature(self) -> None: original_signature = self._get_string_between( "Signature=", ",", self._headers["Authorization"] ) @@ -206,11 +216,11 @@ class IAMRequestBase(object, metaclass=ABCMeta): if original_signature != calculated_signature: self._raise_signature_does_not_match() - def check_action_permitted(self): + def check_action_permitted(self) -> None: if ( self._action == "sts:GetCallerIdentity" ): # always allowed, even if there's an explicit Deny for it - return True + return policies = self._access_key.collect_policies() permitted = False @@ -226,30 +236,32 @@ class IAMRequestBase(object, metaclass=ABCMeta): self._raise_access_denied() @abstractmethod - def _raise_signature_does_not_match(self): + def _raise_signature_does_not_match(self) -> None: raise NotImplementedError() @abstractmethod - def _raise_access_denied(self): + def _raise_access_denied(self) -> None: raise NotImplementedError() @abstractmethod - def _raise_invalid_access_key(self, reason): + def _raise_invalid_access_key(self, reason: str) -> None: raise NotImplementedError() @abstractmethod - def _create_auth(self, credentials): + def _create_auth(self, credentials: Credentials) -> SigV4Auth: # type: ignore[misc] raise NotImplementedError() @staticmethod - def _create_headers_for_aws_request(signed_headers, original_headers): + def _create_headers_for_aws_request( + signed_headers: List[str], original_headers: Dict[str, str] + ) -> Dict[str, str]: headers = {} for key, value in original_headers.items(): if key.lower() in signed_headers: headers[key] = value return headers - def _create_aws_request(self): + def _create_aws_request(self) -> AWSRequest: signed_headers = self._get_string_between( "SignedHeaders=", ",", self._headers["Authorization"] ).split(";") @@ -261,7 +273,7 @@ class IAMRequestBase(object, metaclass=ABCMeta): return request - def _calculate_signature(self): + def _calculate_signature(self) -> str: credentials = self._access_key.create_credentials() auth = self._create_auth(credentials) request = self._create_aws_request() @@ -270,38 +282,40 @@ class IAMRequestBase(object, metaclass=ABCMeta): return auth.signature(string_to_sign, request) @staticmethod - def _get_string_between(first_separator, second_separator, string): + def _get_string_between( + first_separator: str, second_separator: str, string: str + ) -> str: return string.partition(first_separator)[2].partition(second_separator)[0] class IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): + def _raise_signature_does_not_match(self) -> None: if self._service == "ec2": raise AuthFailureError() else: raise SignatureDoesNotMatchError() - def _raise_invalid_access_key(self, _): + def _raise_invalid_access_key(self, _: str) -> None: if self._service == "ec2": raise AuthFailureError() else: raise InvalidClientTokenIdError() - def _create_auth(self, credentials): + def _create_auth(self, credentials: Any) -> SigV4Auth: return SigV4Auth(credentials, self._service, self._region) - def _raise_access_denied(self): + def _raise_access_denied(self) -> None: raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action) class S3IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): + def _raise_signature_does_not_match(self) -> None: if "BucketName" in self._data: raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"]) else: raise S3SignatureDoesNotMatchError() - def _raise_invalid_access_key(self, reason): + def _raise_invalid_access_key(self, reason: str) -> None: if reason == "InvalidToken": if "BucketName" in self._data: raise BucketInvalidTokenError(bucket=self._data["BucketName"]) @@ -313,18 +327,18 @@ class S3IAMRequest(IAMRequestBase): else: raise S3InvalidAccessKeyIdError() - def _create_auth(self, credentials): + def _create_auth(self, credentials: Any) -> S3SigV4Auth: return S3SigV4Auth(credentials, self._service, self._region) - def _raise_access_denied(self): + def _raise_access_denied(self) -> None: if "BucketName" in self._data: raise BucketAccessDeniedError(bucket=self._data["BucketName"]) else: raise S3AccessDeniedError() -class IAMPolicy(object): - def __init__(self, policy): +class IAMPolicy: + def __init__(self, policy: Any): if isinstance(policy, Policy): default_version = next( policy_version @@ -337,9 +351,11 @@ class IAMPolicy(object): else: policy_document = policy["policy_document"] - self._policy_json = json.loads(policy_document) + self._policy_json = json.loads(policy_document) # type: ignore[arg-type] - def is_action_permitted(self, action, resource="*"): + def is_action_permitted( + self, action: str, resource: str = "*" + ) -> "PermissionResult": permitted = False if isinstance(self._policy_json["Statement"], list): for policy_statement in self._policy_json["Statement"]: @@ -361,11 +377,13 @@ class IAMPolicy(object): return PermissionResult.NEUTRAL -class IAMPolicyStatement(object): - def __init__(self, statement): +class IAMPolicyStatement: + def __init__(self, statement: Any): self._statement = statement - def is_action_permitted(self, action, resource="*"): + def is_action_permitted( + self, action: str, resource: str = "*" + ) -> "PermissionResult": is_action_concerned = False if "NotAction" in self._statement: @@ -386,7 +404,7 @@ class IAMPolicyStatement(object): else: return PermissionResult.NEUTRAL - def is_unknown_principal(self, principal) -> bool: + def is_unknown_principal(self, principal: Optional[str]) -> bool: # https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-bucket-user-policy-specifying-principal-intro.html # For now, Moto only verifies principal == * # 'Unknown' principals are not verified @@ -401,17 +419,17 @@ class IAMPolicyStatement(object): return True return False - def _check_element_matches(self, statement_element, value): + def _check_element_matches(self, statement_element: Any, value: str) -> bool: if isinstance(self._statement[statement_element], list): for statement_element_value in self._statement[statement_element]: if self._match(statement_element_value, value): return True return False else: # string - return self._match(self._statement[statement_element], value) + return self._match(self._statement[statement_element], value) is not None @staticmethod - def _match(pattern, string): + def _match(pattern: str, string: str) -> Optional[Match[str]]: pattern = pattern.replace("*", ".*") pattern = f"^{pattern}$" return re.match(pattern, string) diff --git a/moto/iam/config.py b/moto/iam/config.py index 131117e67..654ea9397 100644 --- a/moto/iam/config.py +++ b/moto/iam/config.py @@ -1,5 +1,6 @@ import json import boto3 +from typing import Any, Dict, List, Optional, Tuple from moto.core.exceptions import InvalidNextTokenException from moto.core.common_models import ConfigQueryModel from moto.iam import iam_backends @@ -8,15 +9,15 @@ from moto.iam import iam_backends class RoleConfigQuery(ConfigQueryModel): def list_config_service_resources( self, - account_id, - resource_ids, - resource_name, - limit, - next_token, - backend_region=None, - resource_region=None, - aggregator=None, - ): + account_id: str, + resource_ids: Optional[List[str]], + resource_name: Optional[str], + limit: int, + next_token: Optional[str], + backend_region: Optional[str] = None, + resource_region: Optional[str] = None, + aggregator: Optional[Dict[str, Any]] = None, + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: # IAM roles are "global" and aren't assigned into any availability zone # The resource ID is a AWS-assigned random string like "AROA0BSVNSZKXVHS00SBJ" # The resource name is a user-assigned string like "MyDevelopmentAdminRole" @@ -43,7 +44,7 @@ class RoleConfigQuery(ConfigQueryModel): return [], None else: for role in role_list: - if role.id in resource_ids: + if role.id in resource_ids: # type: ignore[operator] filtered_roles.append(role) # Filtered roles are now the subject for the listing @@ -60,7 +61,7 @@ class RoleConfigQuery(ConfigQueryModel): aggregator_sources = aggregator.get( "account_aggregation_sources" ) or aggregator.get("organization_aggregation_source") - for source in aggregator_sources: + for source in aggregator_sources: # type: ignore[union-attr] source_dict = source.__dict__ if source_dict.get("all_aws_regions", False): aggregated_regions = boto3.Session().get_available_regions("config") @@ -86,7 +87,7 @@ class RoleConfigQuery(ConfigQueryModel): else: # Non-aggregated queries are in the else block, and we can treat these like a normal config resource # Pagination logic, sort by role id - sorted_roles = sorted(role_list, key=lambda role: role.id) + sorted_roles = sorted(role_list, key=lambda role: role.id) # type: ignore[attr-defined] new_token = None @@ -99,7 +100,7 @@ class RoleConfigQuery(ConfigQueryModel): start = next( index for (index, r) in enumerate(sorted_roles) - if next_token == (r["_id"] if aggregator else r.id) + if next_token == (r["_id"] if aggregator else r.id) # type: ignore[attr-defined] ) except StopIteration: raise InvalidNextTokenException() @@ -109,14 +110,14 @@ class RoleConfigQuery(ConfigQueryModel): if len(sorted_roles) > (start + limit): record = sorted_roles[start + limit] - new_token = record["_id"] if aggregator else record.id + new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined] return ( [ { "type": "AWS::IAM::Role", - "id": role["id"] if aggregator else role.id, - "name": role["name"] if aggregator else role.name, + "id": role["id"] if aggregator else role.id, # type: ignore[attr-defined] + "name": role["name"] if aggregator else role.name, # type: ignore[attr-defined] "region": role["region"] if aggregator else "global", } for role in role_list @@ -126,20 +127,20 @@ class RoleConfigQuery(ConfigQueryModel): def get_config_resource( self, - account_id, - resource_id, - resource_name=None, - backend_region=None, - resource_region=None, - ): + account_id: str, + resource_id: str, + resource_name: Optional[str] = None, + backend_region: Optional[str] = None, + resource_region: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: role = self.backends[account_id]["global"].roles.get(resource_id, {}) if not role: - return + return None if resource_name and role.name != resource_name: - return + return None # Format the role to the AWS Config format: config_data = role.to_config_dict() @@ -158,15 +159,15 @@ class RoleConfigQuery(ConfigQueryModel): class PolicyConfigQuery(ConfigQueryModel): def list_config_service_resources( self, - account_id, - resource_ids, - resource_name, - limit, - next_token, - backend_region=None, - resource_region=None, - aggregator=None, - ): + account_id: str, + resource_ids: Optional[List[str]], + resource_name: Optional[str], + limit: int, + next_token: Optional[str], + backend_region: Optional[str] = None, + resource_region: Optional[str] = None, + aggregator: Optional[Dict[str, Any]] = None, + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: # IAM policies are "global" and aren't assigned into any availability zone # The resource ID is a AWS-assigned random string like "ANPA0BSVNSZK00SJSPVUJ" # The resource name is a user-assigned string like "my-development-policy" @@ -206,7 +207,7 @@ class PolicyConfigQuery(ConfigQueryModel): else: for policy in policy_list: - if policy.id in resource_ids: + if policy.id in resource_ids: # type: ignore[operator] filtered_policies.append(policy) # Filtered roles are now the subject for the listing @@ -223,7 +224,7 @@ class PolicyConfigQuery(ConfigQueryModel): aggregator_sources = aggregator.get( "account_aggregation_sources" ) or aggregator.get("organization_aggregation_source") - for source in aggregator_sources: + for source in aggregator_sources: # type: ignore[union-attr] source_dict = source.__dict__ if source_dict.get("all_aws_regions", False): aggregated_regions = boto3.Session().get_available_regions("config") @@ -252,7 +253,7 @@ class PolicyConfigQuery(ConfigQueryModel): else: # Non-aggregated queries are in the else block, and we can treat these like a normal config resource # Pagination logic, sort by role id - sorted_policies = sorted(policy_list, key=lambda role: role.id) + sorted_policies = sorted(policy_list, key=lambda role: role.id) # type: ignore[attr-defined] new_token = None @@ -265,7 +266,7 @@ class PolicyConfigQuery(ConfigQueryModel): start = next( index for (index, p) in enumerate(sorted_policies) - if next_token == (p["_id"] if aggregator else p.id) + if next_token == (p["_id"] if aggregator else p.id) # type: ignore[attr-defined] ) except StopIteration: raise InvalidNextTokenException() @@ -275,14 +276,14 @@ class PolicyConfigQuery(ConfigQueryModel): if len(sorted_policies) > (start + limit): record = sorted_policies[start + limit] - new_token = record["_id"] if aggregator else record.id + new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined] return ( [ { "type": "AWS::IAM::Policy", - "id": policy["id"] if aggregator else policy.id, - "name": policy["name"] if aggregator else policy.name, + "id": policy["id"] if aggregator else policy.id, # type: ignore[attr-defined] + "name": policy["name"] if aggregator else policy.name, # type: ignore[attr-defined] "region": policy["region"] if aggregator else "global", } for policy in policy_list @@ -292,12 +293,12 @@ class PolicyConfigQuery(ConfigQueryModel): def get_config_resource( self, - account_id, - resource_id, - resource_name=None, - backend_region=None, - resource_region=None, - ): + account_id: str, + resource_id: str, + resource_name: Optional[str] = None, + backend_region: Optional[str] = None, + resource_region: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: # policies are listed in the backend as arns, but we have to accept the PolicyID as the resource_id # we'll make a really crude search for it policy = None @@ -308,10 +309,10 @@ class PolicyConfigQuery(ConfigQueryModel): break if not policy: - return + return None if resource_name and policy.name != resource_name: - return + return None # Format the policy to the AWS Config format: config_data = policy.to_config_dict() diff --git a/moto/iam/exceptions.py b/moto/iam/exceptions.py index 545a9a008..aebc3b0c8 100644 --- a/moto/iam/exceptions.py +++ b/moto/iam/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any from moto.core.exceptions import RESTError XMLNS_IAM = "https://iam.amazonaws.com/doc/2010-05-08/" @@ -15,28 +16,28 @@ class IAMNotFoundException(RESTError): class IAMConflictException(RESTError): code = 409 - def __init__(self, code="Conflict", message=""): + def __init__(self, code: str = "Conflict", message: str = ""): super().__init__(code, message) class IAMReportNotPresentException(RESTError): code = 410 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ReportNotPresent", message) class IAMLimitExceededException(RESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("LimitExceeded", message) class MalformedCertificate(RESTError): code = 400 - def __init__(self, cert): + def __init__(self, cert: str): super().__init__("MalformedCertificate", f"Certificate {cert} is malformed") @@ -55,7 +56,7 @@ class MalformedPolicyDocument(RESTError): class DuplicateTags(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidInput", "Duplicate tag keys found. Please note that Tag keys are case insensitive.", @@ -65,7 +66,7 @@ class DuplicateTags(RESTError): class TagKeyTooBig(RESTError): code = 400 - def __init__(self, tag, param="tags.X.member.key"): + def __init__(self, tag: str, param: str = "tags.X.member.key"): super().__init__( "ValidationError", f"1 validation error detected: Value '{tag}' at '{param}' failed to satisfy " @@ -76,7 +77,7 @@ class TagKeyTooBig(RESTError): class TagValueTooBig(RESTError): code = 400 - def __init__(self, tag): + def __init__(self, tag: str): super().__init__( "ValidationError", f"1 validation error detected: Value '{tag}' at 'tags.X.member.value' failed to satisfy " @@ -87,7 +88,7 @@ class TagValueTooBig(RESTError): class InvalidTagCharacters(RESTError): code = 400 - def __init__(self, tag, param="tags.X.member.key"): + def __init__(self, tag: str, param: str = "tags.X.member.key"): message = f"1 validation error detected: Value '{tag}' at '{param}' failed to satisfy constraint: Member must satisfy regular expression pattern: [\\p{{L}}\\p{{Z}}\\p{{N}}_.:/=+\\-@]+" super().__init__("ValidationError", message) @@ -96,7 +97,7 @@ class InvalidTagCharacters(RESTError): class TooManyTags(RESTError): code = 400 - def __init__(self, tags, param="tags"): + def __init__(self, tags: Any, param: str = "tags"): super().__init__( "ValidationError", f"1 validation error detected: Value '{tags}' at '{param}' failed to satisfy " @@ -107,28 +108,28 @@ class TooManyTags(RESTError): class EntityAlreadyExists(RESTError): code = 409 - def __init__(self, message): + def __init__(self, message: str): super().__init__("EntityAlreadyExists", message) class ValidationError(RESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ValidationError", message) class InvalidInput(RESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidInput", message) class NoSuchEntity(RESTError): code = 404 - def __init__(self, message): + def __init__(self, message: str): super().__init__( "NoSuchEntity", message, xmlns=XMLNS_IAM, template="wrapped_single_error" ) diff --git a/moto/iam/models.py b/moto/iam/models.py index 1e61d7111..a84ed1af0 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -9,7 +9,8 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from jinja2 import Template -from typing import List, Mapping +from typing import Any, Dict, Optional, Tuple, Union +from typing import List, Iterable from urllib import parse from moto.core.exceptions import RESTError from moto.core import ( @@ -92,22 +93,24 @@ def mark_account_as_visited( LIMIT_KEYS_PER_USER = 2 -class MFADevice(object): +class MFADevice: """MFA Device class.""" - def __init__(self, serial_number, authentication_code_1, authentication_code_2): + def __init__( + self, serial_number: str, authentication_code_1: str, authentication_code_2: str + ): self.enable_date = datetime.utcnow() self.serial_number = serial_number self.authentication_code_1 = authentication_code_1 self.authentication_code_2 = authentication_code_2 @property - def enabled_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.enable_date) + def enabled_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.enable_date) # type: ignore[return-value] -class VirtualMfaDevice(object): - def __init__(self, account_id, device_name): +class VirtualMfaDevice: + def __init__(self, account_id: str, device_name: str): self.serial_number = f"arn:aws:iam::{account_id}:mfa{device_name}" random_base32_string = "".join( @@ -120,13 +123,13 @@ class VirtualMfaDevice(object): "ascii" ) # this would be a generated PNG - self.enable_date = None - self.user_attribute = None - self.user = None + self.enable_date: Optional[datetime] = None + self.user_attribute: Optional[Dict[str, Any]] = None + self.user: Optional[User] = None @property - def enabled_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.enable_date) + def enabled_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.enable_date) # type: ignore[return-value] class Policy(CloudFormationModel): @@ -138,15 +141,15 @@ class Policy(CloudFormationModel): def __init__( self, - name, - account_id, - default_version_id=None, - description=None, - document=None, - path=None, - create_date=None, - update_date=None, - tags=None, + name: str, + account_id: str, + default_version_id: Optional[str] = None, + description: Optional[str] = None, + document: Optional[str] = None, + path: Optional[str] = None, + create_date: Optional[datetime] = None, + update_date: Optional[datetime] = None, + tags: Optional[Dict[str, Dict[str, str]]] = None, ): self.name = name self.account_id = account_id @@ -154,7 +157,7 @@ class Policy(CloudFormationModel): self.description = description or "" self.id = random_policy_id() self.path = path or "/" - self.tags = tags + self.tags = tags or {} if default_version_id: self.default_version_id = default_version_id @@ -164,14 +167,14 @@ class Policy(CloudFormationModel): self.next_version_num = 2 self.versions = [ PolicyVersion( - self.arn, document, True, self.default_version_id, update_date + self.arn, document, True, self.default_version_id, update_date # type: ignore ) ] - self.create_date = create_date if create_date is not None else datetime.utcnow() - self.update_date = update_date if update_date is not None else datetime.utcnow() + self.create_date = create_date or datetime.utcnow() + self.update_date = update_date or datetime.utcnow() - def update_default_version(self, new_default_version_id): + def update_default_version(self, new_default_version_id: str) -> None: for version in self.versions: if version.version_id == new_default_version_id: version.is_default = True @@ -180,33 +183,40 @@ class Policy(CloudFormationModel): self.default_version_id = new_default_version_id @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @property - def updated_iso_8601(self): + def updated_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.update_date) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return [self.tags[tag] for tag in self.tags] class SAMLProvider(BaseModel): - def __init__(self, account_id, name, saml_metadata_document=None): + def __init__( + self, account_id: str, name: str, saml_metadata_document: Optional[str] = None + ): self.account_id = account_id self.name = name self.saml_metadata_document = saml_metadata_document @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:saml-provider/{self.name}" class OpenIDConnectProvider(BaseModel): def __init__( - self, account_id, url, thumbprint_list, client_id_list=None, tags=None + self, + account_id: str, + url: str, + thumbprint_list: List[str], + client_id_list: List[str], + tags: Dict[str, Dict[str, str]], ): - self._errors = [] + self._errors: List[str] = [] self._validate(url, thumbprint_list, client_id_list) self.account_id = account_id @@ -215,17 +225,19 @@ class OpenIDConnectProvider(BaseModel): self.thumbprint_list = thumbprint_list self.client_id_list = client_id_list self.create_date = datetime.utcnow() - self.tags = tags + self.tags = tags or {} @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:oidc-provider/{self.url}" @property - def created_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.create_date) + def created_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.create_date) # type: ignore[return-value] - def _validate(self, url, thumbprint_list, client_id_list): + def _validate( + self, url: str, thumbprint_list: List[str], client_id_list: List[str] + ) -> None: if any(len(client_id) > 255 for client_id in client_id_list): self._errors.append( self._format_error( @@ -271,10 +283,10 @@ class OpenIDConnectProvider(BaseModel): "Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100" ) - def _format_error(self, key, value, constraint): + def _format_error(self, key: str, value: Any, constraint: str) -> str: return f'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}' - def _raise_errors(self): + def _raise_errors(self) -> None: if self._errors: count = len(self._errors) plural = "s" if len(self._errors) > 1 else "" @@ -285,23 +297,28 @@ class OpenIDConnectProvider(BaseModel): f"{count} validation error{plural} detected: {errors}" ) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return [self.tags[tag] for tag in self.tags] -class PolicyVersion(object): +class PolicyVersion: def __init__( - self, policy_arn, document, is_default=False, version_id="v1", create_date=None + self, + policy_arn: str, + document: str, + is_default: bool = False, + version_id: str = "v1", + create_date: Optional[datetime] = None, ): self.policy_arn = policy_arn - self.document = document or {} + self.document = document or "" self.is_default = is_default self.version_id = version_id - self.create_date = create_date if create_date is not None else datetime.utcnow() + self.create_date = create_date or datetime.utcnow() @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @@ -309,24 +326,24 @@ class ManagedPolicy(Policy, CloudFormationModel): """Managed policy.""" @property - def backend(self): + def backend(self) -> "IAMBackend": return iam_backends[self.account_id]["global"] is_attachable = True - def attach_to(self, obj): + def attach_to(self, obj: Union["Role", "Group", "User"]) -> None: self.attachment_count += 1 - obj.managed_policies[self.arn] = self + obj.managed_policies[self.arn] = self # type: ignore[assignment] - def detach_from(self, obj): + def detach_from(self, obj: Union["Role", "Group", "User"]) -> None: self.attachment_count -= 1 del obj.managed_policies[self.arn] @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:policy{self.path}{self.name}" - def to_config_dict(self): + def to_config_dict(self) -> Dict[str, Any]: return { "version": "1.3", "configurationItemCaptureTime": str(self.create_date), @@ -374,17 +391,22 @@ class ManagedPolicy(Policy, CloudFormationModel): } @staticmethod - def cloudformation_name_type(): - return None # Resource never gets named after by template PolicyName! + def cloudformation_name_type() -> str: + return "" # Resource never gets named after by template PolicyName! @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::IAM::ManagedPolicy" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "ManagedPolicy": properties = cloudformation_json.get("Properties", {}) policy_document = json.dumps(properties.get("PolicyDocument")) name = properties.get("ManagedPolicyName", resource_name) @@ -417,7 +439,7 @@ class ManagedPolicy(Policy, CloudFormationModel): return policy @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @@ -425,7 +447,7 @@ class AWSManagedPolicy(ManagedPolicy): """AWS-managed policy.""" @classmethod - def from_data(cls, name, account_id, data): + def from_data(cls, name: str, account_id: str, data: Dict[str, Any]) -> "AWSManagedPolicy": # type: ignore[misc] return cls( name, account_id=account_id, @@ -433,15 +455,15 @@ class AWSManagedPolicy(ManagedPolicy): path=data.get("Path"), document=json.dumps(data.get("Document")), create_date=datetime.strptime( - data.get("CreateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + data.get("CreateDate"), "%Y-%m-%dT%H:%M:%S+00:00" # type: ignore[arg-type] ), update_date=datetime.strptime( - data.get("UpdateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + data.get("UpdateDate"), "%Y-%m-%dT%H:%M:%S+00:00" # type: ignore[arg-type] ), ) @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::aws:policy{self.path}{self.name}" @@ -449,22 +471,29 @@ class InlinePolicy(CloudFormationModel): # Represents an Inline Policy created by CloudFormation def __init__( self, - resource_name, - policy_name, - policy_document, - group_names, - role_names, - user_names, + resource_name: str, + policy_name: str, + policy_document: str, + group_names: List[str], + role_names: List[str], + user_names: List[str], ): self.name = resource_name - self.policy_name = None - self.policy_document = None - self.group_names = None - self.role_names = None - self.user_names = None + self.policy_name = policy_name + self.policy_document = policy_document + self.group_names = group_names + self.role_names = role_names + self.user_names = user_names self.update(policy_name, policy_document, group_names, role_names, user_names) - def update(self, policy_name, policy_document, group_names, role_names, user_names): + def update( + self, + policy_name: str, + policy_document: str, + group_names: List[str], + role_names: List[str], + user_names: List[str], + ) -> None: self.policy_name = policy_name self.policy_document = ( json.dumps(policy_document) @@ -476,17 +505,22 @@ class InlinePolicy(CloudFormationModel): self.user_names = user_names @staticmethod - def cloudformation_name_type(): - return None # Resource never gets named after by template PolicyName! + def cloudformation_name_type() -> str: + return "" # Resource never gets named after by template PolicyName! @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::IAM::Policy" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "InlinePolicy": properties = cloudformation_json.get("Properties", {}) policy_document = properties.get("PolicyDocument") policy_name = properties.get("PolicyName") @@ -504,14 +538,14 @@ class InlinePolicy(CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "InlinePolicy": properties = cloudformation_json["Properties"] if cls.is_replacement_update(properties): @@ -548,14 +582,18 @@ class InlinePolicy(CloudFormationModel): ) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: iam_backends[account_id]["global"].delete_inline_policy(resource_name) @staticmethod - def is_replacement_update(properties): - properties_requiring_replacement_update = [] + def is_replacement_update(properties: List[str]) -> bool: + properties_requiring_replacement_update: List[str] = [] return any( [ property_requiring_replacement in properties @@ -564,10 +602,10 @@ class InlinePolicy(CloudFormationModel): ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.name - def apply_policy(self, backend): + def apply_policy(self, backend: "IAMBackend") -> None: if self.user_names: for user_name in self.user_names: backend.put_user_policy( @@ -584,7 +622,7 @@ class InlinePolicy(CloudFormationModel): group_name, self.policy_name, self.policy_document ) - def unapply_policy(self, backend): + def unapply_policy(self, backend: "IAMBackend") -> None: if self.user_names: for user_name in self.user_names: backend.delete_user_policy(user_name, self.policy_name) @@ -599,55 +637,61 @@ class InlinePolicy(CloudFormationModel): class Role(CloudFormationModel): def __init__( self, - account_id, - role_id, - name, - assume_role_policy_document, - path, - permissions_boundary, - description, - tags, - max_session_duration, - linked_service=None, + account_id: str, + role_id: str, + name: str, + assume_role_policy_document: str, + path: str, + permissions_boundary: Optional[str], + description: str, + tags: Dict[str, Dict[str, str]], + max_session_duration: Optional[str], + linked_service: Optional[str] = None, ): self.account_id = account_id self.id = role_id self.name = name self.assume_role_policy_document = assume_role_policy_document self.path = path or "/" - self.policies = {} - self.managed_policies = {} + self.policies: Dict[str, str] = {} + self.managed_policies: Dict[str, ManagedPolicy] = {} self.create_date = datetime.utcnow() self.tags = tags self.last_used = None self.last_used_region = None self.description = description - self.permissions_boundary = permissions_boundary + self.permissions_boundary: Optional[str] = permissions_boundary self.max_session_duration = max_session_duration self._linked_service = linked_service @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @property - def last_used_iso_8601(self): + def last_used_iso_8601(self) -> Optional[str]: if self.last_used: return iso_8601_datetime_with_milliseconds(self.last_used) + return None @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "RoleName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-iam-role.html return "AWS::IAM::Role" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Role": properties = cloudformation_json["Properties"] role_name = properties.get("RoleName", resource_name) @@ -671,9 +715,13 @@ class Role(CloudFormationModel): return role @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: backend = iam_backends[account_id]["global"] for profile in backend.instance_profiles.values(): profile.delete_role(role_name=resource_name) @@ -685,12 +733,12 @@ class Role(CloudFormationModel): backend.delete_role(resource_name) @property - def arn(self): + def arn(self) -> str: if self._linked_service: return f"arn:aws:iam::{self.account_id}:role/aws-service-role/{self._linked_service}/{self.name}" return f"arn:aws:iam::{self.account_id}:role{self.path}{self.name}" - def to_config_dict(self): + def to_config_dict(self) -> Dict[str, Any]: _managed_policies = [] for key in self.managed_policies.keys(): _managed_policies.append( @@ -758,10 +806,10 @@ class Role(CloudFormationModel): } return config_dict - def put_policy(self, policy_name, policy_json): + def put_policy(self, policy_name: str, policy_json: str) -> None: self.policies[policy_name] = policy_json - def delete_policy(self, policy_name): + def delete_policy(self, policy_name: str) -> None: try: del self.policies[policy_name] except KeyError: @@ -770,30 +818,30 @@ class Role(CloudFormationModel): ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.name @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() - def get_tags(self): - return [self.tags[tag] for tag in self.tags] + def get_tags(self) -> List[str]: + return [self.tags[tag] for tag in self.tags] # type: ignore @property - def description_escaped(self): + def description_escaped(self) -> str: import html return html.escape(self.description or "") - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ role.path }} @@ -838,7 +886,15 @@ class Role(CloudFormationModel): class InstanceProfile(CloudFormationModel): - def __init__(self, account_id, instance_profile_id, name, path, roles, tags=None): + def __init__( + self, + account_id: str, + instance_profile_id: str, + name: str, + path: str, + roles: List[Role], + tags: Optional[List[Dict[str, str]]] = None, + ): self.id = instance_profile_id self.account_id = account_id self.name = name @@ -848,22 +904,27 @@ class InstanceProfile(CloudFormationModel): self.tags = {tag["Key"]: tag["Value"] for tag in tags or []} @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "InstanceProfileName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-iam-instanceprofile.html return "AWS::IAM::InstanceProfile" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "InstanceProfile": properties = cloudformation_json["Properties"] role_names = properties["Roles"] @@ -874,34 +935,38 @@ class InstanceProfile(CloudFormationModel): ) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: iam_backends[account_id]["global"].delete_instance_profile(resource_name) - def delete_role(self, role_name): + def delete_role(self, role_name: str) -> None: self.roles = [role for role in self.roles if role.name != role_name] @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:instance-profile{self.path}{self.name}" @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.name @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() - def to_embedded_config_dict(self): + def to_embedded_config_dict(self) -> Dict[str, Any]: # Instance Profiles aren't a config item itself, but they are returned in IAM roles with # a "config like" json structure It's also different than Role.to_config_dict() roles = [] @@ -941,7 +1006,13 @@ class InstanceProfile(CloudFormationModel): class Certificate(BaseModel): def __init__( - self, account_id, cert_name, cert_body, private_key, cert_chain=None, path=None + self, + account_id: str, + cert_name: str, + cert_body: str, + private_key: str, + cert_chain: Optional[str] = None, + path: Optional[str] = None, ): self.account_id = account_id self.cert_name = cert_name @@ -953,16 +1024,16 @@ class Certificate(BaseModel): self.cert_chain = cert_chain @property - def physical_resource_id(self): - return self.name + def physical_resource_id(self) -> str: + return self.cert_name @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:server-certificate{self.path}{self.cert_name}" class SigningCertificate(BaseModel): - def __init__(self, certificate_id, user_name, body): + def __init__(self, certificate_id: str, user_name: str, body: str): self.id = certificate_id self.user_name = user_name self.body = body @@ -970,23 +1041,29 @@ class SigningCertificate(BaseModel): self.status = "Active" @property - def uploaded_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.upload_date) + def uploaded_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.upload_date) # type: ignore class AccessKeyLastUsed: - def __init__(self, timestamp, service, region): + def __init__(self, timestamp: datetime, service: str, region: str): self._timestamp = timestamp self.service = service self.region = region @property - def timestamp(self): - return iso_8601_datetime_without_milliseconds(self._timestamp) + def timestamp(self) -> str: + return iso_8601_datetime_without_milliseconds(self._timestamp) # type: ignore class AccessKey(CloudFormationModel): - def __init__(self, user_name, prefix, account_id, status="Active"): + def __init__( + self, + user_name: Optional[str], + prefix: str, + account_id: str, + status: str = "Active", + ): self.user_name = user_name self.access_key_id = generate_access_key_id_from_account_id( account_id, prefix=prefix, total_length=20 @@ -994,17 +1071,17 @@ class AccessKey(CloudFormationModel): self.secret_access_key = random_alphanumeric(40) self.status = status self.create_date = datetime.utcnow() - self.last_used: AccessKeyLastUsed = None + self.last_used: Optional[datetime] = None @property - def created_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.create_date) + def created_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.create_date) # type: ignore @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["SecretAccessKey"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "SecretAccessKey": @@ -1012,17 +1089,22 @@ class AccessKey(CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None # Resource never gets named after by template PolicyName! + def cloudformation_name_type() -> str: + return "" # Resource never gets named after by template PolicyName! @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::IAM::AccessKey" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "AccessKey": properties = cloudformation_json.get("Properties", {}) user_name = properties.get("UserName") status = properties.get("Status", "Active") @@ -1032,14 +1114,14 @@ class AccessKey(CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "AccessKey": properties = cloudformation_json["Properties"] if cls.is_replacement_update(properties): @@ -1062,13 +1144,17 @@ class AccessKey(CloudFormationModel): ) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: iam_backends[account_id]["global"].delete_access_key_by_name(resource_name) @staticmethod - def is_replacement_update(properties): + def is_replacement_update(properties: List[str]) -> bool: properties_requiring_replacement_update = ["Serial", "UserName"] return any( [ @@ -1078,12 +1164,12 @@ class AccessKey(CloudFormationModel): ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.access_key_id class SshPublicKey(BaseModel): - def __init__(self, user_name, ssh_public_key_body): + def __init__(self, user_name: str, ssh_public_key_body: str): self.user_name = user_name self.ssh_public_key_body = ssh_public_key_body self.ssh_public_key_id = "APKA" + random_access_key() @@ -1092,31 +1178,31 @@ class SshPublicKey(BaseModel): self.upload_date = datetime.utcnow() @property - def uploaded_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.upload_date) + def uploaded_iso_8601(self) -> str: + return iso_8601_datetime_without_milliseconds(self.upload_date) # type: ignore class Group(BaseModel): - def __init__(self, account_id, name, path="/"): + def __init__(self, account_id: str, name: str, path: str = "/"): self.account_id = account_id self.name = name self.id = random_resource_id() self.path = path self.create_date = datetime.utcnow() - self.users = [] - self.managed_policies = {} - self.policies = {} + self.users: List[User] = [] + self.managed_policies: Dict[str, str] = {} + self.policies: Dict[str, str] = {} @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> None: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -1124,14 +1210,14 @@ class Group(BaseModel): raise UnformattedGetAttTemplateException() @property - def arn(self): + def arn(self) -> str: if self.path == "/": return f"arn:aws:iam::{self.account_id}:group/{self.name}" else: # The path must by definition end and start with a forward slash. So we don't have to add more slashes to the ARN return f"arn:aws:iam::{self.account_id}:group{self.path}{self.name}" - def get_policy(self, policy_name): + def get_policy(self, policy_name: str) -> Dict[str, str]: try: policy_json = self.policies[policy_name] except KeyError: @@ -1143,13 +1229,13 @@ class Group(BaseModel): "group_name": self.name, } - def put_policy(self, policy_name, policy_json): + def put_policy(self, policy_name: str, policy_json: str) -> None: self.policies[policy_name] = policy_json - def list_policies(self): - return self.policies.keys() + def list_policies(self) -> List[str]: + return list(self.policies.keys()) - def delete_policy(self, policy_name): + def delete_policy(self, policy_name: str) -> None: if policy_name not in self.policies: raise IAMNotFoundException(f"Policy {policy_name} not found") @@ -1157,39 +1243,38 @@ class Group(BaseModel): class User(CloudFormationModel): - def __init__(self, account_id, name, path=None): + def __init__(self, account_id: str, name: str, path: Optional[str] = None): self.account_id = account_id self.name = name self.id = random_resource_id() self.path = path if path else "/" self.create_date = datetime.utcnow() - self.mfa_devices = {} - self.policies = {} - self.managed_policies = {} - self.access_keys: Mapping[str, AccessKey] = [] - self.ssh_public_keys = [] - self.password = None + self.mfa_devices: Dict[str, MFADevice] = {} + self.policies: Dict[str, str] = {} + self.managed_policies: Dict[str, Dict[str, str]] = {} + self.access_keys: List[AccessKey] = [] + self.ssh_public_keys: List[SshPublicKey] = [] + self.password: Optional[str] = None self.password_last_used = None self.password_reset_required = False - self.signing_certificates = {} + self.signing_certificates: Dict[str, SigningCertificate] = {} @property - def arn(self): + def arn(self) -> str: return f"arn:aws:iam::{self.account_id}:user{self.path}{self.name}" @property - def created_iso_8601(self): + def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_date) @property - def password_last_used_iso_8601(self): + def password_last_used_iso_8601(self) -> Optional[str]: if self.password_last_used is not None: return iso_8601_datetime_with_milliseconds(self.password_last_used) else: return None - def get_policy(self, policy_name): - policy_json = None + def get_policy(self, policy_name: str) -> Dict[str, str]: try: policy_json = self.policies[policy_name] except KeyError: @@ -1201,19 +1286,19 @@ class User(CloudFormationModel): "user_name": self.name, } - def put_policy(self, policy_name, policy_json): + def put_policy(self, policy_name: str, policy_json: str) -> None: self.policies[policy_name] = policy_json - def deactivate_mfa_device(self, serial_number): + def deactivate_mfa_device(self, serial_number: str) -> None: self.mfa_devices.pop(serial_number) - def delete_policy(self, policy_name): + def delete_policy(self, policy_name: str) -> None: if policy_name not in self.policies: raise IAMNotFoundException(f"Policy {policy_name} not found") del self.policies[policy_name] - def create_access_key(self, prefix, status="Active") -> AccessKey: + def create_access_key(self, prefix: str, status: str = "Active") -> AccessKey: access_key = AccessKey( self.name, prefix=prefix, status=status, account_id=self.account_id ) @@ -1221,26 +1306,28 @@ class User(CloudFormationModel): return access_key def enable_mfa_device( - self, serial_number, authentication_code_1, authentication_code_2 - ): + self, serial_number: str, authentication_code_1: str, authentication_code_2: str + ) -> None: self.mfa_devices[serial_number] = MFADevice( serial_number, authentication_code_1, authentication_code_2 ) - def get_all_access_keys(self): + def get_all_access_keys(self) -> List[AccessKey]: return self.access_keys - def delete_access_key(self, access_key_id): + def delete_access_key(self, access_key_id: str) -> None: key = self.get_access_key_by_id(access_key_id) self.access_keys.remove(key) - def update_access_key(self, access_key_id, status=None): + def update_access_key( + self, access_key_id: str, status: Optional[str] = None + ) -> AccessKey: key = self.get_access_key_by_id(access_key_id) if status is not None: key.status = status return key - def get_access_key_by_id(self, access_key_id): + def get_access_key_by_id(self, access_key_id: str) -> AccessKey: for key in self.access_keys: if key.access_key_id == access_key_id: return key @@ -1249,7 +1336,7 @@ class User(CloudFormationModel): f"The Access Key with id {access_key_id} cannot be found" ) - def has_access_key(self, access_key_id): + def has_access_key(self, access_key_id: str) -> bool: return any( [ access_key @@ -1258,12 +1345,12 @@ class User(CloudFormationModel): ] ) - def upload_ssh_public_key(self, ssh_public_key_body): + def upload_ssh_public_key(self, ssh_public_key_body: str) -> SshPublicKey: pubkey = SshPublicKey(self.name, ssh_public_key_body) self.ssh_public_keys.append(pubkey) return pubkey - def get_ssh_public_key(self, ssh_public_key_id): + def get_ssh_public_key(self, ssh_public_key_id: str) -> SshPublicKey: for key in self.ssh_public_keys: if key.ssh_public_key_id == ssh_public_key_id: return key @@ -1272,29 +1359,29 @@ class User(CloudFormationModel): f"The SSH Public Key with id {ssh_public_key_id} cannot be found" ) - def get_all_ssh_public_keys(self): + def get_all_ssh_public_keys(self) -> List[SshPublicKey]: return self.ssh_public_keys - def update_ssh_public_key(self, ssh_public_key_id, status): + def update_ssh_public_key(self, ssh_public_key_id: str, status: str) -> None: key = self.get_ssh_public_key(ssh_public_key_id) key.status = status - def delete_ssh_public_key(self, ssh_public_key_id): + def delete_ssh_public_key(self, ssh_public_key_id: str) -> None: key = self.get_ssh_public_key(ssh_public_key_id) self.ssh_public_keys.remove(key) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() - def to_csv(self): + def to_csv(self) -> str: date_format = "%Y-%m-%dT%H:%M:%S+00:00" date_created = self.create_date # aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A @@ -1380,31 +1467,36 @@ class User(CloudFormationModel): return ",".join(fields) + "\n" @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "UserName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::IAM::User" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "User": properties = cloudformation_json.get("Properties", {}) path = properties.get("Path") user, _ = iam_backends[account_id]["global"].create_user(resource_name, path) return user @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "User": properties = cloudformation_json["Properties"] if cls.is_replacement_update(properties): @@ -1429,13 +1521,17 @@ class User(CloudFormationModel): return original_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: iam_backends[account_id]["global"].delete_user(resource_name) @staticmethod - def is_replacement_update(properties): + def is_replacement_update(properties: List[str]) -> bool: properties_requiring_replacement_update = ["UserName"] return any( [ @@ -1445,24 +1541,24 @@ class User(CloudFormationModel): ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.name class AccountPasswordPolicy(BaseModel): def __init__( self, - allow_change_password, - hard_expiry, - max_password_age, - minimum_password_length, - password_reuse_prevention, - require_lowercase_characters, - require_numbers, - require_symbols, - require_uppercase_characters, + allow_change_password: bool, + hard_expiry: int, + max_password_age: int, + minimum_password_length: int, + password_reuse_prevention: int, + require_lowercase_characters: bool, + require_numbers: bool, + require_symbols: bool, + require_uppercase_characters: bool, ): - self._errors = [] + self._errors: List[str] = [] self._validate( max_password_age, minimum_password_length, password_reuse_prevention ) @@ -1478,12 +1574,15 @@ class AccountPasswordPolicy(BaseModel): self.require_uppercase_characters = require_uppercase_characters @property - def expire_passwords(self): + def expire_passwords(self) -> bool: return True if self.max_password_age and self.max_password_age > 0 else False def _validate( - self, max_password_age, minimum_password_length, password_reuse_prevention - ): + self, + max_password_age: int, + minimum_password_length: int, + password_reuse_prevention: int, + ) -> None: if minimum_password_length > 128: self._errors.append( self._format_error( @@ -1513,10 +1612,10 @@ class AccountPasswordPolicy(BaseModel): self._raise_errors() - def _format_error(self, key, value, constraint): + def _format_error(self, key: str, value: Union[str, int], constraint: str) -> str: return f'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}' - def _raise_errors(self): + def _raise_errors(self) -> None: if self._errors: count = len(self._errors) plural = "s" if len(self._errors) > 1 else "" @@ -1529,7 +1628,7 @@ class AccountPasswordPolicy(BaseModel): class AccountSummary(BaseModel): - def __init__(self, iam_backend): + def __init__(self, iam_backend: "IAMBackend"): self._iam_backend = iam_backend self._group_policy_size_quota = 5120 @@ -1559,7 +1658,7 @@ class AccountSummary(BaseModel): self._groups_quota = 300 @property - def summary_map(self): + def summary_map(self) -> Dict[str, Any]: # type: ignore[misc] return { "GroupPolicySizeQuota": self._group_policy_size_quota, "InstanceProfilesQuota": self._instance_profiles_quota, @@ -1597,20 +1696,20 @@ class AccountSummary(BaseModel): } @property - def _groups(self): + def _groups(self) -> int: return len(self._iam_backend.groups) @property - def _instance_profiles(self): + def _instance_profiles(self) -> int: return len(self._iam_backend.instance_profiles) @property - def _mfa_devices(self): + def _mfa_devices(self) -> int: # Don't know, if hardware devices are also counted here return len(self._iam_backend.virtual_mfa_devices) @property - def _mfa_devices_in_use(self): + def _mfa_devices_in_use(self) -> int: devices = 0 for user in self._iam_backend.users.values(): @@ -1619,7 +1718,7 @@ class AccountSummary(BaseModel): return devices @property - def _policies(self): + def _policies(self) -> int: customer_policies = [ policy for policy in self._iam_backend.managed_policies @@ -1628,7 +1727,7 @@ class AccountSummary(BaseModel): return len(customer_policies) @property - def _policy_versions_in_use(self): + def _policy_versions_in_use(self) -> int: attachments = 0 for policy in self._iam_backend.managed_policies.values(): @@ -1637,53 +1736,59 @@ class AccountSummary(BaseModel): return attachments @property - def _providers(self): - providers = len(self._iam_backend.saml_providers) + len( + def _providers(self) -> int: + return len(self._iam_backend.saml_providers) + len( self._iam_backend.open_id_providers ) - return providers @property - def _roles(self): + def _roles(self) -> int: return len(self._iam_backend.roles) @property - def _server_certificates(self): + def _server_certificates(self) -> int: return len(self._iam_backend.certificates) @property - def _users(self): + def _users(self) -> int: return len(self._iam_backend.users) -def filter_items_with_path_prefix(path_prefix, items): +def filter_items_with_path_prefix( + path_prefix: str, items: Iterable[Any] +) -> Iterable[Any]: return [role for role in items if role.path.startswith(path_prefix)] class IAMBackend(BaseBackend): - def __init__(self, region_name, account_id=None, aws_policies=None): - super().__init__(region_name=region_name, account_id=account_id) - self.instance_profiles = {} - self.roles = {} - self.certificates = {} - self.groups = {} - self.users = {} - self.credential_report = None + def __init__( + self, + region_name: str, + account_id: Optional[str] = None, + aws_policies: Optional[List[ManagedPolicy]] = None, + ): + super().__init__(region_name=region_name, account_id=account_id) # type: ignore + self.instance_profiles: Dict[str, InstanceProfile] = {} + self.roles: Dict[str, Role] = {} + self.certificates: Dict[str, Certificate] = {} + self.groups: Dict[str, Group] = {} + self.users: Dict[str, User] = {} + self.credential_report: Optional[bool] = None self.aws_managed_policies = aws_policies or self._init_aws_policies() self.managed_policies = self._init_managed_policies() - self.account_aliases = [] - self.saml_providers = {} - self.open_id_providers = {} + self.account_aliases: List[str] = [] + self.saml_providers: Dict[str, SAMLProvider] = {} + self.open_id_providers: Dict[str, OpenIDConnectProvider] = {} self.policy_arn_regex = re.compile(r"^arn:aws:iam::(aws|[0-9]*):policy/.*$") - self.virtual_mfa_devices = {} - self.account_password_policy = None + self.virtual_mfa_devices: Dict[str, VirtualMfaDevice] = {} + self.account_password_policy: Optional[AccountPasswordPolicy] = None self.account_summary = AccountSummary(self) - self.inline_policies = {} - self.access_keys = {} + self.inline_policies: Dict[str, InlinePolicy] = {} + self.access_keys: Dict[str, AccessKey] = {} self.tagger = TaggingService() - def _init_aws_policies(self): + def _init_aws_policies(self) -> List[ManagedPolicy]: # AWS defines some of its own managed policies and we periodically # import them via `make aws_managed_policies` aws_managed_policies_data_parsed = json.loads(aws_managed_policies_data) @@ -1692,34 +1797,38 @@ class IAMBackend(BaseBackend): for name, d in aws_managed_policies_data_parsed.items() ] - def _init_managed_policies(self): + def _init_managed_policies(self) -> Dict[str, ManagedPolicy]: return dict((p.arn, p) for p in self.aws_managed_policies) - def reset(self): + def reset(self) -> None: region_name = self.region_name account_id = self.account_id # Do not reset these policies, as they take a long time to load aws_policies = self.aws_managed_policies self.__dict__ = {} - self.__init__(region_name, account_id, aws_policies) + self.__init__(region_name, account_id, aws_policies) # type: ignore[misc] - def attach_role_policy(self, policy_arn, role_name): + def attach_role_policy(self, policy_arn: str, role_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) policy = arns[policy_arn] policy.attach_to(self.get_role(role_name)) - def update_role_description(self, role_name, role_description): + def update_role_description(self, role_name: str, role_description: str) -> Role: role = self.get_role(role_name) role.description = role_description return role - def update_role(self, role_name, role_description, max_session_duration): + def update_role( + self, role_name: str, role_description: str, max_session_duration: str + ) -> Role: role = self.get_role(role_name) role.description = role_description role.max_session_duration = max_session_duration return role - def put_role_permissions_boundary(self, role_name, permissions_boundary): + def put_role_permissions_boundary( + self, role_name: str, permissions_boundary: str + ) -> None: if permissions_boundary and not self.policy_arn_regex.match( permissions_boundary ): @@ -1730,11 +1839,11 @@ class IAMBackend(BaseBackend): role = self.get_role(role_name) role.permissions_boundary = permissions_boundary - def delete_role_permissions_boundary(self, role_name): + def delete_role_permissions_boundary(self, role_name: str) -> None: role = self.get_role(role_name) role.permissions_boundary = None - def detach_role_policy(self, policy_arn, role_name): + def detach_role_policy(self, policy_arn: str, role_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) try: policy = arns[policy_arn] @@ -1744,7 +1853,7 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException(f"Policy {policy_arn} was not found.") policy.detach_from(self.get_role(role_name)) - def attach_group_policy(self, policy_arn, group_name): + def attach_group_policy(self, policy_arn: str, group_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) try: policy = arns[policy_arn] @@ -1754,7 +1863,7 @@ class IAMBackend(BaseBackend): return policy.attach_to(self.get_group(group_name)) - def detach_group_policy(self, policy_arn, group_name): + def detach_group_policy(self, policy_arn: str, group_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) try: policy = arns[policy_arn] @@ -1764,7 +1873,7 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException(f"Policy {policy_arn} was not found.") policy.detach_from(self.get_group(group_name)) - def attach_user_policy(self, policy_arn, user_name): + def attach_user_policy(self, policy_arn: str, user_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) try: policy = arns[policy_arn] @@ -1772,7 +1881,7 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException(f"Policy {policy_arn} was not found.") policy.attach_to(self.get_user(user_name)) - def detach_user_policy(self, policy_arn, user_name): + def detach_user_policy(self, policy_arn: str, user_name: str) -> None: arns = dict((p.arn, p) for p in self.managed_policies.values()) try: policy = arns[policy_arn] @@ -1782,7 +1891,14 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException(f"Policy {policy_arn} was not found.") policy.detach_from(self.get_user(user_name)) - def create_policy(self, description, path, policy_document, policy_name, tags): + def create_policy( + self, + description: str, + path: str, + policy_document: str, + policy_name: str, + tags: List[Dict[str, str]], + ) -> ManagedPolicy: iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document) iam_policy_document_validator.validate() @@ -1802,31 +1918,50 @@ class IAMBackend(BaseBackend): self.managed_policies[policy.arn] = policy return policy - def get_policy(self, policy_arn): + def get_policy(self, policy_arn: str) -> ManagedPolicy: if policy_arn not in self.managed_policies: raise IAMNotFoundException(f"Policy {policy_arn} not found") - return self.managed_policies.get(policy_arn) + return self.managed_policies[policy_arn] def list_attached_role_policies( - self, role_name, marker=None, max_items=100, path_prefix="/" - ): + self, + role_name: str, + marker: Optional[str] = None, + max_items: int = 100, + path_prefix: str = "/", + ) -> Tuple[Iterable[ManagedPolicy], Optional[str]]: policies = self.get_role(role_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) def list_attached_group_policies( - self, group_name, marker=None, max_items=100, path_prefix="/" - ): + self, + group_name: str, + marker: Optional[str] = None, + max_items: int = 100, + path_prefix: str = "/", + ) -> Tuple[Iterable[Dict[str, str]], Optional[str]]: policies = self.get_group(group_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) def list_attached_user_policies( - self, user_name, marker=None, max_items=100, path_prefix="/" - ): + self, + user_name: str, + marker: Optional[str] = None, + max_items: int = 100, + path_prefix: str = "/", + ) -> Tuple[Iterable[Dict[str, str]], Optional[str]]: policies = self.get_user(user_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def list_policies(self, marker, max_items, only_attached, path_prefix, scope): - policies = self.managed_policies.values() + def list_policies( + self, + marker: Optional[str], + max_items: int, + only_attached: bool, + path_prefix: str, + scope: str, + ) -> Tuple[Iterable[ManagedPolicy], Optional[str]]: + policies = list(self.managed_policies.values()) if only_attached: policies = [p for p in policies if p.attachment_count > 0] @@ -1838,7 +1973,7 @@ class IAMBackend(BaseBackend): return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def set_default_policy_version(self, policy_arn, version_id): + def set_default_policy_version(self, policy_arn: str, version_id: str) -> bool: if re.match(r"v[1-9][0-9]*(\.[A-Za-z0-9-]*)?", version_id) is None: raise ValidationError( f"Value '{version_id}' at 'versionId' failed to satisfy constraint: Member must satisfy regular expression pattern: v[1-9][0-9]*(\\.[A-Za-z0-9-]*)?" @@ -1855,7 +1990,13 @@ class IAMBackend(BaseBackend): f"Policy {policy_arn} version {version_id} does not exist or is not attachable." ) - def _filter_attached_policies(self, policies, marker, max_items, path_prefix): + def _filter_attached_policies( + self, + policies: Iterable[Any], + marker: Optional[str], + max_items: int, + path_prefix: str, + ) -> Tuple[Iterable[Any], Optional[str]]: if path_prefix: policies = [p for p in policies if p.path.startswith(path_prefix)] @@ -1873,15 +2014,15 @@ class IAMBackend(BaseBackend): def create_role( self, - role_name, - assume_role_policy_document, - path, - permissions_boundary, - description, - tags, - max_session_duration, - linked_service=None, - ): + role_name: str, + assume_role_policy_document: str, + path: str, + permissions_boundary: Optional[str], + description: str, + tags: List[Dict[str, str]], + max_session_duration: Optional[str], + linked_service: Optional[str] = None, + ) -> Role: role_id = random_role_id(self.account_id) if permissions_boundary and not self.policy_arn_regex.match( permissions_boundary @@ -1909,10 +2050,10 @@ class IAMBackend(BaseBackend): self.roles[role_id] = role return role - def get_role_by_id(self, role_id): + def get_role_by_id(self, role_id: str) -> Optional[Role]: return self.roles.get(role_id) - def get_role(self, role_name): + def get_role(self, role_name: str) -> Role: for role in self.get_roles(): if role.name == role_name: return role @@ -1924,7 +2065,7 @@ class IAMBackend(BaseBackend): return role raise IAMNotFoundException(f"Role {arn} not found") - def delete_role(self, role_name): + def delete_role(self, role_name: str) -> None: role = self.get_role(role_name) for instance_profile in self.get_instance_profiles(): for profile_role in instance_profile.roles: @@ -1945,27 +2086,29 @@ class IAMBackend(BaseBackend): ) del self.roles[role.id] - def get_roles(self): + def get_roles(self) -> Iterable[Role]: return self.roles.values() - def update_assume_role_policy(self, role_name, policy_document): + def update_assume_role_policy(self, role_name: str, policy_document: str) -> None: role = self.get_role(role_name) iam_policy_document_validator = IAMTrustPolicyDocumentValidator(policy_document) iam_policy_document_validator.validate() role.assume_role_policy_document = policy_document - def put_role_policy(self, role_name, policy_name, policy_json): + def put_role_policy( + self, role_name: str, policy_name: str, policy_json: str + ) -> None: role = self.get_role(role_name) iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json) iam_policy_document_validator.validate() role.put_policy(policy_name, policy_json) - def delete_role_policy(self, role_name, policy_name): + def delete_role_policy(self, role_name: str, policy_name: str) -> None: role = self.get_role(role_name) role.delete_policy(policy_name) - def get_role_policy(self, role_name, policy_name): + def get_role_policy(self, role_name: str, policy_name: str) -> Tuple[str, str]: role = self.get_role(role_name) for p, d in role.policies.items(): if p == policy_name: @@ -1974,15 +2117,17 @@ class IAMBackend(BaseBackend): f"Policy Document {policy_name} not attached to role {role_name}" ) - def list_role_policies(self, role_name): + def list_role_policies(self, role_name: str) -> List[str]: role = self.get_role(role_name) - return role.policies.keys() + return list(role.policies.keys()) - def _tag_verification(self, tags): + def _tag_verification( + self, tags: List[Dict[str, str]] + ) -> Dict[str, Dict[str, str]]: if len(tags) > 50: raise TooManyTags(tags) - tag_keys = {} + tag_keys: Dict[str, Dict[str, str]] = {} for tag in tags: # Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained. ref_key = tag["Key"].lower() @@ -1995,7 +2140,9 @@ class IAMBackend(BaseBackend): return tag_keys - def _validate_tag_key(self, tag_key, exception_param="tags.X.member.key"): + def _validate_tag_key( + self, tag_key: str, exception_param: str = "tags.X.member.key" + ) -> None: """Validates the tag key. :param tag_key: The tag key to check against. @@ -2014,7 +2161,9 @@ class IAMBackend(BaseBackend): if not len(match) or len(match[0]) < len(tag_key): raise InvalidTagCharacters(tag_key, param=exception_param) - def _check_tag_duplicate(self, all_tags, tag_key): + def _check_tag_duplicate( + self, all_tags: Dict[str, Dict[str, str]], tag_key: str + ) -> None: """Validates that a tag key is not a duplicate :param all_tags: Dict to check if there is a duplicate tag. @@ -2024,7 +2173,9 @@ class IAMBackend(BaseBackend): if tag_key in all_tags: raise DuplicateTags() - def list_role_tags(self, role_name, marker, max_items=100): + def list_role_tags( + self, role_name: str, marker: Optional[str], max_items: int = 100 + ) -> Tuple[List[Dict[str, str]], Optional[str]]: role = self.get_role(role_name) max_items = int(max_items) @@ -2043,12 +2194,12 @@ class IAMBackend(BaseBackend): return tags, marker - def tag_role(self, role_name, tags): + def tag_role(self, role_name: str, tags: List[Dict[str, str]]) -> None: clean_tags = self._tag_verification(tags) role = self.get_role(role_name) role.tags.update(clean_tags) - def untag_role(self, role_name, tag_keys): + def untag_role(self, role_name: str, tag_keys: List[str]) -> None: if len(tag_keys) > 50: raise TooManyTags(tag_keys, param="tagKeys") @@ -2060,7 +2211,9 @@ class IAMBackend(BaseBackend): role.tags.pop(ref_key, None) - def list_policy_tags(self, policy_arn, marker, max_items=100): + def list_policy_tags( + self, policy_arn: str, marker: Optional[str], max_items: int = 100 + ) -> Tuple[List[Dict[str, str]], Optional[str]]: policy = self.get_policy(policy_arn) max_items = int(max_items) @@ -2079,12 +2232,12 @@ class IAMBackend(BaseBackend): return tags, marker - def tag_policy(self, policy_arn, tags): + def tag_policy(self, policy_arn: str, tags: List[Dict[str, str]]) -> None: clean_tags = self._tag_verification(tags) policy = self.get_policy(policy_arn) policy.tags.update(clean_tags) - def untag_policy(self, policy_arn, tag_keys): + def untag_policy(self, policy_arn: str, tag_keys: List[str]) -> None: if len(tag_keys) > 50: raise TooManyTags(tag_keys, param="tagKeys") @@ -2094,9 +2247,11 @@ class IAMBackend(BaseBackend): ref_key = key.lower() self._validate_tag_key(key, exception_param="tagKeys") - policy.tags.pop(ref_key, None) + policy.tags.pop(ref_key, None) # type: ignore[union-attr] - def create_policy_version(self, policy_arn, policy_document, set_as_default): + def create_policy_version( + self, policy_arn: str, policy_document: str, set_as_default: str + ) -> PolicyVersion: iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document) iam_policy_document_validator.validate() @@ -2107,16 +2262,16 @@ class IAMBackend(BaseBackend): raise IAMLimitExceededException( "A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version." ) - set_as_default = set_as_default == "true" # convert it to python bool - version = PolicyVersion(policy_arn, policy_document, set_as_default) + _as_default = set_as_default == "true" # convert it to python bool + version = PolicyVersion(policy_arn, policy_document, _as_default) policy.versions.append(version) version.version_id = f"v{policy.next_version_num}" policy.next_version_num += 1 - if set_as_default: + if _as_default: policy.update_default_version(version.version_id) return version - def get_policy_version(self, policy_arn, version_id): + def get_policy_version(self, policy_arn: str, version_id: str) -> PolicyVersion: policy = self.get_policy(policy_arn) if not policy: raise IAMNotFoundException("Policy not found") @@ -2125,13 +2280,13 @@ class IAMBackend(BaseBackend): return version raise IAMNotFoundException("Policy version not found") - def list_policy_versions(self, policy_arn): + def list_policy_versions(self, policy_arn: str) -> List[PolicyVersion]: policy = self.get_policy(policy_arn) if not policy: raise IAMNotFoundException("Policy not found") return policy.versions - def delete_policy_version(self, policy_arn, version_id): + def delete_policy_version(self, policy_arn: str, version_id: str) -> None: policy = self.get_policy(policy_arn) if not policy: raise IAMNotFoundException("Policy not found") @@ -2146,7 +2301,13 @@ class IAMBackend(BaseBackend): return raise IAMNotFoundException("Policy not found") - def create_instance_profile(self, name, path, role_names, tags=None): + def create_instance_profile( + self, + name: str, + path: str, + role_names: List[str], + tags: Optional[List[Dict[str, str]]] = None, + ) -> InstanceProfile: if self.instance_profiles.get(name): raise IAMConflictException( code="EntityAlreadyExists", @@ -2162,7 +2323,7 @@ class IAMBackend(BaseBackend): self.instance_profiles[name] = instance_profile return instance_profile - def delete_instance_profile(self, name): + def delete_instance_profile(self, name: str) -> None: instance_profile = self.get_instance_profile(name) if len(instance_profile.roles) > 0: raise IAMConflictException( @@ -2171,24 +2332,24 @@ class IAMBackend(BaseBackend): ) del self.instance_profiles[name] - def get_instance_profile(self, profile_name): + def get_instance_profile(self, profile_name: str) -> InstanceProfile: for profile in self.get_instance_profiles(): if profile.name == profile_name: return profile raise IAMNotFoundException(f"Instance profile {profile_name} not found") - def get_instance_profile_by_arn(self, profile_arn): + def get_instance_profile_by_arn(self, profile_arn: str) -> InstanceProfile: for profile in self.get_instance_profiles(): if profile.arn == profile_arn: return profile raise IAMNotFoundException(f"Instance profile {profile_arn} not found") - def get_instance_profiles(self) -> List[InstanceProfile]: + def get_instance_profiles(self) -> Iterable[InstanceProfile]: return self.instance_profiles.values() - def get_instance_profiles_for_role(self, role_name): + def get_instance_profiles_for_role(self, role_name: str) -> List[InstanceProfile]: found_profiles = [] for profile in self.get_instance_profiles(): @@ -2198,25 +2359,32 @@ class IAMBackend(BaseBackend): return found_profiles - def add_role_to_instance_profile(self, profile_name, role_name): + def add_role_to_instance_profile(self, profile_name: str, role_name: str) -> None: profile = self.get_instance_profile(profile_name) role = self.get_role(role_name) profile.roles.append(role) - def remove_role_from_instance_profile(self, profile_name, role_name): + def remove_role_from_instance_profile( + self, profile_name: str, role_name: str + ) -> None: profile = self.get_instance_profile(profile_name) role = self.get_role(role_name) profile.roles.remove(role) - def list_server_certificates(self): + def list_server_certificates(self) -> Iterable[Certificate]: """ Pagination is not yet implemented """ return self.certificates.values() def upload_server_certificate( - self, cert_name, cert_body, private_key, cert_chain=None, path=None - ): + self, + cert_name: str, + cert_body: str, + private_key: str, + cert_chain: Optional[str] = None, + path: Optional[str] = None, + ) -> Certificate: certificate_id = random_resource_id() cert = Certificate( self.account_id, cert_name, cert_body, private_key, cert_chain, path @@ -2224,7 +2392,7 @@ class IAMBackend(BaseBackend): self.certificates[certificate_id] = cert return cert - def get_server_certificate(self, name): + def get_server_certificate(self, name: str) -> Certificate: for cert in self.certificates.values(): if name == cert.cert_name: return cert @@ -2233,13 +2401,13 @@ class IAMBackend(BaseBackend): f"The Server Certificate with name {name} cannot be found." ) - def get_certificate_by_arn(self, arn): + def get_certificate_by_arn(self, arn: str) -> Optional[Certificate]: for cert in self.certificates.values(): if arn == cert.arn: return cert return None - def delete_server_certificate(self, name): + def delete_server_certificate(self, name: str) -> None: cert_id = None for key, cert in self.certificates.items(): if name == cert.cert_name: @@ -2253,7 +2421,7 @@ class IAMBackend(BaseBackend): self.certificates.pop(cert_id, None) - def create_group(self, group_name, path="/"): + def create_group(self, group_name: str, path: str = "/") -> Group: if group_name in self.groups: raise IAMConflictException(f"Group {group_name} already exists") @@ -2261,7 +2429,7 @@ class IAMBackend(BaseBackend): self.groups[group_name] = group return group - def get_group(self, group_name): + def get_group(self, group_name: str) -> Group: """ Pagination is not yet implemented """ @@ -2270,10 +2438,10 @@ class IAMBackend(BaseBackend): except KeyError: raise IAMNotFoundException(f"Group {group_name} not found") - def list_groups(self): + def list_groups(self) -> Iterable[Group]: return self.groups.values() - def get_groups_for_user(self, user_name): + def get_groups_for_user(self, user_name: str) -> List[Group]: user = self.get_user(user_name) groups = [] for group in self.list_groups(): @@ -2282,29 +2450,31 @@ class IAMBackend(BaseBackend): return groups - def put_group_policy(self, group_name, policy_name, policy_json): + def put_group_policy( + self, group_name: str, policy_name: str, policy_json: str + ) -> None: group = self.get_group(group_name) iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json) iam_policy_document_validator.validate() group.put_policy(policy_name, policy_json) - def list_group_policies(self, group_name): + def list_group_policies(self, group_name: str) -> List[str]: """ Pagination is not yet implemented """ group = self.get_group(group_name) return group.list_policies() - def delete_group_policy(self, group_name, policy_name): + def delete_group_policy(self, group_name: str, policy_name: str) -> None: group = self.get_group(group_name) group.delete_policy(policy_name) - def get_group_policy(self, group_name, policy_name): + def get_group_policy(self, group_name: str, policy_name: str) -> Dict[str, str]: group = self.get_group(group_name) return group.get_policy(policy_name) - def delete_group(self, group_name): + def delete_group(self, group_name: str) -> None: try: del self.groups[group_name] except KeyError: @@ -2312,7 +2482,9 @@ class IAMBackend(BaseBackend): f"The group with name {group_name} cannot be found." ) - def update_group(self, group_name, new_group_name, new_path): + def update_group( + self, group_name: str, new_group_name: Optional[str], new_path: Optional[str] + ) -> None: if new_group_name: if new_group_name in self.groups: raise IAMConflictException( @@ -2335,7 +2507,12 @@ class IAMBackend(BaseBackend): for policy_arn in existing_policies: self.attach_group_policy(policy_arn, new_group_name) - def create_user(self, user_name, path="/", tags=None): + def create_user( + self, + user_name: str, + path: str = "/", + tags: Optional[List[Dict[str, str]]] = None, + ) -> Tuple[User, Dict[str, List[Dict[str, str]]]]: if user_name in self.users: raise IAMConflictException( "EntityAlreadyExists", f"User {user_name} already exists" @@ -2346,7 +2523,7 @@ class IAMBackend(BaseBackend): self.users[user_name] = user return user, self.tagger.list_tags_for_resource(user.arn) - def get_user(self, name) -> User: + def get_user(self, name: str) -> User: user = self.users.get(name) if not user: @@ -2354,11 +2531,14 @@ class IAMBackend(BaseBackend): return user - def list_users(self, path_prefix, marker, max_items): - users = None + def list_users( + self, + path_prefix: Optional[str], + marker: Optional[str], + max_items: Optional[int], + ) -> Iterable[User]: try: - - users = self.users.values() + users: Iterable[User] = list(self.users.values()) if path_prefix: users = filter_items_with_path_prefix(path_prefix, users) @@ -2369,7 +2549,12 @@ class IAMBackend(BaseBackend): return users - def update_user(self, user_name, new_path=None, new_user_name=None): + def update_user( + self, + user_name: str, + new_path: Optional[str] = None, + new_user_name: Optional[str] = None, + ) -> None: try: user = self.users[user_name] except KeyError: @@ -2381,12 +2566,17 @@ class IAMBackend(BaseBackend): user.name = new_user_name self.users[new_user_name] = self.users.pop(user_name) - def list_roles(self, path_prefix=None, marker=None, max_items=None): + def list_roles( + self, + path_prefix: Optional[str] = None, + marker: Optional[str] = None, + max_items: Optional[int] = None, + ) -> Tuple[List[Role], Optional[str]]: path_prefix = path_prefix if path_prefix else "/" max_items = int(max_items) if max_items else 100 start_index = int(marker) if marker else 0 - roles = self.roles.values() + roles: Iterable[Role] = list(self.roles.values()) roles = filter_items_with_path_prefix(path_prefix, roles) sorted_roles = sorted(roles, key=lambda role: role.id) @@ -2399,7 +2589,9 @@ class IAMBackend(BaseBackend): return roles_to_return, marker - def upload_signing_certificate(self, user_name, body): + def upload_signing_certificate( + self, user_name: str, body: str + ) -> SigningCertificate: user = self.get_user(user_name) cert_id = random_resource_id(size=32) @@ -2418,7 +2610,7 @@ class IAMBackend(BaseBackend): return user.signing_certificates[cert_id] - def delete_signing_certificate(self, user_name, cert_id): + def delete_signing_certificate(self, user_name: str, cert_id: str) -> None: user = self.get_user(user_name) try: @@ -2428,12 +2620,14 @@ class IAMBackend(BaseBackend): f"The Certificate with id {cert_id} cannot be found." ) - def list_signing_certificates(self, user_name): + def list_signing_certificates(self, user_name: str) -> List[SigningCertificate]: user = self.get_user(user_name) return list(user.signing_certificates.values()) - def update_signing_certificate(self, user_name, cert_id, status): + def update_signing_certificate( + self, user_name: str, cert_id: str, status: str + ) -> None: user = self.get_user(user_name) try: @@ -2444,7 +2638,7 @@ class IAMBackend(BaseBackend): f"The Certificate with id {cert_id} cannot be found." ) - def create_login_profile(self, user_name, password): + def create_login_profile(self, user_name: str, password: str) -> User: # This does not currently deal with PasswordPolicyViolation. user = self.get_user(user_name) if user.password: @@ -2452,13 +2646,15 @@ class IAMBackend(BaseBackend): user.password = password return user - def get_login_profile(self, user_name): + def get_login_profile(self, user_name: str) -> User: user = self.get_user(user_name) if not user.password: raise IAMNotFoundException(f"Login profile for {user_name} not found") return user - def update_login_profile(self, user_name, password, password_reset_required): + def update_login_profile( + self, user_name: str, password: str, password_reset_required: bool + ) -> User: # This does not currently deal with PasswordPolicyViolation. user = self.get_user(user_name) if not user.password: @@ -2467,19 +2663,19 @@ class IAMBackend(BaseBackend): user.password_reset_required = password_reset_required return user - def delete_login_profile(self, user_name): + def delete_login_profile(self, user_name: str) -> None: user = self.get_user(user_name) if not user.password: raise IAMNotFoundException(f"Login profile for {user_name} not found") user.password = None - def add_user_to_group(self, group_name, user_name): + def add_user_to_group(self, group_name: str, user_name: str) -> None: user = self.get_user(user_name) group = self.get_group(group_name) if user not in group.users: group.users.append(user) - def remove_user_from_group(self, group_name, user_name): + def remove_user_from_group(self, group_name: str, user_name: str) -> None: group = self.get_group(group_name) user = self.get_user(user_name) try: @@ -2487,35 +2683,38 @@ class IAMBackend(BaseBackend): except ValueError: raise IAMNotFoundException(f"User {user_name} not in group {group_name}") - def get_user_policy(self, user_name, policy_name): + def get_user_policy(self, user_name: str, policy_name: str) -> Dict[str, str]: user = self.get_user(user_name) - policy = user.get_policy(policy_name) - return policy + return user.get_policy(policy_name) - def list_user_policies(self, user_name): + def list_user_policies(self, user_name: str) -> Iterable[str]: user = self.get_user(user_name) return user.policies.keys() - def list_user_tags(self, user_name): + def list_user_tags(self, user_name: str) -> Dict[str, List[Dict[str, str]]]: user = self.get_user(user_name) return self.tagger.list_tags_for_resource(user.arn) - def put_user_policy(self, user_name, policy_name, policy_json): + def put_user_policy( + self, user_name: str, policy_name: str, policy_json: str + ) -> None: user = self.get_user(user_name) iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json) iam_policy_document_validator.validate() user.put_policy(policy_name, policy_json) - def delete_user_policy(self, user_name, policy_name): + def delete_user_policy(self, user_name: str, policy_name: str) -> None: user = self.get_user(user_name) user.delete_policy(policy_name) - def delete_policy(self, policy_arn): + def delete_policy(self, policy_arn: str) -> None: policy = self.get_policy(policy_arn) del self.managed_policies[policy.arn] - def create_access_key(self, user_name=None, prefix="AKIA", status="Active"): + def create_access_key( + self, user_name: str, prefix: str = "AKIA", status: str = "Active" + ) -> AccessKey: keys = self.list_access_keys(user_name) if len(keys) >= LIMIT_KEYS_PER_USER: raise IAMLimitExceededException( @@ -2526,18 +2725,20 @@ class IAMBackend(BaseBackend): self.access_keys[key.physical_resource_id] = key return key - def create_temp_access_key(self): + def create_temp_access_key(self) -> AccessKey: # Temporary access keys such as the ones returned by STS when assuming a role temporarily key = AccessKey(user_name=None, prefix="ASIA", account_id=self.account_id) self.access_keys[key.physical_resource_id] = key return key - def update_access_key(self, user_name, access_key_id, status=None): + def update_access_key( + self, user_name: str, access_key_id: str, status: Optional[str] = None + ) -> AccessKey: user = self.get_user(user_name) return user.update_access_key(access_key_id, status) - def get_access_key_last_used(self, access_key_id): + def get_access_key_last_used(self, access_key_id: str) -> Dict[str, Any]: access_keys_list = self.get_all_access_keys_for_all_users() for key in access_keys_list: if key.access_key_id == access_key_id: @@ -2547,58 +2748,67 @@ class IAMBackend(BaseBackend): f"The Access Key with id {access_key_id} cannot be found" ) - def get_all_access_keys_for_all_users(self): + def get_all_access_keys_for_all_users(self) -> List[AccessKey]: access_keys_list = [] for account in iam_backends.values(): for user_name in account["global"].users: access_keys_list += account["global"].list_access_keys(user_name) return access_keys_list - def list_access_keys(self, user_name): + def list_access_keys(self, user_name: str) -> List[AccessKey]: """ Pagination is not yet implemented """ user = self.get_user(user_name) - keys = user.get_all_access_keys() - return keys + return user.get_all_access_keys() - def delete_access_key(self, access_key_id, user_name): + def delete_access_key(self, access_key_id: str, user_name: str) -> None: user = self.get_user(user_name) access_key = user.get_access_key_by_id(access_key_id) self.delete_access_key_by_name(access_key.access_key_id) - def delete_access_key_by_name(self, name): + def delete_access_key_by_name(self, name: str) -> None: key = self.access_keys[name] try: # User may have been deleted before their access key... - user = self.get_user(key.user_name) + user = self.get_user(key.user_name) # type: ignore user.delete_access_key(key.access_key_id) except NoSuchEntity: pass del self.access_keys[name] - def upload_ssh_public_key(self, user_name, ssh_public_key_body): + def upload_ssh_public_key( + self, user_name: str, ssh_public_key_body: str + ) -> SshPublicKey: user = self.get_user(user_name) return user.upload_ssh_public_key(ssh_public_key_body) - def get_ssh_public_key(self, user_name, ssh_public_key_id): + def get_ssh_public_key( + self, user_name: str, ssh_public_key_id: str + ) -> SshPublicKey: user = self.get_user(user_name) return user.get_ssh_public_key(ssh_public_key_id) - def get_all_ssh_public_keys(self, user_name): + def get_all_ssh_public_keys(self, user_name: str) -> Iterable[SshPublicKey]: user = self.get_user(user_name) return user.get_all_ssh_public_keys() - def update_ssh_public_key(self, user_name, ssh_public_key_id, status): + def update_ssh_public_key( + self, user_name: str, ssh_public_key_id: str, status: str + ) -> None: user = self.get_user(user_name) - return user.update_ssh_public_key(ssh_public_key_id, status) + user.update_ssh_public_key(ssh_public_key_id, status) - def delete_ssh_public_key(self, user_name, ssh_public_key_id): + def delete_ssh_public_key(self, user_name: str, ssh_public_key_id: str) -> None: user = self.get_user(user_name) - return user.delete_ssh_public_key(ssh_public_key_id) + user.delete_ssh_public_key(ssh_public_key_id) def enable_mfa_device( - self, user_name, serial_number, authentication_code_1, authentication_code_2 - ): + self, + user_name: str, + serial_number: str, + authentication_code_1: str, + authentication_code_2: str, + ) -> None: """Enable MFA Device for user.""" user = self.get_user(user_name) if serial_number in user.mfa_devices: @@ -2625,7 +2835,7 @@ class IAMBackend(BaseBackend): serial_number, authentication_code_1, authentication_code_2 ) - def deactivate_mfa_device(self, user_name, serial_number): + def deactivate_mfa_device(self, user_name: str, serial_number: str) -> None: """Deactivate and detach MFA Device from user if device exists.""" user = self.get_user(user_name) if serial_number not in user.mfa_devices: @@ -2639,11 +2849,13 @@ class IAMBackend(BaseBackend): user.deactivate_mfa_device(serial_number) - def list_mfa_devices(self, user_name): + def list_mfa_devices(self, user_name: str) -> Iterable[MFADevice]: user = self.get_user(user_name) return user.mfa_devices.values() - def create_virtual_mfa_device(self, device_name, path): + def create_virtual_mfa_device( + self, device_name: str, path: str + ) -> VirtualMfaDevice: if not path: path = "/" @@ -2676,7 +2888,7 @@ class IAMBackend(BaseBackend): self.virtual_mfa_devices[device.serial_number] = device return device - def delete_virtual_mfa_device(self, serial_number): + def delete_virtual_mfa_device(self, serial_number: str) -> None: device = self.virtual_mfa_devices.pop(serial_number, None) if not device: @@ -2684,7 +2896,9 @@ class IAMBackend(BaseBackend): f"VirtualMFADevice with serial number {serial_number} doesn't exist." ) - def list_virtual_mfa_devices(self, assignment_status, marker, max_items): + def list_virtual_mfa_devices( + self, assignment_status: str, marker: Optional[str], max_items: int + ) -> Tuple[List[VirtualMfaDevice], Optional[str]]: devices = list(self.virtual_mfa_devices.values()) if assignment_status == "Assigned": @@ -2709,7 +2923,7 @@ class IAMBackend(BaseBackend): return devices, marker - def delete_user(self, user_name): + def delete_user(self, user_name: str) -> None: user = self.get_user(user_name) if user.managed_policies: raise IAMConflictException( @@ -2724,13 +2938,13 @@ class IAMBackend(BaseBackend): self.tagger.delete_all_tags_for_resource(user.arn) del self.users[user_name] - def report_generated(self): + def report_generated(self) -> Optional[bool]: return self.credential_report - def generate_report(self): + def generate_report(self) -> None: self.credential_report = True - def get_credential_report(self): + def get_credential_report(self) -> str: if not self.credential_report: raise IAMReportNotPresentException("Credential report not present") report = "user,arn,user_creation_time,password_enabled,password_last_used,password_last_changed,password_next_rotation,mfa_active,access_key_1_active,access_key_1_last_rotated,access_key_1_last_used_date,access_key_1_last_used_region,access_key_1_last_used_service,access_key_2_active,access_key_2_last_rotated,access_key_2_last_used_date,access_key_2_last_used_region,access_key_2_last_used_service,cert_1_active,cert_1_last_rotated,cert_2_active,cert_2_last_rotated\n" @@ -2738,17 +2952,19 @@ class IAMBackend(BaseBackend): report += self.users[user].to_csv() return base64.b64encode(report.encode("ascii")).decode("ascii") - def list_account_aliases(self): + def list_account_aliases(self) -> List[str]: return self.account_aliases - def create_account_alias(self, alias): + def create_account_alias(self, alias: str) -> None: # alias is force updated self.account_aliases = [alias] - def delete_account_alias(self): + def delete_account_alias(self) -> None: self.account_aliases = [] - def get_account_authorization_details(self, policy_filter): + def get_account_authorization_details( + self, policy_filter: List[str] + ) -> Dict[str, Any]: policies = self.managed_policies.values() local_policies = set(policies) - set(self.aws_managed_policies) returned_policies = [] @@ -2775,17 +2991,21 @@ class IAMBackend(BaseBackend): "managed_policies": returned_policies, } - def create_saml_provider(self, name, saml_metadata_document): + def create_saml_provider( + self, name: str, saml_metadata_document: str + ) -> SAMLProvider: saml_provider = SAMLProvider(self.account_id, name, saml_metadata_document) self.saml_providers[name] = saml_provider return saml_provider - def update_saml_provider(self, saml_provider_arn, saml_metadata_document): + def update_saml_provider( + self, saml_provider_arn: str, saml_metadata_document: str + ) -> SAMLProvider: saml_provider = self.get_saml_provider(saml_provider_arn) saml_provider.saml_metadata_document = saml_metadata_document return saml_provider - def delete_saml_provider(self, saml_provider_arn): + def delete_saml_provider(self, saml_provider_arn: str) -> None: try: for saml_provider in list(self.list_saml_providers()): if saml_provider.arn == saml_provider_arn: @@ -2793,16 +3013,16 @@ class IAMBackend(BaseBackend): except KeyError: raise IAMNotFoundException(f"SAMLProvider {saml_provider_arn} not found") - def list_saml_providers(self): + def list_saml_providers(self) -> Iterable[SAMLProvider]: return self.saml_providers.values() - def get_saml_provider(self, saml_provider_arn): + def get_saml_provider(self, saml_provider_arn: str) -> SAMLProvider: for saml_provider in self.list_saml_providers(): if saml_provider.arn == saml_provider_arn: return saml_provider raise IAMNotFoundException(f"SamlProvider {saml_provider_arn} not found") - def get_user_from_access_key_id(self, access_key_id): + def get_user_from_access_key_id(self, access_key_id: str) -> Optional[User]: for user_name, user in self.users.items(): access_keys = self.list_access_keys(user_name) for access_key in access_keys: @@ -2811,8 +3031,12 @@ class IAMBackend(BaseBackend): return None def create_open_id_connect_provider( - self, url, thumbprint_list, client_id_list, tags - ): + self, + url: str, + thumbprint_list: List[str], + client_id_list: List[str], + tags: List[Dict[str, str]], + ) -> OpenIDConnectProvider: clean_tags = self._tag_verification(tags) open_id_provider = OpenIDConnectProvider( self.account_id, url, thumbprint_list, client_id_list, clean_tags @@ -2824,16 +3048,20 @@ class IAMBackend(BaseBackend): self.open_id_providers[open_id_provider.arn] = open_id_provider return open_id_provider - def update_open_id_connect_provider_thumbprint(self, arn, thumbprint_list): + def update_open_id_connect_provider_thumbprint( + self, arn: str, thumbprint_list: List[str] + ) -> None: open_id_provider = self.get_open_id_connect_provider(arn) open_id_provider.thumbprint_list = thumbprint_list - def tag_open_id_connect_provider(self, arn, tags): + def tag_open_id_connect_provider( + self, arn: str, tags: List[Dict[str, str]] + ) -> None: open_id_provider = self.get_open_id_connect_provider(arn) clean_tags = self._tag_verification(tags) open_id_provider.tags.update(clean_tags) - def untag_open_id_connect_provider(self, arn, tag_keys): + def untag_open_id_connect_provider(self, arn: str, tag_keys: List[str]) -> None: open_id_provider = self.get_open_id_connect_provider(arn) for key in tag_keys: @@ -2841,7 +3069,9 @@ class IAMBackend(BaseBackend): self._validate_tag_key(key, exception_param="tagKeys") open_id_provider.tags.pop(ref_key, None) - def list_open_id_connect_provider_tags(self, arn, marker, max_items=100): + def list_open_id_connect_provider_tags( + self, arn: str, marker: Optional[str], max_items: int = 100 + ) -> Tuple[List[Dict[str, str]], Optional[str]]: open_id_provider = self.get_open_id_connect_provider(arn) max_items = int(max_items) @@ -2858,10 +3088,10 @@ class IAMBackend(BaseBackend): tags = [open_id_provider.tags[tag] for tag in tag_index] return tags, marker - def delete_open_id_connect_provider(self, arn): + def delete_open_id_connect_provider(self, arn: str) -> None: self.open_id_providers.pop(arn, None) - def get_open_id_connect_provider(self, arn): + def get_open_id_connect_provider(self, arn: str) -> OpenIDConnectProvider: open_id_provider = self.open_id_providers.get(arn) if not open_id_provider: @@ -2871,21 +3101,21 @@ class IAMBackend(BaseBackend): return open_id_provider - def list_open_id_connect_providers(self): + def list_open_id_connect_providers(self) -> List[str]: return list(self.open_id_providers.keys()) def update_account_password_policy( self, - allow_change_password, - hard_expiry, - max_password_age, - minimum_password_length, - password_reuse_prevention, - require_lowercase_characters, - require_numbers, - require_symbols, - require_uppercase_characters, - ): + allow_change_password: bool, + hard_expiry: int, + max_password_age: int, + minimum_password_length: int, + password_reuse_prevention: int, + require_lowercase_characters: bool, + require_numbers: bool, + require_symbols: bool, + require_uppercase_characters: bool, + ) -> None: self.account_password_policy = AccountPasswordPolicy( allow_change_password, hard_expiry, @@ -2898,7 +3128,7 @@ class IAMBackend(BaseBackend): require_uppercase_characters, ) - def get_account_password_policy(self): + def get_account_password_policy(self) -> AccountPasswordPolicy: if not self.account_password_policy: raise NoSuchEntity( f"The Password Policy with domain name {self.account_id} cannot be found." @@ -2906,7 +3136,7 @@ class IAMBackend(BaseBackend): return self.account_password_policy - def delete_account_password_policy(self): + def delete_account_password_policy(self) -> None: if not self.account_password_policy: raise NoSuchEntity( "The account policy with name PasswordPolicy cannot be found." @@ -2914,18 +3144,18 @@ class IAMBackend(BaseBackend): self.account_password_policy = None - def get_account_summary(self): + def get_account_summary(self) -> AccountSummary: return self.account_summary def create_inline_policy( self, - resource_name, - policy_name, - policy_document, - group_names, - role_names, - user_names, - ): + resource_name: str, + policy_name: str, + policy_document: str, + group_names: List[str], + role_names: List[str], + user_names: List[str], + ) -> InlinePolicy: if resource_name in self.inline_policies: raise IAMConflictException( "EntityAlreadyExists", f"Inline Policy {resource_name} already exists" @@ -2943,7 +3173,7 @@ class IAMBackend(BaseBackend): inline_policy.apply_policy(self) return inline_policy - def get_inline_policy(self, policy_id): + def get_inline_policy(self, policy_id: str) -> InlinePolicy: try: return self.inline_policies[policy_id] except KeyError: @@ -2951,13 +3181,13 @@ class IAMBackend(BaseBackend): def update_inline_policy( self, - resource_name, - policy_name, - policy_document, - group_names, - role_names, - user_names, - ): + resource_name: str, + policy_name: str, + policy_document: str, + group_names: List[str], + role_names: List[str], + user_names: List[str], + ) -> InlinePolicy: inline_policy = self.get_inline_policy(resource_name) inline_policy.unapply_policy(self) inline_policy.update( @@ -2966,22 +3196,24 @@ class IAMBackend(BaseBackend): inline_policy.apply_policy(self) return inline_policy - def delete_inline_policy(self, policy_id): + def delete_inline_policy(self, policy_id: str) -> None: inline_policy = self.get_inline_policy(policy_id) inline_policy.unapply_policy(self) del self.inline_policies[policy_id] - def tag_user(self, name, tags): + def tag_user(self, name: str, tags: List[Dict[str, str]]) -> None: user = self.get_user(name) self.tagger.tag_resource(user.arn, tags) - def untag_user(self, name, tag_keys): + def untag_user(self, name: str, tag_keys: List[str]) -> None: user = self.get_user(name) self.tagger.untag_resource_using_names(user.arn, tag_keys) - def create_service_linked_role(self, service_name, description, suffix): + def create_service_linked_role( + self, service_name: str, description: str, suffix: str + ) -> Role: # service.amazonaws.com -> Service # some-thing.service.amazonaws.com -> Service_SomeThing service = service_name.split(".")[-3] @@ -3016,12 +3248,12 @@ class IAMBackend(BaseBackend): linked_service=service_name, ) - def delete_service_linked_role(self, role_name): + def delete_service_linked_role(self, role_name: str) -> str: self.delete_role(role_name) deletion_task_id = str(random.uuid4()) return deletion_task_id - def get_service_linked_role_deletion_status(self): + def get_service_linked_role_deletion_status(self) -> bool: """ This method always succeeds for now - we do not yet keep track of deletions """ diff --git a/moto/iam/policy_validation.py b/moto/iam/policy_validation.py index 90e09c863..eff39e3a4 100644 --- a/moto/iam/policy_validation.py +++ b/moto/iam/policy_validation.py @@ -1,6 +1,6 @@ import json import re - +from typing import Any, Dict, List from moto.iam.exceptions import MalformedPolicyDocument @@ -61,7 +61,7 @@ SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS = { "s3": "Resource {resource} can not contain region information.", } -VALID_RESOURCE_PATH_STARTING_VALUES = { +VALID_RESOURCE_PATH_STARTING_VALUES: Dict[str, Any] = { "iam": { "values": [ "user/", @@ -84,13 +84,13 @@ VALID_RESOURCE_PATH_STARTING_VALUES = { class BaseIAMPolicyValidator: - def __init__(self, policy_document): + def __init__(self, policy_document: str): self._policy_document = policy_document - self._policy_json = {} - self._statements = [] + self._policy_json: Dict[str, Any] = {} + self._statements: List[Dict[str, Any]] = [] self._resource_error = "" # the first resource error found that does not generate a legacy parsing error - def validate(self): + def validate(self) -> None: try: self._validate_syntax() except Exception: @@ -124,7 +124,7 @@ class BaseIAMPolicyValidator: self._validate_actions_for_prefixes() self._validate_not_actions_for_prefixes() - def _validate_syntax(self): + def _validate_syntax(self) -> None: self._policy_json = json.loads(self._policy_document) assert isinstance(self._policy_json, dict) self._validate_top_elements() @@ -132,19 +132,19 @@ class BaseIAMPolicyValidator: self._validate_id_syntax() self._validate_statements_syntax() - def _validate_top_elements(self): + def _validate_top_elements(self) -> None: top_elements = self._policy_json.keys() for element in top_elements: assert element in VALID_TOP_ELEMENTS - def _validate_version_syntax(self): + def _validate_version_syntax(self) -> None: if "Version" in self._policy_json: assert self._policy_json["Version"] in VALID_VERSIONS - def _validate_version(self): + def _validate_version(self) -> None: assert self._policy_json["Version"] == "2012-10-17" - def _validate_sid_uniqueness(self): + def _validate_sid_uniqueness(self) -> None: sids = [] for statement in self._statements: if "Sid" in statement: @@ -153,7 +153,7 @@ class BaseIAMPolicyValidator: assert statementId not in sids sids.append(statementId) - def _validate_statements_syntax(self): + def _validate_statements_syntax(self) -> None: assert "Statement" in self._policy_json assert isinstance(self._policy_json["Statement"], (dict, list)) @@ -167,7 +167,7 @@ class BaseIAMPolicyValidator: self._validate_statement_syntax(statement) @staticmethod - def _validate_statement_syntax(statement): + def _validate_statement_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] assert isinstance(statement, dict) for statement_element in statement.keys(): assert statement_element in VALID_STATEMENT_ELEMENTS @@ -184,7 +184,7 @@ class BaseIAMPolicyValidator: IAMPolicyDocumentValidator._validate_sid_syntax(statement) @staticmethod - def _validate_effect_syntax(statement): + def _validate_effect_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] assert "Effect" in statement assert isinstance(statement["Effect"], str) assert statement["Effect"].lower() in [ @@ -192,31 +192,31 @@ class BaseIAMPolicyValidator: ] @staticmethod - def _validate_action_syntax(statement): + def _validate_action_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( statement, "Action" ) @staticmethod - def _validate_not_action_syntax(statement): + def _validate_not_action_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( statement, "NotAction" ) @staticmethod - def _validate_resource_syntax(statement): + def _validate_resource_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( statement, "Resource" ) @staticmethod - def _validate_not_resource_syntax(statement): + def _validate_not_resource_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( statement, "NotResource" ) @staticmethod - def _validate_string_or_list_of_strings_syntax(statement, key): + def _validate_string_or_list_of_strings_syntax(statement: Dict[str, Any], key: str) -> None: # type: ignore[misc] if key in statement: assert isinstance(statement[key], (str, list)) if isinstance(statement[key], list): @@ -224,7 +224,7 @@ class BaseIAMPolicyValidator: assert isinstance(resource, str) @staticmethod - def _validate_condition_syntax(statement): + def _validate_condition_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] if "Condition" in statement: assert isinstance(statement["Condition"], dict) for condition_key, condition_value in statement["Condition"].items(): @@ -239,7 +239,7 @@ class BaseIAMPolicyValidator: assert not condition_value # empty dict @staticmethod - def _strip_condition_key(condition_key): + def _strip_condition_key(condition_key: str) -> str: for valid_prefix in VALID_CONDITION_PREFIXES: if condition_key.startswith(valid_prefix): condition_key = condition_key[len(valid_prefix) :] @@ -253,15 +253,15 @@ class BaseIAMPolicyValidator: return condition_key @staticmethod - def _validate_sid_syntax(statement): + def _validate_sid_syntax(statement: Dict[str, Any]) -> None: # type: ignore[misc] if "Sid" in statement: assert isinstance(statement["Sid"], str) - def _validate_id_syntax(self): + def _validate_id_syntax(self) -> None: if "Id" in self._policy_json: assert isinstance(self._policy_json["Id"], str) - def _validate_resource_exist(self): + def _validate_resource_exist(self) -> None: for statement in self._statements: assert "Resource" in statement or "NotResource" in statement if "Resource" in statement and isinstance(statement["Resource"], list): @@ -271,7 +271,7 @@ class BaseIAMPolicyValidator: ): assert statement["NotResource"] - def _validate_action_like_exist(self): + def _validate_action_like_exist(self) -> None: for statement in self._statements: assert "Action" in statement or "NotAction" in statement if "Action" in statement and isinstance(statement["Action"], list): @@ -279,13 +279,13 @@ class BaseIAMPolicyValidator: elif "NotAction" in statement and isinstance(statement["NotAction"], list): assert statement["NotAction"] - def _validate_actions_for_prefixes(self): + def _validate_actions_for_prefixes(self) -> None: self._validate_action_like_for_prefixes("Action") - def _validate_not_actions_for_prefixes(self): + def _validate_not_actions_for_prefixes(self) -> None: self._validate_action_like_for_prefixes("NotAction") - def _validate_action_like_for_prefixes(self, key): + def _validate_action_like_for_prefixes(self, key: str) -> None: for statement in self._statements: if key in statement: if isinstance(statement[key], str): @@ -295,7 +295,7 @@ class BaseIAMPolicyValidator: self._validate_action_prefix(action) @staticmethod - def _validate_action_prefix(action): + def _validate_action_prefix(action: str) -> None: action_parts = action.split(":") if len(action_parts) == 1 and action_parts[0] != "*": raise MalformedPolicyDocument( @@ -310,13 +310,13 @@ class BaseIAMPolicyValidator: if action_parts[0] != "*" and vendor_pattern.search(action_parts[0]): raise MalformedPolicyDocument(f"Vendor {action_parts[0]} is not valid") - def _validate_resources_for_formats(self): + def _validate_resources_for_formats(self) -> None: self._validate_resource_like_for_formats("Resource") - def _validate_not_resources_for_formats(self): + def _validate_not_resources_for_formats(self) -> None: self._validate_resource_like_for_formats("NotResource") - def _validate_resource_like_for_formats(self, key): + def _validate_resource_like_for_formats(self, key: str) -> None: for statement in self._statements: if key in statement: if isinstance(statement[key], str): @@ -329,7 +329,7 @@ class BaseIAMPolicyValidator: statement, key ) - def _validate_resource_format(self, resource): + def _validate_resource_format(self, resource: str) -> None: if resource != "*": resource_partitions = resource.partition(":") @@ -407,13 +407,13 @@ class BaseIAMPolicyValidator: ) ) - def _perform_first_legacy_parsing(self): + def _perform_first_legacy_parsing(self) -> None: """This method excludes legacy parsing resources, since that have to be done later.""" for statement in self._statements: self._legacy_parse_statement(statement) @staticmethod - def _legacy_parse_statement(statement): + def _legacy_parse_statement(statement: Dict[str, Any]) -> None: # type: ignore[misc] assert statement["Effect"] in VALID_EFFECTS # case-sensitive matching if "Condition" in statement: for condition_key, condition_value in statement["Condition"].items(): @@ -422,7 +422,7 @@ class BaseIAMPolicyValidator: ) @staticmethod - def _legacy_parse_resource_like(statement, key): + def _legacy_parse_resource_like(statement: Dict[str, Any], key: str) -> None: # type: ignore[misc] if isinstance(statement[key], str): if statement[key] != "*": assert statement[key].count(":") >= 5 or "::" not in statement[key] @@ -434,7 +434,7 @@ class BaseIAMPolicyValidator: assert resource[2] != "" @staticmethod - def _legacy_parse_condition(condition_key, condition_value): + def _legacy_parse_condition(condition_key: str, condition_value: Dict[str, Any]) -> None: # type: ignore[misc] stripped_condition_key = IAMPolicyDocumentValidator._strip_condition_key( condition_key ) @@ -452,7 +452,7 @@ class BaseIAMPolicyValidator: ) @staticmethod - def _legacy_parse_date_condition_value(date_condition_value): + def _legacy_parse_date_condition_value(date_condition_value: str) -> None: if "t" in date_condition_value.lower() or "-" in date_condition_value: IAMPolicyDocumentValidator._validate_iso_8601_datetime( date_condition_value.lower() @@ -461,7 +461,7 @@ class BaseIAMPolicyValidator: assert 0 <= int(date_condition_value) <= 9223372036854775807 @staticmethod - def _validate_iso_8601_datetime(datetime): + def _validate_iso_8601_datetime(datetime: str) -> None: datetime_parts = datetime.partition("t") negative_year = datetime_parts[0].startswith("-") date_parts = ( @@ -525,10 +525,10 @@ class IAMPolicyDocumentValidator(BaseIAMPolicyValidator): class IAMTrustPolicyDocumentValidator(BaseIAMPolicyValidator): - def __init__(self, policy_document): + def __init__(self, policy_document: str): super().__init__(policy_document) - def validate(self): + def validate(self) -> None: super().validate() try: for statement in self._statements: @@ -551,12 +551,12 @@ class IAMTrustPolicyDocumentValidator(BaseIAMPolicyValidator): except Exception: raise MalformedPolicyDocument("Has prohibited field Resource.") - def _validate_resource_not_exist(self): + def _validate_resource_not_exist(self) -> None: for statement in self._statements: assert "Resource" not in statement and "NotResource" not in statement @staticmethod - def _validate_trust_policy_action(action): + def _validate_trust_policy_action(action: str) -> None: # https://docs.aws.amazon.com/service-authorization/latest/reference/list_awssecuritytokenservice.html assert action in ( "sts:AssumeRole", diff --git a/moto/iam/responses.py b/moto/iam/responses.py index 3a75e0367..ac321c912 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -1,59 +1,59 @@ from moto.core.responses import BaseResponse -from .models import iam_backends, User +from .models import iam_backends, IAMBackend, User class IamResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="iam") @property - def backend(self): + def backend(self) -> IAMBackend: return iam_backends[self.current_account]["global"] - def attach_role_policy(self): + def attach_role_policy(self) -> str: policy_arn = self._get_param("PolicyArn") role_name = self._get_param("RoleName") self.backend.attach_role_policy(policy_arn, role_name) template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE) return template.render() - def detach_role_policy(self): + def detach_role_policy(self) -> str: role_name = self._get_param("RoleName") policy_arn = self._get_param("PolicyArn") self.backend.detach_role_policy(policy_arn, role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DetachRolePolicy") - def attach_group_policy(self): + def attach_group_policy(self) -> str: policy_arn = self._get_param("PolicyArn") group_name = self._get_param("GroupName") self.backend.attach_group_policy(policy_arn, group_name) template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE) return template.render() - def detach_group_policy(self): + def detach_group_policy(self) -> str: policy_arn = self._get_param("PolicyArn") group_name = self._get_param("GroupName") self.backend.detach_group_policy(policy_arn, group_name) template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE) return template.render() - def attach_user_policy(self): + def attach_user_policy(self) -> str: policy_arn = self._get_param("PolicyArn") user_name = self._get_param("UserName") self.backend.attach_user_policy(policy_arn, user_name) template = self.response_template(ATTACH_USER_POLICY_TEMPLATE) return template.render() - def detach_user_policy(self): + def detach_user_policy(self) -> str: policy_arn = self._get_param("PolicyArn") user_name = self._get_param("UserName") self.backend.detach_user_policy(policy_arn, user_name) template = self.response_template(DETACH_USER_POLICY_TEMPLATE) return template.render() - def create_policy(self): + def create_policy(self) -> str: description = self._get_param("Description") path = self._get_param("Path") policy_document = self._get_param("PolicyDocument") @@ -65,13 +65,13 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_POLICY_TEMPLATE) return template.render(policy=policy) - def get_policy(self): + def get_policy(self) -> str: policy_arn = self._get_param("PolicyArn") policy = self.backend.get_policy(policy_arn) template = self.response_template(GET_POLICY_TEMPLATE) return template.render(policy=policy) - def list_attached_role_policies(self): + def list_attached_role_policies(self) -> str: marker = self._get_param("Marker") max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") @@ -82,7 +82,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) - def list_attached_group_policies(self): + def list_attached_group_policies(self) -> str: marker = self._get_param("Marker") max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") @@ -93,7 +93,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) - def list_attached_user_policies(self): + def list_attached_user_policies(self) -> str: marker = self._get_param("Marker") max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") @@ -104,7 +104,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) - def list_policies(self): + def list_policies(self) -> str: marker = self._get_param("Marker") max_items = self._get_int_param("MaxItems", 100) only_attached = self._get_bool_param("OnlyAttached", False) @@ -116,7 +116,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) - def list_entities_for_policy(self): + def list_entities_for_policy(self) -> str: policy_arn = self._get_param("PolicyArn") # Options 'User'|'Role'|'Group'|'LocalManagedPolicy'|'AWSManagedPolicy @@ -181,14 +181,14 @@ class IamResponse(BaseResponse): roles=entity_roles, users=entity_users, groups=entity_groups ) - def set_default_policy_version(self): + def set_default_policy_version(self) -> str: policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") self.backend.set_default_policy_version(policy_arn, version_id) template = self.response_template(SET_DEFAULT_POLICY_VERSION_TEMPLATE) return template.render() - def create_role(self): + def create_role(self) -> str: role_name = self._get_param("RoleName") path = self._get_param("Path") assume_role_policy_document = self._get_param("AssumeRolePolicyDocument") @@ -209,26 +209,26 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_ROLE_TEMPLATE) return template.render(role=role) - def get_role(self): + def get_role(self) -> str: role_name = self._get_param("RoleName") role = self.backend.get_role(role_name) template = self.response_template(GET_ROLE_TEMPLATE) return template.render(role=role) - def delete_role(self): + def delete_role(self) -> str: role_name = self._get_param("RoleName") self.backend.delete_role(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRole") - def list_role_policies(self): + def list_role_policies(self) -> str: role_name = self._get_param("RoleName") role_policies_names = self.backend.list_role_policies(role_name) template = self.response_template(LIST_ROLE_POLICIES) return template.render(role_policies=role_policies_names) - def put_role_policy(self): + def put_role_policy(self) -> str: role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") @@ -236,14 +236,14 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePolicy") - def delete_role_policy(self): + def delete_role_policy(self) -> str: role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") self.backend.delete_role_policy(role_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePolicy") - def get_role_policy(self): + def get_role_policy(self) -> str: role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") policy_name, policy_document = self.backend.get_role_policy( @@ -256,21 +256,21 @@ class IamResponse(BaseResponse): policy_document=policy_document, ) - def update_assume_role_policy(self): + def update_assume_role_policy(self) -> str: role_name = self._get_param("RoleName") policy_document = self._get_param("PolicyDocument") self.backend.update_assume_role_policy(role_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAssumeRolePolicy") - def update_role_description(self): + def update_role_description(self) -> str: role_name = self._get_param("RoleName") description = self._get_param("Description") role = self.backend.update_role_description(role_name, description) template = self.response_template(UPDATE_ROLE_DESCRIPTION_TEMPLATE) return template.render(role=role) - def update_role(self): + def update_role(self) -> str: role_name = self._get_param("RoleName") description = self._get_param("Description") max_session_duration = self._get_param("MaxSessionDuration", 3600) @@ -278,20 +278,20 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_ROLE_TEMPLATE) return template.render(role=role) - def put_role_permissions_boundary(self): + def put_role_permissions_boundary(self) -> str: permissions_boundary = self._get_param("PermissionsBoundary") role_name = self._get_param("RoleName") self.backend.put_role_permissions_boundary(role_name, permissions_boundary) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePermissionsBoundary") - def delete_role_permissions_boundary(self): + def delete_role_permissions_boundary(self) -> str: role_name = self._get_param("RoleName") self.backend.delete_role_permissions_boundary(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePermissionsBoundary") - def create_policy_version(self): + def create_policy_version(self) -> str: policy_arn = self._get_param("PolicyArn") policy_document = self._get_param("PolicyDocument") set_as_default = self._get_param("SetAsDefault") @@ -301,21 +301,21 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) - def get_policy_version(self): + def get_policy_version(self) -> str: policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") policy_version = self.backend.get_policy_version(policy_arn, version_id) template = self.response_template(GET_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) - def list_policy_versions(self): + def list_policy_versions(self) -> str: policy_arn = self._get_param("PolicyArn") policy_versions = self.backend.list_policy_versions(policy_arn) template = self.response_template(LIST_POLICY_VERSIONS_TEMPLATE) return template.render(policy_versions=policy_versions) - def list_policy_tags(self): + def list_policy_tags(self) -> str: policy_arn = self._get_param("PolicyArn") marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) @@ -325,7 +325,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_POLICY_TAG_TEMPLATE) return template.render(tags=tags, marker=marker) - def tag_policy(self): + def tag_policy(self) -> str: policy_arn = self._get_param("PolicyArn") tags = self._get_multi_param("Tags.member") @@ -334,7 +334,7 @@ class IamResponse(BaseResponse): template = self.response_template(TAG_POLICY_TEMPLATE) return template.render() - def untag_policy(self): + def untag_policy(self) -> str: policy_arn = self._get_param("PolicyArn") tag_keys = self._get_multi_param("TagKeys.member") @@ -343,7 +343,7 @@ class IamResponse(BaseResponse): template = self.response_template(UNTAG_POLICY_TEMPLATE) return template.render() - def delete_policy_version(self): + def delete_policy_version(self) -> str: policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") @@ -351,7 +351,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeletePolicyVersion") - def create_instance_profile(self): + def create_instance_profile(self) -> str: profile_name = self._get_param("InstanceProfileName") path = self._get_param("Path", "/") tags = self._get_multi_param("Tags.member") @@ -362,21 +362,21 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) - def delete_instance_profile(self): + def delete_instance_profile(self) -> str: profile_name = self._get_param("InstanceProfileName") - profile = self.backend.delete_instance_profile(profile_name) + self.backend.delete_instance_profile(profile_name) template = self.response_template(DELETE_INSTANCE_PROFILE_TEMPLATE) - return template.render(profile=profile) + return template.render() - def get_instance_profile(self): + def get_instance_profile(self) -> str: profile_name = self._get_param("InstanceProfileName") profile = self.backend.get_instance_profile(profile_name) template = self.response_template(GET_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) - def add_role_to_instance_profile(self): + def add_role_to_instance_profile(self) -> str: profile_name = self._get_param("InstanceProfileName") role_name = self._get_param("RoleName") @@ -384,7 +384,7 @@ class IamResponse(BaseResponse): template = self.response_template(ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) return template.render() - def remove_role_from_instance_profile(self): + def remove_role_from_instance_profile(self) -> str: profile_name = self._get_param("InstanceProfileName") role_name = self._get_param("RoleName") @@ -392,7 +392,7 @@ class IamResponse(BaseResponse): template = self.response_template(REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) return template.render() - def list_roles(self): + def list_roles(self) -> str: path_prefix = self._get_param("PathPrefix", "/") marker = self._get_param("Marker", "0") max_items = self._get_param("MaxItems", 100) @@ -401,20 +401,20 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ROLES_TEMPLATE) return template.render(roles=roles, marker=marker) - def list_instance_profiles(self): + def list_instance_profiles(self) -> str: profiles = self.backend.get_instance_profiles() template = self.response_template(LIST_INSTANCE_PROFILES_TEMPLATE) return template.render(instance_profiles=profiles) - def list_instance_profiles_for_role(self): + def list_instance_profiles_for_role(self) -> str: role_name = self._get_param("RoleName") profiles = self.backend.get_instance_profiles_for_role(role_name=role_name) template = self.response_template(LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) return template.render(instance_profiles=profiles) - def upload_server_certificate(self): + def upload_server_certificate(self) -> str: cert_name = self._get_param("ServerCertificateName") cert_body = self._get_param("CertificateBody") path = self._get_param("Path") @@ -427,24 +427,24 @@ class IamResponse(BaseResponse): template = self.response_template(UPLOAD_CERT_TEMPLATE) return template.render(certificate=cert) - def list_server_certificates(self): + def list_server_certificates(self) -> str: certs = self.backend.list_server_certificates() template = self.response_template(LIST_SERVER_CERTIFICATES_TEMPLATE) return template.render(server_certificates=certs) - def get_server_certificate(self): + def get_server_certificate(self) -> str: cert_name = self._get_param("ServerCertificateName") cert = self.backend.get_server_certificate(cert_name) template = self.response_template(GET_SERVER_CERTIFICATE_TEMPLATE) return template.render(certificate=cert) - def delete_server_certificate(self): + def delete_server_certificate(self) -> str: cert_name = self._get_param("ServerCertificateName") self.backend.delete_server_certificate(cert_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteServerCertificate") - def create_group(self): + def create_group(self) -> str: group_name = self._get_param("GroupName") path = self._get_param("Path", "/") @@ -452,26 +452,26 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_GROUP_TEMPLATE) return template.render(group=group) - def get_group(self): + def get_group(self) -> str: group_name = self._get_param("GroupName") group = self.backend.get_group(group_name) template = self.response_template(GET_GROUP_TEMPLATE) return template.render(group=group) - def list_groups(self): + def list_groups(self) -> str: groups = self.backend.list_groups() template = self.response_template(LIST_GROUPS_TEMPLATE) return template.render(groups=groups) - def list_groups_for_user(self): + def list_groups_for_user(self) -> str: user_name = self._get_param("UserName") groups = self.backend.get_groups_for_user(user_name) template = self.response_template(LIST_GROUPS_FOR_USER_TEMPLATE) return template.render(groups=groups) - def put_group_policy(self): + def put_group_policy(self) -> str: group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") @@ -479,7 +479,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutGroupPolicy") - def list_group_policies(self): + def list_group_policies(self) -> str: group_name = self._get_param("GroupName") marker = self._get_param("Marker") policies = self.backend.list_group_policies(group_name) @@ -488,27 +488,27 @@ class IamResponse(BaseResponse): name="ListGroupPoliciesResponse", policies=policies, marker=marker ) - def get_group_policy(self): + def get_group_policy(self) -> str: group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") policy_result = self.backend.get_group_policy(group_name, policy_name) template = self.response_template(GET_GROUP_POLICY_TEMPLATE) return template.render(name="GetGroupPolicyResponse", **policy_result) - def delete_group_policy(self): + def delete_group_policy(self) -> str: group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") self.backend.delete_group_policy(group_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteGroupPolicy") - def delete_group(self): + def delete_group(self) -> str: group_name = self._get_param("GroupName") self.backend.delete_group(group_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteGroup") - def update_group(self): + def update_group(self) -> str: group_name = self._get_param("GroupName") new_group_name = self._get_param("NewGroupName") new_path = self._get_param("NewPath") @@ -516,7 +516,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateGroup") - def create_user(self): + def create_user(self) -> str: user_name = self._get_param("UserName") path = self._get_param("Path") tags = self._get_multi_param("Tags.member") @@ -524,7 +524,7 @@ class IamResponse(BaseResponse): template = self.response_template(USER_TEMPLATE) return template.render(action="Create", user=user, tags=user_tags["Tags"]) - def get_user(self): + def get_user(self) -> str: user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_access_key() @@ -537,7 +537,7 @@ class IamResponse(BaseResponse): template = self.response_template(USER_TEMPLATE) return template.render(action="Get", user=user, tags=tags) - def list_users(self): + def list_users(self) -> str: path_prefix = self._get_param("PathPrefix") marker = self._get_param("Marker") max_items = self._get_param("MaxItems") @@ -545,7 +545,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_USERS_TEMPLATE) return template.render(action="List", users=users, isTruncated=False) - def update_user(self): + def update_user(self) -> str: user_name = self._get_param("UserName") new_path = self._get_param("NewPath") new_user_name = self._get_param("NewUserName") @@ -557,7 +557,7 @@ class IamResponse(BaseResponse): template = self.response_template(USER_TEMPLATE) return template.render(action="Update", user=user) - def create_login_profile(self): + def create_login_profile(self) -> str: user_name = self._get_param("UserName") password = self._get_param("Password") user = self.backend.create_login_profile(user_name, password) @@ -565,14 +565,14 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) - def get_login_profile(self): + def get_login_profile(self) -> str: user_name = self._get_param("UserName") user = self.backend.get_login_profile(user_name) template = self.response_template(GET_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) - def update_login_profile(self): + def update_login_profile(self) -> str: user_name = self._get_param("UserName") password = self._get_param("Password") password_reset_required = self._get_param("PasswordResetRequired") @@ -583,7 +583,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) - def add_user_to_group(self): + def add_user_to_group(self) -> str: group_name = self._get_param("GroupName") user_name = self._get_param("UserName") @@ -591,7 +591,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="AddUserToGroup") - def remove_user_from_group(self): + def remove_user_from_group(self) -> str: group_name = self._get_param("GroupName") user_name = self._get_param("UserName") @@ -599,7 +599,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="RemoveUserFromGroup") - def get_user_policy(self): + def get_user_policy(self) -> str: user_name = self._get_param("UserName") policy_name = self._get_param("PolicyName") @@ -611,19 +611,19 @@ class IamResponse(BaseResponse): policy_document=policy_document.get("policy_document"), ) - def list_user_policies(self): + def list_user_policies(self) -> str: user_name = self._get_param("UserName") policies = self.backend.list_user_policies(user_name) template = self.response_template(LIST_USER_POLICIES_TEMPLATE) return template.render(policies=policies) - def list_user_tags(self): + def list_user_tags(self) -> str: user_name = self._get_param("UserName") tags = self.backend.list_user_tags(user_name) template = self.response_template(LIST_USER_TAGS_TEMPLATE) return template.render(user_tags=tags["Tags"]) - def put_user_policy(self): + def put_user_policy(self) -> str: user_name = self._get_param("UserName") policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") @@ -632,7 +632,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutUserPolicy") - def delete_user_policy(self): + def delete_user_policy(self) -> str: user_name = self._get_param("UserName") policy_name = self._get_param("PolicyName") @@ -640,7 +640,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteUserPolicy") - def create_access_key(self): + def create_access_key(self) -> str: user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_access_key() @@ -651,7 +651,7 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) return template.render(key=key) - def update_access_key(self): + def update_access_key(self) -> str: user_name = self._get_param("UserName") access_key_id = self._get_param("AccessKeyId") status = self._get_param("Status") @@ -663,7 +663,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAccessKey") - def get_access_key_last_used(self): + def get_access_key_last_used(self) -> str: access_key_id = self._get_param("AccessKeyId") last_used_response = self.backend.get_access_key_last_used(access_key_id) template = self.response_template(GET_ACCESS_KEY_LAST_USED_TEMPLATE) @@ -672,7 +672,7 @@ class IamResponse(BaseResponse): last_used=last_used_response["last_used"], ) - def list_access_keys(self): + def list_access_keys(self) -> str: user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_access_key() @@ -683,7 +683,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) return template.render(user_name=user_name, keys=keys) - def delete_access_key(self): + def delete_access_key(self) -> str: user_name = self._get_param("UserName") access_key_id = self._get_param("AccessKeyId") if not user_name: @@ -694,7 +694,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteAccessKey") - def upload_ssh_public_key(self): + def upload_ssh_public_key(self) -> str: user_name = self._get_param("UserName") ssh_public_key_body = self._get_param("SSHPublicKeyBody") @@ -702,7 +702,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPLOAD_SSH_PUBLIC_KEY_TEMPLATE) return template.render(key=key) - def get_ssh_public_key(self): + def get_ssh_public_key(self) -> str: user_name = self._get_param("UserName") ssh_public_key_id = self._get_param("SSHPublicKeyId") @@ -710,14 +710,14 @@ class IamResponse(BaseResponse): template = self.response_template(GET_SSH_PUBLIC_KEY_TEMPLATE) return template.render(key=key) - def list_ssh_public_keys(self): + def list_ssh_public_keys(self) -> str: user_name = self._get_param("UserName") keys = self.backend.get_all_ssh_public_keys(user_name) template = self.response_template(LIST_SSH_PUBLIC_KEYS_TEMPLATE) return template.render(keys=keys) - def update_ssh_public_key(self): + def update_ssh_public_key(self) -> str: user_name = self._get_param("UserName") ssh_public_key_id = self._get_param("SSHPublicKeyId") status = self._get_param("Status") @@ -726,7 +726,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_SSH_PUBLIC_KEY_TEMPLATE) return template.render() - def delete_ssh_public_key(self): + def delete_ssh_public_key(self) -> str: user_name = self._get_param("UserName") ssh_public_key_id = self._get_param("SSHPublicKeyId") @@ -734,7 +734,7 @@ class IamResponse(BaseResponse): template = self.response_template(DELETE_SSH_PUBLIC_KEY_TEMPLATE) return template.render() - def deactivate_mfa_device(self): + def deactivate_mfa_device(self) -> str: user_name = self._get_param("UserName") serial_number = self._get_param("SerialNumber") @@ -742,7 +742,7 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeactivateMFADevice") - def enable_mfa_device(self): + def enable_mfa_device(self) -> str: user_name = self._get_param("UserName") serial_number = self._get_param("SerialNumber") authentication_code_1 = self._get_param("AuthenticationCode1") @@ -754,13 +754,13 @@ class IamResponse(BaseResponse): template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="EnableMFADevice") - def list_mfa_devices(self): + def list_mfa_devices(self) -> str: user_name = self._get_param("UserName") devices = self.backend.list_mfa_devices(user_name) template = self.response_template(LIST_MFA_DEVICES_TEMPLATE) return template.render(user_name=user_name, devices=devices) - def create_virtual_mfa_device(self): + def create_virtual_mfa_device(self) -> str: path = self._get_param("Path") virtual_mfa_device_name = self._get_param("VirtualMFADeviceName") @@ -771,7 +771,7 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_VIRTUAL_MFA_DEVICE_TEMPLATE) return template.render(device=virtual_mfa_device) - def delete_virtual_mfa_device(self): + def delete_virtual_mfa_device(self) -> str: serial_number = self._get_param("SerialNumber") self.backend.delete_virtual_mfa_device(serial_number) @@ -779,7 +779,7 @@ class IamResponse(BaseResponse): template = self.response_template(DELETE_VIRTUAL_MFA_DEVICE_TEMPLATE) return template.render() - def list_virtual_mfa_devices(self): + def list_virtual_mfa_devices(self) -> str: assignment_status = self._get_param("AssignmentStatus", "Any") marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) @@ -791,25 +791,25 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_VIRTUAL_MFA_DEVICES_TEMPLATE) return template.render(devices=devices, marker=marker) - def delete_user(self): + def delete_user(self) -> str: user_name = self._get_param("UserName") self.backend.delete_user(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteUser") - def delete_policy(self): + def delete_policy(self) -> str: policy_arn = self._get_param("PolicyArn") self.backend.delete_policy(policy_arn) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeletePolicy") - def delete_login_profile(self): + def delete_login_profile(self) -> str: user_name = self._get_param("UserName") self.backend.delete_login_profile(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteLoginProfile") - def generate_credential_report(self): + def generate_credential_report(self) -> str: if self.backend.report_generated(): template = self.response_template(CREDENTIAL_REPORT_GENERATED) else: @@ -817,28 +817,28 @@ class IamResponse(BaseResponse): self.backend.generate_report() return template.render() - def get_credential_report(self): + def get_credential_report(self) -> str: report = self.backend.get_credential_report() template = self.response_template(CREDENTIAL_REPORT) return template.render(report=report) - def list_account_aliases(self): + def list_account_aliases(self) -> str: aliases = self.backend.list_account_aliases() template = self.response_template(LIST_ACCOUNT_ALIASES_TEMPLATE) return template.render(aliases=aliases) - def create_account_alias(self): + def create_account_alias(self) -> str: alias = self._get_param("AccountAlias") self.backend.create_account_alias(alias) template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE) return template.render() - def delete_account_alias(self): + def delete_account_alias(self) -> str: self.backend.delete_account_alias() template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE) return template.render() - def get_account_authorization_details(self): + def get_account_authorization_details(self) -> str: filter_param = self._get_multi_param("Filter.member") account_details = self.backend.get_account_authorization_details(filter_param) template = self.response_template(GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE) @@ -852,7 +852,7 @@ class IamResponse(BaseResponse): list_tags_for_user=self.backend.list_user_tags, ) - def create_saml_provider(self): + def create_saml_provider(self) -> str: saml_provider_name = self._get_param("Name") saml_metadata_document = self._get_param("SAMLMetadataDocument") saml_provider = self.backend.create_saml_provider( @@ -862,7 +862,7 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) - def update_saml_provider(self): + def update_saml_provider(self) -> str: saml_provider_arn = self._get_param("SAMLProviderArn") saml_metadata_document = self._get_param("SAMLMetadataDocument") saml_provider = self.backend.update_saml_provider( @@ -872,27 +872,27 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) - def delete_saml_provider(self): + def delete_saml_provider(self) -> str: saml_provider_arn = self._get_param("SAMLProviderArn") self.backend.delete_saml_provider(saml_provider_arn) template = self.response_template(DELETE_SAML_PROVIDER_TEMPLATE) return template.render() - def list_saml_providers(self): + def list_saml_providers(self) -> str: saml_providers = self.backend.list_saml_providers() template = self.response_template(LIST_SAML_PROVIDERS_TEMPLATE) return template.render(saml_providers=saml_providers) - def get_saml_provider(self): + def get_saml_provider(self) -> str: saml_provider_arn = self._get_param("SAMLProviderArn") saml_provider = self.backend.get_saml_provider(saml_provider_arn) template = self.response_template(GET_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) - def upload_signing_certificate(self): + def upload_signing_certificate(self) -> str: user_name = self._get_param("UserName") cert_body = self._get_param("CertificateBody") @@ -900,7 +900,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPLOAD_SIGNING_CERTIFICATE_TEMPLATE) return template.render(cert=cert) - def update_signing_certificate(self): + def update_signing_certificate(self) -> str: user_name = self._get_param("UserName") cert_id = self._get_param("CertificateId") status = self._get_param("Status") @@ -909,7 +909,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() - def delete_signing_certificate(self): + def delete_signing_certificate(self) -> str: user_name = self._get_param("UserName") cert_id = self._get_param("CertificateId") @@ -917,14 +917,14 @@ class IamResponse(BaseResponse): template = self.response_template(DELETE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() - def list_signing_certificates(self): + def list_signing_certificates(self) -> str: user_name = self._get_param("UserName") certs = self.backend.list_signing_certificates(user_name) template = self.response_template(LIST_SIGNING_CERTIFICATES_TEMPLATE) return template.render(user_name=user_name, certificates=certs) - def list_role_tags(self): + def list_role_tags(self) -> str: role_name = self._get_param("RoleName") marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) @@ -934,7 +934,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_ROLE_TAG_TEMPLATE) return template.render(tags=tags, marker=marker) - def tag_role(self): + def tag_role(self) -> str: role_name = self._get_param("RoleName") tags = self._get_multi_param("Tags.member") @@ -943,7 +943,7 @@ class IamResponse(BaseResponse): template = self.response_template(TAG_ROLE_TEMPLATE) return template.render() - def untag_role(self): + def untag_role(self) -> str: role_name = self._get_param("RoleName") tag_keys = self._get_multi_param("TagKeys.member") @@ -952,7 +952,7 @@ class IamResponse(BaseResponse): template = self.response_template(UNTAG_ROLE_TEMPLATE) return template.render() - def create_open_id_connect_provider(self): + def create_open_id_connect_provider(self) -> str: open_id_provider_url = self._get_param("Url") thumbprint_list = self._get_multi_param("ThumbprintList.member") client_id_list = self._get_multi_param("ClientIDList.member") @@ -965,7 +965,7 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render(open_id_provider=open_id_provider) - def update_open_id_connect_provider_thumbprint(self): + def update_open_id_connect_provider_thumbprint(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") thumbprint_list = self._get_multi_param("ThumbprintList.member") @@ -976,7 +976,7 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_OPEN_ID_CONNECT_PROVIDER_THUMBPRINT) return template.render() - def tag_open_id_connect_provider(self): + def tag_open_id_connect_provider(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") tags = self._get_multi_param("Tags.member") @@ -985,7 +985,7 @@ class IamResponse(BaseResponse): template = self.response_template(TAG_OPEN_ID_CONNECT_PROVIDER) return template.render() - def untag_open_id_connect_provider(self): + def untag_open_id_connect_provider(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") tag_keys = self._get_multi_param("TagKeys.member") @@ -994,7 +994,7 @@ class IamResponse(BaseResponse): template = self.response_template(UNTAG_OPEN_ID_CONNECT_PROVIDER) return template.render() - def list_open_id_connect_provider_tags(self): + def list_open_id_connect_provider_tags(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) @@ -1004,7 +1004,7 @@ class IamResponse(BaseResponse): template = self.response_template(LIST_OPEN_ID_CONNECT_PROVIDER_TAGS) return template.render(tags=tags, marker=marker) - def delete_open_id_connect_provider(self): + def delete_open_id_connect_provider(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") self.backend.delete_open_id_connect_provider(open_id_provider_arn) @@ -1012,7 +1012,7 @@ class IamResponse(BaseResponse): template = self.response_template(DELETE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render() - def get_open_id_connect_provider(self): + def get_open_id_connect_provider(self) -> str: open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") open_id_provider = self.backend.get_open_id_connect_provider( @@ -1022,13 +1022,13 @@ class IamResponse(BaseResponse): template = self.response_template(GET_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render(open_id_provider=open_id_provider) - def list_open_id_connect_providers(self): + def list_open_id_connect_providers(self) -> str: open_id_provider_arns = self.backend.list_open_id_connect_providers() template = self.response_template(LIST_OPEN_ID_CONNECT_PROVIDERS_TEMPLATE) return template.render(open_id_provider_arns=open_id_provider_arns) - def update_account_password_policy(self): + def update_account_password_policy(self) -> str: allow_change_password = self._get_bool_param( "AllowUsersToChangePassword", False ) @@ -1060,25 +1060,25 @@ class IamResponse(BaseResponse): template = self.response_template(UPDATE_ACCOUNT_PASSWORD_POLICY_TEMPLATE) return template.render() - def get_account_password_policy(self): + def get_account_password_policy(self) -> str: account_password_policy = self.backend.get_account_password_policy() template = self.response_template(GET_ACCOUNT_PASSWORD_POLICY_TEMPLATE) return template.render(password_policy=account_password_policy) - def delete_account_password_policy(self): + def delete_account_password_policy(self) -> str: self.backend.delete_account_password_policy() template = self.response_template(DELETE_ACCOUNT_PASSWORD_POLICY_TEMPLATE) return template.render() - def get_account_summary(self): + def get_account_summary(self) -> str: account_summary = self.backend.get_account_summary() template = self.response_template(GET_ACCOUNT_SUMMARY_TEMPLATE) return template.render(summary_map=account_summary.summary_map) - def tag_user(self): + def tag_user(self) -> str: name = self._get_param("UserName") tags = self._get_multi_param("Tags.member") @@ -1087,7 +1087,7 @@ class IamResponse(BaseResponse): template = self.response_template(TAG_USER_TEMPLATE) return template.render() - def untag_user(self): + def untag_user(self) -> str: name = self._get_param("UserName") tag_keys = self._get_multi_param("TagKeys.member") @@ -1096,7 +1096,7 @@ class IamResponse(BaseResponse): template = self.response_template(UNTAG_USER_TEMPLATE) return template.render() - def create_service_linked_role(self): + def create_service_linked_role(self) -> str: service_name = self._get_param("AWSServiceName") description = self._get_param("Description") suffix = self._get_param("CustomSuffix") @@ -1108,7 +1108,7 @@ class IamResponse(BaseResponse): template = self.response_template(CREATE_SERVICE_LINKED_ROLE_TEMPLATE) return template.render(role=role) - def delete_service_linked_role(self): + def delete_service_linked_role(self) -> str: role_name = self._get_param("RoleName") deletion_task_id = self.backend.delete_service_linked_role(role_name) @@ -1116,7 +1116,7 @@ class IamResponse(BaseResponse): template = self.response_template(DELETE_SERVICE_LINKED_ROLE_TEMPLATE) return template.render(deletion_task_id=deletion_task_id) - def get_service_linked_role_deletion_status(self): + def get_service_linked_role_deletion_status(self) -> str: self.backend.get_service_linked_role_deletion_status() template = self.response_template( diff --git a/moto/iam/utils.py b/moto/iam/utils.py index e467033e1..db19115ac 100644 --- a/moto/iam/utils.py +++ b/moto/iam/utils.py @@ -6,13 +6,13 @@ AWS_ROLE_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" ACCOUNT_OFFSET = 549755813888 # int.from_bytes(base64.b32decode(b"QAAAAAAA"), byteorder="big"), start value -def _random_uppercase_or_digit_sequence(length): +def _random_uppercase_or_digit_sequence(length: int) -> str: return "".join(str(random.choice(AWS_ROLE_ALPHABET)) for _ in range(length)) def generate_access_key_id_from_account_id( account_id: str, prefix: str, total_length: int = 20 -): +) -> str: """ Generates a key id (e.g. access key id) for the given account id and prefix @@ -21,13 +21,13 @@ def generate_access_key_id_from_account_id( :param total_length: Total length of the access key (e.g. 20 for temp access keys, 21 for role ids) :return: Generated id """ - account_id = int(account_id) - id_with_offset = account_id // 2 + ACCOUNT_OFFSET + account_id_nr = int(account_id) + id_with_offset = account_id_nr // 2 + ACCOUNT_OFFSET account_bytes = int.to_bytes(id_with_offset, byteorder="big", length=5) account_part = base64.b32encode(account_bytes).decode("utf-8") middle_char = ( random.choice(AWS_ROLE_ALPHABET[16:]) - if account_id % 2 + if account_id_nr % 2 else random.choice(AWS_ROLE_ALPHABET[:16]) ) semi_fixed_part = prefix + account_part + middle_char @@ -36,14 +36,14 @@ def generate_access_key_id_from_account_id( ) -def random_alphanumeric(length): +def random_alphanumeric(length: int) -> str: return "".join( str(random.choice(string.ascii_letters + string.digits + "+" + "/")) for _ in range(length) ) -def random_resource_id(size=20): +def random_resource_id(size: int = 20) -> str: chars = list(range(10)) + list(string.ascii_lowercase) return "".join(str(random.choice(chars)) for x in range(size)) @@ -55,13 +55,13 @@ def random_role_id(account_id: str) -> str: ) -def random_access_key(): +def random_access_key() -> str: return "".join( str(random.choice(string.ascii_uppercase + string.digits)) for _ in range(16) ) -def random_policy_id(): +def random_policy_id() -> str: return "A" + "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(20) ) diff --git a/moto/s3/exceptions.py b/moto/s3/exceptions.py index 3fab7491a..7345f929f 100644 --- a/moto/s3/exceptions.py +++ b/moto/s3/exceptions.py @@ -55,12 +55,12 @@ class InvalidArgumentError(S3ClientError): class AccessForbidden(S3ClientError): code = 403 - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("AccessForbidden", msg) class BucketError(S3ClientError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): kwargs.setdefault("template", "bucket_error") self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME super().__init__(*args, **kwargs) @@ -335,45 +335,40 @@ class DuplicateTagKeys(S3ClientError): class S3AccessDeniedError(S3ClientError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): super().__init__("AccessDenied", "Access Denied", *args, **kwargs) class BucketAccessDeniedError(BucketError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): super().__init__("AccessDenied", "Access Denied", *args, **kwargs) class S3InvalidTokenError(S3ClientError): code = 400 - def __init__(self, *args, **kwargs): + def __init__(self) -> None: super().__init__( - "InvalidToken", - "The provided token is malformed or otherwise invalid.", - *args, - **kwargs, + "InvalidToken", "The provided token is malformed or otherwise invalid." ) class S3AclAndGrantError(S3ClientError): code = 400 - def __init__(self, *args, **kwargs): + def __init__(self) -> None: super().__init__( "InvalidRequest", "Specifying both Canned ACLs and Header Grants is not allowed", - *args, - **kwargs, ) class BucketInvalidTokenError(BucketError): code = 400 - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): super().__init__( "InvalidToken", "The provided token is malformed or otherwise invalid.", @@ -385,19 +380,17 @@ class BucketInvalidTokenError(BucketError): class S3InvalidAccessKeyIdError(S3ClientError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self) -> None: super().__init__( "InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", - *args, - **kwargs, ) class BucketInvalidAccessKeyIdError(S3ClientError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): super().__init__( "InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", @@ -409,19 +402,17 @@ class BucketInvalidAccessKeyIdError(S3ClientError): class S3SignatureDoesNotMatchError(S3ClientError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self) -> None: super().__init__( "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided. Check your key and signing method.", - *args, - **kwargs, ) class BucketSignatureDoesNotMatchError(S3ClientError): code = 403 - def __init__(self, *args, **kwargs): + def __init__(self, *args: str, **kwargs: str): super().__init__( "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided. Check your key and signing method.", diff --git a/moto/sts/models.py b/moto/sts/models.py index 5af678fdf..950f0df44 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -10,7 +10,6 @@ from moto.sts.utils import ( DEFAULT_STS_SESSION_DURATION, random_assumed_role_id, ) -from typing import Mapping class Token(BaseModel): @@ -185,6 +184,6 @@ class STSBackend(BaseBackend): return account_id, iam_backend.create_temp_access_key() -sts_backends: Mapping[str, STSBackend] = BackendDict( +sts_backends = BackendDict( STSBackend, "sts", use_boto3_regions=False, additional_regions=["global"] ) diff --git a/setup.cfg b/setup.cfg index 917dd704b..f4f1288a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -230,7 +230,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/moto_api,moto/neptune +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/iam,moto/moto_api,moto/neptune show_column_numbers=True show_error_codes = True disable_error_code=abstract