Techdebt: MyPy SecretsManager (#6244)

This commit is contained in:
Bert Blommers 2023-04-22 15:39:48 +00:00 committed by GitHub
parent ce3234a6a9
commit f54f4a666f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 219 additions and 187 deletions

View File

@ -6,13 +6,13 @@ class SecretsManagerClientError(JsonRESTError):
class ResourceNotFoundException(SecretsManagerClientError): class ResourceNotFoundException(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
self.code = 404 self.code = 404
super().__init__("ResourceNotFoundException", message) super().__init__("ResourceNotFoundException", message)
class SecretNotFoundException(SecretsManagerClientError): class SecretNotFoundException(SecretsManagerClientError):
def __init__(self): def __init__(self) -> None:
self.code = 404 self.code = 404
super().__init__( super().__init__(
"ResourceNotFoundException", "ResourceNotFoundException",
@ -21,7 +21,7 @@ class SecretNotFoundException(SecretsManagerClientError):
class SecretHasNoValueException(SecretsManagerClientError): class SecretHasNoValueException(SecretsManagerClientError):
def __init__(self, version_stage): def __init__(self, version_stage: str):
self.code = 404 self.code = 404
super().__init__( super().__init__(
"ResourceNotFoundException", "ResourceNotFoundException",
@ -31,7 +31,7 @@ class SecretHasNoValueException(SecretsManagerClientError):
class SecretStageVersionMismatchException(SecretsManagerClientError): class SecretStageVersionMismatchException(SecretsManagerClientError):
def __init__(self): def __init__(self) -> None:
self.code = 404 self.code = 404
super().__init__( super().__init__(
"InvalidRequestException", "InvalidRequestException",
@ -40,25 +40,25 @@ class SecretStageVersionMismatchException(SecretsManagerClientError):
class ClientError(SecretsManagerClientError): class ClientError(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterValue", message) super().__init__("InvalidParameterValue", message)
class InvalidParameterException(SecretsManagerClientError): class InvalidParameterException(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterException", message) super().__init__("InvalidParameterException", message)
class ResourceExistsException(SecretsManagerClientError): class ResourceExistsException(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceExistsException", message) super().__init__("ResourceExistsException", message)
class InvalidRequestException(SecretsManagerClientError): class InvalidRequestException(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidRequestException", message) super().__init__("InvalidRequestException", message)
class ValidationException(SecretsManagerClientError): class ValidationException(SecretsManagerClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("ValidationException", message) super().__init__("ValidationException", message)

View File

@ -1,30 +1,36 @@
def name_filter(secret, names): from typing import List, TYPE_CHECKING
if TYPE_CHECKING:
from ..models import FakeSecret
def name_filter(secret: "FakeSecret", names: List[str]) -> bool:
return _matcher(names, [secret.name]) return _matcher(names, [secret.name])
def description_filter(secret, descriptions): def description_filter(secret: "FakeSecret", descriptions: List[str]) -> bool:
return _matcher(descriptions, [secret.description]) return _matcher(descriptions, [secret.description]) # type: ignore
def tag_key(secret, tag_keys): def tag_key(secret: "FakeSecret", tag_keys: List[str]) -> bool:
return _matcher(tag_keys, [tag["Key"] for tag in secret.tags]) return _matcher(tag_keys, [tag["Key"] for tag in secret.tags])
def tag_value(secret, tag_values): def tag_value(secret: "FakeSecret", tag_values: List[str]) -> bool:
return _matcher(tag_values, [tag["Value"] for tag in secret.tags]) return _matcher(tag_values, [tag["Value"] for tag in secret.tags])
def filter_all(secret, values): def filter_all(secret: "FakeSecret", values: List[str]) -> bool:
attributes = ( attributes = (
[secret.name, secret.description] [secret.name, secret.description]
+ [tag["Key"] for tag in secret.tags] + [tag["Key"] for tag in secret.tags]
+ [tag["Value"] for tag in secret.tags] + [tag["Value"] for tag in secret.tags]
) )
return _matcher(values, attributes) return _matcher(values, attributes) # type: ignore
def _matcher(patterns, strings): def _matcher(patterns: List[str], strings: List[str]) -> bool:
for pattern in [p for p in patterns if p.startswith("!")]: for pattern in [p for p in patterns if p.startswith("!")]:
for string in strings: for string in strings:
if _match_pattern(pattern[1:], string): if _match_pattern(pattern[1:], string):
@ -37,7 +43,7 @@ def _matcher(patterns, strings):
return False return False
def _match_pattern(pattern, value): def _match_pattern(pattern: str, value: str) -> bool:
for word in pattern.split(" "): for word in pattern.split(" "):
if word not in value: if word not in value:
return False return False

View File

@ -2,7 +2,7 @@ import time
import json import json
import datetime import datetime
from typing import List, Tuple from typing import Any, Dict, List, Tuple, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
@ -35,41 +35,41 @@ _filter_functions = {
} }
def filter_keys(): def filter_keys() -> List[str]:
return list(_filter_functions.keys()) return list(_filter_functions.keys())
def _matches(secret, filters): def _matches(secret: "FakeSecret", filters: List[Dict[str, Any]]) -> bool:
is_match = True is_match = True
for f in filters: for f in filters:
# Filter names are pre-validated in the resource layer # Filter names are pre-validated in the resource layer
filter_function = _filter_functions.get(f["Key"]) filter_function = _filter_functions.get(f["Key"])
is_match = is_match and filter_function(secret, f["Values"]) is_match = is_match and filter_function(secret, f["Values"]) # type: ignore
return is_match return is_match
class SecretsManager(BaseModel): class SecretsManager(BaseModel):
def __init__(self, region_name): def __init__(self, region_name: str):
self.region = region_name self.region = region_name
class FakeSecret: class FakeSecret:
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
secret_id, secret_id: str,
secret_string=None, secret_string: Optional[str] = None,
secret_binary=None, secret_binary: Optional[str] = None,
description=None, description: Optional[str] = None,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
kms_key_id=None, kms_key_id: Optional[str] = None,
version_id=None, version_id: Optional[str] = None,
version_stages=None, version_stages: Optional[List[str]] = None,
last_changed_date=None, last_changed_date: Optional[int] = None,
created_date=None, created_date: Optional[int] = None,
): ):
self.secret_id = secret_id self.secret_id = secret_id
self.name = secret_id self.name = secret_id
@ -86,12 +86,16 @@ class FakeSecret:
self.rotation_enabled = False self.rotation_enabled = False
self.rotation_lambda_arn = "" self.rotation_lambda_arn = ""
self.auto_rotate_after_days = 0 self.auto_rotate_after_days = 0
self.deleted_date = None self.deleted_date: Optional[float] = None
self.policy = None self.policy: Optional[str] = None
def update( def update(
self, description=None, tags=None, kms_key_id=None, last_changed_date=None self,
): description: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
kms_key_id: Optional[str] = None,
last_changed_date: Optional[int] = None,
) -> None:
self.description = description self.description = description
self.tags = tags or [] self.tags = tags or []
if last_changed_date is not None: if last_changed_date is not None:
@ -100,13 +104,15 @@ class FakeSecret:
if kms_key_id is not None: if kms_key_id is not None:
self.kms_key_id = kms_key_id self.kms_key_id = kms_key_id
def set_versions(self, versions): def set_versions(self, versions: Dict[str, Dict[str, Any]]) -> None:
self.versions = versions self.versions = versions
def set_default_version_id(self, version_id): def set_default_version_id(self, version_id: str) -> None:
self.default_version_id = version_id self.default_version_id = version_id
def reset_default_version(self, secret_version, version_id): def reset_default_version(
self, secret_version: Dict[str, Any], version_id: str
) -> None:
# remove all old AWSPREVIOUS stages # remove all old AWSPREVIOUS stages
for old_version in self.versions.values(): for old_version in self.versions.values():
if "AWSPREVIOUS" in old_version["version_stages"]: if "AWSPREVIOUS" in old_version["version_stages"]:
@ -119,22 +125,26 @@ class FakeSecret:
self.versions[version_id] = secret_version self.versions[version_id] = secret_version
self.default_version_id = version_id self.default_version_id = version_id
def remove_version_stages_from_old_versions(self, version_stages): def remove_version_stages_from_old_versions(
self, version_stages: List[str]
) -> None:
for version_stage in version_stages: for version_stage in version_stages:
for old_version in self.versions.values(): for old_version in self.versions.values():
if version_stage in old_version["version_stages"]: if version_stage in old_version["version_stages"]:
old_version["version_stages"].remove(version_stage) old_version["version_stages"].remove(version_stage)
def delete(self, deleted_date): def delete(self, deleted_date: float) -> None:
self.deleted_date = deleted_date self.deleted_date = deleted_date
def restore(self): def restore(self) -> None:
self.deleted_date = None self.deleted_date = None
def is_deleted(self): def is_deleted(self) -> bool:
return self.deleted_date is not None return self.deleted_date is not None
def to_short_dict(self, include_version_stages=False, version_id=None): def to_short_dict(
self, include_version_stages: bool = False, version_id: Optional[str] = None
) -> str:
if not version_id: if not version_id:
version_id = self.default_version_id version_id = self.default_version_id
dct = { dct = {
@ -146,7 +156,7 @@ class FakeSecret:
dct["VersionStages"] = self.versions[version_id]["version_stages"] dct["VersionStages"] = self.versions[version_id]["version_stages"]
return json.dumps(dct) return json.dumps(dct)
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
version_id_to_stages = self._form_version_ids_to_stages() version_id_to_stages = self._form_version_ids_to_stages()
return { return {
@ -167,7 +177,7 @@ class FakeSecret:
"CreatedDate": self.created_date, "CreatedDate": self.created_date,
} }
def _form_version_ids_to_stages(self): def _form_version_ids_to_stages(self) -> Dict[str, str]:
version_id_to_stages = {} version_id_to_stages = {}
for key, value in self.versions.items(): for key, value in self.versions.items():
version_id_to_stages[key] = value["version_stages"] version_id_to_stages[key] = value["version_stages"]
@ -175,86 +185,87 @@ class FakeSecret:
return version_id_to_stages return version_id_to_stages
class SecretsStore(dict): class SecretsStore(Dict[str, FakeSecret]):
# Parameters to this dictionary can be three possible values: # Parameters to this dictionary can be three possible values:
# names, full ARNs, and partial ARNs # names, full ARNs, and partial ARNs
# Every retrieval method should check which type of input it receives # Every retrieval method should check which type of input it receives
def __setitem__(self, key, value): def __setitem__(self, key: str, value: FakeSecret) -> None:
super().__setitem__(key, value) super().__setitem__(key, value)
def __getitem__(self, key): def __getitem__(self, key: str) -> FakeSecret:
for secret in dict.values(self): for secret in dict.values(self):
if secret.arn == key or secret.name == key: if secret.arn == key or secret.name == key:
return secret return secret
name = get_secret_name_from_partial_arn(key) name = get_secret_name_from_partial_arn(key)
return super().__getitem__(name) return super().__getitem__(name)
def __contains__(self, key): def __contains__(self, key: str) -> bool: # type: ignore
for secret in dict.values(self): for secret in dict.values(self):
if secret.arn == key or secret.name == key: if secret.arn == key or secret.name == key:
return True return True
name = get_secret_name_from_partial_arn(key) name = get_secret_name_from_partial_arn(key)
return dict.__contains__(self, name) return dict.__contains__(self, name) # type: ignore
def get(self, key, *args, **kwargs): def get(self, key: str) -> Optional[FakeSecret]: # type: ignore
for secret in dict.values(self): for secret in dict.values(self):
if secret.arn == key or secret.name == key: if secret.arn == key or secret.name == key:
return secret return secret
name = get_secret_name_from_partial_arn(key) name = get_secret_name_from_partial_arn(key)
return super().get(name, *args, **kwargs) return super().get(name)
def pop(self, key, *args, **kwargs): def pop(self, key: str) -> Optional[FakeSecret]: # type: ignore
for secret in dict.values(self): for secret in dict.values(self):
if secret.arn == key or secret.name == key: if secret.arn == key or secret.name == key:
key = secret.name key = secret.name
name = get_secret_name_from_partial_arn(key) name = get_secret_name_from_partial_arn(key)
return super().pop(name, *args, **kwargs) return super().pop(name, None)
class SecretsManagerBackend(BaseBackend): class SecretsManagerBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.secrets = SecretsStore() self.secrets = SecretsStore()
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint services.""" """Default VPC endpoint services."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "secretsmanager" service_region, zones, "secretsmanager"
) )
def _is_valid_identifier(self, identifier): def _is_valid_identifier(self, identifier: str) -> bool:
return identifier in self.secrets return identifier in self.secrets
def _unix_time_secs(self, dt): def _unix_time_secs(self, dt: datetime.datetime) -> float:
epoch = datetime.datetime.utcfromtimestamp(0) epoch = datetime.datetime.utcfromtimestamp(0)
return (dt - epoch).total_seconds() return (dt - epoch).total_seconds()
def _client_request_token_validator(self, client_request_token): def _client_request_token_validator(self, client_request_token: str) -> None:
token_length = len(client_request_token) token_length = len(client_request_token)
if token_length < 32 or token_length > 64: if token_length < 32 or token_length > 64:
msg = "ClientRequestToken must be 32-64 characters long." msg = "ClientRequestToken must be 32-64 characters long."
raise InvalidParameterException(msg) raise InvalidParameterException(msg)
def _from_client_request_token(self, client_request_token): def _from_client_request_token(self, client_request_token: Optional[str]) -> str:
version_id = client_request_token if client_request_token:
if version_id: self._client_request_token_validator(client_request_token)
self._client_request_token_validator(version_id) return client_request_token
else: else:
version_id = str(mock_random.uuid4()) return str(mock_random.uuid4())
return version_id
def cancel_rotate_secret(self, secret_id: str): def cancel_rotate_secret(self, secret_id: str) -> str:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
if self.secrets[secret_id].is_deleted(): secret = self.secrets[secret_id]
if secret.is_deleted():
raise InvalidRequestException( raise InvalidRequestException(
"You tried to perform the operation on a secret that's currently marked deleted." "You tried to perform the operation on a secret that's currently marked deleted."
) )
secret = self.secrets.get(key=secret_id)
if not secret.rotation_lambda_arn: if not secret.rotation_lambda_arn:
# This response doesn't make much sense for `CancelRotateSecret`, but this is what AWS has documented ... # This response doesn't make much sense for `CancelRotateSecret`, but this is what AWS has documented ...
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_CancelRotateSecret.html # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_CancelRotateSecret.html
@ -268,7 +279,9 @@ class SecretsManagerBackend(BaseBackend):
secret.rotation_enabled = False secret.rotation_enabled = False
return secret.to_short_dict() return secret.to_short_dict()
def get_secret_value(self, secret_id, version_id, version_stage): def get_secret_value(
self, secret_id: str, version_id: str, version_stage: str
) -> Dict[str, Any]:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
@ -331,12 +344,12 @@ class SecretsManagerBackend(BaseBackend):
def update_secret( def update_secret(
self, self,
secret_id, secret_id: str,
secret_string=None, secret_string: Optional[str] = None,
secret_binary=None, secret_binary: Optional[str] = None,
client_request_token=None, client_request_token: Optional[str] = None,
kms_key_id=None, kms_key_id: Optional[str] = None,
): ) -> str:
# error if secret does not exist # error if secret does not exist
if secret_id not in self.secrets: if secret_id not in self.secrets:
@ -366,14 +379,14 @@ class SecretsManagerBackend(BaseBackend):
def create_secret( def create_secret(
self, self,
name, name: str,
secret_string=None, secret_string: Optional[str] = None,
secret_binary=None, secret_binary: Optional[str] = None,
description=None, description: Optional[str] = None,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
kms_key_id=None, kms_key_id: Optional[str] = None,
client_request_token=None, client_request_token: Optional[str] = None,
): ) -> str:
# error if secret exists # error if secret exists
if name in self.secrets.keys(): if name in self.secrets.keys():
@ -395,15 +408,15 @@ class SecretsManagerBackend(BaseBackend):
def _add_secret( def _add_secret(
self, self,
secret_id, secret_id: str,
secret_string=None, secret_string: Optional[str] = None,
secret_binary=None, secret_binary: Optional[str] = None,
description=None, description: Optional[str] = None,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
kms_key_id=None, kms_key_id: Optional[str] = None,
version_id=None, version_id: Optional[str] = None,
version_stages=None, version_stages: Optional[List[str]] = None,
): ) -> FakeSecret:
if version_stages is None: if version_stages is None:
version_stages = ["AWSCURRENT"] version_stages = ["AWSCURRENT"]
@ -453,12 +466,12 @@ class SecretsManagerBackend(BaseBackend):
def put_secret_value( def put_secret_value(
self, self,
secret_id, secret_id: str,
secret_string, secret_string: str,
secret_binary, secret_binary: str,
client_request_token, client_request_token: str,
version_stages, version_stages: List[str],
): ) -> str:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
@ -481,7 +494,7 @@ class SecretsManagerBackend(BaseBackend):
return secret.to_short_dict(include_version_stages=True, version_id=version_id) return secret.to_short_dict(include_version_stages=True, version_id=version_id)
def describe_secret(self, secret_id): def describe_secret(self, secret_id: str) -> Dict[str, Any]:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
@ -491,11 +504,11 @@ class SecretsManagerBackend(BaseBackend):
def rotate_secret( def rotate_secret(
self, self,
secret_id, secret_id: str,
client_request_token=None, client_request_token: Optional[str] = None,
rotation_lambda_arn=None, rotation_lambda_arn: Optional[str] = None,
rotation_rules=None, rotation_rules: Optional[Dict[str, Any]] = None,
): ) -> str:
rotation_days = "AutomaticallyAfterDays" rotation_days = "AutomaticallyAfterDays"
@ -584,8 +597,8 @@ class SecretsManagerBackend(BaseBackend):
lambda_backend = lambda_backends[self.account_id][self.region_name] lambda_backend = lambda_backends[self.account_id][self.region_name]
request_headers = {} request_headers: Dict[str, Any] = {}
response_headers = {} response_headers: Dict[str, Any] = {}
try: try:
func = lambda_backend.get_function(secret.rotation_lambda_arn) func = lambda_backend.get_function(secret.rotation_lambda_arn)
@ -617,15 +630,15 @@ class SecretsManagerBackend(BaseBackend):
def get_random_password( def get_random_password(
self, self,
password_length, password_length: int,
exclude_characters, exclude_characters: str,
exclude_numbers, exclude_numbers: bool,
exclude_punctuation, exclude_punctuation: bool,
exclude_uppercase, exclude_uppercase: bool,
exclude_lowercase, exclude_lowercase: bool,
include_space, include_space: bool,
require_each_included_type, require_each_included_type: bool,
): ) -> str:
# password size must have value less than or equal to 4096 # password size must have value less than or equal to 4096
if password_length > 4096: if password_length > 4096:
raise ClientError( raise ClientError(
@ -639,7 +652,7 @@ class SecretsManagerBackend(BaseBackend):
when calling the GetRandomPassword operation: Password length is too short based on the required types." when calling the GetRandomPassword operation: Password length is too short based on the required types."
) )
response = json.dumps( return json.dumps(
{ {
"RandomPassword": random_password( "RandomPassword": random_password(
password_length, password_length,
@ -654,9 +667,7 @@ class SecretsManagerBackend(BaseBackend):
} }
) )
return response def list_secret_version_ids(self, secret_id: str) -> str:
def list_secret_version_ids(self, secret_id):
secret = self.secrets[secret_id] secret = self.secrets[secret_id]
version_list = [] version_list = []
@ -670,7 +681,7 @@ class SecretsManagerBackend(BaseBackend):
} }
) )
response = json.dumps( return json.dumps(
{ {
"ARN": secret.secret_id, "ARN": secret.secret_id,
"Name": secret.name, "Name": secret.name,
@ -679,11 +690,12 @@ class SecretsManagerBackend(BaseBackend):
} }
) )
return response
def list_secrets( def list_secrets(
self, filters: List, max_results: int = 100, next_token: str = None self,
) -> Tuple[List, str]: filters: List[Dict[str, Any]],
max_results: int = 100,
next_token: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
secret_list = [] secret_list = []
for secret in self.secrets.values(): for secret in self.secrets.values():
if _matches(secret, filters): if _matches(secret, filters):
@ -697,8 +709,11 @@ class SecretsManagerBackend(BaseBackend):
return secret_page, new_next_token return secret_page, new_next_token
def delete_secret( def delete_secret(
self, secret_id, recovery_window_in_days, force_delete_without_recovery self,
): secret_id: str,
recovery_window_in_days: int,
force_delete_without_recovery: bool,
) -> Tuple[str, str, float]:
if recovery_window_in_days and ( if recovery_window_in_days and (
recovery_window_in_days < 7 or recovery_window_in_days > 30 recovery_window_in_days < 7 or recovery_window_in_days > 30
@ -718,9 +733,11 @@ class SecretsManagerBackend(BaseBackend):
if not force_delete_without_recovery: if not force_delete_without_recovery:
raise SecretNotFoundException() raise SecretNotFoundException()
else: else:
secret = FakeSecret(self.account_id, self.region_name, secret_id) unknown_secret = FakeSecret(
arn = secret.arn self.account_id, self.region_name, secret_id
name = secret.name )
arn = unknown_secret.arn
name = unknown_secret.name
deletion_date = datetime.datetime.utcnow() deletion_date = datetime.datetime.utcnow()
return arn, name, self._unix_time_secs(deletion_date) return arn, name, self._unix_time_secs(deletion_date)
else: else:
@ -733,11 +750,11 @@ class SecretsManagerBackend(BaseBackend):
deletion_date = datetime.datetime.utcnow() deletion_date = datetime.datetime.utcnow()
if force_delete_without_recovery: if force_delete_without_recovery:
secret = self.secrets.pop(secret_id, None) secret = self.secrets.pop(secret_id)
else: else:
deletion_date += datetime.timedelta(days=recovery_window_in_days or 30) deletion_date += datetime.timedelta(days=recovery_window_in_days or 30)
self.secrets[secret_id].delete(self._unix_time_secs(deletion_date)) self.secrets[secret_id].delete(self._unix_time_secs(deletion_date))
secret = self.secrets.get(secret_id, None) secret = self.secrets.get(secret_id)
if not secret: if not secret:
raise SecretNotFoundException() raise SecretNotFoundException()
@ -747,7 +764,7 @@ class SecretsManagerBackend(BaseBackend):
return arn, name, self._unix_time_secs(deletion_date) return arn, name, self._unix_time_secs(deletion_date)
def restore_secret(self, secret_id): def restore_secret(self, secret_id: str) -> Tuple[str, str]:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
@ -757,7 +774,7 @@ class SecretsManagerBackend(BaseBackend):
return secret.arn, secret.name return secret.arn, secret.name
def tag_resource(self, secret_id, tags): def tag_resource(self, secret_id: str, tags: List[Dict[str, str]]) -> None:
if secret_id not in self.secrets: if secret_id not in self.secrets:
raise SecretNotFoundException() raise SecretNotFoundException()
@ -778,9 +795,7 @@ class SecretsManagerBackend(BaseBackend):
old_tags.remove(existing_key_name) old_tags.remove(existing_key_name)
old_tags.append(tag) old_tags.append(tag)
return secret_id def untag_resource(self, secret_id: str, tag_keys: List[str]) -> None:
def untag_resource(self, secret_id, tag_keys):
if secret_id not in self.secrets: if secret_id not in self.secrets:
raise SecretNotFoundException() raise SecretNotFoundException()
@ -792,11 +807,13 @@ class SecretsManagerBackend(BaseBackend):
if tag["Key"] in tag_keys: if tag["Key"] in tag_keys:
tags.remove(tag) tags.remove(tag)
return secret_id
def update_secret_version_stage( def update_secret_version_stage(
self, secret_id, version_stage, remove_from_version_id, move_to_version_id self,
): secret_id: str,
version_stage: str,
remove_from_version_id: str,
move_to_version_id: str,
) -> Tuple[str, str]:
if secret_id not in self.secrets: if secret_id not in self.secrets:
raise SecretNotFoundException() raise SecretNotFoundException()
@ -839,9 +856,9 @@ class SecretsManagerBackend(BaseBackend):
if "AWSPREVIOUS" in stages: if "AWSPREVIOUS" in stages:
stages.remove("AWSPREVIOUS") stages.remove("AWSPREVIOUS")
return secret_id return secret.arn, secret.name
def put_resource_policy(self, secret_id: str, policy: str): def put_resource_policy(self, secret_id: str, policy: str) -> Tuple[str, str]:
""" """
The BlockPublicPolicy-parameter is not yet implemented The BlockPublicPolicy-parameter is not yet implemented
""" """
@ -852,7 +869,7 @@ class SecretsManagerBackend(BaseBackend):
secret.policy = policy secret.policy = policy
return secret.arn, secret.name return secret.arn, secret.name
def get_resource_policy(self, secret_id): def get_resource_policy(self, secret_id: str) -> str:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()
@ -865,7 +882,7 @@ class SecretsManagerBackend(BaseBackend):
resp["ResourcePolicy"] = secret.policy resp["ResourcePolicy"] = secret.policy
return json.dumps(resp) return json.dumps(resp)
def delete_resource_policy(self, secret_id): def delete_resource_policy(self, secret_id: str) -> Tuple[str, str]:
if not self._is_valid_identifier(secret_id): if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException() raise SecretNotFoundException()

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.secretsmanager.exceptions import ( from moto.secretsmanager.exceptions import (
InvalidRequestException, InvalidRequestException,
@ -5,12 +6,12 @@ from moto.secretsmanager.exceptions import (
ValidationException, ValidationException,
) )
from .models import secretsmanager_backends, filter_keys from .models import secretsmanager_backends, filter_keys, SecretsManagerBackend
import json import json
def _validate_filters(filters): def _validate_filters(filters: List[Dict[str, Any]]) -> None:
for idx, f in enumerate(filters): for idx, f in enumerate(filters):
filter_key = f.get("Key", None) filter_key = f.get("Key", None)
filter_values = f.get("Values", None) filter_values = f.get("Values", None)
@ -28,18 +29,18 @@ def _validate_filters(filters):
class SecretsManagerResponse(BaseResponse): class SecretsManagerResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="secretsmanager") super().__init__(service_name="secretsmanager")
@property @property
def backend(self): def backend(self) -> SecretsManagerBackend:
return secretsmanager_backends[self.current_account][self.region] return secretsmanager_backends[self.current_account][self.region]
def cancel_rotate_secret(self): def cancel_rotate_secret(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
return self.backend.cancel_rotate_secret(secret_id=secret_id) return self.backend.cancel_rotate_secret(secret_id=secret_id)
def get_secret_value(self): def get_secret_value(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
version_id = self._get_param("VersionId") version_id = self._get_param("VersionId")
version_stage = self._get_param("VersionStage") version_stage = self._get_param("VersionStage")
@ -48,7 +49,7 @@ class SecretsManagerResponse(BaseResponse):
) )
return json.dumps(value) return json.dumps(value)
def create_secret(self): def create_secret(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
secret_string = self._get_param("SecretString") secret_string = self._get_param("SecretString")
secret_binary = self._get_param("SecretBinary") secret_binary = self._get_param("SecretBinary")
@ -66,7 +67,7 @@ class SecretsManagerResponse(BaseResponse):
client_request_token=client_request_token, client_request_token=client_request_token,
) )
def update_secret(self): def update_secret(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
secret_string = self._get_param("SecretString") secret_string = self._get_param("SecretString")
secret_binary = self._get_param("SecretBinary") secret_binary = self._get_param("SecretBinary")
@ -80,7 +81,7 @@ class SecretsManagerResponse(BaseResponse):
kms_key_id=kms_key_id, kms_key_id=kms_key_id,
) )
def get_random_password(self): def get_random_password(self) -> str:
password_length = self._get_param("PasswordLength", if_none=32) password_length = self._get_param("PasswordLength", if_none=32)
exclude_characters = self._get_param("ExcludeCharacters", if_none="") exclude_characters = self._get_param("ExcludeCharacters", if_none="")
exclude_numbers = self._get_param("ExcludeNumbers", if_none=False) exclude_numbers = self._get_param("ExcludeNumbers", if_none=False)
@ -102,12 +103,12 @@ class SecretsManagerResponse(BaseResponse):
require_each_included_type=require_each_included_type, require_each_included_type=require_each_included_type,
) )
def describe_secret(self): def describe_secret(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
secret = self.backend.describe_secret(secret_id=secret_id) secret = self.backend.describe_secret(secret_id=secret_id)
return json.dumps(secret) return json.dumps(secret)
def rotate_secret(self): def rotate_secret(self) -> str:
client_request_token = self._get_param("ClientRequestToken") client_request_token = self._get_param("ClientRequestToken")
rotation_lambda_arn = self._get_param("RotationLambdaARN") rotation_lambda_arn = self._get_param("RotationLambdaARN")
rotation_rules = self._get_param("RotationRules") rotation_rules = self._get_param("RotationRules")
@ -119,7 +120,7 @@ class SecretsManagerResponse(BaseResponse):
rotation_rules=rotation_rules, rotation_rules=rotation_rules,
) )
def put_secret_value(self): def put_secret_value(self) -> str:
secret_id = self._get_param("SecretId", if_none="") secret_id = self._get_param("SecretId", if_none="")
secret_string = self._get_param("SecretString") secret_string = self._get_param("SecretString")
secret_binary = self._get_param("SecretBinary") secret_binary = self._get_param("SecretBinary")
@ -140,11 +141,11 @@ class SecretsManagerResponse(BaseResponse):
client_request_token=client_request_token, client_request_token=client_request_token,
) )
def list_secret_version_ids(self): def list_secret_version_ids(self) -> str:
secret_id = self._get_param("SecretId", if_none="") secret_id = self._get_param("SecretId", if_none="")
return self.backend.list_secret_version_ids(secret_id=secret_id) return self.backend.list_secret_version_ids(secret_id=secret_id)
def list_secrets(self): def list_secrets(self) -> str:
filters = self._get_param("Filters", if_none=[]) filters = self._get_param("Filters", if_none=[])
_validate_filters(filters) _validate_filters(filters)
max_results = self._get_int_param("MaxResults") max_results = self._get_int_param("MaxResults")
@ -154,7 +155,7 @@ class SecretsManagerResponse(BaseResponse):
) )
return json.dumps(dict(SecretList=secret_list, NextToken=next_token)) return json.dumps(dict(SecretList=secret_list, NextToken=next_token))
def delete_secret(self): def delete_secret(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
recovery_window_in_days = self._get_param("RecoveryWindowInDays") recovery_window_in_days = self._get_param("RecoveryWindowInDays")
force_delete_without_recovery = self._get_param("ForceDeleteWithoutRecovery") force_delete_without_recovery = self._get_param("ForceDeleteWithoutRecovery")
@ -165,44 +166,47 @@ class SecretsManagerResponse(BaseResponse):
) )
return json.dumps(dict(ARN=arn, Name=name, DeletionDate=deletion_date)) return json.dumps(dict(ARN=arn, Name=name, DeletionDate=deletion_date))
def restore_secret(self): def restore_secret(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
arn, name = self.backend.restore_secret(secret_id=secret_id) arn, name = self.backend.restore_secret(secret_id=secret_id)
return json.dumps(dict(ARN=arn, Name=name)) return json.dumps(dict(ARN=arn, Name=name))
def get_resource_policy(self): def get_resource_policy(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
return self.backend.get_resource_policy(secret_id=secret_id) return self.backend.get_resource_policy(secret_id=secret_id)
def put_resource_policy(self): def put_resource_policy(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
policy = self._get_param("ResourcePolicy") policy = self._get_param("ResourcePolicy")
arn, name = self.backend.put_resource_policy(secret_id, policy) arn, name = self.backend.put_resource_policy(secret_id, policy)
return json.dumps(dict(ARN=arn, Name=name)) return json.dumps(dict(ARN=arn, Name=name))
def delete_resource_policy(self): def delete_resource_policy(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
arn, name = self.backend.delete_resource_policy(secret_id) arn, name = self.backend.delete_resource_policy(secret_id)
return json.dumps(dict(ARN=arn, Name=name)) return json.dumps(dict(ARN=arn, Name=name))
def tag_resource(self): def tag_resource(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
tags = self._get_param("Tags", if_none=[]) tags = self._get_param("Tags", if_none=[])
return self.backend.tag_resource(secret_id, tags) self.backend.tag_resource(secret_id, tags)
return "{}"
def untag_resource(self): def untag_resource(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
tag_keys = self._get_param("TagKeys", if_none=[]) tag_keys = self._get_param("TagKeys", if_none=[])
return self.backend.untag_resource(secret_id=secret_id, tag_keys=tag_keys) self.backend.untag_resource(secret_id=secret_id, tag_keys=tag_keys)
return "{}"
def update_secret_version_stage(self): def update_secret_version_stage(self) -> str:
secret_id = self._get_param("SecretId") secret_id = self._get_param("SecretId")
version_stage = self._get_param("VersionStage") version_stage = self._get_param("VersionStage")
remove_from_version_id = self._get_param("RemoveFromVersionId") remove_from_version_id = self._get_param("RemoveFromVersionId")
move_to_version_id = self._get_param("MoveToVersionId") move_to_version_id = self._get_param("MoveToVersionId")
return self.backend.update_secret_version_stage( arn, name = self.backend.update_secret_version_stage(
secret_id=secret_id, secret_id=secret_id,
version_stage=version_stage, version_stage=version_stage,
remove_from_version_id=remove_from_version_id, remove_from_version_id=remove_from_version_id,
move_to_version_id=move_to_version_id, move_to_version_id=move_to_version_id,
) )
return json.dumps({"ARN": arn, "Name": name})

View File

@ -4,15 +4,15 @@ from moto.moto_api._internal import mock_random as random
def random_password( def random_password(
password_length, password_length: int,
exclude_characters, exclude_characters: str,
exclude_numbers, exclude_numbers: bool,
exclude_punctuation, exclude_punctuation: bool,
exclude_uppercase, exclude_uppercase: bool,
exclude_lowercase, exclude_lowercase: bool,
include_space, include_space: bool,
require_each_included_type, require_each_included_type: bool,
): ) -> str:
password = "" password = ""
required_characters = "" required_characters = ""
@ -61,7 +61,7 @@ def random_password(
return password return password
def secret_arn(account_id, region, secret_id): def secret_arn(account_id: str, region: str, secret_id: str) -> str:
id_string = "".join(random.choice(string.ascii_letters) for _ in range(6)) id_string = "".join(random.choice(string.ascii_letters) for _ in range(6))
return ( return (
f"arn:aws:secretsmanager:{region}:{account_id}:secret:{secret_id}-{id_string}" f"arn:aws:secretsmanager:{region}:{account_id}:secret:{secret_id}-{id_string}"
@ -84,7 +84,7 @@ def get_secret_name_from_partial_arn(partial_arn: str) -> str:
return partial_arn return partial_arn
def _exclude_characters(password, exclude_characters): def _exclude_characters(password: str, exclude_characters: str) -> str:
for c in exclude_characters: for c in exclude_characters:
if c in string.punctuation: if c in string.punctuation:
# Escape punctuation regex usage # Escape punctuation regex usage
@ -93,7 +93,9 @@ def _exclude_characters(password, exclude_characters):
return password return password
def _add_password_require_each_included_type(password, required_characters): def _add_password_require_each_included_type(
password: str, required_characters: str
) -> str:
password_with_required_char = password[: -len(required_characters)] password_with_required_char = password[: -len(required_characters)]
password_with_required_char += required_characters password_with_required_char += required_characters

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract

View File

@ -791,7 +791,7 @@ def test_update_secret_version_stage(pass_arn):
assert stages[initial_version] == ["AWSCURRENT"] assert stages[initial_version] == ["AWSCURRENT"]
assert stages[new_version] == [custom_stage] assert stages[new_version] == [custom_stage]
test_client.post( resp = test_client.post(
"/", "/",
data={ data={
"SecretId": secret_id, "SecretId": secret_id,
@ -801,6 +801,9 @@ def test_update_secret_version_stage(pass_arn):
}, },
headers={"X-Amz-Target": "secretsmanager.UpdateSecretVersionStage"}, headers={"X-Amz-Target": "secretsmanager.UpdateSecretVersionStage"},
) )
resp = json.loads(resp.data.decode("utf-8"))
assert resp.get("ARN") == create_secret["ARN"]
assert resp.get("Name") == DEFAULT_SECRET_NAME
describe_secret = test_client.post( describe_secret = test_client.post(
"/", "/",