From f1f4454b0fc0766bd5063c445e595196720e93a3 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 11 Mar 2023 16:00:52 -0100 Subject: [PATCH] Techdebt: MyPy g-models (#6048) --- moto/glacier/models.py | 84 +++-- moto/glacier/responses.py | 67 ++-- moto/glacier/utils.py | 4 +- moto/glue/exceptions.py | 113 +++--- moto/glue/glue_schema_registry_utils.py | 141 ++++--- moto/glue/models.py | 479 ++++++++++++++---------- moto/glue/responses.py | 235 ++++++------ moto/glue/utils.py | 18 +- moto/greengrass/exceptions.py | 12 +- moto/greengrass/models.py | 413 +++++++++++++------- moto/greengrass/responses.py | 182 ++++----- moto/guardduty/exceptions.py | 17 +- moto/guardduty/models.py | 114 ++++-- moto/guardduty/responses.py | 42 ++- moto/utilities/tagging_service.py | 4 +- setup.cfg | 2 +- 16 files changed, 1130 insertions(+), 797 deletions(-) diff --git a/moto/glacier/models.py b/moto/glacier/models.py index 8d66e5f14..b9dfe696e 100644 --- a/moto/glacier/models.py +++ b/moto/glacier/models.py @@ -1,6 +1,6 @@ import hashlib - import datetime +from typing import Any, Dict, List, Optional, Union from moto.core import BaseBackend, BackendDict, BaseModel from moto.utilities.utils import md5_hash @@ -9,7 +9,7 @@ from .utils import get_job_id class Job(BaseModel): - def __init__(self, tier): + def __init__(self, tier: str): self.st = datetime.datetime.now() if tier.lower() == "expedited": @@ -20,16 +20,19 @@ class Job(BaseModel): # Standard self.et = self.st + datetime.timedelta(seconds=5) + def to_dict(self) -> Dict[str, Any]: + return {} + class ArchiveJob(Job): - def __init__(self, job_id, tier, arn, archive_id): + def __init__(self, job_id: str, tier: str, arn: str, archive_id: Optional[str]): self.job_id = job_id self.tier = tier self.arn = arn self.archive_id = archive_id Job.__init__(self, tier) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: d = { "Action": "ArchiveRetrieval", "ArchiveId": self.archive_id, @@ -57,13 +60,13 @@ class ArchiveJob(Job): class InventoryJob(Job): - def __init__(self, job_id, tier, arn): + def __init__(self, job_id: str, tier: str, arn: str): self.job_id = job_id self.tier = tier self.arn = arn Job.__init__(self, tier) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: d = { "Action": "InventoryRetrieval", "ArchiveSHA256TreeHash": None, @@ -89,15 +92,15 @@ class InventoryJob(Job): class Vault(BaseModel): - def __init__(self, vault_name, account_id, region): + def __init__(self, vault_name: str, account_id: str, region: str): self.st = datetime.datetime.now() self.vault_name = vault_name self.region = region - self.archives = {} - self.jobs = {} + self.archives: Dict[str, Dict[str, Any]] = {} + self.jobs: Dict[str, Job] = {} self.arn = f"arn:aws:glacier:{region}:{account_id}:vaults/{vault_name}" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: archives_size = 0 for k in self.archives: archives_size += self.archives[k]["size"] @@ -111,7 +114,7 @@ class Vault(BaseModel): } return d - def create_archive(self, body, description): + def create_archive(self, body: bytes, description: str) -> Dict[str, Any]: archive_id = md5_hash(body).hexdigest() self.archives[archive_id] = {} self.archives[archive_id]["archive_id"] = archive_id @@ -124,10 +127,10 @@ class Vault(BaseModel): self.archives[archive_id]["description"] = description return self.archives[archive_id] - def get_archive_body(self, archive_id): + def get_archive_body(self, archive_id: str) -> str: return self.archives[archive_id]["body"] - def get_archive_list(self): + def 
get_archive_list(self) -> List[Dict[str, Any]]: archive_list = [] for a in self.archives: archive = self.archives[a] @@ -141,34 +144,33 @@ class Vault(BaseModel): archive_list.append(aobj) return archive_list - def delete_archive(self, archive_id): + def delete_archive(self, archive_id: str) -> Dict[str, Any]: return self.archives.pop(archive_id) - def initiate_job(self, job_type, tier, archive_id): + def initiate_job(self, job_type: str, tier: str, archive_id: Optional[str]) -> str: job_id = get_job_id() if job_type == "inventory-retrieval": - job = InventoryJob(job_id, tier, self.arn) + self.jobs[job_id] = InventoryJob(job_id, tier, self.arn) elif job_type == "archive-retrieval": - job = ArchiveJob(job_id, tier, self.arn, archive_id) + self.jobs[job_id] = ArchiveJob(job_id, tier, self.arn, archive_id) - self.jobs[job_id] = job return job_id - def list_jobs(self): - return self.jobs.values() + def list_jobs(self) -> List[Job]: + return list(self.jobs.values()) - def describe_job(self, job_id): + def describe_job(self, job_id: str) -> Optional[Job]: return self.jobs.get(job_id) - def job_ready(self, job_id): + def job_ready(self, job_id: str) -> str: job = self.describe_job(job_id) - jobj = job.to_dict() + jobj = job.to_dict() # type: ignore return jobj["Completed"] - def get_job_output(self, job_id): + def get_job_output(self, job_id: str) -> Union[str, Dict[str, Any]]: job = self.describe_job(job_id) - jobj = job.to_dict() + jobj = job.to_dict() # type: ignore if jobj["Action"] == "InventoryRetrieval": archives = self.get_archive_list() return { @@ -177,48 +179,54 @@ class Vault(BaseModel): "ArchiveList": archives, } else: - archive_body = self.get_archive_body(job.archive_id) + archive_body = self.get_archive_body(job.archive_id) # type: ignore return archive_body class GlacierBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.vaults = {} + self.vaults: Dict[str, Vault] = {} - def get_vault(self, vault_name): + def get_vault(self, vault_name: str) -> Vault: return self.vaults[vault_name] - def create_vault(self, vault_name): + def create_vault(self, vault_name: str) -> None: self.vaults[vault_name] = Vault(vault_name, self.account_id, self.region_name) - def list_vaults(self): - return self.vaults.values() + def list_vaults(self) -> List[Vault]: + return list(self.vaults.values()) - def delete_vault(self, vault_name): + def delete_vault(self, vault_name: str) -> None: self.vaults.pop(vault_name) - def initiate_job(self, vault_name, job_type, tier, archive_id): + def initiate_job( + self, vault_name: str, job_type: str, tier: str, archive_id: Optional[str] + ) -> str: vault = self.get_vault(vault_name) job_id = vault.initiate_job(job_type, tier, archive_id) return job_id - def describe_job(self, vault_name, archive_id): + def describe_job(self, vault_name: str, archive_id: str) -> Optional[Job]: vault = self.get_vault(vault_name) return vault.describe_job(archive_id) - def list_jobs(self, vault_name): + def list_jobs(self, vault_name: str) -> List[Job]: vault = self.get_vault(vault_name) return vault.list_jobs() - def get_job_output(self, vault_name, job_id): + def get_job_output( + self, vault_name: str, job_id: str + ) -> Union[str, Dict[str, Any], None]: vault = self.get_vault(vault_name) if vault.job_ready(job_id): return vault.get_job_output(job_id) else: return None - def upload_archive(self, vault_name, body, description): + def upload_archive( + self, 
vault_name: str, body: bytes, description: str + ) -> Dict[str, Any]: vault = self.get_vault(vault_name) return vault.create_archive(body, description) diff --git a/moto/glacier/responses.py b/moto/glacier/responses.py index b32f5e0c1..1909e09b9 100644 --- a/moto/glacier/responses.py +++ b/moto/glacier/responses.py @@ -1,23 +1,26 @@ import json - +from typing import Any, Dict +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import glacier_backends +from .models import glacier_backends, GlacierBackend from .utils import vault_from_glacier_url class GlacierResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="glacier") @property - def glacier_backend(self): + def glacier_backend(self) -> GlacierBackend: return glacier_backends[self.current_account][self.region] - def all_vault_response(self, request, full_url, headers): + def all_vault_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._all_vault_response(headers) - def _all_vault_response(self, headers): + def _all_vault_response(self, headers: Any) -> TYPE_RESPONSE: vaults = self.glacier_backend.list_vaults() response = json.dumps( {"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]} @@ -26,11 +29,13 @@ class GlacierResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, response - def vault_response(self, request, full_url, headers): + def vault_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._vault_response(request, full_url, headers) - def _vault_response(self, request, full_url, headers): + def _vault_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] method = request.method vault_name = vault_from_glacier_url(full_url) @@ -41,23 +46,27 @@ class GlacierResponse(BaseResponse): elif method == "DELETE": return self._vault_response_delete(vault_name, headers) - def _vault_response_get(self, vault_name, headers): + def _vault_response_get(self, vault_name: str, headers: Any) -> TYPE_RESPONSE: vault = self.glacier_backend.get_vault(vault_name) headers["content-type"] = "application/json" return 200, headers, json.dumps(vault.to_dict()) - def _vault_response_put(self, vault_name, headers): + def _vault_response_put(self, vault_name: str, headers: Any) -> TYPE_RESPONSE: self.glacier_backend.create_vault(vault_name) return 201, headers, "" - def _vault_response_delete(self, vault_name, headers): + def _vault_response_delete(self, vault_name: str, headers: Any) -> TYPE_RESPONSE: self.glacier_backend.delete_vault(vault_name) return 204, headers, "" - def vault_archive_response(self, request, full_url, headers): + def vault_archive_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: return self._vault_archive_response(request, full_url, headers) - def _vault_archive_response(self, request, full_url, headers): + def _vault_archive_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: method = request.method if hasattr(request, "body"): body = request.body @@ -75,17 +84,21 @@ class GlacierResponse(BaseResponse): else: return 400, headers, "400 Bad Request" - def _vault_archive_response_post(self, vault_name, body, description, headers): + def _vault_archive_response_post( + self, vault_name: str, body: bytes, 
description: str, headers: Dict[str, Any] + ) -> TYPE_RESPONSE: vault = self.glacier_backend.upload_archive(vault_name, body, description) headers["x-amz-archive-id"] = vault["archive_id"] headers["x-amz-sha256-tree-hash"] = vault["sha256"] return 201, headers, "" - def vault_archive_individual_response(self, request, full_url, headers): + def vault_archive_individual_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._vault_archive_individual_response(request, full_url, headers) - def _vault_archive_individual_response(self, request, full_url, headers): + def _vault_archive_individual_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] method = request.method vault_name = full_url.split("/")[-3] archive_id = full_url.split("/")[-1] @@ -95,11 +108,13 @@ class GlacierResponse(BaseResponse): vault.delete_archive(archive_id) return 204, headers, "" - def vault_jobs_response(self, request, full_url, headers): + def vault_jobs_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._vault_jobs_response(request, full_url, headers) - def _vault_jobs_response(self, request, full_url, headers): + def _vault_jobs_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] method = request.method if hasattr(request, "body"): body = request.body @@ -135,22 +150,28 @@ class GlacierResponse(BaseResponse): headers["Location"] = f"/{account_id}/vaults/{vault_name}/jobs/{job_id}" return 202, headers, "" - def vault_jobs_individual_response(self, request, full_url, headers): + def vault_jobs_individual_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._vault_jobs_individual_response(full_url, headers) - def _vault_jobs_individual_response(self, full_url, headers): + def _vault_jobs_individual_response( + self, full_url: str, headers: Any + ) -> TYPE_RESPONSE: vault_name = full_url.split("/")[-3] archive_id = full_url.split("/")[-1] job = self.glacier_backend.describe_job(vault_name, archive_id) - return 200, headers, json.dumps(job.to_dict()) + return 200, headers, json.dumps(job.to_dict()) # type: ignore - def vault_jobs_output_response(self, request, full_url, headers): + def vault_jobs_output_response( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self._vault_jobs_output_response(full_url, headers) - def _vault_jobs_output_response(self, full_url, headers): + def _vault_jobs_output_response(self, full_url: str, headers: Any) -> TYPE_RESPONSE: vault_name = full_url.split("/")[-4] job_id = full_url.split("/")[-2] output = self.glacier_backend.get_job_output(vault_name, job_id) diff --git a/moto/glacier/utils.py b/moto/glacier/utils.py index 813945b5f..067c1f844 100644 --- a/moto/glacier/utils.py +++ b/moto/glacier/utils.py @@ -2,11 +2,11 @@ from moto.moto_api._internal import mock_random as random import string -def vault_from_glacier_url(full_url): +def vault_from_glacier_url(full_url: str) -> str: return full_url.split("/")[-1] -def get_job_id(): +def get_job_id() -> str: return "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(92) ) diff --git a/moto/glue/exceptions.py b/moto/glue/exceptions.py index 2e9c9ca0d..2ed7262bf 100644 --- a/moto/glue/exceptions.py 
+++ b/moto/glue/exceptions.py @@ -1,3 +1,4 @@ +from typing import Optional from moto.core.exceptions import JsonRESTError @@ -6,72 +7,78 @@ class GlueClientError(JsonRESTError): class AlreadyExistsException(GlueClientError): - def __init__(self, typ): + def __init__(self, typ: str): super().__init__("AlreadyExistsException", f"{typ} already exists.") class DatabaseAlreadyExistsException(AlreadyExistsException): - def __init__(self): + def __init__(self) -> None: super().__init__("Database") class TableAlreadyExistsException(AlreadyExistsException): - def __init__(self): + def __init__(self) -> None: super().__init__("Table") class PartitionAlreadyExistsException(AlreadyExistsException): - def __init__(self): + def __init__(self) -> None: super().__init__("Partition") class CrawlerAlreadyExistsException(AlreadyExistsException): - def __init__(self): + def __init__(self) -> None: super().__init__("Crawler") class EntityNotFoundException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("EntityNotFoundException", msg) class DatabaseNotFoundException(EntityNotFoundException): - def __init__(self, db): + def __init__(self, db: str): super().__init__(f"Database {db} not found.") class TableNotFoundException(EntityNotFoundException): - def __init__(self, tbl): + def __init__(self, tbl: str): super().__init__(f"Table {tbl} not found.") class PartitionNotFoundException(EntityNotFoundException): - def __init__(self): + def __init__(self) -> None: super().__init__("Cannot find partition.") class CrawlerNotFoundException(EntityNotFoundException): - def __init__(self, crawler): + def __init__(self, crawler: str): super().__init__(f"Crawler {crawler} not found.") class JobNotFoundException(EntityNotFoundException): - def __init__(self, job): + def __init__(self, job: str): super().__init__(f"Job {job} not found.") class JobRunNotFoundException(EntityNotFoundException): - def __init__(self, job_run): + def __init__(self, job_run: str): super().__init__(f"Job run {job_run} not found.") class VersionNotFoundException(EntityNotFoundException): - def __init__(self): + def __init__(self) -> None: super().__init__("Version not found.") class SchemaNotFoundException(EntityNotFoundException): - def __init__(self, schema_name, registry_name, schema_arn, null="null"): + def __init__( + self, + schema_name: str, + registry_name: str, + schema_arn: Optional[str], + null: str = "null", + ): super().__init__( f"Schema is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}", ) @@ -80,13 +87,13 @@ class SchemaNotFoundException(EntityNotFoundException): class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException): def __init__( self, - registry_name, - schema_name, - schema_arn, - version_number, - latest_version, - null="null", - false="false", + registry_name: Optional[str], + schema_name: Optional[str], + schema_arn: Optional[str], + version_number: Optional[str], + latest_version: Optional[str], + null: str = "null", + false: str = "false", ): super().__init__( f"Schema version is not found. 
RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}, VersionNumber: {version_number if version_number else null}, isLatestVersion: {latest_version if latest_version else false}", @@ -94,36 +101,36 @@ class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException): class SchemaVersionNotFoundFromSchemaVersionIdException(EntityNotFoundException): - def __init__(self, schema_version_id): + def __init__(self, schema_version_id: str): super().__init__( f"Schema version is not found. SchemaVersionId: {schema_version_id}", ) class RegistryNotFoundException(EntityNotFoundException): - def __init__(self, resource, param_name, param_value): + def __init__(self, resource: str, param_name: str, param_value: Optional[str]): super().__init__( - resource + " is not found. " + param_name + ": " + param_value, + resource + " is not found. " + param_name + ": " + param_value, # type: ignore ) class CrawlerRunningException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("CrawlerRunningException", msg) class CrawlerNotRunningException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("CrawlerNotRunningException", msg) class ConcurrentRunsExceededException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("ConcurrentRunsExceededException", msg) class ResourceNumberLimitExceededException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__( "ResourceNumberLimitExceededException", msg, @@ -131,7 +138,7 @@ class ResourceNumberLimitExceededException(GlueClientError): class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededException): - def __init__(self, resource): + def __init__(self, resource: str): super().__init__( "More " + resource @@ -140,14 +147,14 @@ class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededExc class SchemaVersionMetadataLimitExceededException(ResourceNumberLimitExceededException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Your resource limits for Schema Version Metadata have been exceeded.", ) class GSRAlreadyExistsException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__( "AlreadyExistsException", msg, @@ -155,21 +162,21 @@ class GSRAlreadyExistsException(GlueClientError): class SchemaVersionMetadataAlreadyExistsException(GSRAlreadyExistsException): - def __init__(self, schema_version_id, metadata_key, metadata_value): + def __init__(self, schema_version_id: str, metadata_key: str, metadata_value: str): super().__init__( f"Resource already exist for schema version id: {schema_version_id}, metadata key: {metadata_key}, metadata value: {metadata_value}", ) class GeneralGSRAlreadyExistsException(GSRAlreadyExistsException): - def __init__(self, resource, param_name, param_value): + def __init__(self, resource: str, param_name: str, param_value: str): super().__init__( resource + " already exists. 
" + param_name + ": " + param_value, ) class _InvalidOperationException(GlueClientError): - def __init__(self, error_type, op, msg): + def __init__(self, error_type: str, op: str, msg: str): super().__init__( error_type, "An error occurred (%s) when calling the %s operation: %s" @@ -178,22 +185,22 @@ class _InvalidOperationException(GlueClientError): class InvalidStateException(_InvalidOperationException): - def __init__(self, op, msg): + def __init__(self, op: str, msg: str): super().__init__("InvalidStateException", op, msg) class InvalidInputException(_InvalidOperationException): - def __init__(self, op, msg): + def __init__(self, op: str, msg: str): super().__init__("InvalidInputException", op, msg) class GSRInvalidInputException(GlueClientError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__("InvalidInputException", msg) class ResourceNameTooLongException(GSRInvalidInputException): - def __init__(self, param_name): + def __init__(self, param_name: str): super().__init__( "The resource name contains too many or too few characters. Parameter Name: " + param_name, @@ -201,7 +208,7 @@ class ResourceNameTooLongException(GSRInvalidInputException): class ParamValueContainsInvalidCharactersException(GSRInvalidInputException): - def __init__(self, param_name): + def __init__(self, param_name: str): super().__init__( "The parameter value contains one or more characters that are not valid. Parameter Name: " + param_name, @@ -209,28 +216,28 @@ class ParamValueContainsInvalidCharactersException(GSRInvalidInputException): class InvalidNumberOfTagsException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "New Tags cannot be empty or more than 50", ) class InvalidDataFormatException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Data format is not valid.", ) class InvalidCompatibilityException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Compatibility is not valid.", ) class InvalidSchemaDefinitionException(GSRInvalidInputException): - def __init__(self, data_format_name, err): + def __init__(self, data_format_name: str, err: ValueError): super().__init__( "Schema definition of " + data_format_name @@ -240,45 +247,51 @@ class InvalidSchemaDefinitionException(GSRInvalidInputException): class InvalidRegistryIdBothParamsProvidedException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "One of registryName or registryArn has to be provided, both cannot be provided.", ) class InvalidSchemaIdBothParamsProvidedException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "One of (registryName and schemaName) or schemaArn has to be provided, both cannot be provided.", ) class InvalidSchemaIdNotProvidedException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "At least one of (registryName and schemaName) or schemaArn has to be provided.", ) class InvalidSchemaVersionNumberBothParamsProvidedException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__("Only one of VersionNumber or LatestVersion is required.") class InvalidSchemaVersionNumberNotProvidedException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__("One of version number (or) latest version is required.") class 
InvalidSchemaVersionIdProvidedWithOtherParamsException(GSRInvalidInputException): - def __init__(self): + def __init__(self) -> None: super().__init__( "No other input parameters can be specified when fetching by SchemaVersionId." ) class DisabledCompatibilityVersioningException(GSRInvalidInputException): - def __init__(self, schema_name, registry_name, schema_arn, null="null"): + def __init__( + self, + schema_name: str, + registry_name: str, + schema_arn: Optional[str], + null: str = "null", + ): super().__init__( f"Compatibility DISABLED does not allow versioning. SchemaId: SchemaId(schemaArn={schema_arn if schema_arn else null}, schemaName={schema_name if schema_name else null}, registryName={registry_name if registry_name else null})" ) diff --git a/moto/glue/glue_schema_registry_utils.py b/moto/glue/glue_schema_registry_utils.py index 1ce5bf3c2..3a9bddd2b 100644 --- a/moto/glue/glue_schema_registry_utils.py +++ b/moto/glue/glue_schema_registry_utils.py @@ -1,5 +1,6 @@ import re import json +from typing import Any, Dict, Optional, Tuple, Pattern from .glue_schema_registry_constants import ( MAX_REGISTRY_NAME_LENGTH, @@ -53,7 +54,7 @@ from .exceptions import ( ) -def validate_registry_name_pattern_and_length(param_value): +def validate_registry_name_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="registryName", @@ -62,7 +63,7 @@ def validate_registry_name_pattern_and_length(param_value): ) -def validate_arn_pattern_and_length(param_value): +def validate_arn_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="registryArn", @@ -71,7 +72,7 @@ def validate_arn_pattern_and_length(param_value): ) -def validate_description_pattern_and_length(param_value): +def validate_description_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="description", @@ -80,7 +81,7 @@ def validate_description_pattern_and_length(param_value): ) -def validate_schema_name_pattern_and_length(param_value): +def validate_schema_name_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="schemaName", @@ -89,7 +90,7 @@ def validate_schema_name_pattern_and_length(param_value): ) -def validate_schema_version_metadata_key_pattern_and_length(param_value): +def validate_schema_version_metadata_key_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="key", @@ -98,7 +99,7 @@ def validate_schema_version_metadata_key_pattern_and_length(param_value): ) -def validate_schema_version_metadata_value_pattern_and_length(param_value): +def validate_schema_version_metadata_value_pattern_and_length(param_value: str) -> None: validate_param_pattern_and_length( param_value, param_name="value", @@ -108,8 +109,8 @@ def validate_schema_version_metadata_value_pattern_and_length(param_value): def validate_param_pattern_and_length( - param_value, param_name, max_name_length, pattern -): + param_value: str, param_name: str, max_name_length: int, pattern: Pattern[str] +) -> None: if len(param_value.encode("utf-8")) > max_name_length: raise ResourceNameTooLongException(param_name) @@ -117,7 +118,7 @@ def validate_param_pattern_and_length( raise ParamValueContainsInvalidCharactersException(param_name) -def validate_schema_definition(schema_definition, data_format): +def validate_schema_definition(schema_definition: str, data_format: str) -> None: 
validate_schema_definition_length(schema_definition) if data_format in ["AVRO", "JSON"]: try: @@ -126,38 +127,39 @@ def validate_schema_definition(schema_definition, data_format): raise InvalidSchemaDefinitionException(data_format, err) -def validate_schema_definition_length(schema_definition): +def validate_schema_definition_length(schema_definition: str) -> None: if len(schema_definition) > MAX_SCHEMA_DEFINITION_LENGTH: param_name = SCHEMA_DEFINITION raise ResourceNameTooLongException(param_name) -def validate_schema_version_id_pattern(schema_version_id): +def validate_schema_version_id_pattern(schema_version_id: str) -> None: if re.match(SCHEMA_VERSION_ID_PATTERN, schema_version_id) is None: raise ParamValueContainsInvalidCharactersException(SCHEMA_VERSION_ID) -def validate_number_of_tags(tags): +def validate_number_of_tags(tags: Dict[str, str]) -> None: if len(tags) > MAX_TAGS_ALLOWED: raise InvalidNumberOfTagsException() -def validate_registry_id(registry_id, registries): +def validate_registry_id( + registry_id: Dict[str, Any], registries: Dict[str, Any] +) -> str: if not registry_id: - registry_name = DEFAULT_REGISTRY_NAME - return registry_name + return DEFAULT_REGISTRY_NAME if registry_id.get(REGISTRY_NAME) and registry_id.get(REGISTRY_ARN): raise InvalidRegistryIdBothParamsProvidedException() if registry_id.get(REGISTRY_NAME): registry_name = registry_id.get(REGISTRY_NAME) - validate_registry_name_pattern_and_length(registry_name) + validate_registry_name_pattern_and_length(registry_name) # type: ignore elif registry_id.get(REGISTRY_ARN): registry_arn = registry_id.get(REGISTRY_ARN) - validate_arn_pattern_and_length(registry_arn) - registry_name = registry_arn.split("/")[-1] + validate_arn_pattern_and_length(registry_arn) # type: ignore + registry_name = registry_arn.split("/")[-1] # type: ignore if registry_name != DEFAULT_REGISTRY_NAME and registry_name not in registries: if registry_id.get(REGISTRY_NAME): @@ -174,10 +176,15 @@ def validate_registry_id(registry_id, registries): param_value=registry_arn, ) - return registry_name + return registry_name # type: ignore -def validate_registry_params(registries, registry_name, description=None, tags=None): +def validate_registry_params( + registries: Any, + registry_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, +) -> None: validate_registry_name_pattern_and_length(registry_name) if description: @@ -197,7 +204,9 @@ def validate_registry_params(registries, registry_name, description=None, tags=N ) -def validate_schema_id(schema_id, registries): +def validate_schema_id( + schema_id: Dict[str, str], registries: Dict[str, Any] +) -> Tuple[str, str, Optional[str]]: schema_arn = schema_id.get(SCHEMA_ARN) registry_name = schema_id.get(REGISTRY_NAME) schema_name = schema_id.get(SCHEMA_NAME) @@ -225,15 +234,15 @@ def validate_schema_id(schema_id, registries): def validate_schema_params( - registry, - schema_name, - data_format, - compatibility, - schema_definition, - num_schemas, - description=None, - tags=None, -): + registry: Any, + schema_name: str, + data_format: str, + compatibility: str, + schema_definition: str, + num_schemas: int, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, +) -> None: validate_schema_name_pattern_and_length(schema_name) if data_format not in ["AVRO", "JSON", "PROTOBUF"]: @@ -271,14 +280,14 @@ def validate_schema_params( def validate_register_schema_version_params( - registry_name, - schema_name, - schema_arn, - num_schema_versions, - 
schema_definition, - compatibility, - data_format, -): + registry_name: str, + schema_name: str, + schema_arn: Optional[str], + num_schema_versions: int, + schema_definition: str, + compatibility: str, + data_format: str, +) -> None: if compatibility == "DISABLED": raise DisabledCompatibilityVersioningException( schema_name, registry_name, schema_arn @@ -290,9 +299,19 @@ def validate_register_schema_version_params( raise GeneralResourceNumberLimitExceededException(resource="schema versions") -def validate_schema_version_params( - registries, schema_id, schema_version_id, schema_version_number -): +def validate_schema_version_params( # type: ignore[return] + registries: Dict[str, Any], + schema_id: Optional[Dict[str, Any]], + schema_version_id: Optional[str], + schema_version_number: Optional[Dict[str, Any]], +) -> Tuple[ + Optional[str], + Optional[str], + Optional[str], + Optional[str], + Optional[str], + Optional[str], +]: if not schema_version_id and not schema_id and not schema_version_number: raise InvalidSchemaIdNotProvidedException() @@ -329,8 +348,11 @@ def validate_schema_version_params( def validate_schema_version_number( - registries, registry_name, schema_name, schema_version_number -): + registries: Dict[str, Any], + registry_name: str, + schema_name: str, + schema_version_number: Dict[str, str], +) -> Tuple[str, str]: latest_version = schema_version_number.get(LATEST_VERSION) version_number = schema_version_number.get(VERSION_NUMBER) schema = registries[registry_name].schemas[schema_name] @@ -339,20 +361,24 @@ def validate_schema_version_number( raise InvalidSchemaVersionNumberBothParamsProvidedException() return schema.latest_schema_version, latest_version - return version_number, latest_version + return version_number, latest_version # type: ignore -def validate_schema_version_metadata_pattern_and_length(metadata_key_value): +def validate_schema_version_metadata_pattern_and_length( + metadata_key_value: Dict[str, str] +) -> Tuple[str, str]: metadata_key = metadata_key_value.get(METADATA_KEY) metadata_value = metadata_key_value.get(METADATA_VALUE) - validate_schema_version_metadata_key_pattern_and_length(metadata_key) - validate_schema_version_metadata_value_pattern_and_length(metadata_value) + validate_schema_version_metadata_key_pattern_and_length(metadata_key) # type: ignore + validate_schema_version_metadata_value_pattern_and_length(metadata_value) # type: ignore - return metadata_key, metadata_value + return metadata_key, metadata_value # type: ignore[return-value] -def validate_number_of_schema_version_metadata_allowed(metadata): +def validate_number_of_schema_version_metadata_allowed( + metadata: Dict[str, Any] +) -> None: num_metadata_key_value_pairs = 0 for m in metadata.values(): num_metadata_key_value_pairs += len(m) @@ -362,8 +388,8 @@ def validate_number_of_schema_version_metadata_allowed(metadata): def get_schema_version_if_definition_exists( - schema_versions, data_format, schema_definition -): + schema_versions: Any, data_format: str, schema_definition: str +) -> Optional[Dict[str, Any]]: if data_format in ["AVRO", "JSON"]: for schema_version in schema_versions: if json.loads(schema_definition) == json.loads( @@ -378,9 +404,12 @@ def get_schema_version_if_definition_exists( def get_put_schema_version_metadata_response( - schema_id, schema_version_number, schema_version_id, metadata_key_value -): - put_schema_version_metadata_response_dict = {} + schema_id: Dict[str, Any], + schema_version_number: Optional[Dict[str, str]], + schema_version_id: str, + 
metadata_key_value: Dict[str, str], +) -> Dict[str, Any]: + put_schema_version_metadata_response_dict: Dict[str, Any] = {} if schema_version_id: put_schema_version_metadata_response_dict[SCHEMA_VERSION_ID] = schema_version_id if schema_id: @@ -416,7 +445,9 @@ def get_put_schema_version_metadata_response( return put_schema_version_metadata_response_dict -def delete_schema_response(schema_name, schema_arn, status): +def delete_schema_response( + schema_name: str, schema_arn: str, status: str +) -> Dict[str, Any]: return { "SchemaName": schema_name, "SchemaArn": schema_arn, diff --git a/moto/glue/models.py b/moto/glue/models.py index 6abbf4aa9..fb4508c68 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -3,7 +3,7 @@ import time from collections import OrderedDict from datetime import datetime import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api import state_manager @@ -77,12 +77,12 @@ class GlueBackend(BaseBackend): }, } - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.databases = OrderedDict() - self.crawlers = OrderedDict() - self.jobs = OrderedDict() - self.job_runs = OrderedDict() + self.databases: Dict[str, FakeDatabase] = OrderedDict() + self.crawlers: Dict[str, FakeCrawler] = OrderedDict() + self.jobs: Dict[str, FakeJob] = OrderedDict() + self.job_runs: Dict[str, FakeJobRun] = OrderedDict() self.tagger = TaggingService() self.registries: Dict[str, FakeRegistry] = OrderedDict() self.num_schemas = 0 @@ -93,13 +93,17 @@ class GlueBackend(BaseBackend): ) @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "glue" ) - def create_database(self, database_name, database_input): + def create_database( + self, database_name: str, database_input: Dict[str, Any] + ) -> "FakeDatabase": if database_name in self.databases: raise DatabaseAlreadyExistsException() @@ -107,27 +111,31 @@ class GlueBackend(BaseBackend): self.databases[database_name] = database return database - def get_database(self, database_name): + def get_database(self, database_name: str) -> "FakeDatabase": try: return self.databases[database_name] except KeyError: raise DatabaseNotFoundException(database_name) - def update_database(self, database_name, database_input): + def update_database( + self, database_name: str, database_input: Dict[str, Any] + ) -> None: if database_name not in self.databases: raise DatabaseNotFoundException(database_name) self.databases[database_name].input = database_input - def get_databases(self): + def get_databases(self) -> List["FakeDatabase"]: return [self.databases[key] for key in self.databases] if self.databases else [] - def delete_database(self, database_name): + def delete_database(self, database_name: str) -> None: if database_name not in self.databases: raise DatabaseNotFoundException(database_name) del self.databases[database_name] - def create_table(self, database_name, table_name, table_input): + def create_table( + self, database_name: str, table_name: str, table_input: Dict[str, Any] + ) -> "FakeTable": database = self.get_database(database_name) if table_name in database.tables: @@ -144,7 +152,9 @@ class 
GlueBackend(BaseBackend): except KeyError: raise TableNotFoundException(table_name) - def get_tables(self, database_name, expression): + def get_tables( + self, database_name: str, expression: Optional[str] + ) -> List["FakeTable"]: database = self.get_database(database_name) if expression: # sanitise expression, * is treated as a glob-like wildcard @@ -164,15 +174,16 @@ class GlueBackend(BaseBackend): else: return [table for table_name, table in database.tables.items()] - def delete_table(self, database_name, table_name): + def delete_table(self, database_name: str, table_name: str) -> None: database = self.get_database(database_name) try: del database.tables[table_name] except KeyError: raise TableNotFoundException(table_name) - return {} - def update_table(self, database_name, table_name: str, table_input) -> None: + def update_table( + self, database_name: str, table_name: str, table_input: Dict[str, Any] + ) -> None: table = self.get_table(database_name, table_name) table.update(table_input) @@ -202,15 +213,21 @@ class GlueBackend(BaseBackend): table = self.get_table(database_name, table_name) table.delete_version(version_id) - def create_partition(self, database_name: str, table_name: str, part_input) -> None: + def create_partition( + self, database_name: str, table_name: str, part_input: Dict[str, Any] + ) -> None: table = self.get_table(database_name, table_name) table.create_partition(part_input) - def get_partition(self, database_name: str, table_name: str, values): + def get_partition( + self, database_name: str, table_name: str, values: str + ) -> "FakePartition": table = self.get_table(database_name, table_name) return table.get_partition(values) - def get_partitions(self, database_name, table_name, expression): + def get_partitions( + self, database_name: str, table_name: str, expression: str + ) -> List["FakePartition"]: """ See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html for supported expressions. 
@@ -226,34 +243,38 @@ class GlueBackend(BaseBackend): return table.get_partitions(expression) def update_partition( - self, database_name, table_name, part_input, part_to_update + self, + database_name: str, + table_name: str, + part_input: Dict[str, Any], + part_to_update: str, ) -> None: table = self.get_table(database_name, table_name) table.update_partition(part_to_update, part_input) def delete_partition( - self, database_name: str, table_name: str, part_to_delete + self, database_name: str, table_name: str, part_to_delete: int ) -> None: table = self.get_table(database_name, table_name) table.delete_partition(part_to_delete) def create_crawler( self, - name, - role, - database_name, - description, - targets, - schedule, - classifiers, - table_prefix, - schema_change_policy, - recrawl_policy, - lineage_configuration, - configuration, - crawler_security_configuration, - tags, - ): + name: str, + role: str, + database_name: str, + description: str, + targets: Dict[str, Any], + schedule: str, + classifiers: List[str], + table_prefix: str, + schema_change_policy: Dict[str, str], + recrawl_policy: Dict[str, str], + lineage_configuration: Dict[str, str], + configuration: str, + crawler_security_configuration: str, + tags: Dict[str, str], + ) -> None: if name in self.crawlers: raise CrawlerAlreadyExistsException() @@ -276,28 +297,28 @@ class GlueBackend(BaseBackend): ) self.crawlers[name] = crawler - def get_crawler(self, name): + def get_crawler(self, name: str) -> "FakeCrawler": try: return self.crawlers[name] except KeyError: raise CrawlerNotFoundException(name) - def get_crawlers(self): + def get_crawlers(self) -> List["FakeCrawler"]: return [self.crawlers[key] for key in self.crawlers] if self.crawlers else [] @paginate(pagination_model=PAGINATION_MODEL) - def list_crawlers(self): + def list_crawlers(self) -> List["FakeCrawler"]: # type: ignore[misc] return [crawler for _, crawler in self.crawlers.items()] - def start_crawler(self, name): + def start_crawler(self, name: str) -> None: crawler = self.get_crawler(name) crawler.start_crawler() - def stop_crawler(self, name): + def stop_crawler(self, name: str) -> None: crawler = self.get_crawler(name) crawler.stop_crawler() - def delete_crawler(self, name): + def delete_crawler(self, name: str) -> None: try: del self.crawlers[name] except KeyError: @@ -305,29 +326,29 @@ class GlueBackend(BaseBackend): def create_job( self, - name, - role, - command, - description, - log_uri, - execution_property, - default_arguments, - non_overridable_arguments, - connections, - max_retries, - allocated_capacity, - timeout, - max_capacity, - security_configuration, - tags, - notification_property, - glue_version, - number_of_workers, - worker_type, - code_gen_configuration_nodes, - execution_class, - source_control_details, - ): + name: str, + role: str, + command: str, + description: str, + log_uri: str, + execution_property: Dict[str, int], + default_arguments: Dict[str, str], + non_overridable_arguments: Dict[str, str], + connections: Dict[str, List[str]], + max_retries: int, + allocated_capacity: int, + timeout: int, + max_capacity: float, + security_configuration: str, + tags: Dict[str, str], + notification_property: Dict[str, int], + glue_version: str, + number_of_workers: int, + worker_type: str, + code_gen_configuration_nodes: Dict[str, Any], + execution_class: str, + source_control_details: Dict[str, str], + ) -> None: self.jobs[name] = FakeJob( name, role, @@ -353,46 +374,50 @@ class GlueBackend(BaseBackend): source_control_details, backend=self, 
) - return name - def get_job(self, name): + def get_job(self, name: str) -> "FakeJob": try: return self.jobs[name] except KeyError: raise JobNotFoundException(name) @paginate(pagination_model=PAGINATION_MODEL) - def get_jobs(self): + def get_jobs(self) -> List["FakeJob"]: # type: ignore return [job for _, job in self.jobs.items()] - def start_job_run(self, name): + def start_job_run(self, name: str) -> str: job = self.get_job(name) return job.start_job_run() - def get_job_run(self, name, run_id): + def get_job_run(self, name: str, run_id: str) -> "FakeJobRun": job = self.get_job(name) return job.get_job_run(run_id) @paginate(pagination_model=PAGINATION_MODEL) - def list_jobs(self): + def list_jobs(self) -> List["FakeJob"]: # type: ignore return [job for _, job in self.jobs.items()] - def get_tags(self, resource_id): + def get_tags(self, resource_id: str) -> Dict[str, str]: return self.tagger.get_tag_dict_for_resource(resource_id) - def tag_resource(self, resource_arn, tags): - tags = TaggingService.convert_dict_to_tags_input(tags or {}) - self.tagger.tag_resource(resource_arn, tags) + def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: + tag_list = TaggingService.convert_dict_to_tags_input(tags or {}) + self.tagger.tag_resource(resource_arn, tag_list) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: self.tagger.untag_resource_using_names(resource_arn, tag_keys) - def create_registry(self, registry_name, description=None, tags=None): + def create_registry( + self, + registry_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: # If registry name id default-registry, create default-registry if registry_name == DEFAULT_REGISTRY_NAME: registry = FakeRegistry(self, registry_name, description, tags) self.registries[registry_name] = registry - return registry + return registry # type: ignore # Validate Registry Parameters validate_registry_params(self.registries, registry_name, description, tags) @@ -401,27 +426,27 @@ class GlueBackend(BaseBackend): self.registries[registry_name] = registry return registry.as_dict() - def delete_registry(self, registry_id): + def delete_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]: registry_name = validate_registry_id(registry_id, self.registries) return self.registries.pop(registry_name).as_dict() - def get_registry(self, registry_id): + def get_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]: registry_name = validate_registry_id(registry_id, self.registries) return self.registries[registry_name].as_dict() - def list_registries(self): + def list_registries(self) -> List[Dict[str, Any]]: return [reg.as_dict() for reg in self.registries.values()] def create_schema( self, - registry_id, - schema_name, - data_format, - compatibility, - schema_definition, - description=None, - tags=None, - ): + registry_id: Dict[str, Any], + schema_name: str, + data_format: str, + compatibility: str, + schema_definition: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: """ The following parameters/features are not yet implemented: Glue Schema Registry: compatibility checks NONE | BACKWARD | BACKWARD_ALL | FORWARD | FORWARD_ALL | FULL | FULL_ALL and Data format parsing and syntax validation. 
""" @@ -479,7 +504,9 @@ class GlueBackend(BaseBackend): resp.update({"Tags": tags}) return resp - def register_schema_version(self, schema_id, schema_definition): + def register_schema_version( + self, schema_id: Dict[str, Any], schema_definition: str + ) -> Dict[str, Any]: # Validate Schema Id registry_name, schema_name, schema_arn = validate_schema_id( schema_id, self.registries @@ -538,8 +565,11 @@ class GlueBackend(BaseBackend): return schema_version.as_dict() def get_schema_version( - self, schema_id=None, schema_version_id=None, schema_version_number=None - ): + self, + schema_id: Optional[Dict[str, str]] = None, + schema_version_id: Optional[str] = None, + schema_version_number: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: # Validate Schema Parameters ( schema_version_id, @@ -571,10 +601,10 @@ class GlueBackend(BaseBackend): raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id) # GetSchemaVersion using VersionNumber - schema = self.registries[registry_name].schemas[schema_name] + schema = self.registries[registry_name].schemas[schema_name] # type: ignore for schema_version in schema.schema_versions.values(): if ( - version_number == schema_version.version_number + version_number == schema_version.version_number # type: ignore and schema_version.schema_version_status != DELETING_STATUS ): get_schema_version_dict = schema_version.get_schema_version_as_dict() @@ -584,7 +614,9 @@ class GlueBackend(BaseBackend): registry_name, schema_name, schema_arn, version_number, latest_version ) - def get_schema_by_definition(self, schema_id, schema_definition): + def get_schema_by_definition( + self, schema_id: Dict[str, str], schema_definition: str + ) -> Dict[str, Any]: # Validate SchemaId validate_schema_definition_length(schema_definition) registry_name, schema_name, schema_arn = validate_schema_id( @@ -606,8 +638,12 @@ class GlueBackend(BaseBackend): raise SchemaNotFoundException(schema_name, registry_name, schema_arn) def put_schema_version_metadata( - self, schema_id, schema_version_number, schema_version_id, metadata_key_value - ): + self, + schema_id: Dict[str, Any], + schema_version_number: Dict[str, str], + schema_version_id: str, + metadata_key_value: Dict[str, str], + ) -> Dict[str, Any]: # Validate metadata_key_value and schema version params ( metadata_key, @@ -620,7 +656,7 @@ class GlueBackend(BaseBackend): schema_arn, version_number, latest_version, - ) = validate_schema_version_params( + ) = validate_schema_version_params( # type: ignore self.registries, schema_id, schema_version_id, schema_version_number ) @@ -650,9 +686,9 @@ class GlueBackend(BaseBackend): raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id) # PutSchemaVersionMetadata using VersionNumber - schema = self.registries[registry_name].schemas[schema_name] + schema = self.registries[registry_name].schemas[schema_name] # type: ignore for schema_version in schema.schema_versions.values(): - if version_number == schema_version.version_number: + if version_number == schema_version.version_number: # type: ignore validate_number_of_schema_version_metadata_allowed( schema_version.metadata ) @@ -677,12 +713,12 @@ class GlueBackend(BaseBackend): registry_name, schema_name, schema_arn, version_number, latest_version ) - def get_schema(self, schema_id): + def get_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]: registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) schema = self.registries[registry_name].schemas[schema_name] return 
schema.as_dict() - def delete_schema(self, schema_id): + def delete_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]: # Validate schema_id registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) @@ -701,7 +737,9 @@ class GlueBackend(BaseBackend): return response - def update_schema(self, schema_id, compatibility, description): + def update_schema( + self, schema_id: Dict[str, str], compatibility: str, description: str + ) -> Dict[str, Any]: """ The SchemaVersionNumber-argument is not yet implemented """ @@ -715,7 +753,9 @@ class GlueBackend(BaseBackend): return schema.as_dict() - def batch_delete_table(self, database_name, tables): + def batch_delete_table( + self, database_name: str, tables: List[str] + ) -> List[Dict[str, Any]]: errors = [] for table_name in tables: try: @@ -732,7 +772,12 @@ class GlueBackend(BaseBackend): ) return errors - def batch_get_partition(self, database_name, table_name, partitions_to_get): + def batch_get_partition( + self, + database_name: str, + table_name: str, + partitions_to_get: List[Dict[str, str]], + ) -> List[Dict[str, Any]]: table = self.get_table(database_name, table_name) partitions = [] @@ -744,7 +789,9 @@ class GlueBackend(BaseBackend): continue return partitions - def batch_create_partition(self, database_name, table_name, partition_input): + def batch_create_partition( + self, database_name: str, table_name: str, partition_input: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: table = self.get_table(database_name, table_name) errors_output = [] @@ -763,7 +810,9 @@ class GlueBackend(BaseBackend): ) return errors_output - def batch_update_partition(self, database_name, table_name, entries): + def batch_update_partition( + self, database_name: str, table_name: str, entries: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: table = self.get_table(database_name, table_name) errors_output = [] @@ -785,14 +834,16 @@ class GlueBackend(BaseBackend): ) return errors_output - def batch_delete_partition(self, database_name, table_name, parts): + def batch_delete_partition( + self, database_name: str, table_name: str, parts: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: table = self.get_table(database_name, table_name) errors_output = [] for part_input in parts: values = part_input.get("Values") try: - table.delete_partition(values) + table.delete_partition(values) # type: ignore except PartitionNotFoundException: errors_output.append( { @@ -805,7 +856,7 @@ class GlueBackend(BaseBackend): ) return errors_output - def batch_get_crawlers(self, crawler_names): + def batch_get_crawlers(self, crawler_names: List[str]) -> List[Dict[str, Any]]: crawlers = [] for crawler in self.get_crawlers(): if crawler.as_dict()["Name"] in crawler_names: @@ -814,13 +865,13 @@ class GlueBackend(BaseBackend): class FakeDatabase(BaseModel): - def __init__(self, database_name, database_input): + def __init__(self, database_name: str, database_input: Dict[str, Any]): self.name = database_name self.input = database_input self.created_time = datetime.utcnow() - self.tables = OrderedDict() + self.tables: Dict[str, FakeTable] = OrderedDict() - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "Name": self.name, "Description": self.input.get("Description"), @@ -836,23 +887,25 @@ class FakeDatabase(BaseModel): class FakeTable(BaseModel): - def __init__(self, database_name: str, table_name: str, table_input): + def __init__( + self, database_name: str, table_name: str, table_input: Dict[str, Any] + ): self.database_name = database_name 
self.name = table_name - self.partitions = OrderedDict() + self.partitions: Dict[str, FakePartition] = OrderedDict() self.created_time = datetime.utcnow() - self.updated_time = None + self.updated_time: Optional[datetime] = None self._current_version = 1 self.versions: Dict[str, Dict[str, Any]] = { str(self._current_version): table_input } - def update(self, table_input): + def update(self, table_input: Dict[str, Any]) -> None: self.versions[str(self._current_version + 1)] = table_input self._current_version += 1 self.updated_time = datetime.utcnow() - def get_version(self, ver): + def get_version(self, ver: str) -> Dict[str, Any]: try: int(ver) except ValueError as e: @@ -863,11 +916,11 @@ class FakeTable(BaseModel): except KeyError: raise VersionNotFoundException() - def delete_version(self, version_id): + def delete_version(self, version_id: str) -> None: self.versions.pop(version_id) - def as_dict(self, version=None): - version = version or self._current_version + def as_dict(self, version: Optional[str] = None) -> Dict[str, Any]: + version = version or self._current_version # type: ignore obj = { "DatabaseName": self.database_name, "Name": self.name, @@ -880,23 +933,23 @@ class FakeTable(BaseModel): obj["UpdateTime"] = unix_time(self.updated_time) return obj - def create_partition(self, partiton_input): + def create_partition(self, partiton_input: Dict[str, Any]) -> None: partition = FakePartition(self.database_name, self.name, partiton_input) key = str(partition.values) if key in self.partitions: raise PartitionAlreadyExistsException() self.partitions[str(partition.values)] = partition - def get_partitions(self, expression): + def get_partitions(self, expression: str) -> List["FakePartition"]: return list(filter(PartitionFilter(expression, self), self.partitions.values())) - def get_partition(self, values): + def get_partition(self, values: str) -> "FakePartition": try: return self.partitions[str(values)] except KeyError: raise PartitionNotFoundException() - def update_partition(self, old_values, partiton_input): + def update_partition(self, old_values: str, partiton_input: Dict[str, Any]) -> None: partition = FakePartition(self.database_name, self.name, partiton_input) key = str(partition.values) if old_values == partiton_input["Values"]: @@ -913,7 +966,7 @@ class FakeTable(BaseModel): raise PartitionAlreadyExistsException() self.partitions[key] = partition - def delete_partition(self, values): + def delete_partition(self, values: int) -> None: try: del self.partitions[str(values)] except KeyError: @@ -921,14 +974,16 @@ class FakeTable(BaseModel): class FakePartition(BaseModel): - def __init__(self, database_name, table_name, partiton_input): + def __init__( + self, database_name: str, table_name: str, partiton_input: Dict[str, Any] + ): self.creation_time = time.time() self.database_name = database_name self.table_name = table_name self.partition_input = partiton_input self.values = self.partition_input.get("Values", []) - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: obj = { "DatabaseName": self.database_name, "TableName": self.table_name, @@ -941,21 +996,21 @@ class FakePartition(BaseModel): class FakeCrawler(BaseModel): def __init__( self, - name, - role, - database_name, - description, - targets, - schedule, - classifiers, - table_prefix, - schema_change_policy, - recrawl_policy, - lineage_configuration, - configuration, - crawler_security_configuration, - tags, - backend, + name: str, + role: str, + database_name: str, + description: str, + targets: Dict[str, Any], 
+ schedule: str, + classifiers: List[str], + table_prefix: str, + schema_change_policy: Dict[str, str], + recrawl_policy: Dict[str, str], + lineage_configuration: Dict[str, str], + configuration: str, + crawler_security_configuration: str, + tags: Dict[str, str], + backend: GlueBackend, ): self.name = name self.role = role @@ -980,11 +1035,11 @@ class FakeCrawler(BaseModel): self.backend = backend self.backend.tag_resource(self.arn, tags) - def get_name(self): + def get_name(self) -> str: return self.name - def as_dict(self): - last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None + def as_dict(self) -> Dict[str, Any]: + last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None # type: ignore data = { "Name": self.name, "Role": self.role, @@ -1017,14 +1072,14 @@ class FakeCrawler(BaseModel): return data - def start_crawler(self): + def start_crawler(self) -> None: if self.state == "RUNNING": raise CrawlerRunningException( f"Crawler with name {self.name} has already started" ) self.state = "RUNNING" - def stop_crawler(self): + def stop_crawler(self) -> None: if self.state != "RUNNING": raise CrawlerNotRunningException( f"Crawler with name {self.name} isn't running" @@ -1034,7 +1089,13 @@ class FakeCrawler(BaseModel): class LastCrawlInfo(BaseModel): def __init__( - self, error_message, log_group, log_stream, message_prefix, start_time, status + self, + error_message: str, + log_group: str, + log_stream: str, + message_prefix: str, + start_time: str, + status: str, ): self.error_message = error_message self.log_group = log_group @@ -1043,7 +1104,7 @@ class LastCrawlInfo(BaseModel): self.start_time = start_time self.status = status - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "ErrorMessage": self.error_message, "LogGroup": self.log_group, @@ -1057,29 +1118,29 @@ class LastCrawlInfo(BaseModel): class FakeJob: def __init__( self, - name, - role, - command, - description=None, - log_uri=None, - execution_property=None, - default_arguments=None, - non_overridable_arguments=None, - connections=None, - max_retries=None, - allocated_capacity=None, - timeout=None, - max_capacity=None, - security_configuration=None, - tags=None, - notification_property=None, - glue_version=None, - number_of_workers=None, - worker_type=None, - code_gen_configuration_nodes=None, - execution_class=None, - source_control_details=None, - backend=None, + name: str, + role: str, + command: str, + description: str, + log_uri: str, + execution_property: Dict[str, int], + default_arguments: Dict[str, str], + non_overridable_arguments: Dict[str, str], + connections: Dict[str, List[str]], + max_retries: int, + allocated_capacity: int, + timeout: int, + max_capacity: float, + security_configuration: str, + tags: Dict[str, str], + notification_property: Dict[str, int], + glue_version: str, + number_of_workers: int, + worker_type: str, + code_gen_configuration_nodes: Dict[str, Any], + execution_class: str, + source_control_details: Dict[str, str], + backend: GlueBackend, ): self.name = name self.description = description @@ -1112,10 +1173,10 @@ class FakeJob: self.job_runs: List[FakeJobRun] = [] - def get_name(self): + def get_name(self) -> str: return self.name - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "Name": self.name, "Description": self.description, @@ -1142,7 +1203,7 @@ class FakeJob: "SourceControlDetails": self.source_control_details, } - def start_job_run(self): + def start_job_run(self) -> str: running_jobs = len( [jr for jr in 
self.job_runs if jr.status in ["STARTING", "RUNNING"]] ) @@ -1154,7 +1215,7 @@ class FakeJob: self.job_runs.append(fake_job_run) return fake_job_run.job_run_id - def get_job_run(self, run_id): + def get_job_run(self, run_id: str) -> "FakeJobRun": for job_run in self.job_runs: if job_run.job_run_id == run_id: job_run.advance() @@ -1165,11 +1226,11 @@ class FakeJob: class FakeJobRun(ManagedState): def __init__( self, - job_name: int, + job_name: str, job_run_id: str = "01", - arguments: dict = None, - allocated_capacity: int = None, - timeout: int = None, + arguments: Optional[Dict[str, Any]] = None, + allocated_capacity: Optional[int] = None, + timeout: Optional[int] = None, worker_type: str = "Standard", ): ManagedState.__init__( @@ -1187,10 +1248,10 @@ class FakeJobRun(ManagedState): self.modified_on = datetime.utcnow() self.completed_on = datetime.utcnow() - def get_name(self): + def get_name(self) -> str: return self.job_name - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "Id": self.job_run_id, "Attempt": 1, @@ -1220,7 +1281,13 @@ class FakeJobRun(ManagedState): class FakeRegistry(BaseModel): - def __init__(self, backend, registry_name, description=None, tags=None): + def __init__( + self, + backend: GlueBackend, + registry_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + ): self.name = registry_name self.description = description self.tags = tags @@ -1230,7 +1297,7 @@ class FakeRegistry(BaseModel): self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.name}" self.schemas: Dict[str, FakeSchema] = OrderedDict() - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "RegistryArn": self.registry_arn, "RegistryName": self.name, @@ -1243,12 +1310,12 @@ class FakeSchema(BaseModel): def __init__( self, backend: GlueBackend, - registry_name, - schema_name, - data_format, - compatibility, - schema_version_id, - description=None, + registry_name: str, + schema_name: str, + data_format: str, + compatibility: str, + schema_version_id: str, + description: Optional[str] = None, ): self.registry_name = registry_name self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.registry_name}" @@ -1265,18 +1332,18 @@ class FakeSchema(BaseModel): self.schema_version_status = AVAILABLE_STATUS self.created_time = datetime.utcnow() self.updated_time = datetime.utcnow() - self.schema_versions = OrderedDict() + self.schema_versions: Dict[str, FakeSchemaVersion] = OrderedDict() - def update_next_schema_version(self): + def update_next_schema_version(self) -> None: self.next_schema_version += 1 - def update_latest_schema_version(self): + def update_latest_schema_version(self) -> None: self.latest_schema_version += 1 - def get_next_schema_version(self): + def get_next_schema_version(self) -> int: return self.next_schema_version - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "RegistryArn": self.registry_arn, "RegistryName": self.registry_name, @@ -1298,10 +1365,10 @@ class FakeSchemaVersion(BaseModel): def __init__( self, backend: GlueBackend, - registry_name, - schema_name, - schema_definition, - version_number, + registry_name: str, + schema_name: str, + schema_definition: str, + version_number: int, ): self.registry_name = registry_name self.schema_name = schema_name @@ -1312,19 +1379,19 @@ class FakeSchemaVersion(BaseModel): self.schema_version_id = str(mock_random.uuid4()) self.created_time = datetime.utcnow() self.updated_time = 
datetime.utcnow() - self.metadata = OrderedDict() + self.metadata: Dict[str, Any] = {} - def get_schema_version_id(self): + def get_schema_version_id(self) -> str: return self.schema_version_id - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: return { "SchemaVersionId": self.schema_version_id, "VersionNumber": self.version_number, "Status": self.schema_version_status, } - def get_schema_version_as_dict(self): + def get_schema_version_as_dict(self) -> Dict[str, Any]: # add data_format for full return dictionary of get_schema_version return { "SchemaVersionId": self.schema_version_id, @@ -1335,7 +1402,7 @@ class FakeSchemaVersion(BaseModel): "CreatedTime": str(self.created_time), } - def get_schema_by_definition_as_dict(self): + def get_schema_by_definition_as_dict(self) -> Dict[str, Any]: # add data_format for full return dictionary of get_schema_by_definition return { "SchemaVersionId": self.schema_version_id, diff --git a/moto/glue/responses.py b/moto/glue/responses.py index 934b1ed76..5b7eee45b 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -1,11 +1,13 @@ import json +from typing import Any, Dict, List +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import glue_backends, GlueBackend +from .models import glue_backends, GlueBackend, FakeJob, FakeCrawler class GlueResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="glue") @property @@ -13,66 +15,66 @@ class GlueResponse(BaseResponse): return glue_backends[self.current_account][self.region] @property - def parameters(self): + def parameters(self) -> Dict[str, Any]: # type: ignore[misc] return json.loads(self.body) - def create_database(self): + def create_database(self) -> str: database_input = self.parameters.get("DatabaseInput") - database_name = database_input.get("Name") + database_name = database_input.get("Name") # type: ignore if "CatalogId" in self.parameters: - database_input["CatalogId"] = self.parameters.get("CatalogId") - self.glue_backend.create_database(database_name, database_input) + database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore + self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type] return "" - def get_database(self): + def get_database(self) -> str: database_name = self.parameters.get("Name") - database = self.glue_backend.get_database(database_name) + database = self.glue_backend.get_database(database_name) # type: ignore[arg-type] return json.dumps({"Database": database.as_dict()}) - def get_databases(self): + def get_databases(self) -> str: database_list = self.glue_backend.get_databases() return json.dumps( {"DatabaseList": [database.as_dict() for database in database_list]} ) - def update_database(self): + def update_database(self) -> str: database_input = self.parameters.get("DatabaseInput") database_name = self.parameters.get("Name") if "CatalogId" in self.parameters: - database_input["CatalogId"] = self.parameters.get("CatalogId") - self.glue_backend.update_database(database_name, database_input) + database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore + self.glue_backend.update_database(database_name, database_input) # type: ignore[arg-type] return "" - def delete_database(self): + def delete_database(self) -> str: name = self.parameters.get("Name") - self.glue_backend.delete_database(name) + self.glue_backend.delete_database(name) # type: ignore[arg-type] return json.dumps({}) - 
def create_table(self): + def create_table(self) -> str: database_name = self.parameters.get("DatabaseName") table_input = self.parameters.get("TableInput") - table_name = table_input.get("Name") - self.glue_backend.create_table(database_name, table_name, table_input) + table_name = table_input.get("Name") # type: ignore + self.glue_backend.create_table(database_name, table_name, table_input) # type: ignore[arg-type] return "" - def get_table(self): + def get_table(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("Name") - table = self.glue_backend.get_table(database_name, table_name) + table = self.glue_backend.get_table(database_name, table_name) # type: ignore[arg-type] return json.dumps({"Table": table.as_dict()}) - def update_table(self): + def update_table(self) -> str: database_name = self.parameters.get("DatabaseName") table_input = self.parameters.get("TableInput") - table_name = table_input.get("Name") - self.glue_backend.update_table(database_name, table_name, table_input) + table_name = table_input.get("Name") # type: ignore + self.glue_backend.update_table(database_name, table_name, table_input) # type: ignore[arg-type] return "" - def get_table_versions(self): + def get_table_versions(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") - versions = self.glue_backend.get_table_versions(database_name, table_name) + versions = self.glue_backend.get_table_versions(database_name, table_name) # type: ignore[arg-type] return json.dumps( { "TableVersions": [ @@ -82,36 +84,36 @@ class GlueResponse(BaseResponse): } ) - def get_table_version(self): + def get_table_version(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") ver_id = self.parameters.get("VersionId") - return self.glue_backend.get_table_version(database_name, table_name, ver_id) + return self.glue_backend.get_table_version(database_name, table_name, ver_id) # type: ignore[arg-type] def delete_table_version(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") version_id = self.parameters.get("VersionId") - self.glue_backend.delete_table_version(database_name, table_name, version_id) + self.glue_backend.delete_table_version(database_name, table_name, version_id) # type: ignore[arg-type] return "{}" - def get_tables(self): + def get_tables(self) -> str: database_name = self.parameters.get("DatabaseName") expression = self.parameters.get("Expression") - tables = self.glue_backend.get_tables(database_name, expression) + tables = self.glue_backend.get_tables(database_name, expression) # type: ignore[arg-type] return json.dumps({"TableList": [table.as_dict() for table in tables]}) - def delete_table(self): + def delete_table(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("Name") - resp = self.glue_backend.delete_table(database_name, table_name) - return json.dumps(resp) + self.glue_backend.delete_table(database_name, table_name) # type: ignore[arg-type] + return "{}" - def batch_delete_table(self): + def batch_delete_table(self) -> str: database_name = self.parameters.get("DatabaseName") tables = self.parameters.get("TablesToDelete") - errors = self.glue_backend.batch_delete_table(database_name, tables) + errors = self.glue_backend.batch_delete_table(database_name, tables) # type: ignore[arg-type] out = {} if errors: @@ -119,50 +121,50 @@ class 
GlueResponse(BaseResponse): return json.dumps(out) - def get_partitions(self): + def get_partitions(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") expression = self.parameters.get("Expression") partitions = self.glue_backend.get_partitions( - database_name, table_name, expression + database_name, table_name, expression # type: ignore[arg-type] ) return json.dumps({"Partitions": [p.as_dict() for p in partitions]}) - def get_partition(self): + def get_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") values = self.parameters.get("PartitionValues") - p = self.glue_backend.get_partition(database_name, table_name, values) + p = self.glue_backend.get_partition(database_name, table_name, values) # type: ignore[arg-type] return json.dumps({"Partition": p.as_dict()}) - def batch_get_partition(self): + def batch_get_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") partitions_to_get = self.parameters.get("PartitionsToGet") partitions = self.glue_backend.batch_get_partition( - database_name, table_name, partitions_to_get + database_name, table_name, partitions_to_get # type: ignore[arg-type] ) return json.dumps({"Partitions": partitions}) - def create_partition(self): + def create_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") part_input = self.parameters.get("PartitionInput") - self.glue_backend.create_partition(database_name, table_name, part_input) + self.glue_backend.create_partition(database_name, table_name, part_input) # type: ignore[arg-type] return "" - def batch_create_partition(self): + def batch_create_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") partition_input = self.parameters.get("PartitionInputList") errors_output = self.glue_backend.batch_create_partition( - database_name, table_name, partition_input + database_name, table_name, partition_input # type: ignore[arg-type] ) out = {} @@ -171,24 +173,24 @@ class GlueResponse(BaseResponse): return json.dumps(out) - def update_partition(self): + def update_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") part_input = self.parameters.get("PartitionInput") part_to_update = self.parameters.get("PartitionValueList") self.glue_backend.update_partition( - database_name, table_name, part_input, part_to_update + database_name, table_name, part_input, part_to_update # type: ignore[arg-type] ) return "" - def batch_update_partition(self): + def batch_update_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") entries = self.parameters.get("Entries") errors_output = self.glue_backend.batch_update_partition( - database_name, table_name, entries + database_name, table_name, entries # type: ignore[arg-type] ) out = {} @@ -197,21 +199,21 @@ class GlueResponse(BaseResponse): return json.dumps(out) - def delete_partition(self): + def delete_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") part_to_delete = self.parameters.get("PartitionValues") - self.glue_backend.delete_partition(database_name, table_name, part_to_delete) + self.glue_backend.delete_partition(database_name, table_name, 
part_to_delete) # type: ignore[arg-type] return "" - def batch_delete_partition(self): + def batch_delete_partition(self) -> str: database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") parts = self.parameters.get("PartitionsToDelete") errors_output = self.glue_backend.batch_delete_partition( - database_name, table_name, parts + database_name, table_name, parts # type: ignore[arg-type] ) out = {} @@ -220,37 +222,37 @@ class GlueResponse(BaseResponse): return json.dumps(out) - def create_crawler(self): + def create_crawler(self) -> str: self.glue_backend.create_crawler( - name=self.parameters.get("Name"), - role=self.parameters.get("Role"), - database_name=self.parameters.get("DatabaseName"), - description=self.parameters.get("Description"), - targets=self.parameters.get("Targets"), - schedule=self.parameters.get("Schedule"), - classifiers=self.parameters.get("Classifiers"), - table_prefix=self.parameters.get("TablePrefix"), - schema_change_policy=self.parameters.get("SchemaChangePolicy"), - recrawl_policy=self.parameters.get("RecrawlPolicy"), - lineage_configuration=self.parameters.get("LineageConfiguration"), - configuration=self.parameters.get("Configuration"), - crawler_security_configuration=self.parameters.get( + name=self.parameters.get("Name"), # type: ignore[arg-type] + role=self.parameters.get("Role"), # type: ignore[arg-type] + database_name=self.parameters.get("DatabaseName"), # type: ignore[arg-type] + description=self.parameters.get("Description"), # type: ignore[arg-type] + targets=self.parameters.get("Targets"), # type: ignore[arg-type] + schedule=self.parameters.get("Schedule"), # type: ignore[arg-type] + classifiers=self.parameters.get("Classifiers"), # type: ignore[arg-type] + table_prefix=self.parameters.get("TablePrefix"), # type: ignore[arg-type] + schema_change_policy=self.parameters.get("SchemaChangePolicy"), # type: ignore[arg-type] + recrawl_policy=self.parameters.get("RecrawlPolicy"), # type: ignore[arg-type] + lineage_configuration=self.parameters.get("LineageConfiguration"), # type: ignore[arg-type] + configuration=self.parameters.get("Configuration"), # type: ignore[arg-type] + crawler_security_configuration=self.parameters.get( # type: ignore[arg-type] "CrawlerSecurityConfiguration" ), - tags=self.parameters.get("Tags"), + tags=self.parameters.get("Tags"), # type: ignore[arg-type] ) return "" - def get_crawler(self): + def get_crawler(self) -> str: name = self.parameters.get("Name") - crawler = self.glue_backend.get_crawler(name) + crawler = self.glue_backend.get_crawler(name) # type: ignore[arg-type] return json.dumps({"Crawler": crawler.as_dict()}) - def get_crawlers(self): + def get_crawlers(self) -> str: crawlers = self.glue_backend.get_crawlers() return json.dumps({"Crawlers": [crawler.as_dict() for crawler in crawlers]}) - def list_crawlers(self): + def list_crawlers(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_int_param("MaxResults") tags = self._get_param("Tags") @@ -265,31 +267,33 @@ class GlueResponse(BaseResponse): ) ) - def filter_crawlers_by_tags(self, crawlers, tags): + def filter_crawlers_by_tags( + self, crawlers: List[FakeCrawler], tags: Dict[str, str] + ) -> List[str]: if not tags: return [crawler.get_name() for crawler in crawlers] return [ crawler.get_name() for crawler in crawlers - if self.is_tags_match(self, crawler.arn, tags) + if self.is_tags_match(crawler.arn, tags) ] - def start_crawler(self): + def start_crawler(self) -> str: name = 
self.parameters.get("Name") - self.glue_backend.start_crawler(name) + self.glue_backend.start_crawler(name) # type: ignore[arg-type] return "" - def stop_crawler(self): + def stop_crawler(self) -> str: name = self.parameters.get("Name") - self.glue_backend.stop_crawler(name) + self.glue_backend.stop_crawler(name) # type: ignore[arg-type] return "" - def delete_crawler(self): + def delete_crawler(self) -> str: name = self.parameters.get("Name") - self.glue_backend.delete_crawler(name) + self.glue_backend.delete_crawler(name) # type: ignore[arg-type] return "" - def create_job(self): + def create_job(self) -> str: name = self._get_param("Name") description = self._get_param("Description") log_uri = self._get_param("LogUri") @@ -312,7 +316,7 @@ class GlueResponse(BaseResponse): code_gen_configuration_nodes = self._get_param("CodeGenConfigurationNodes") execution_class = self._get_param("ExecutionClass") source_control_details = self._get_param("SourceControlDetails") - name = self.glue_backend.create_job( + self.glue_backend.create_job( name=name, description=description, log_uri=log_uri, @@ -338,12 +342,12 @@ class GlueResponse(BaseResponse): ) return json.dumps(dict(Name=name)) - def get_job(self): + def get_job(self) -> str: name = self.parameters.get("JobName") - job = self.glue_backend.get_job(name) + job = self.glue_backend.get_job(name) # type: ignore[arg-type] return json.dumps({"Job": job.as_dict()}) - def get_jobs(self): + def get_jobs(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_int_param("MaxResults") jobs, next_token = self.glue_backend.get_jobs( @@ -356,18 +360,18 @@ class GlueResponse(BaseResponse): ) ) - def start_job_run(self): + def start_job_run(self) -> str: name = self.parameters.get("JobName") - job_run_id = self.glue_backend.start_job_run(name) + job_run_id = self.glue_backend.start_job_run(name) # type: ignore[arg-type] return json.dumps(dict(JobRunId=job_run_id)) - def get_job_run(self): + def get_job_run(self) -> str: name = self.parameters.get("JobName") run_id = self.parameters.get("RunId") - job_run = self.glue_backend.get_job_run(name, run_id) + job_run = self.glue_backend.get_job_run(name, run_id) # type: ignore[arg-type] return json.dumps({"JobRun": job_run.as_dict()}) - def list_jobs(self): + def list_jobs(self) -> str: next_token = self._get_param("NextToken") max_results = self._get_int_param("MaxResults") tags = self._get_param("Tags") @@ -382,32 +386,31 @@ class GlueResponse(BaseResponse): ) ) - def get_tags(self): + def get_tags(self) -> TYPE_RESPONSE: resource_arn = self.parameters.get("ResourceArn") - tags = self.glue_backend.get_tags(resource_arn) + tags = self.glue_backend.get_tags(resource_arn) # type: ignore[arg-type] return 200, {}, json.dumps({"Tags": tags}) - def tag_resource(self): + def tag_resource(self) -> TYPE_RESPONSE: resource_arn = self.parameters.get("ResourceArn") tags = self.parameters.get("TagsToAdd", {}) - self.glue_backend.tag_resource(resource_arn, tags) + self.glue_backend.tag_resource(resource_arn, tags) # type: ignore[arg-type] return 201, {}, "{}" - def untag_resource(self): + def untag_resource(self) -> TYPE_RESPONSE: resource_arn = self._get_param("ResourceArn") tag_keys = self.parameters.get("TagsToRemove") - self.glue_backend.untag_resource(resource_arn, tag_keys) + self.glue_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type] return 200, {}, "{}" - def filter_jobs_by_tags(self, jobs, tags): + def filter_jobs_by_tags( + self, jobs: List[FakeJob], tags: Dict[str, str] + 
) -> List[str]: if not tags: return [job.get_name() for job in jobs] - return [ - job.get_name() for job in jobs if self.is_tags_match(self, job.arn, tags) - ] + return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)] - @staticmethod - def is_tags_match(self, resource_arn, tags): + def is_tags_match(self, resource_arn: str, tags: Dict[str, str]) -> bool: glue_resource_tags = self.glue_backend.get_tags(resource_arn) mutual_keys = set(glue_resource_tags).intersection(tags) for key in mutual_keys: @@ -415,28 +418,28 @@ class GlueResponse(BaseResponse): return True return False - def create_registry(self): + def create_registry(self) -> str: registry_name = self._get_param("RegistryName") description = self._get_param("Description") tags = self._get_param("Tags") registry = self.glue_backend.create_registry(registry_name, description, tags) return json.dumps(registry) - def delete_registry(self): + def delete_registry(self) -> str: registry_id = self._get_param("RegistryId") registry = self.glue_backend.delete_registry(registry_id) return json.dumps(registry) - def get_registry(self): + def get_registry(self) -> str: registry_id = self._get_param("RegistryId") registry = self.glue_backend.get_registry(registry_id) return json.dumps(registry) - def list_registries(self): + def list_registries(self) -> str: registries = self.glue_backend.list_registries() return json.dumps({"Registries": registries}) - def create_schema(self): + def create_schema(self) -> str: registry_id = self._get_param("RegistryId") schema_name = self._get_param("SchemaName") data_format = self._get_param("DataFormat") @@ -455,7 +458,7 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema) - def register_schema_version(self): + def register_schema_version(self) -> str: schema_id = self._get_param("SchemaId") schema_definition = self._get_param("SchemaDefinition") schema_version = self.glue_backend.register_schema_version( @@ -463,7 +466,7 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema_version) - def get_schema_version(self): + def get_schema_version(self) -> str: schema_id = self._get_param("SchemaId") schema_version_id = self._get_param("SchemaVersionId") schema_version_number = self._get_param("SchemaVersionNumber") @@ -473,7 +476,7 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema_version) - def get_schema_by_definition(self): + def get_schema_by_definition(self) -> str: schema_id = self._get_param("SchemaId") schema_definition = self._get_param("SchemaDefinition") schema_version = self.glue_backend.get_schema_by_definition( @@ -481,7 +484,7 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema_version) - def put_schema_version_metadata(self): + def put_schema_version_metadata(self) -> str: schema_id = self._get_param("SchemaId") schema_version_number = self._get_param("SchemaVersionNumber") schema_version_id = self._get_param("SchemaVersionId") @@ -491,24 +494,24 @@ class GlueResponse(BaseResponse): ) return json.dumps(schema_version) - def get_schema(self): + def get_schema(self) -> str: schema_id = self._get_param("SchemaId") schema = self.glue_backend.get_schema(schema_id) return json.dumps(schema) - def delete_schema(self): + def delete_schema(self) -> str: schema_id = self._get_param("SchemaId") schema = self.glue_backend.delete_schema(schema_id) return json.dumps(schema) - def update_schema(self): + def update_schema(self) -> str: schema_id = self._get_param("SchemaId") compatibility = self._get_param("Compatibility") description = 
self._get_param("Description") schema = self.glue_backend.update_schema(schema_id, compatibility, description) return json.dumps(schema) - def batch_get_crawlers(self): + def batch_get_crawlers(self) -> str: crawler_names = self._get_param("CrawlerNames") crawlers = self.glue_backend.batch_get_crawlers(crawler_names) crawlers_not_found = list( @@ -521,5 +524,5 @@ class GlueResponse(BaseResponse): } ) - def get_partition_indexes(self): + def get_partition_indexes(self) -> str: return json.dumps({"PartitionIndexDescriptorList": []}) diff --git a/moto/glue/utils.py b/moto/glue/utils.py index 9849968c4..11b063692 100644 --- a/moto/glue/utils.py +++ b/moto/glue/utils.py @@ -97,7 +97,7 @@ def _escape_regex(pattern: str) -> str: class _Expr(abc.ABC): @abc.abstractmethod - def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: # type: ignore[misc] raise NotImplementedError() @@ -196,7 +196,7 @@ class _Like(_Expr): pattern = _cast("string", self.literal) # prepare SQL pattern for conversion to regex pattern - pattern = _escape_regex(pattern) + pattern = _escape_regex(pattern) # type: ignore # NOTE convert SQL wildcards to regex, no literal matches possible pattern = pattern.replace("_", ".").replace("%", ".*") @@ -265,19 +265,19 @@ class _BoolOr(_Expr): class _PartitionFilterExpressionCache: - def __init__(self): + def __init__(self) -> None: # build grammar according to Glue.Client.get_partitions(Expression) lpar, rpar = map(Suppress, "()") # NOTE these are AWS Athena column name best practices ident = Forward().set_name("ident") - ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar + ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar # type: ignore number = Forward().set_name("number") - number <<= pyparsing_common.number | lpar + number + rpar + number <<= pyparsing_common.number | lpar + number + rpar # type: ignore string = Forward().set_name("string") - string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar + string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar # type: ignore literal = (number | string).set_name("literal") literal_list = delimited_list(literal, min=1).set_name("list") @@ -293,7 +293,7 @@ class _PartitionFilterExpressionCache: in_, between, like, not_, is_, null = map( CaselessKeyword, "in between like not is null".split() ) - not_ = Suppress(not_) # only needed for matching + not_ = Suppress(not_) # type: ignore # only needed for matching cond = ( (ident + is_ + null).set_parse_action(_IsNull) @@ -343,11 +343,11 @@ _PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache() class PartitionFilter: - def __init__(self, expression: Optional[str], fake_table): + def __init__(self, expression: Optional[str], fake_table: Any): self.expression = expression self.fake_table = fake_table - def __call__(self, fake_partition) -> bool: + def __call__(self, fake_partition: Any) -> bool: expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression) if expression is None: return True diff --git a/moto/greengrass/exceptions.py b/moto/greengrass/exceptions.py index 271fa2130..a9ee3158b 100644 --- a/moto/greengrass/exceptions.py +++ b/moto/greengrass/exceptions.py @@ -6,36 +6,36 @@ class GreengrassClientError(JsonRESTError): class IdNotFoundException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 404 
super().__init__("IdNotFoundException", msg) class InvalidContainerDefinitionException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 400 super().__init__("InvalidContainerDefinitionException", msg) class VersionNotFoundException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 404 super().__init__("VersionNotFoundException", msg) class InvalidInputException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 400 super().__init__("InvalidInputException", msg) class MissingCoreException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 400 super().__init__("MissingCoreException", msg) class ResourceNotFoundException(GreengrassClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 404 super().__init__("ResourceNotFoundException", msg) diff --git a/moto/greengrass/models.py b/moto/greengrass/models.py index 0508d3c92..3eefdfa7b 100644 --- a/moto/greengrass/models.py +++ b/moto/greengrass/models.py @@ -1,6 +1,7 @@ import json from collections import OrderedDict from datetime import datetime +from typing import Any, Dict, List, Iterable, Optional import re from moto.core import BaseBackend, BackendDict, BaseModel @@ -18,7 +19,7 @@ from .exceptions import ( class FakeCoreDefinition(BaseModel): - def __init__(self, account_id, region_name, name): + def __init__(self, account_id: str, region_name: str, name: str): self.region_name = region_name self.name = name self.id = str(mock_random.uuid4()) @@ -27,7 +28,7 @@ class FakeCoreDefinition(BaseModel): self.latest_version = "" self.latest_version_arn = "" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -44,7 +45,13 @@ class FakeCoreDefinition(BaseModel): class FakeCoreDefinitionVersion(BaseModel): - def __init__(self, account_id, region_name, core_definition_id, definition): + def __init__( + self, + account_id: str, + region_name: str, + core_definition_id: str, + definition: Dict[str, Any], + ): self.region_name = region_name self.core_definition_id = core_definition_id self.definition = definition @@ -52,8 +59,8 @@ class FakeCoreDefinitionVersion(BaseModel): self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() - def to_dict(self, include_detail=False): - obj = { + def to_dict(self, include_detail: bool = False) -> Dict[str, Any]: + obj: Dict[str, Any] = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( self.created_at_datetime @@ -69,7 +76,13 @@ class FakeCoreDefinitionVersion(BaseModel): class FakeDeviceDefinition(BaseModel): - def __init__(self, account_id, region_name, name, initial_version): + def __init__( + self, + account_id: str, + region_name: str, + name: str, + initial_version: Dict[str, Any], + ): self.region_name = region_name self.id = str(mock_random.uuid4()) self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.id}" @@ -80,7 +93,7 @@ class FakeDeviceDefinition(BaseModel): self.name = name self.initial_version = initial_version - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: res = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -99,7 +112,13 @@ class FakeDeviceDefinition(BaseModel): class 
FakeDeviceDefinitionVersion(BaseModel): - def __init__(self, account_id, region_name, device_definition_id, devices): + def __init__( + self, + account_id: str, + region_name: str, + device_definition_id: str, + devices: List[Dict[str, Any]], + ): self.region_name = region_name self.device_definition_id = device_definition_id self.devices = devices @@ -107,8 +126,8 @@ class FakeDeviceDefinitionVersion(BaseModel): self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() - def to_dict(self, include_detail=False): - obj = { + def to_dict(self, include_detail: bool = False) -> Dict[str, Any]: + obj: Dict[str, Any] = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( self.created_at_datetime @@ -124,7 +143,13 @@ class FakeDeviceDefinitionVersion(BaseModel): class FakeResourceDefinition(BaseModel): - def __init__(self, account_id, region_name, name, initial_version): + def __init__( + self, + account_id: str, + region_name: str, + name: str, + initial_version: Dict[str, Any], + ): self.region_name = region_name self.id = str(mock_random.uuid4()) self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.id}" @@ -135,7 +160,7 @@ class FakeResourceDefinition(BaseModel): self.name = name self.initial_version = initial_version - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -152,7 +177,13 @@ class FakeResourceDefinition(BaseModel): class FakeResourceDefinitionVersion(BaseModel): - def __init__(self, account_id, region_name, resource_definition_id, resources): + def __init__( + self, + account_id: str, + region_name: str, + resource_definition_id: str, + resources: List[Dict[str, Any]], + ): self.region_name = region_name self.resource_definition_id = resource_definition_id self.resources = resources @@ -160,7 +191,7 @@ class FakeResourceDefinitionVersion(BaseModel): self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -173,7 +204,13 @@ class FakeResourceDefinitionVersion(BaseModel): class FakeFunctionDefinition(BaseModel): - def __init__(self, account_id, region_name, name, initial_version): + def __init__( + self, + account_id: str, + region_name: str, + name: str, + initial_version: Dict[str, Any], + ): self.region_name = region_name self.id = str(mock_random.uuid4()) self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.id}" @@ -184,7 +221,7 @@ class FakeFunctionDefinition(BaseModel): self.name = name self.initial_version = initial_version - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: res = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -204,7 +241,12 @@ class FakeFunctionDefinition(BaseModel): class FakeFunctionDefinitionVersion(BaseModel): def __init__( - self, account_id, region_name, function_definition_id, functions, default_config + self, + account_id: str, + region_name: str, + function_definition_id: str, + functions: List[Dict[str, Any]], + default_config: Dict[str, Any], ): self.region_name = region_name 
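# --- Illustrative sketch (editor's addition, not part of the moto patch) ---
# Several to_dict() methods above (FakeCoreDefinitionVersion,
# FakeDeviceDefinitionVersion) gain an explicit "obj: Dict[str, Any]"
# annotation.  Without it, mypy infers the narrower Dict[str, str] from the
# initial string-valued literal and then rejects the conditional assignment
# of a nested dict under "Definition".  A minimal standalone version of the
# pattern follows; the values used here are made up for illustration.
from typing import Any, Dict


def to_dict(include_detail: bool = False) -> Dict[str, Any]:
    obj: Dict[str, Any] = {"Arn": "arn:aws:greengrass:example"}
    if include_detail:
        # Accepted only because obj is declared as Dict[str, Any] above;
        # with the inferred Dict[str, str] this assignment would be an error.
        obj["Definition"] = {"Devices": []}
    return obj


print(to_dict(include_detail=True))
# --- end of sketch ---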
self.function_definition_id = function_definition_id @@ -214,7 +256,7 @@ class FakeFunctionDefinitionVersion(BaseModel): self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -227,7 +269,13 @@ class FakeFunctionDefinitionVersion(BaseModel): class FakeSubscriptionDefinition(BaseModel): - def __init__(self, account_id, region_name, name, initial_version): + def __init__( + self, + account_id: str, + region_name: str, + name: str, + initial_version: Dict[str, Any], + ): self.region_name = region_name self.id = str(mock_random.uuid4()) self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.id}" @@ -238,7 +286,7 @@ class FakeSubscriptionDefinition(BaseModel): self.name = name self.initial_version = initial_version - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -256,7 +304,11 @@ class FakeSubscriptionDefinition(BaseModel): class FakeSubscriptionDefinitionVersion(BaseModel): def __init__( - self, account_id, region_name, subscription_definition_id, subscriptions + self, + account_id: str, + region_name: str, + subscription_definition_id: str, + subscriptions: List[Dict[str, Any]], ): self.region_name = region_name self.subscription_definition_id = subscription_definition_id @@ -265,7 +317,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel): self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -278,7 +330,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel): class FakeGroup(BaseModel): - def __init__(self, account_id, region_name, name): + def __init__(self, account_id: str, region_name: str, name: str): self.region_name = region_name self.group_id = str(mock_random.uuid4()) self.name = name @@ -288,7 +340,7 @@ class FakeGroup(BaseModel): self.latest_version = "" self.latest_version_arn = "" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: obj = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( @@ -308,14 +360,14 @@ class FakeGroup(BaseModel): class FakeGroupVersion(BaseModel): def __init__( self, - account_id, - region_name, - group_id, - core_definition_version_arn, - device_definition_version_arn, - function_definition_version_arn, - resource_definition_version_arn, - subscription_definition_version_arn, + account_id: str, + region_name: str, + group_id: str, + core_definition_version_arn: Optional[str], + device_definition_version_arn: Optional[str], + function_definition_version_arn: Optional[str], + resource_definition_version_arn: Optional[str], + subscription_definition_version_arn: Optional[str], ): self.region_name = region_name self.group_id = group_id @@ -328,7 +380,7 @@ class FakeGroupVersion(BaseModel): self.resource_definition_version_arn = resource_definition_version_arn self.subscription_definition_version_arn = subscription_definition_version_arn - def to_dict(self, include_detail=False): + def to_dict(self, 
include_detail: bool = False) -> Dict[str, Any]: definition = {} if self.core_definition_version_arn: @@ -354,7 +406,7 @@ class FakeGroupVersion(BaseModel): "SubscriptionDefinitionVersionArn" ] = self.subscription_definition_version_arn - obj = { + obj: Dict[str, Any] = { "Arn": self.arn, "CreationTimestamp": iso_8601_datetime_with_milliseconds( self.created_at_datetime @@ -370,7 +422,14 @@ class FakeGroupVersion(BaseModel): class FakeDeployment(BaseModel): - def __init__(self, account_id, region_name, group_id, group_arn, deployment_type): + def __init__( + self, + account_id: str, + region_name: str, + group_id: str, + group_arn: str, + deployment_type: str, + ): self.region_name = region_name self.id = str(mock_random.uuid4()) self.group_id = group_id @@ -381,7 +440,7 @@ class FakeDeployment(BaseModel): self.deployment_type = deployment_type self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:/greengrass/groups/{self.group_id}/deployments/{self.id}" - def to_dict(self, include_detail=False): + def to_dict(self, include_detail: bool = False) -> Dict[str, Any]: obj = {"DeploymentId": self.id, "DeploymentArn": self.arn} if include_detail: @@ -395,11 +454,11 @@ class FakeDeployment(BaseModel): class FakeAssociatedRole(BaseModel): - def __init__(self, role_arn): + def __init__(self, role_arn: str): self.role_arn = role_arn self.associated_at = datetime.utcnow() - def to_dict(self, include_detail=False): + def to_dict(self, include_detail: bool = False) -> Dict[str, Any]: obj = {"AssociatedAt": iso_8601_datetime_with_milliseconds(self.associated_at)} if include_detail: @@ -409,12 +468,17 @@ class FakeAssociatedRole(BaseModel): class FakeDeploymentStatus(BaseModel): - def __init__(self, deployment_type, updated_at, deployment_status="InProgress"): + def __init__( + self, + deployment_type: str, + updated_at: datetime, + deployment_status: str = "InProgress", + ): self.deployment_type = deployment_type self.update_at_datetime = updated_at self.deployment_status = deployment_status - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "DeploymentStatus": self.deployment_status, "DeploymentType": self.deployment_type, @@ -423,24 +487,38 @@ class FakeDeploymentStatus(BaseModel): class GreengrassBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.groups = OrderedDict() - self.group_role_associations = OrderedDict() - self.group_versions = OrderedDict() - self.core_definitions = OrderedDict() - self.core_definition_versions = OrderedDict() - self.device_definitions = OrderedDict() - self.device_definition_versions = OrderedDict() - self.function_definitions = OrderedDict() - self.function_definition_versions = OrderedDict() - self.resource_definitions = OrderedDict() - self.resource_definition_versions = OrderedDict() - self.subscription_definitions = OrderedDict() - self.subscription_definition_versions = OrderedDict() - self.deployments = OrderedDict() + self.groups: Dict[str, FakeGroup] = OrderedDict() + self.group_role_associations: Dict[str, FakeAssociatedRole] = OrderedDict() + self.group_versions: Dict[str, Dict[str, FakeGroupVersion]] = OrderedDict() + self.core_definitions: Dict[str, FakeCoreDefinition] = OrderedDict() + self.core_definition_versions: Dict[ + str, Dict[str, FakeCoreDefinitionVersion] + ] = OrderedDict() + self.device_definitions: Dict[str, FakeDeviceDefinition] = OrderedDict() + self.device_definition_versions: Dict[ + str, 
Dict[str, FakeDeviceDefinitionVersion] + ] = OrderedDict() + self.function_definitions: Dict[str, FakeFunctionDefinition] = OrderedDict() + self.function_definition_versions: Dict[ + str, Dict[str, FakeFunctionDefinitionVersion] + ] = OrderedDict() + self.resource_definitions: Dict[str, FakeResourceDefinition] = OrderedDict() + self.resource_definition_versions: Dict[ + str, Dict[str, FakeResourceDefinitionVersion] + ] = OrderedDict() + self.subscription_definitions: Dict[ + str, FakeSubscriptionDefinition + ] = OrderedDict() + self.subscription_definition_versions: Dict[ + str, Dict[str, FakeSubscriptionDefinitionVersion] + ] = OrderedDict() + self.deployments: Dict[str, FakeDeployment] = OrderedDict() - def create_core_definition(self, name, initial_version): + def create_core_definition( + self, name: str, initial_version: Dict[str, Any] + ) -> FakeCoreDefinition: core_definition = FakeCoreDefinition(self.account_id, self.region_name, name) self.core_definitions[core_definition.id] = core_definition @@ -449,22 +527,22 @@ class GreengrassBackend(BaseBackend): ) return core_definition - def list_core_definitions(self): + def list_core_definitions(self) -> Iterable[FakeCoreDefinition]: return self.core_definitions.values() - def get_core_definition(self, core_definition_id): + def get_core_definition(self, core_definition_id: str) -> FakeCoreDefinition: if core_definition_id not in self.core_definitions: raise IdNotFoundException("That Core List Definition does not exist") return self.core_definitions[core_definition_id] - def delete_core_definition(self, core_definition_id): + def delete_core_definition(self, core_definition_id: str) -> None: if core_definition_id not in self.core_definitions: raise IdNotFoundException("That cores definition does not exist.") del self.core_definitions[core_definition_id] del self.core_definition_versions[core_definition_id] - def update_core_definition(self, core_definition_id, name): + def update_core_definition(self, core_definition_id: str, name: str) -> None: if name == "": raise InvalidContainerDefinitionException( @@ -474,7 +552,9 @@ class GreengrassBackend(BaseBackend): raise IdNotFoundException("That cores definition does not exist.") self.core_definitions[core_definition_id].name = name - def create_core_definition_version(self, core_definition_id, cores): + def create_core_definition_version( + self, core_definition_id: str, cores: List[Dict[str, Any]] + ) -> FakeCoreDefinitionVersion: definition = {"Cores": cores} core_def_ver = FakeCoreDefinitionVersion( @@ -491,15 +571,17 @@ class GreengrassBackend(BaseBackend): return core_def_ver - def list_core_definition_versions(self, core_definition_id): + def list_core_definition_versions( + self, core_definition_id: str + ) -> Iterable[FakeCoreDefinitionVersion]: if core_definition_id not in self.core_definitions: raise IdNotFoundException("That cores definition does not exist.") return self.core_definition_versions[core_definition_id].values() def get_core_definition_version( - self, core_definition_id, core_definition_version_id - ): + self, core_definition_id: str, core_definition_version_id: str + ) -> FakeCoreDefinitionVersion: if core_definition_id not in self.core_definitions: raise IdNotFoundException("That cores definition does not exist.") @@ -516,7 +598,9 @@ class GreengrassBackend(BaseBackend): core_definition_version_id ] - def create_device_definition(self, name, initial_version): + def create_device_definition( + self, name: str, initial_version: Dict[str, Any] + ) -> 
FakeDeviceDefinition: device_def = FakeDeviceDefinition( self.account_id, self.region_name, name, initial_version ) @@ -527,10 +611,12 @@ class GreengrassBackend(BaseBackend): return device_def - def list_device_definitions(self): + def list_device_definitions(self) -> Iterable[FakeDeviceDefinition]: return self.device_definitions.values() - def create_device_definition_version(self, device_definition_id, devices): + def create_device_definition_version( + self, device_definition_id: str, devices: List[Dict[str, Any]] + ) -> FakeDeviceDefinitionVersion: if device_definition_id not in self.device_definitions: raise IdNotFoundException("That devices definition does not exist.") @@ -552,25 +638,27 @@ class GreengrassBackend(BaseBackend): return device_ver - def list_device_definition_versions(self, device_definition_id): + def list_device_definition_versions( + self, device_definition_id: str + ) -> Iterable[FakeDeviceDefinitionVersion]: if device_definition_id not in self.device_definitions: raise IdNotFoundException("That devices definition does not exist.") return self.device_definition_versions[device_definition_id].values() - def get_device_definition(self, device_definition_id): + def get_device_definition(self, device_definition_id: str) -> FakeDeviceDefinition: if device_definition_id not in self.device_definitions: raise IdNotFoundException("That Device List Definition does not exist.") return self.device_definitions[device_definition_id] - def delete_device_definition(self, device_definition_id): + def delete_device_definition(self, device_definition_id: str) -> None: if device_definition_id not in self.device_definitions: raise IdNotFoundException("That devices definition does not exist.") del self.device_definitions[device_definition_id] del self.device_definition_versions[device_definition_id] - def update_device_definition(self, device_definition_id, name): + def update_device_definition(self, device_definition_id: str, name: str) -> None: if name == "": raise InvalidContainerDefinitionException( @@ -581,8 +669,8 @@ class GreengrassBackend(BaseBackend): self.device_definitions[device_definition_id].name = name def get_device_definition_version( - self, device_definition_id, device_definition_version_id - ): + self, device_definition_id: str, device_definition_version_id: str + ) -> FakeDeviceDefinitionVersion: if device_definition_id not in self.device_definitions: raise IdNotFoundException("That devices definition does not exist.") @@ -599,7 +687,9 @@ class GreengrassBackend(BaseBackend): device_definition_version_id ] - def create_resource_definition(self, name, initial_version): + def create_resource_definition( + self, name: str, initial_version: Dict[str, Any] + ) -> FakeResourceDefinition: resources = initial_version.get("Resources", []) GreengrassBackend._validate_resources(resources) @@ -614,22 +704,26 @@ class GreengrassBackend(BaseBackend): return resource_def - def list_resource_definitions(self): - return self.resource_definitions + def list_resource_definitions(self) -> Iterable[FakeResourceDefinition]: + return self.resource_definitions.values() - def get_resource_definition(self, resource_definition_id): + def get_resource_definition( + self, resource_definition_id: str + ) -> FakeResourceDefinition: if resource_definition_id not in self.resource_definitions: raise IdNotFoundException("That Resource List Definition does not exist.") return self.resource_definitions[resource_definition_id] - def delete_resource_definition(self, resource_definition_id): + def 
delete_resource_definition(self, resource_definition_id: str) -> None: if resource_definition_id not in self.resource_definitions: raise IdNotFoundException("That resources definition does not exist.") del self.resource_definitions[resource_definition_id] del self.resource_definition_versions[resource_definition_id] - def update_resource_definition(self, resource_definition_id, name): + def update_resource_definition( + self, resource_definition_id: str, name: str + ) -> None: if name == "": raise InvalidInputException("Invalid resource name.") @@ -637,7 +731,9 @@ class GreengrassBackend(BaseBackend): raise IdNotFoundException("That resources definition does not exist.") self.resource_definitions[resource_definition_id].name = name - def create_resource_definition_version(self, resource_definition_id, resources): + def create_resource_definition_version( + self, resource_definition_id: str, resources: List[Dict[str, Any]] + ) -> FakeResourceDefinitionVersion: if resource_definition_id not in self.resource_definitions: raise IdNotFoundException("That resource definition does not exist.") @@ -666,7 +762,9 @@ class GreengrassBackend(BaseBackend): return resource_def_ver - def list_resource_definition_versions(self, resource_definition_id): + def list_resource_definition_versions( + self, resource_definition_id: str + ) -> Iterable[FakeResourceDefinitionVersion]: if resource_definition_id not in self.resource_definition_versions: raise IdNotFoundException("That resources definition does not exist.") @@ -674,8 +772,8 @@ class GreengrassBackend(BaseBackend): return self.resource_definition_versions[resource_definition_id].values() def get_resource_definition_version( - self, resource_definition_id, resource_definition_version_id - ): + self, resource_definition_id: str, resource_definition_version_id: str + ) -> FakeResourceDefinitionVersion: if resource_definition_id not in self.resource_definition_versions: raise IdNotFoundException("That resources definition does not exist.") @@ -693,7 +791,7 @@ class GreengrassBackend(BaseBackend): ] @staticmethod - def _validate_resources(resources): + def _validate_resources(resources: List[Dict[str, Any]]) -> None: # type: ignore[misc] for resource in resources: volume_source_path = ( resource.get("ResourceDataContainer", {}) @@ -719,7 +817,9 @@ class GreengrassBackend(BaseBackend): f", but got: {device_source_path}])", ) - def create_function_definition(self, name, initial_version): + def create_function_definition( + self, name: str, initial_version: Dict[str, Any] + ) -> FakeFunctionDefinition: func_def = FakeFunctionDefinition( self.account_id, self.region_name, name, initial_version ) @@ -731,22 +831,26 @@ class GreengrassBackend(BaseBackend): return func_def - def list_function_definitions(self): - return self.function_definitions.values() + def list_function_definitions(self) -> List[FakeFunctionDefinition]: + return list(self.function_definitions.values()) - def get_function_definition(self, function_definition_id): + def get_function_definition( + self, function_definition_id: str + ) -> FakeFunctionDefinition: if function_definition_id not in self.function_definitions: raise IdNotFoundException("That Lambda List Definition does not exist.") return self.function_definitions[function_definition_id] - def delete_function_definition(self, function_definition_id): + def delete_function_definition(self, function_definition_id: str) -> None: if function_definition_id not in self.function_definitions: raise IdNotFoundException("That lambdas definition 
does not exist.") del self.function_definitions[function_definition_id] del self.function_definition_versions[function_definition_id] - def update_function_definition(self, function_definition_id, name): + def update_function_definition( + self, function_definition_id: str, name: str + ) -> None: if name == "": raise InvalidContainerDefinitionException( @@ -757,8 +861,11 @@ class GreengrassBackend(BaseBackend): self.function_definitions[function_definition_id].name = name def create_function_definition_version( - self, function_definition_id, functions, default_config - ): + self, + function_definition_id: str, + functions: List[Dict[str, Any]], + default_config: Dict[str, Any], + ) -> FakeFunctionDefinitionVersion: if function_definition_id not in self.function_definitions: raise IdNotFoundException("That lambdas does not exist.") @@ -784,14 +891,16 @@ class GreengrassBackend(BaseBackend): return func_ver - def list_function_definition_versions(self, function_definition_id): + def list_function_definition_versions( + self, function_definition_id: str + ) -> Dict[str, FakeFunctionDefinitionVersion]: if function_definition_id not in self.function_definition_versions: raise IdNotFoundException("That lambdas definition does not exist.") return self.function_definition_versions[function_definition_id] def get_function_definition_version( - self, function_definition_id, function_definition_version_id - ): + self, function_definition_id: str, function_definition_version_id: str + ) -> FakeFunctionDefinitionVersion: if function_definition_id not in self.function_definition_versions: raise IdNotFoundException("That lambdas definition does not exist.") @@ -809,7 +918,7 @@ class GreengrassBackend(BaseBackend): ] @staticmethod - def _is_valid_subscription_target_or_source(target_or_source): + def _is_valid_subscription_target_or_source(target_or_source: str) -> bool: if target_or_source in ["cloud", "GGShadowService"]: return True @@ -829,10 +938,10 @@ class GreengrassBackend(BaseBackend): return False @staticmethod - def _validate_subscription_target_or_source(subscriptions): + def _validate_subscription_target_or_source(subscriptions: List[Dict[str, Any]]) -> None: # type: ignore[misc] - target_errors = [] - source_errors = [] + target_errors: List[str] = [] + source_errors: List[str] = [] for subscription in subscriptions: subscription_id = subscription["Id"] @@ -863,7 +972,9 @@ class GreengrassBackend(BaseBackend): f"The subscriptions definition is invalid or corrupted. 
(ErrorDetails: [{error_msg}])", ) - def create_subscription_definition(self, name, initial_version): + def create_subscription_definition( + self, name: str, initial_version: Dict[str, Any] + ) -> FakeSubscriptionDefinition: GreengrassBackend._validate_subscription_target_or_source( initial_version["Subscriptions"] @@ -883,10 +994,12 @@ class GreengrassBackend(BaseBackend): sub_def.latest_version_arn = sub_def_ver.arn return sub_def - def list_subscription_definitions(self): - return self.subscription_definitions.values() + def list_subscription_definitions(self) -> List[FakeSubscriptionDefinition]: + return list(self.subscription_definitions.values()) - def get_subscription_definition(self, subscription_definition_id): + def get_subscription_definition( + self, subscription_definition_id: str + ) -> FakeSubscriptionDefinition: if subscription_definition_id not in self.subscription_definitions: raise IdNotFoundException( @@ -894,13 +1007,15 @@ class GreengrassBackend(BaseBackend): ) return self.subscription_definitions[subscription_definition_id] - def delete_subscription_definition(self, subscription_definition_id): + def delete_subscription_definition(self, subscription_definition_id: str) -> None: if subscription_definition_id not in self.subscription_definitions: raise IdNotFoundException("That subscriptions definition does not exist.") del self.subscription_definitions[subscription_definition_id] del self.subscription_definition_versions[subscription_definition_id] - def update_subscription_definition(self, subscription_definition_id, name): + def update_subscription_definition( + self, subscription_definition_id: str, name: str + ) -> None: if name == "": raise InvalidContainerDefinitionException( @@ -911,8 +1026,8 @@ class GreengrassBackend(BaseBackend): self.subscription_definitions[subscription_definition_id].name = name def create_subscription_definition_version( - self, subscription_definition_id, subscriptions - ): + self, subscription_definition_id: str, subscriptions: List[Dict[str, Any]] + ) -> FakeSubscriptionDefinitionVersion: GreengrassBackend._validate_subscription_target_or_source(subscriptions) @@ -931,14 +1046,16 @@ class GreengrassBackend(BaseBackend): return sub_def_ver - def list_subscription_definition_versions(self, subscription_definition_id): + def list_subscription_definition_versions( + self, subscription_definition_id: str + ) -> Dict[str, FakeSubscriptionDefinitionVersion]: if subscription_definition_id not in self.subscription_definition_versions: raise IdNotFoundException("That subscriptions definition does not exist.") return self.subscription_definition_versions[subscription_definition_id] def get_subscription_definition_version( - self, subscription_definition_id, subscription_definition_version_id - ): + self, subscription_definition_id: str, subscription_definition_version_id: str + ) -> FakeSubscriptionDefinitionVersion: if subscription_definition_id not in self.subscription_definitions: raise IdNotFoundException("That subscriptions definition does not exist.") @@ -955,7 +1072,7 @@ class GreengrassBackend(BaseBackend): subscription_definition_version_id ] - def create_group(self, name, initial_version): + def create_group(self, name: str, initial_version: Dict[str, Any]) -> FakeGroup: group = FakeGroup(self.account_id, self.region_name, name) self.groups[group.group_id] = group @@ -983,22 +1100,22 @@ class GreengrassBackend(BaseBackend): return group - def list_groups(self): - return self.groups.values() + def list_groups(self) -> List[FakeGroup]: + 
return list(self.groups.values()) - def get_group(self, group_id): + def get_group(self, group_id: str) -> Optional[FakeGroup]: if group_id not in self.groups: raise IdNotFoundException("That Group Definition does not exist.") return self.groups.get(group_id) - def delete_group(self, group_id): + def delete_group(self, group_id: str) -> None: if group_id not in self.groups: # I don't know why, the error message is different between get_group and delete_group raise IdNotFoundException("That group definition does not exist.") del self.groups[group_id] del self.group_versions[group_id] - def update_group(self, group_id, name): + def update_group(self, group_id: str, name: str) -> None: if name == "": raise InvalidContainerDefinitionException( @@ -1010,13 +1127,13 @@ class GreengrassBackend(BaseBackend): def create_group_version( self, - group_id, - core_definition_version_arn, - device_definition_version_arn, - function_definition_version_arn, - resource_definition_version_arn, - subscription_definition_version_arn, - ): + group_id: str, + core_definition_version_arn: Optional[str], + device_definition_version_arn: Optional[str], + function_definition_version_arn: Optional[str], + resource_definition_version_arn: Optional[str], + subscription_definition_version_arn: Optional[str], + ) -> FakeGroupVersion: if group_id not in self.groups: raise IdNotFoundException("That group does not exist.") @@ -1048,19 +1165,21 @@ class GreengrassBackend(BaseBackend): def _validate_group_version_definitions( self, - core_definition_version_arn=None, - device_definition_version_arn=None, - function_definition_version_arn=None, - resource_definition_version_arn=None, - subscription_definition_version_arn=None, - ): - def _is_valid_def_ver_arn(definition_version_arn, kind="cores"): + core_definition_version_arn: Optional[str] = None, + device_definition_version_arn: Optional[str] = None, + function_definition_version_arn: Optional[str] = None, + resource_definition_version_arn: Optional[str] = None, + subscription_definition_version_arn: Optional[str] = None, + ) -> None: + def _is_valid_def_ver_arn( + definition_version_arn: Optional[str], kind: str = "cores" + ) -> bool: if definition_version_arn is None: return True if kind == "cores": - versions = self.core_definition_versions + versions: Any = self.core_definition_versions elif kind == "devices": versions = self.device_definition_versions elif kind == "functions": @@ -1124,12 +1243,14 @@ class GreengrassBackend(BaseBackend): f"The group is invalid or corrupted. 
(ErrorDetails: [{error_details}])", ) - def list_group_versions(self, group_id): + def list_group_versions(self, group_id: str) -> List[FakeGroupVersion]: if group_id not in self.group_versions: raise IdNotFoundException("That group definition does not exist.") - return self.group_versions[group_id].values() + return list(self.group_versions[group_id].values()) - def get_group_version(self, group_id, group_version_id): + def get_group_version( + self, group_id: str, group_version_id: str + ) -> FakeGroupVersion: if group_id not in self.group_versions: raise IdNotFoundException("That group definition does not exist.") @@ -1142,8 +1263,12 @@ class GreengrassBackend(BaseBackend): return self.group_versions[group_id][group_version_id] def create_deployment( - self, group_id, group_version_id, deployment_type, deployment_id=None - ): + self, + group_id: str, + group_version_id: str, + deployment_type: str, + deployment_id: Optional[str] = None, + ) -> FakeDeployment: deployment_types = ( "NewDeployment", @@ -1199,7 +1324,7 @@ class GreengrassBackend(BaseBackend): self.deployments[deployment.id] = deployment return deployment - def list_deployments(self, group_id): + def list_deployments(self, group_id: str) -> List[FakeDeployment]: # ListDeployments API does not check specified group is exists return [ @@ -1208,7 +1333,9 @@ class GreengrassBackend(BaseBackend): if deployment.group_id == group_id ] - def get_deployment_status(self, group_id, deployment_id): + def get_deployment_status( + self, group_id: str, deployment_id: str + ) -> FakeDeploymentStatus: if deployment_id not in self.deployments: raise InvalidInputException(f"Deployment '{deployment_id}' does not exist.") @@ -1224,7 +1351,7 @@ class GreengrassBackend(BaseBackend): deployment.deployment_status, ) - def reset_deployments(self, group_id, force=False): + def reset_deployments(self, group_id: str, force: bool = False) -> FakeDeployment: if group_id not in self.groups: raise ResourceNotFoundException("That Group Definition does not exist.") @@ -1248,7 +1375,9 @@ class GreengrassBackend(BaseBackend): self.deployments[deployment.id] = deployment return deployment - def associate_role_to_group(self, group_id, role_arn): + def associate_role_to_group( + self, group_id: str, role_arn: str + ) -> FakeAssociatedRole: # I don't know why, AssociateRoleToGroup does not check specified group is exists # So, this API allows any group id such as "a" @@ -1257,7 +1386,7 @@ class GreengrassBackend(BaseBackend): self.group_role_associations[group_id] = associated_role return associated_role - def get_associated_role(self, group_id): + def get_associated_role(self, group_id: str) -> FakeAssociatedRole: if group_id not in self.group_role_associations: raise GreengrassClientError( @@ -1266,7 +1395,7 @@ class GreengrassBackend(BaseBackend): return self.group_role_associations[group_id] - def disassociate_role_from_group(self, group_id): + def disassociate_role_from_group(self, group_id: str) -> None: if group_id not in self.group_role_associations: return del self.group_role_associations[group_id] diff --git a/moto/greengrass/responses.py b/moto/greengrass/responses.py index d36ad8516..8eb1d1c46 100644 --- a/moto/greengrass/responses.py +++ b/moto/greengrass/responses.py @@ -1,20 +1,22 @@ from datetime import datetime +from typing import Any import json +from moto.core.common_types import TYPE_RESPONSE from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.responses import BaseResponse -from .models import greengrass_backends +from 
.models import greengrass_backends, GreengrassBackend class GreengrassResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="greengrass") @property - def greengrass_backend(self): + def greengrass_backend(self) -> GreengrassBackend: return greengrass_backends[self.current_account][self.region] - def core_definitions(self, request, full_url, headers): + def core_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -23,7 +25,7 @@ class GreengrassResponse(BaseResponse): if self.method == "POST": return self.create_core_definition() - def list_core_definitions(self): + def list_core_definitions(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_core_definitions() return ( 200, @@ -33,7 +35,7 @@ class GreengrassResponse(BaseResponse): ), ) - def create_core_definition(self): + def create_core_definition(self) -> TYPE_RESPONSE: name = self._get_param("Name") initial_version = self._get_param("InitialVersion") res = self.greengrass_backend.create_core_definition( @@ -41,7 +43,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def core_definition(self, request, full_url, headers): + def core_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -53,21 +55,21 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_core_definition() - def get_core_definition(self): + def get_core_definition(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-1] res = self.greengrass_backend.get_core_definition( core_definition_id=core_definition_id ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def delete_core_definition(self): + def delete_core_definition(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-1] self.greengrass_backend.delete_core_definition( core_definition_id=core_definition_id ) return 200, {"status": 200}, json.dumps({}) - def update_core_definition(self): + def update_core_definition(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-1] name = self._get_param("Name") self.greengrass_backend.update_core_definition( @@ -75,7 +77,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def core_definition_versions(self, request, full_url, headers): + def core_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -84,7 +86,7 @@ class GreengrassResponse(BaseResponse): if self.method == "POST": return self.create_core_definition_version() - def create_core_definition_version(self): + def create_core_definition_version(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-2] cores = self._get_param("Cores") @@ -93,7 +95,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_core_definition_versions(self): + def list_core_definition_versions(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-2] res = self.greengrass_backend.list_core_definition_versions(core_definition_id) return ( @@ -102,13 +104,13 @@ class GreengrassResponse(BaseResponse): json.dumps({"Versions": [core_def_ver.to_dict() for core_def_ver in res]}), ) - 
def core_definition_version(self, request, full_url, headers): + def core_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_core_definition_version() - def get_core_definition_version(self): + def get_core_definition_version(self) -> TYPE_RESPONSE: core_definition_id = self.path.split("/")[-3] core_definition_version_id = self.path.split("/")[-1] res = self.greengrass_backend.get_core_definition_version( @@ -117,7 +119,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) - def device_definitions(self, request, full_url, headers): + def device_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -126,7 +128,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_device_definition() - def create_device_definition(self): + def create_device_definition(self) -> TYPE_RESPONSE: name = self._get_param("Name") initial_version = self._get_param("InitialVersion") @@ -135,7 +137,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_device_definition(self): + def list_device_definition(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_device_definitions() return ( 200, @@ -149,7 +151,7 @@ class GreengrassResponse(BaseResponse): ), ) - def device_definition_versions(self, request, full_url, headers): + def device_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -158,7 +160,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_device_definition_versions() - def create_device_definition_version(self): + def create_device_definition_version(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-2] devices = self._get_param("Devices") @@ -168,7 +170,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_device_definition_versions(self): + def list_device_definition_versions(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-2] res = self.greengrass_backend.list_device_definition_versions( @@ -182,7 +184,7 @@ class GreengrassResponse(BaseResponse): ), ) - def device_definition(self, request, full_url, headers): + def device_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -194,14 +196,14 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_device_definition() - def get_device_definition(self): + def get_device_definition(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-1] res = self.greengrass_backend.get_device_definition( device_definition_id=device_definition_id ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def delete_device_definition(self): + def delete_device_definition(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-1] self.greengrass_backend.delete_device_definition( @@ -209,7 +211,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def update_device_definition(self): + 
def update_device_definition(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-1] name = self._get_param("Name") @@ -218,13 +220,13 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def device_definition_version(self, request, full_url, headers): + def device_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_device_definition_version() - def get_device_definition_version(self): + def get_device_definition_version(self) -> TYPE_RESPONSE: device_definition_id = self.path.split("/")[-3] device_definition_version_id = self.path.split("/")[-1] res = self.greengrass_backend.get_device_definition_version( @@ -233,7 +235,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) - def resource_definitions(self, request, full_url, headers): + def resource_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -242,7 +244,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_resource_definitions() - def create_resource_definition(self): + def create_resource_definition(self) -> TYPE_RESPONSE: initial_version = self._get_param("InitialVersion") name = self._get_param("Name") @@ -251,16 +253,16 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_resource_definitions(self): + def list_resource_definitions(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_resource_definitions() return ( 200, {"status": 200}, - json.dumps({"Definitions": [i.to_dict() for i in res.values()]}), + json.dumps({"Definitions": [i.to_dict() for i in res]}), ) - def resource_definition(self, request, full_url, headers): + def resource_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -272,14 +274,14 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_resource_definition() - def get_resource_definition(self): + def get_resource_definition(self) -> TYPE_RESPONSE: resource_definition_id = self.path.split("/")[-1] res = self.greengrass_backend.get_resource_definition( resource_definition_id=resource_definition_id ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def delete_resource_definition(self): + def delete_resource_definition(self) -> TYPE_RESPONSE: resource_definition_id = self.path.split("/")[-1] self.greengrass_backend.delete_resource_definition( @@ -287,7 +289,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def update_resource_definition(self): + def update_resource_definition(self) -> TYPE_RESPONSE: resource_definition_id = self.path.split("/")[-1] name = self._get_param("Name") @@ -296,7 +298,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def resource_definition_versions(self, request, full_url, headers): + def resource_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -305,7 +307,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return 
self.list_resource_definition_versions() - def create_resource_definition_version(self): + def create_resource_definition_version(self) -> TYPE_RESPONSE: resource_definition_id = self.path.split("/")[-2] resources = self._get_param("Resources") @@ -315,7 +317,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_resource_definition_versions(self): + def list_resource_definition_versions(self) -> TYPE_RESPONSE: resource_device_definition_id = self.path.split("/")[-2] res = self.greengrass_backend.list_resource_definition_versions( @@ -330,13 +332,13 @@ class GreengrassResponse(BaseResponse): ), ) - def resource_definition_version(self, request, full_url, headers): + def resource_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_resource_definition_version() - def get_resource_definition_version(self): + def get_resource_definition_version(self) -> TYPE_RESPONSE: resource_definition_id = self.path.split("/")[-3] resource_definition_version_id = self.path.split("/")[-1] res = self.greengrass_backend.get_resource_definition_version( @@ -345,7 +347,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def function_definitions(self, request, full_url, headers): + def function_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -354,7 +356,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_function_definitions() - def create_function_definition(self): + def create_function_definition(self) -> TYPE_RESPONSE: initial_version = self._get_param("InitialVersion") name = self._get_param("Name") @@ -363,7 +365,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_function_definitions(self): + def list_function_definitions(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_function_definitions() return ( 200, @@ -373,7 +375,7 @@ class GreengrassResponse(BaseResponse): ), ) - def function_definition(self, request, full_url, headers): + def function_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -385,21 +387,21 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_function_definition() - def get_function_definition(self): + def get_function_definition(self) -> TYPE_RESPONSE: function_definition_id = self.path.split("/")[-1] res = self.greengrass_backend.get_function_definition( function_definition_id=function_definition_id, ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def delete_function_definition(self): + def delete_function_definition(self) -> TYPE_RESPONSE: function_definition_id = self.path.split("/")[-1] self.greengrass_backend.delete_function_definition( function_definition_id=function_definition_id, ) return 200, {"status": 200}, json.dumps({}) - def update_function_definition(self): + def update_function_definition(self) -> TYPE_RESPONSE: function_definition_id = self.path.split("/")[-1] name = self._get_param("Name") self.greengrass_backend.update_function_definition( @@ -407,7 +409,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 
200}, json.dumps({}) - def function_definition_versions(self, request, full_url, headers): + def function_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -416,7 +418,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_function_definition_versions() - def create_function_definition_version(self): + def create_function_definition_version(self) -> TYPE_RESPONSE: default_config = self._get_param("DefaultConfig") function_definition_id = self.path.split("/")[-2] @@ -429,7 +431,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_function_definition_versions(self): + def list_function_definition_versions(self) -> TYPE_RESPONSE: function_definition_id = self.path.split("/")[-2] res = self.greengrass_backend.list_function_definition_versions( function_definition_id=function_definition_id @@ -437,13 +439,13 @@ class GreengrassResponse(BaseResponse): versions = [i.to_dict() for i in res.values()] return 200, {"status": 200}, json.dumps({"Versions": versions}) - def function_definition_version(self, request, full_url, headers): + def function_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_function_definition_version() - def get_function_definition_version(self): + def get_function_definition_version(self) -> TYPE_RESPONSE: function_definition_id = self.path.split("/")[-3] function_definition_version_id = self.path.split("/")[-1] res = self.greengrass_backend.get_function_definition_version( @@ -452,7 +454,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def subscription_definitions(self, request, full_url, headers): + def subscription_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -461,7 +463,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_subscription_definitions() - def create_subscription_definition(self): + def create_subscription_definition(self) -> TYPE_RESPONSE: initial_version = self._get_param("InitialVersion") name = self._get_param("Name") @@ -470,7 +472,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_subscription_definitions(self): + def list_subscription_definitions(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_subscription_definitions() return ( @@ -486,7 +488,7 @@ class GreengrassResponse(BaseResponse): ), ) - def subscription_definition(self, request, full_url, headers): + def subscription_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -498,21 +500,21 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_subscription_definition() - def get_subscription_definition(self): + def get_subscription_definition(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-1] res = self.greengrass_backend.get_subscription_definition( subscription_definition_id=subscription_definition_id ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def 
delete_subscription_definition(self): + def delete_subscription_definition(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-1] self.greengrass_backend.delete_subscription_definition( subscription_definition_id=subscription_definition_id ) return 200, {"status": 200}, json.dumps({}) - def update_subscription_definition(self): + def update_subscription_definition(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-1] name = self._get_param("Name") self.greengrass_backend.update_subscription_definition( @@ -520,7 +522,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps({}) - def subscription_definition_versions(self, request, full_url, headers): + def subscription_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -529,7 +531,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_subscription_definition_versions() - def create_subscription_definition_version(self): + def create_subscription_definition_version(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-2] subscriptions = self._get_param("Subscriptions") @@ -539,7 +541,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_subscription_definition_versions(self): + def list_subscription_definition_versions(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-2] res = self.greengrass_backend.list_subscription_definition_versions( subscription_definition_id=subscription_definition_id @@ -547,13 +549,13 @@ class GreengrassResponse(BaseResponse): versions = [i.to_dict() for i in res.values()] return 200, {"status": 200}, json.dumps({"Versions": versions}) - def subscription_definition_version(self, request, full_url, headers): + def subscription_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_subscription_definition_version() - def get_subscription_definition_version(self): + def get_subscription_definition_version(self) -> TYPE_RESPONSE: subscription_definition_id = self.path.split("/")[-3] subscription_definition_version_id = self.path.split("/")[-1] res = self.greengrass_backend.get_subscription_definition_version( @@ -562,7 +564,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def groups(self, request, full_url, headers): + def groups(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -571,7 +573,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_groups() - def create_group(self): + def create_group(self) -> TYPE_RESPONSE: initial_version = self._get_param("InitialVersion") name = self._get_param("Name") @@ -580,7 +582,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_groups(self): + def list_groups(self) -> TYPE_RESPONSE: res = self.greengrass_backend.list_groups() return ( @@ -589,7 +591,7 @@ class GreengrassResponse(BaseResponse): json.dumps({"Groups": [group.to_dict() for group in res]}), ) - def group(self, request, full_url, headers): + def group(self, request: Any, full_url: 
str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -601,27 +603,25 @@ class GreengrassResponse(BaseResponse): if self.method == "PUT": return self.update_group() - def get_group(self): + def get_group(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-1] - res = self.greengrass_backend.get_group( - group_id=group_id, - ) - return 200, {"status": 200}, json.dumps(res.to_dict()) + res = self.greengrass_backend.get_group(group_id=group_id) + return 200, {"status": 200}, json.dumps(res.to_dict()) # type: ignore - def delete_group(self): + def delete_group(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-1] self.greengrass_backend.delete_group( group_id=group_id, ) return 200, {"status": 200}, json.dumps({}) - def update_group(self): + def update_group(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-1] name = self._get_param("Name") self.greengrass_backend.update_group(group_id=group_id, name=name) return 200, {"status": 200}, json.dumps({}) - def group_versions(self, request, full_url, headers): + def group_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -630,7 +630,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_group_versions() - def create_group_version(self): + def create_group_version(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] @@ -656,7 +656,7 @@ class GreengrassResponse(BaseResponse): ) return 201, {"status": 201}, json.dumps(res.to_dict()) - def list_group_versions(self): + def list_group_versions(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] res = self.greengrass_backend.list_group_versions(group_id=group_id) return ( @@ -665,13 +665,13 @@ class GreengrassResponse(BaseResponse): json.dumps({"Versions": [group_ver.to_dict() for group_ver in res]}), ) - def group_version(self, request, full_url, headers): + def group_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_group_version() - def get_group_version(self): + def get_group_version(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-3] group_version_id = self.path.split("/")[-1] @@ -681,7 +681,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) - def deployments(self, request, full_url, headers): + def deployments(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -690,7 +690,7 @@ class GreengrassResponse(BaseResponse): if self.method == "GET": return self.list_deployments() - def create_deployment(self): + def create_deployment(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] group_version_id = self._get_param("GroupVersionId") @@ -705,7 +705,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def list_deployments(self): + def list_deployments(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] res = self.greengrass_backend.list_deployments(group_id=group_id) @@ -721,13 +721,13 @@ class GreengrassResponse(BaseResponse): json.dumps({"Deployments": deployments}), ) - def deployment_satus(self, request, full_url, headers): + def deployment_satus(self, 
request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": return self.get_deployment_status() - def get_deployment_status(self): + def get_deployment_status(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-4] deployment_id = self.path.split("/")[-2] @@ -737,13 +737,13 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def deployments_reset(self, request, full_url, headers): + def deployments_reset(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": return self.reset_deployments() - def reset_deployments(self): + def reset_deployments(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-3] res = self.greengrass_backend.reset_deployments( @@ -751,7 +751,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def role(self, request, full_url, headers): + def role(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "PUT": @@ -763,7 +763,7 @@ class GreengrassResponse(BaseResponse): if self.method == "DELETE": return self.disassociate_role_from_group() - def associate_role_to_group(self): + def associate_role_to_group(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] role_arn = self._get_param("RoleArn") @@ -773,7 +773,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict()) - def get_associated_role(self): + def get_associated_role(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] res = self.greengrass_backend.get_associated_role( @@ -781,7 +781,7 @@ class GreengrassResponse(BaseResponse): ) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) - def disassociate_role_from_group(self): + def disassociate_role_from_group(self) -> TYPE_RESPONSE: group_id = self.path.split("/")[-2] self.greengrass_backend.disassociate_role_from_group( group_id=group_id, diff --git a/moto/guardduty/exceptions.py b/moto/guardduty/exceptions.py index d61b9aa17..59ebe6df5 100644 --- a/moto/guardduty/exceptions.py +++ b/moto/guardduty/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any, List, Tuple from moto.core.exceptions import JsonRESTError @@ -8,24 +9,28 @@ class GuardDutyException(JsonRESTError): class DetectorNotFoundException(GuardDutyException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidInputException", "The request is rejected because the input detectorId is not owned by the current account.", ) - def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument - return {"X-Amzn-ErrorType": "BadRequestException"} + def get_headers( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[str, str]]: # pylint: disable=unused-argument + return [("X-Amzn-ErrorType", "BadRequestException")] class FilterNotFoundException(GuardDutyException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidInputException", "The request is rejected since no such resource found.", ) - def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument - return {"X-Amzn-ErrorType": "BadRequestException"} + def get_headers( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[str, str]]: # pylint: disable=unused-argument + return 
[("X-Amzn-ErrorType", "BadRequestException")] diff --git a/moto/guardduty/models.py b/moto/guardduty/models.py index 7fad4a360..d30ece1b3 100644 --- a/moto/guardduty/models.py +++ b/moto/guardduty/models.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random from datetime import datetime @@ -6,12 +7,18 @@ from .exceptions import DetectorNotFoundException, FilterNotFoundException class GuardDutyBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.admin_account_ids = [] - self.detectors = {} + self.admin_account_ids: List[str] = [] + self.detectors: Dict[str, Detector] = {} - def create_detector(self, enable, finding_publishing_frequency, data_sources, tags): + def create_detector( + self, + enable: bool, + finding_publishing_frequency: str, + data_sources: Dict[str, Any], + tags: Dict[str, str], + ) -> str: if finding_publishing_frequency not in [ "FIFTEEN_MINUTES", "ONE_HOUR", @@ -31,29 +38,35 @@ class GuardDutyBackend(BaseBackend): return detector.id def create_filter( - self, detector_id, name, action, description, finding_criteria, rank - ): + self, + detector_id: str, + name: str, + action: str, + description: str, + finding_criteria: Dict[str, Any], + rank: int, + ) -> None: detector = self.get_detector(detector_id) _filter = Filter(name, action, description, finding_criteria, rank) detector.add_filter(_filter) - def delete_detector(self, detector_id): + def delete_detector(self, detector_id: str) -> None: self.detectors.pop(detector_id, None) - def delete_filter(self, detector_id, filter_name): + def delete_filter(self, detector_id: str, filter_name: str) -> None: detector = self.get_detector(detector_id) detector.delete_filter(filter_name) - def enable_organization_admin_account(self, admin_account_id): + def enable_organization_admin_account(self, admin_account_id: str) -> None: self.admin_account_ids.append(admin_account_id) - def list_organization_admin_accounts(self): + def list_organization_admin_accounts(self) -> List[str]: """ Pagination is not yet implemented """ return self.admin_account_ids - def list_detectors(self): + def list_detectors(self) -> List[str]: """ The MaxResults and NextToken-parameter have not yet been implemented. 
""" @@ -62,24 +75,34 @@ class GuardDutyBackend(BaseBackend): detectorids.append(self.detectors[detector].id) return detectorids - def get_detector(self, detector_id): + def get_detector(self, detector_id: str) -> "Detector": if detector_id not in self.detectors: raise DetectorNotFoundException return self.detectors[detector_id] - def get_filter(self, detector_id, filter_name): + def get_filter(self, detector_id: str, filter_name: str) -> "Filter": detector = self.get_detector(detector_id) return detector.get_filter(filter_name) def update_detector( - self, detector_id, enable, finding_publishing_frequency, data_sources - ): + self, + detector_id: str, + enable: bool, + finding_publishing_frequency: str, + data_sources: Dict[str, Any], + ) -> None: detector = self.get_detector(detector_id) detector.update(enable, finding_publishing_frequency, data_sources) def update_filter( - self, detector_id, filter_name, action, description, finding_criteria, rank - ): + self, + detector_id: str, + filter_name: str, + action: str, + description: str, + finding_criteria: Dict[str, Any], + rank: int, + ) -> None: detector = self.get_detector(detector_id) detector.update_filter( filter_name, @@ -91,14 +114,27 @@ class GuardDutyBackend(BaseBackend): class Filter(BaseModel): - def __init__(self, name, action, description, finding_criteria, rank): + def __init__( + self, + name: str, + action: str, + description: str, + finding_criteria: Dict[str, Any], + rank: int, + ): self.name = name self.action = action self.description = description self.finding_criteria = finding_criteria self.rank = rank or 1 - def update(self, action, description, finding_criteria, rank): + def update( + self, + action: Optional[str], + description: Optional[str], + finding_criteria: Optional[Dict[str, Any]], + rank: Optional[int], + ) -> None: if action is not None: self.action = action if description is not None: @@ -108,7 +144,7 @@ class Filter(BaseModel): if rank is not None: self.rank = rank - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "name": self.name, "action": self.action, @@ -121,12 +157,12 @@ class Filter(BaseModel): class Detector(BaseModel): def __init__( self, - account_id, - created_at, - finding_publish_freq, - enabled, - datasources, - tags, + account_id: str, + created_at: datetime, + finding_publish_freq: str, + enabled: bool, + datasources: Dict[str, Any], + tags: Dict[str, str], ): self.id = mock_random.get_random_hex(length=32) self.created_at = created_at @@ -137,20 +173,27 @@ class Detector(BaseModel): self.datasources = datasources or {} self.tags = tags or {} - self.filters = dict() + self.filters: Dict[str, Filter] = dict() - def add_filter(self, _filter: Filter): + def add_filter(self, _filter: Filter) -> None: self.filters[_filter.name] = _filter - def delete_filter(self, filter_name): + def delete_filter(self, filter_name: str) -> None: self.filters.pop(filter_name, None) - def get_filter(self, filter_name: str): + def get_filter(self, filter_name: str) -> Filter: if filter_name not in self.filters: raise FilterNotFoundException return self.filters[filter_name] - def update_filter(self, filter_name, action, description, finding_criteria, rank): + def update_filter( + self, + filter_name: str, + action: str, + description: str, + finding_criteria: Dict[str, Any], + rank: int, + ) -> None: _filter = self.get_filter(filter_name) _filter.update( action=action, @@ -159,7 +202,12 @@ class Detector(BaseModel): rank=rank, ) - def update(self, enable, finding_publishing_frequency, 
data_sources): + def update( + self, + enable: bool, + finding_publishing_frequency: str, + data_sources: Dict[str, Any], + ) -> None: if enable is not None: self.enabled = enable if finding_publishing_frequency is not None: @@ -167,7 +215,7 @@ class Detector(BaseModel): if data_sources is not None: self.datasources = data_sources - def to_json(self): + def to_json(self) -> Dict[str, Any]: data_sources = { "cloudTrail": {"status": "DISABLED"}, "dnsLogs": {"status": "DISABLED"}, diff --git a/moto/guardduty/responses.py b/moto/guardduty/responses.py index 7980bb1c3..03763e057 100644 --- a/moto/guardduty/responses.py +++ b/moto/guardduty/responses.py @@ -1,18 +1,20 @@ +from typing import Any +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import guardduty_backends +from .models import guardduty_backends, GuardDutyBackend import json from urllib.parse import unquote class GuardDutyResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="guardduty") @property - def guardduty_backend(self): + def guardduty_backend(self) -> GuardDutyBackend: return guardduty_backends[self.current_account][self.region] - def filter(self, request, full_url, headers): + def filter(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_filter() @@ -21,12 +23,12 @@ class GuardDutyResponse(BaseResponse): elif request.method == "POST": return self.update_filter() - def filters(self, request, full_url, headers): + def filters(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_filter() - def detectors(self, request, full_url, headers): + def detectors(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_detector() @@ -35,7 +37,7 @@ class GuardDutyResponse(BaseResponse): else: return 404, {}, "" - def detector(self, request, full_url, headers): + def detector(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_detector() @@ -44,7 +46,7 @@ class GuardDutyResponse(BaseResponse): elif request.method == "POST": return self.update_detector() - def create_filter(self): + def create_filter(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-2] name = self._get_param("name") action = self._get_param("action") @@ -57,7 +59,7 @@ class GuardDutyResponse(BaseResponse): ) return 200, {}, json.dumps({"name": name}) - def create_detector(self): + def create_detector(self) -> TYPE_RESPONSE: enable = self._get_param("enable") finding_publishing_frequency = self._get_param("findingPublishingFrequency") data_sources = self._get_param("dataSources") @@ -69,20 +71,22 @@ class GuardDutyResponse(BaseResponse): return 200, {}, json.dumps(dict(detectorId=detector_id)) - def delete_detector(self): + def delete_detector(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-1] self.guardduty_backend.delete_detector(detector_id) return 200, {}, "{}" - def delete_filter(self): + def delete_filter(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-3] filter_name = unquote(self.path.split("/")[-1]) 
self.guardduty_backend.delete_filter(detector_id, filter_name) return 200, {}, "{}" - def enable_organization_admin_account(self, request, full_url, headers): + def enable_organization_admin_account( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) admin_account = self._get_param("adminAccountId") @@ -90,7 +94,9 @@ class GuardDutyResponse(BaseResponse): return 200, {}, "{}" - def list_organization_admin_accounts(self, request, full_url, headers): + def list_organization_admin_accounts( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) account_ids = self.guardduty_backend.list_organization_admin_accounts() @@ -108,25 +114,25 @@ class GuardDutyResponse(BaseResponse): ), ) - def list_detectors(self): + def list_detectors(self) -> TYPE_RESPONSE: detector_ids = self.guardduty_backend.list_detectors() return 200, {}, json.dumps({"detectorIds": detector_ids}) - def get_detector(self): + def get_detector(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-1] detector = self.guardduty_backend.get_detector(detector_id) return 200, {}, json.dumps(detector.to_json()) - def get_filter(self): + def get_filter(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-3] filter_name = unquote(self.path.split("/")[-1]) _filter = self.guardduty_backend.get_filter(detector_id, filter_name) return 200, {}, json.dumps(_filter.to_json()) - def update_detector(self): + def update_detector(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-1] enable = self._get_param("enable") finding_publishing_frequency = self._get_param("findingPublishingFrequency") @@ -137,7 +143,7 @@ class GuardDutyResponse(BaseResponse): ) return 200, {}, "{}" - def update_filter(self): + def update_filter(self) -> TYPE_RESPONSE: detector_id = self.path.split("/")[-3] filter_name = unquote(self.path.split("/")[-1]) action = self._get_param("action") diff --git a/moto/utilities/tagging_service.py b/moto/utilities/tagging_service.py index ed72a8a4a..970d1795d 100644 --- a/moto/utilities/tagging_service.py +++ b/moto/utilities/tagging_service.py @@ -171,7 +171,9 @@ class TaggingService: ) @staticmethod - def convert_dict_to_tags_input(tags: Dict[str, str]) -> List[Dict[str, str]]: + def convert_dict_to_tags_input( + tags: Optional[Dict[str, str]] + ) -> List[Dict[str, str]]: """Given a dictionary, return generic boto params for tags""" if not tags: return [] diff --git a/setup.cfg b/setup.cfg index f617f0fe1..c62229eee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -229,7 +229,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/moto_api,moto/neptune +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/moto_api,moto/neptune show_column_numbers=True show_error_codes = True disable_error_code=abstract
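
The greengrass changes above are annotation-only apart from a few return types that switch from dict views to concrete lists, so behaviour seen through the public mock should be unchanged. A minimal round-trip sketch, assuming moto 4.x still exposes the mock_greengrass decorator (newer releases fold it into mock_aws); the test name, ids and ARNs below are purely illustrative:

import boto3
from moto import mock_greengrass


@mock_greengrass
def test_function_definition_roundtrip():
    client = boto3.client("greengrass", region_name="ap-northeast-1")

    # goes through GreengrassBackend.create_function_definition(name, initial_version)
    created = client.create_function_definition(
        Name="TestFuncDef",
        InitialVersion={
            "Functions": [
                {
                    "Id": "func-1",
                    "FunctionArn": "arn:aws:lambda:ap-northeast-1:123456789012:function:test-func:1",
                    "FunctionConfiguration": {"MemorySize": 128, "Pinned": False, "Timeout": 3},
                }
            ]
        },
    )

    # list_function_definitions now returns list(...) instead of dict_values(...)
    listed = client.list_function_definitions()["Definitions"]
    assert any(d["Id"] == created["Id"] for d in listed)

    # get_function_definition resolves the definition by id
    # (unknown ids still raise IdNotFoundException)
    client.get_function_definition(FunctionDefinitionId=created["Id"])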
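
Similarly for guardduty: Detector and Filter gain annotations but keep their runtime behaviour, including the FindingPublishingFrequency check and the FilterNotFoundException path. A sketch of the expected round trip, again assuming the mock_guardduty decorator from moto 4.x and an illustrative test name:

import boto3
from moto import mock_guardduty


@mock_guardduty
def test_detector_and_filter_roundtrip():
    client = boto3.client("guardduty", region_name="us-east-1")

    # ONE_HOUR is one of the accepted FindingPublishingFrequency values
    detector_id = client.create_detector(
        Enable=True, FindingPublishingFrequency="ONE_HOUR"
    )["DetectorId"]
    assert detector_id in client.list_detectors()["DetectorIds"]

    # CreateFilter/GetFilter exercise Detector.add_filter and Detector.get_filter
    client.create_filter(
        DetectorId=detector_id,
        Name="high-sev",
        Action="ARCHIVE",
        Rank=1,
        FindingCriteria={"Criterion": {"severity": {"Eq": ["8"]}}},
    )
    assert client.get_filter(DetectorId=detector_id, FilterName="high-sev")["Action"] == "ARCHIVE"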
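
The tagging_service tweak widens convert_dict_to_tags_input to Optional, so callers that forward a possibly-absent tags parameter type-check cleanly; the runtime result for None was already an empty list. A tiny illustration, assuming the helper's usual Key/Value output shape:

from moto.utilities.tagging_service import TaggingService

# None is now an accepted, annotated input and still yields []
assert TaggingService.convert_dict_to_tags_input(None) == []

# a populated dict is converted to generic boto tag parameters
assert TaggingService.convert_dict_to_tags_input({"env": "test"}) == [
    {"Key": "env", "Value": "test"}
]

With moto/g* added to the files= list in setup.cfg, a plain mypy run from the repository root (mypy picks up the [mypy] section of setup.cfg) should now cover the greengrass and guardduty packages annotated above, along with the rest of moto/g*.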