Techdebt: MyPy g-models (#6048)

This commit is contained in:
Bert Blommers 2023-03-11 16:00:52 -01:00 committed by GitHub
parent bb39b02098
commit f1f4454b0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1130 additions and 797 deletions

View File

@ -1,6 +1,6 @@
import hashlib
import datetime
from typing import Any, Dict, List, Optional, Union
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.utilities.utils import md5_hash
@ -9,7 +9,7 @@ from .utils import get_job_id
class Job(BaseModel):
def __init__(self, tier):
def __init__(self, tier: str):
self.st = datetime.datetime.now()
if tier.lower() == "expedited":
@ -20,16 +20,19 @@ class Job(BaseModel):
# Standard
self.et = self.st + datetime.timedelta(seconds=5)
def to_dict(self) -> Dict[str, Any]:
return {}
class ArchiveJob(Job):
def __init__(self, job_id, tier, arn, archive_id):
def __init__(self, job_id: str, tier: str, arn: str, archive_id: Optional[str]):
self.job_id = job_id
self.tier = tier
self.arn = arn
self.archive_id = archive_id
Job.__init__(self, tier)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
d = {
"Action": "ArchiveRetrieval",
"ArchiveId": self.archive_id,
@ -57,13 +60,13 @@ class ArchiveJob(Job):
class InventoryJob(Job):
def __init__(self, job_id, tier, arn):
def __init__(self, job_id: str, tier: str, arn: str):
self.job_id = job_id
self.tier = tier
self.arn = arn
Job.__init__(self, tier)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
d = {
"Action": "InventoryRetrieval",
"ArchiveSHA256TreeHash": None,
@ -89,15 +92,15 @@ class InventoryJob(Job):
class Vault(BaseModel):
def __init__(self, vault_name, account_id, region):
def __init__(self, vault_name: str, account_id: str, region: str):
self.st = datetime.datetime.now()
self.vault_name = vault_name
self.region = region
self.archives = {}
self.jobs = {}
self.archives: Dict[str, Dict[str, Any]] = {}
self.jobs: Dict[str, Job] = {}
self.arn = f"arn:aws:glacier:{region}:{account_id}:vaults/{vault_name}"
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
archives_size = 0
for k in self.archives:
archives_size += self.archives[k]["size"]
@ -111,7 +114,7 @@ class Vault(BaseModel):
}
return d
def create_archive(self, body, description):
def create_archive(self, body: bytes, description: str) -> Dict[str, Any]:
archive_id = md5_hash(body).hexdigest()
self.archives[archive_id] = {}
self.archives[archive_id]["archive_id"] = archive_id
@ -124,10 +127,10 @@ class Vault(BaseModel):
self.archives[archive_id]["description"] = description
return self.archives[archive_id]
def get_archive_body(self, archive_id):
def get_archive_body(self, archive_id: str) -> str:
return self.archives[archive_id]["body"]
def get_archive_list(self):
def get_archive_list(self) -> List[Dict[str, Any]]:
archive_list = []
for a in self.archives:
archive = self.archives[a]
@ -141,34 +144,33 @@ class Vault(BaseModel):
archive_list.append(aobj)
return archive_list
def delete_archive(self, archive_id):
def delete_archive(self, archive_id: str) -> Dict[str, Any]:
return self.archives.pop(archive_id)
def initiate_job(self, job_type, tier, archive_id):
def initiate_job(self, job_type: str, tier: str, archive_id: Optional[str]) -> str:
job_id = get_job_id()
if job_type == "inventory-retrieval":
job = InventoryJob(job_id, tier, self.arn)
self.jobs[job_id] = InventoryJob(job_id, tier, self.arn)
elif job_type == "archive-retrieval":
job = ArchiveJob(job_id, tier, self.arn, archive_id)
self.jobs[job_id] = ArchiveJob(job_id, tier, self.arn, archive_id)
self.jobs[job_id] = job
return job_id
def list_jobs(self):
return self.jobs.values()
def list_jobs(self) -> List[Job]:
return list(self.jobs.values())
def describe_job(self, job_id):
def describe_job(self, job_id: str) -> Optional[Job]:
return self.jobs.get(job_id)
def job_ready(self, job_id):
def job_ready(self, job_id: str) -> str:
job = self.describe_job(job_id)
jobj = job.to_dict()
jobj = job.to_dict() # type: ignore
return jobj["Completed"]
def get_job_output(self, job_id):
def get_job_output(self, job_id: str) -> Union[str, Dict[str, Any]]:
job = self.describe_job(job_id)
jobj = job.to_dict()
jobj = job.to_dict() # type: ignore
if jobj["Action"] == "InventoryRetrieval":
archives = self.get_archive_list()
return {
@ -177,48 +179,54 @@ class Vault(BaseModel):
"ArchiveList": archives,
}
else:
archive_body = self.get_archive_body(job.archive_id)
archive_body = self.get_archive_body(job.archive_id) # type: ignore
return archive_body
class GlacierBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.vaults = {}
self.vaults: Dict[str, Vault] = {}
def get_vault(self, vault_name):
def get_vault(self, vault_name: str) -> Vault:
return self.vaults[vault_name]
def create_vault(self, vault_name):
def create_vault(self, vault_name: str) -> None:
self.vaults[vault_name] = Vault(vault_name, self.account_id, self.region_name)
def list_vaults(self):
return self.vaults.values()
def list_vaults(self) -> List[Vault]:
return list(self.vaults.values())
def delete_vault(self, vault_name):
def delete_vault(self, vault_name: str) -> None:
self.vaults.pop(vault_name)
def initiate_job(self, vault_name, job_type, tier, archive_id):
def initiate_job(
self, vault_name: str, job_type: str, tier: str, archive_id: Optional[str]
) -> str:
vault = self.get_vault(vault_name)
job_id = vault.initiate_job(job_type, tier, archive_id)
return job_id
def describe_job(self, vault_name, archive_id):
def describe_job(self, vault_name: str, archive_id: str) -> Optional[Job]:
vault = self.get_vault(vault_name)
return vault.describe_job(archive_id)
def list_jobs(self, vault_name):
def list_jobs(self, vault_name: str) -> List[Job]:
vault = self.get_vault(vault_name)
return vault.list_jobs()
def get_job_output(self, vault_name, job_id):
def get_job_output(
self, vault_name: str, job_id: str
) -> Union[str, Dict[str, Any], None]:
vault = self.get_vault(vault_name)
if vault.job_ready(job_id):
return vault.get_job_output(job_id)
else:
return None
def upload_archive(self, vault_name, body, description):
def upload_archive(
self, vault_name: str, body: bytes, description: str
) -> Dict[str, Any]:
vault = self.get_vault(vault_name)
return vault.create_archive(body, description)

View File

@ -1,23 +1,26 @@
import json
from typing import Any, Dict
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import glacier_backends
from .models import glacier_backends, GlacierBackend
from .utils import vault_from_glacier_url
class GlacierResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="glacier")
@property
def glacier_backend(self):
def glacier_backend(self) -> GlacierBackend:
return glacier_backends[self.current_account][self.region]
def all_vault_response(self, request, full_url, headers):
def all_vault_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._all_vault_response(headers)
def _all_vault_response(self, headers):
def _all_vault_response(self, headers: Any) -> TYPE_RESPONSE:
vaults = self.glacier_backend.list_vaults()
response = json.dumps(
{"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]}
@ -26,11 +29,13 @@ class GlacierResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, response
def vault_response(self, request, full_url, headers):
def vault_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._vault_response(request, full_url, headers)
def _vault_response(self, request, full_url, headers):
def _vault_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method
vault_name = vault_from_glacier_url(full_url)
@ -41,23 +46,27 @@ class GlacierResponse(BaseResponse):
elif method == "DELETE":
return self._vault_response_delete(vault_name, headers)
def _vault_response_get(self, vault_name, headers):
def _vault_response_get(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
vault = self.glacier_backend.get_vault(vault_name)
headers["content-type"] = "application/json"
return 200, headers, json.dumps(vault.to_dict())
def _vault_response_put(self, vault_name, headers):
def _vault_response_put(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
self.glacier_backend.create_vault(vault_name)
return 201, headers, ""
def _vault_response_delete(self, vault_name, headers):
def _vault_response_delete(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
self.glacier_backend.delete_vault(vault_name)
return 204, headers, ""
def vault_archive_response(self, request, full_url, headers):
def vault_archive_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
return self._vault_archive_response(request, full_url, headers)
def _vault_archive_response(self, request, full_url, headers):
def _vault_archive_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
method = request.method
if hasattr(request, "body"):
body = request.body
@ -75,17 +84,21 @@ class GlacierResponse(BaseResponse):
else:
return 400, headers, "400 Bad Request"
def _vault_archive_response_post(self, vault_name, body, description, headers):
def _vault_archive_response_post(
self, vault_name: str, body: bytes, description: str, headers: Dict[str, Any]
) -> TYPE_RESPONSE:
vault = self.glacier_backend.upload_archive(vault_name, body, description)
headers["x-amz-archive-id"] = vault["archive_id"]
headers["x-amz-sha256-tree-hash"] = vault["sha256"]
return 201, headers, ""
def vault_archive_individual_response(self, request, full_url, headers):
def vault_archive_individual_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._vault_archive_individual_response(request, full_url, headers)
def _vault_archive_individual_response(self, request, full_url, headers):
def _vault_archive_individual_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method
vault_name = full_url.split("/")[-3]
archive_id = full_url.split("/")[-1]
@ -95,11 +108,13 @@ class GlacierResponse(BaseResponse):
vault.delete_archive(archive_id)
return 204, headers, ""
def vault_jobs_response(self, request, full_url, headers):
def vault_jobs_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._vault_jobs_response(request, full_url, headers)
def _vault_jobs_response(self, request, full_url, headers):
def _vault_jobs_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method
if hasattr(request, "body"):
body = request.body
@ -135,22 +150,28 @@ class GlacierResponse(BaseResponse):
headers["Location"] = f"/{account_id}/vaults/{vault_name}/jobs/{job_id}"
return 202, headers, ""
def vault_jobs_individual_response(self, request, full_url, headers):
def vault_jobs_individual_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._vault_jobs_individual_response(full_url, headers)
def _vault_jobs_individual_response(self, full_url, headers):
def _vault_jobs_individual_response(
self, full_url: str, headers: Any
) -> TYPE_RESPONSE:
vault_name = full_url.split("/")[-3]
archive_id = full_url.split("/")[-1]
job = self.glacier_backend.describe_job(vault_name, archive_id)
return 200, headers, json.dumps(job.to_dict())
return 200, headers, json.dumps(job.to_dict()) # type: ignore
def vault_jobs_output_response(self, request, full_url, headers):
def vault_jobs_output_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self._vault_jobs_output_response(full_url, headers)
def _vault_jobs_output_response(self, full_url, headers):
def _vault_jobs_output_response(self, full_url: str, headers: Any) -> TYPE_RESPONSE:
vault_name = full_url.split("/")[-4]
job_id = full_url.split("/")[-2]
output = self.glacier_backend.get_job_output(vault_name, job_id)

View File

@ -2,11 +2,11 @@ from moto.moto_api._internal import mock_random as random
import string
def vault_from_glacier_url(full_url):
def vault_from_glacier_url(full_url: str) -> str:
return full_url.split("/")[-1]
def get_job_id():
def get_job_id() -> str:
return "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(92)
)

View File

@ -1,3 +1,4 @@
from typing import Optional
from moto.core.exceptions import JsonRESTError
@ -6,72 +7,78 @@ class GlueClientError(JsonRESTError):
class AlreadyExistsException(GlueClientError):
def __init__(self, typ):
def __init__(self, typ: str):
super().__init__("AlreadyExistsException", f"{typ} already exists.")
class DatabaseAlreadyExistsException(AlreadyExistsException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Database")
class TableAlreadyExistsException(AlreadyExistsException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Table")
class PartitionAlreadyExistsException(AlreadyExistsException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Partition")
class CrawlerAlreadyExistsException(AlreadyExistsException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Crawler")
class EntityNotFoundException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("EntityNotFoundException", msg)
class DatabaseNotFoundException(EntityNotFoundException):
def __init__(self, db):
def __init__(self, db: str):
super().__init__(f"Database {db} not found.")
class TableNotFoundException(EntityNotFoundException):
def __init__(self, tbl):
def __init__(self, tbl: str):
super().__init__(f"Table {tbl} not found.")
class PartitionNotFoundException(EntityNotFoundException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Cannot find partition.")
class CrawlerNotFoundException(EntityNotFoundException):
def __init__(self, crawler):
def __init__(self, crawler: str):
super().__init__(f"Crawler {crawler} not found.")
class JobNotFoundException(EntityNotFoundException):
def __init__(self, job):
def __init__(self, job: str):
super().__init__(f"Job {job} not found.")
class JobRunNotFoundException(EntityNotFoundException):
def __init__(self, job_run):
def __init__(self, job_run: str):
super().__init__(f"Job run {job_run} not found.")
class VersionNotFoundException(EntityNotFoundException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Version not found.")
class SchemaNotFoundException(EntityNotFoundException):
def __init__(self, schema_name, registry_name, schema_arn, null="null"):
def __init__(
self,
schema_name: str,
registry_name: str,
schema_arn: Optional[str],
null: str = "null",
):
super().__init__(
f"Schema is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}",
)
@ -80,13 +87,13 @@ class SchemaNotFoundException(EntityNotFoundException):
class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException):
def __init__(
self,
registry_name,
schema_name,
schema_arn,
version_number,
latest_version,
null="null",
false="false",
registry_name: Optional[str],
schema_name: Optional[str],
schema_arn: Optional[str],
version_number: Optional[str],
latest_version: Optional[str],
null: str = "null",
false: str = "false",
):
super().__init__(
f"Schema version is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}, VersionNumber: {version_number if version_number else null}, isLatestVersion: {latest_version if latest_version else false}",
@ -94,36 +101,36 @@ class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException):
class SchemaVersionNotFoundFromSchemaVersionIdException(EntityNotFoundException):
def __init__(self, schema_version_id):
def __init__(self, schema_version_id: str):
super().__init__(
f"Schema version is not found. SchemaVersionId: {schema_version_id}",
)
class RegistryNotFoundException(EntityNotFoundException):
def __init__(self, resource, param_name, param_value):
def __init__(self, resource: str, param_name: str, param_value: Optional[str]):
super().__init__(
resource + " is not found. " + param_name + ": " + param_value,
resource + " is not found. " + param_name + ": " + param_value, # type: ignore
)
class CrawlerRunningException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("CrawlerRunningException", msg)
class CrawlerNotRunningException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("CrawlerNotRunningException", msg)
class ConcurrentRunsExceededException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("ConcurrentRunsExceededException", msg)
class ResourceNumberLimitExceededException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(
"ResourceNumberLimitExceededException",
msg,
@ -131,7 +138,7 @@ class ResourceNumberLimitExceededException(GlueClientError):
class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededException):
def __init__(self, resource):
def __init__(self, resource: str):
super().__init__(
"More "
+ resource
@ -140,14 +147,14 @@ class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededExc
class SchemaVersionMetadataLimitExceededException(ResourceNumberLimitExceededException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"Your resource limits for Schema Version Metadata have been exceeded.",
)
class GSRAlreadyExistsException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(
"AlreadyExistsException",
msg,
@ -155,21 +162,21 @@ class GSRAlreadyExistsException(GlueClientError):
class SchemaVersionMetadataAlreadyExistsException(GSRAlreadyExistsException):
def __init__(self, schema_version_id, metadata_key, metadata_value):
def __init__(self, schema_version_id: str, metadata_key: str, metadata_value: str):
super().__init__(
f"Resource already exist for schema version id: {schema_version_id}, metadata key: {metadata_key}, metadata value: {metadata_value}",
)
class GeneralGSRAlreadyExistsException(GSRAlreadyExistsException):
def __init__(self, resource, param_name, param_value):
def __init__(self, resource: str, param_name: str, param_value: str):
super().__init__(
resource + " already exists. " + param_name + ": " + param_value,
)
class _InvalidOperationException(GlueClientError):
def __init__(self, error_type, op, msg):
def __init__(self, error_type: str, op: str, msg: str):
super().__init__(
error_type,
"An error occurred (%s) when calling the %s operation: %s"
@ -178,22 +185,22 @@ class _InvalidOperationException(GlueClientError):
class InvalidStateException(_InvalidOperationException):
def __init__(self, op, msg):
def __init__(self, op: str, msg: str):
super().__init__("InvalidStateException", op, msg)
class InvalidInputException(_InvalidOperationException):
def __init__(self, op, msg):
def __init__(self, op: str, msg: str):
super().__init__("InvalidInputException", op, msg)
class GSRInvalidInputException(GlueClientError):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("InvalidInputException", msg)
class ResourceNameTooLongException(GSRInvalidInputException):
def __init__(self, param_name):
def __init__(self, param_name: str):
super().__init__(
"The resource name contains too many or too few characters. Parameter Name: "
+ param_name,
@ -201,7 +208,7 @@ class ResourceNameTooLongException(GSRInvalidInputException):
class ParamValueContainsInvalidCharactersException(GSRInvalidInputException):
def __init__(self, param_name):
def __init__(self, param_name: str):
super().__init__(
"The parameter value contains one or more characters that are not valid. Parameter Name: "
+ param_name,
@ -209,28 +216,28 @@ class ParamValueContainsInvalidCharactersException(GSRInvalidInputException):
class InvalidNumberOfTagsException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"New Tags cannot be empty or more than 50",
)
class InvalidDataFormatException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"Data format is not valid.",
)
class InvalidCompatibilityException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"Compatibility is not valid.",
)
class InvalidSchemaDefinitionException(GSRInvalidInputException):
def __init__(self, data_format_name, err):
def __init__(self, data_format_name: str, err: ValueError):
super().__init__(
"Schema definition of "
+ data_format_name
@ -240,45 +247,51 @@ class InvalidSchemaDefinitionException(GSRInvalidInputException):
class InvalidRegistryIdBothParamsProvidedException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"One of registryName or registryArn has to be provided, both cannot be provided.",
)
class InvalidSchemaIdBothParamsProvidedException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"One of (registryName and schemaName) or schemaArn has to be provided, both cannot be provided.",
)
class InvalidSchemaIdNotProvidedException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"At least one of (registryName and schemaName) or schemaArn has to be provided.",
)
class InvalidSchemaVersionNumberBothParamsProvidedException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__("Only one of VersionNumber or LatestVersion is required.")
class InvalidSchemaVersionNumberNotProvidedException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__("One of version number (or) latest version is required.")
class InvalidSchemaVersionIdProvidedWithOtherParamsException(GSRInvalidInputException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"No other input parameters can be specified when fetching by SchemaVersionId."
)
class DisabledCompatibilityVersioningException(GSRInvalidInputException):
def __init__(self, schema_name, registry_name, schema_arn, null="null"):
def __init__(
self,
schema_name: str,
registry_name: str,
schema_arn: Optional[str],
null: str = "null",
):
super().__init__(
f"Compatibility DISABLED does not allow versioning. SchemaId: SchemaId(schemaArn={schema_arn if schema_arn else null}, schemaName={schema_name if schema_name else null}, registryName={registry_name if registry_name else null})"
)

View File

@ -1,5 +1,6 @@
import re
import json
from typing import Any, Dict, Optional, Tuple, Pattern
from .glue_schema_registry_constants import (
MAX_REGISTRY_NAME_LENGTH,
@ -53,7 +54,7 @@ from .exceptions import (
)
def validate_registry_name_pattern_and_length(param_value):
def validate_registry_name_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="registryName",
@ -62,7 +63,7 @@ def validate_registry_name_pattern_and_length(param_value):
)
def validate_arn_pattern_and_length(param_value):
def validate_arn_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="registryArn",
@ -71,7 +72,7 @@ def validate_arn_pattern_and_length(param_value):
)
def validate_description_pattern_and_length(param_value):
def validate_description_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="description",
@ -80,7 +81,7 @@ def validate_description_pattern_and_length(param_value):
)
def validate_schema_name_pattern_and_length(param_value):
def validate_schema_name_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="schemaName",
@ -89,7 +90,7 @@ def validate_schema_name_pattern_and_length(param_value):
)
def validate_schema_version_metadata_key_pattern_and_length(param_value):
def validate_schema_version_metadata_key_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="key",
@ -98,7 +99,7 @@ def validate_schema_version_metadata_key_pattern_and_length(param_value):
)
def validate_schema_version_metadata_value_pattern_and_length(param_value):
def validate_schema_version_metadata_value_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length(
param_value,
param_name="value",
@ -108,8 +109,8 @@ def validate_schema_version_metadata_value_pattern_and_length(param_value):
def validate_param_pattern_and_length(
param_value, param_name, max_name_length, pattern
):
param_value: str, param_name: str, max_name_length: int, pattern: Pattern[str]
) -> None:
if len(param_value.encode("utf-8")) > max_name_length:
raise ResourceNameTooLongException(param_name)
@ -117,7 +118,7 @@ def validate_param_pattern_and_length(
raise ParamValueContainsInvalidCharactersException(param_name)
def validate_schema_definition(schema_definition, data_format):
def validate_schema_definition(schema_definition: str, data_format: str) -> None:
validate_schema_definition_length(schema_definition)
if data_format in ["AVRO", "JSON"]:
try:
@ -126,38 +127,39 @@ def validate_schema_definition(schema_definition, data_format):
raise InvalidSchemaDefinitionException(data_format, err)
def validate_schema_definition_length(schema_definition):
def validate_schema_definition_length(schema_definition: str) -> None:
if len(schema_definition) > MAX_SCHEMA_DEFINITION_LENGTH:
param_name = SCHEMA_DEFINITION
raise ResourceNameTooLongException(param_name)
def validate_schema_version_id_pattern(schema_version_id):
def validate_schema_version_id_pattern(schema_version_id: str) -> None:
if re.match(SCHEMA_VERSION_ID_PATTERN, schema_version_id) is None:
raise ParamValueContainsInvalidCharactersException(SCHEMA_VERSION_ID)
def validate_number_of_tags(tags):
def validate_number_of_tags(tags: Dict[str, str]) -> None:
if len(tags) > MAX_TAGS_ALLOWED:
raise InvalidNumberOfTagsException()
def validate_registry_id(registry_id, registries):
def validate_registry_id(
registry_id: Dict[str, Any], registries: Dict[str, Any]
) -> str:
if not registry_id:
registry_name = DEFAULT_REGISTRY_NAME
return registry_name
return DEFAULT_REGISTRY_NAME
if registry_id.get(REGISTRY_NAME) and registry_id.get(REGISTRY_ARN):
raise InvalidRegistryIdBothParamsProvidedException()
if registry_id.get(REGISTRY_NAME):
registry_name = registry_id.get(REGISTRY_NAME)
validate_registry_name_pattern_and_length(registry_name)
validate_registry_name_pattern_and_length(registry_name) # type: ignore
elif registry_id.get(REGISTRY_ARN):
registry_arn = registry_id.get(REGISTRY_ARN)
validate_arn_pattern_and_length(registry_arn)
registry_name = registry_arn.split("/")[-1]
validate_arn_pattern_and_length(registry_arn) # type: ignore
registry_name = registry_arn.split("/")[-1] # type: ignore
if registry_name != DEFAULT_REGISTRY_NAME and registry_name not in registries:
if registry_id.get(REGISTRY_NAME):
@ -174,10 +176,15 @@ def validate_registry_id(registry_id, registries):
param_value=registry_arn,
)
return registry_name
return registry_name # type: ignore
def validate_registry_params(registries, registry_name, description=None, tags=None):
def validate_registry_params(
registries: Any,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> None:
validate_registry_name_pattern_and_length(registry_name)
if description:
@ -197,7 +204,9 @@ def validate_registry_params(registries, registry_name, description=None, tags=N
)
def validate_schema_id(schema_id, registries):
def validate_schema_id(
schema_id: Dict[str, str], registries: Dict[str, Any]
) -> Tuple[str, str, Optional[str]]:
schema_arn = schema_id.get(SCHEMA_ARN)
registry_name = schema_id.get(REGISTRY_NAME)
schema_name = schema_id.get(SCHEMA_NAME)
@ -225,15 +234,15 @@ def validate_schema_id(schema_id, registries):
def validate_schema_params(
registry,
schema_name,
data_format,
compatibility,
schema_definition,
num_schemas,
description=None,
tags=None,
):
registry: Any,
schema_name: str,
data_format: str,
compatibility: str,
schema_definition: str,
num_schemas: int,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> None:
validate_schema_name_pattern_and_length(schema_name)
if data_format not in ["AVRO", "JSON", "PROTOBUF"]:
@ -271,14 +280,14 @@ def validate_schema_params(
def validate_register_schema_version_params(
registry_name,
schema_name,
schema_arn,
num_schema_versions,
schema_definition,
compatibility,
data_format,
):
registry_name: str,
schema_name: str,
schema_arn: Optional[str],
num_schema_versions: int,
schema_definition: str,
compatibility: str,
data_format: str,
) -> None:
if compatibility == "DISABLED":
raise DisabledCompatibilityVersioningException(
schema_name, registry_name, schema_arn
@ -290,9 +299,19 @@ def validate_register_schema_version_params(
raise GeneralResourceNumberLimitExceededException(resource="schema versions")
def validate_schema_version_params(
registries, schema_id, schema_version_id, schema_version_number
):
def validate_schema_version_params( # type: ignore[return]
registries: Dict[str, Any],
schema_id: Optional[Dict[str, Any]],
schema_version_id: Optional[str],
schema_version_number: Optional[Dict[str, Any]],
) -> Tuple[
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
]:
if not schema_version_id and not schema_id and not schema_version_number:
raise InvalidSchemaIdNotProvidedException()
@ -329,8 +348,11 @@ def validate_schema_version_params(
def validate_schema_version_number(
registries, registry_name, schema_name, schema_version_number
):
registries: Dict[str, Any],
registry_name: str,
schema_name: str,
schema_version_number: Dict[str, str],
) -> Tuple[str, str]:
latest_version = schema_version_number.get(LATEST_VERSION)
version_number = schema_version_number.get(VERSION_NUMBER)
schema = registries[registry_name].schemas[schema_name]
@ -339,20 +361,24 @@ def validate_schema_version_number(
raise InvalidSchemaVersionNumberBothParamsProvidedException()
return schema.latest_schema_version, latest_version
return version_number, latest_version
return version_number, latest_version # type: ignore
def validate_schema_version_metadata_pattern_and_length(metadata_key_value):
def validate_schema_version_metadata_pattern_and_length(
metadata_key_value: Dict[str, str]
) -> Tuple[str, str]:
metadata_key = metadata_key_value.get(METADATA_KEY)
metadata_value = metadata_key_value.get(METADATA_VALUE)
validate_schema_version_metadata_key_pattern_and_length(metadata_key)
validate_schema_version_metadata_value_pattern_and_length(metadata_value)
validate_schema_version_metadata_key_pattern_and_length(metadata_key) # type: ignore
validate_schema_version_metadata_value_pattern_and_length(metadata_value) # type: ignore
return metadata_key, metadata_value
return metadata_key, metadata_value # type: ignore[return-value]
def validate_number_of_schema_version_metadata_allowed(metadata):
def validate_number_of_schema_version_metadata_allowed(
metadata: Dict[str, Any]
) -> None:
num_metadata_key_value_pairs = 0
for m in metadata.values():
num_metadata_key_value_pairs += len(m)
@ -362,8 +388,8 @@ def validate_number_of_schema_version_metadata_allowed(metadata):
def get_schema_version_if_definition_exists(
schema_versions, data_format, schema_definition
):
schema_versions: Any, data_format: str, schema_definition: str
) -> Optional[Dict[str, Any]]:
if data_format in ["AVRO", "JSON"]:
for schema_version in schema_versions:
if json.loads(schema_definition) == json.loads(
@ -378,9 +404,12 @@ def get_schema_version_if_definition_exists(
def get_put_schema_version_metadata_response(
schema_id, schema_version_number, schema_version_id, metadata_key_value
):
put_schema_version_metadata_response_dict = {}
schema_id: Dict[str, Any],
schema_version_number: Optional[Dict[str, str]],
schema_version_id: str,
metadata_key_value: Dict[str, str],
) -> Dict[str, Any]:
put_schema_version_metadata_response_dict: Dict[str, Any] = {}
if schema_version_id:
put_schema_version_metadata_response_dict[SCHEMA_VERSION_ID] = schema_version_id
if schema_id:
@ -416,7 +445,9 @@ def get_put_schema_version_metadata_response(
return put_schema_version_metadata_response_dict
def delete_schema_response(schema_name, schema_arn, status):
def delete_schema_response(
schema_name: str, schema_arn: str, status: str
) -> Dict[str, Any]:
return {
"SchemaName": schema_name,
"SchemaArn": schema_arn,

View File

@ -3,7 +3,7 @@ import time
from collections import OrderedDict
from datetime import datetime
import re
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api import state_manager
@ -77,12 +77,12 @@ class GlueBackend(BaseBackend):
},
}
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.databases = OrderedDict()
self.crawlers = OrderedDict()
self.jobs = OrderedDict()
self.job_runs = OrderedDict()
self.databases: Dict[str, FakeDatabase] = OrderedDict()
self.crawlers: Dict[str, FakeCrawler] = OrderedDict()
self.jobs: Dict[str, FakeJob] = OrderedDict()
self.job_runs: Dict[str, FakeJobRun] = OrderedDict()
self.tagger = TaggingService()
self.registries: Dict[str, FakeRegistry] = OrderedDict()
self.num_schemas = 0
@ -93,13 +93,17 @@ class GlueBackend(BaseBackend):
)
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "glue"
)
def create_database(self, database_name, database_input):
def create_database(
self, database_name: str, database_input: Dict[str, Any]
) -> "FakeDatabase":
if database_name in self.databases:
raise DatabaseAlreadyExistsException()
@ -107,27 +111,31 @@ class GlueBackend(BaseBackend):
self.databases[database_name] = database
return database
def get_database(self, database_name):
def get_database(self, database_name: str) -> "FakeDatabase":
try:
return self.databases[database_name]
except KeyError:
raise DatabaseNotFoundException(database_name)
def update_database(self, database_name, database_input):
def update_database(
self, database_name: str, database_input: Dict[str, Any]
) -> None:
if database_name not in self.databases:
raise DatabaseNotFoundException(database_name)
self.databases[database_name].input = database_input
def get_databases(self):
def get_databases(self) -> List["FakeDatabase"]:
return [self.databases[key] for key in self.databases] if self.databases else []
def delete_database(self, database_name):
def delete_database(self, database_name: str) -> None:
if database_name not in self.databases:
raise DatabaseNotFoundException(database_name)
del self.databases[database_name]
def create_table(self, database_name, table_name, table_input):
def create_table(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
) -> "FakeTable":
database = self.get_database(database_name)
if table_name in database.tables:
@ -144,7 +152,9 @@ class GlueBackend(BaseBackend):
except KeyError:
raise TableNotFoundException(table_name)
def get_tables(self, database_name, expression):
def get_tables(
self, database_name: str, expression: Optional[str]
) -> List["FakeTable"]:
database = self.get_database(database_name)
if expression:
# sanitise expression, * is treated as a glob-like wildcard
@ -164,15 +174,16 @@ class GlueBackend(BaseBackend):
else:
return [table for table_name, table in database.tables.items()]
def delete_table(self, database_name, table_name):
def delete_table(self, database_name: str, table_name: str) -> None:
database = self.get_database(database_name)
try:
del database.tables[table_name]
except KeyError:
raise TableNotFoundException(table_name)
return {}
def update_table(self, database_name, table_name: str, table_input) -> None:
def update_table(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
) -> None:
table = self.get_table(database_name, table_name)
table.update(table_input)
@ -202,15 +213,21 @@ class GlueBackend(BaseBackend):
table = self.get_table(database_name, table_name)
table.delete_version(version_id)
def create_partition(self, database_name: str, table_name: str, part_input) -> None:
def create_partition(
self, database_name: str, table_name: str, part_input: Dict[str, Any]
) -> None:
table = self.get_table(database_name, table_name)
table.create_partition(part_input)
def get_partition(self, database_name: str, table_name: str, values):
def get_partition(
self, database_name: str, table_name: str, values: str
) -> "FakePartition":
table = self.get_table(database_name, table_name)
return table.get_partition(values)
def get_partitions(self, database_name, table_name, expression):
def get_partitions(
self, database_name: str, table_name: str, expression: str
) -> List["FakePartition"]:
"""
See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html
for supported expressions.
@ -226,34 +243,38 @@ class GlueBackend(BaseBackend):
return table.get_partitions(expression)
def update_partition(
self, database_name, table_name, part_input, part_to_update
self,
database_name: str,
table_name: str,
part_input: Dict[str, Any],
part_to_update: str,
) -> None:
table = self.get_table(database_name, table_name)
table.update_partition(part_to_update, part_input)
def delete_partition(
self, database_name: str, table_name: str, part_to_delete
self, database_name: str, table_name: str, part_to_delete: int
) -> None:
table = self.get_table(database_name, table_name)
table.delete_partition(part_to_delete)
def create_crawler(
self,
name,
role,
database_name,
description,
targets,
schedule,
classifiers,
table_prefix,
schema_change_policy,
recrawl_policy,
lineage_configuration,
configuration,
crawler_security_configuration,
tags,
):
name: str,
role: str,
database_name: str,
description: str,
targets: Dict[str, Any],
schedule: str,
classifiers: List[str],
table_prefix: str,
schema_change_policy: Dict[str, str],
recrawl_policy: Dict[str, str],
lineage_configuration: Dict[str, str],
configuration: str,
crawler_security_configuration: str,
tags: Dict[str, str],
) -> None:
if name in self.crawlers:
raise CrawlerAlreadyExistsException()
@ -276,28 +297,28 @@ class GlueBackend(BaseBackend):
)
self.crawlers[name] = crawler
def get_crawler(self, name):
def get_crawler(self, name: str) -> "FakeCrawler":
try:
return self.crawlers[name]
except KeyError:
raise CrawlerNotFoundException(name)
def get_crawlers(self):
def get_crawlers(self) -> List["FakeCrawler"]:
return [self.crawlers[key] for key in self.crawlers] if self.crawlers else []
@paginate(pagination_model=PAGINATION_MODEL)
def list_crawlers(self):
def list_crawlers(self) -> List["FakeCrawler"]: # type: ignore[misc]
return [crawler for _, crawler in self.crawlers.items()]
def start_crawler(self, name):
def start_crawler(self, name: str) -> None:
crawler = self.get_crawler(name)
crawler.start_crawler()
def stop_crawler(self, name):
def stop_crawler(self, name: str) -> None:
crawler = self.get_crawler(name)
crawler.stop_crawler()
def delete_crawler(self, name):
def delete_crawler(self, name: str) -> None:
try:
del self.crawlers[name]
except KeyError:
@ -305,29 +326,29 @@ class GlueBackend(BaseBackend):
def create_job(
self,
name,
role,
command,
description,
log_uri,
execution_property,
default_arguments,
non_overridable_arguments,
connections,
max_retries,
allocated_capacity,
timeout,
max_capacity,
security_configuration,
tags,
notification_property,
glue_version,
number_of_workers,
worker_type,
code_gen_configuration_nodes,
execution_class,
source_control_details,
):
name: str,
role: str,
command: str,
description: str,
log_uri: str,
execution_property: Dict[str, int],
default_arguments: Dict[str, str],
non_overridable_arguments: Dict[str, str],
connections: Dict[str, List[str]],
max_retries: int,
allocated_capacity: int,
timeout: int,
max_capacity: float,
security_configuration: str,
tags: Dict[str, str],
notification_property: Dict[str, int],
glue_version: str,
number_of_workers: int,
worker_type: str,
code_gen_configuration_nodes: Dict[str, Any],
execution_class: str,
source_control_details: Dict[str, str],
) -> None:
self.jobs[name] = FakeJob(
name,
role,
@ -353,46 +374,50 @@ class GlueBackend(BaseBackend):
source_control_details,
backend=self,
)
return name
def get_job(self, name):
def get_job(self, name: str) -> "FakeJob":
try:
return self.jobs[name]
except KeyError:
raise JobNotFoundException(name)
@paginate(pagination_model=PAGINATION_MODEL)
def get_jobs(self):
def get_jobs(self) -> List["FakeJob"]: # type: ignore
return [job for _, job in self.jobs.items()]
def start_job_run(self, name):
def start_job_run(self, name: str) -> str:
job = self.get_job(name)
return job.start_job_run()
def get_job_run(self, name, run_id):
def get_job_run(self, name: str, run_id: str) -> "FakeJobRun":
job = self.get_job(name)
return job.get_job_run(run_id)
@paginate(pagination_model=PAGINATION_MODEL)
def list_jobs(self):
def list_jobs(self) -> List["FakeJob"]: # type: ignore
return [job for _, job in self.jobs.items()]
def get_tags(self, resource_id):
def get_tags(self, resource_id: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_id)
def tag_resource(self, resource_arn, tags):
tags = TaggingService.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tags)
def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
tag_list = TaggingService.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tag_list)
def untag_resource(self, resource_arn, tag_keys):
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def create_registry(self, registry_name, description=None, tags=None):
def create_registry(
self,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
# If registry name id default-registry, create default-registry
if registry_name == DEFAULT_REGISTRY_NAME:
registry = FakeRegistry(self, registry_name, description, tags)
self.registries[registry_name] = registry
return registry
return registry # type: ignore
# Validate Registry Parameters
validate_registry_params(self.registries, registry_name, description, tags)
@ -401,27 +426,27 @@ class GlueBackend(BaseBackend):
self.registries[registry_name] = registry
return registry.as_dict()
def delete_registry(self, registry_id):
def delete_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]:
registry_name = validate_registry_id(registry_id, self.registries)
return self.registries.pop(registry_name).as_dict()
def get_registry(self, registry_id):
def get_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]:
registry_name = validate_registry_id(registry_id, self.registries)
return self.registries[registry_name].as_dict()
def list_registries(self):
def list_registries(self) -> List[Dict[str, Any]]:
return [reg.as_dict() for reg in self.registries.values()]
def create_schema(
self,
registry_id,
schema_name,
data_format,
compatibility,
schema_definition,
description=None,
tags=None,
):
registry_id: Dict[str, Any],
schema_name: str,
data_format: str,
compatibility: str,
schema_definition: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""
The following parameters/features are not yet implemented: Glue Schema Registry: compatibility checks NONE | BACKWARD | BACKWARD_ALL | FORWARD | FORWARD_ALL | FULL | FULL_ALL and Data format parsing and syntax validation.
"""
@ -479,7 +504,9 @@ class GlueBackend(BaseBackend):
resp.update({"Tags": tags})
return resp
def register_schema_version(self, schema_id, schema_definition):
def register_schema_version(
self, schema_id: Dict[str, Any], schema_definition: str
) -> Dict[str, Any]:
# Validate Schema Id
registry_name, schema_name, schema_arn = validate_schema_id(
schema_id, self.registries
@ -538,8 +565,11 @@ class GlueBackend(BaseBackend):
return schema_version.as_dict()
def get_schema_version(
self, schema_id=None, schema_version_id=None, schema_version_number=None
):
self,
schema_id: Optional[Dict[str, str]] = None,
schema_version_id: Optional[str] = None,
schema_version_number: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# Validate Schema Parameters
(
schema_version_id,
@ -571,10 +601,10 @@ class GlueBackend(BaseBackend):
raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id)
# GetSchemaVersion using VersionNumber
schema = self.registries[registry_name].schemas[schema_name]
schema = self.registries[registry_name].schemas[schema_name] # type: ignore
for schema_version in schema.schema_versions.values():
if (
version_number == schema_version.version_number
version_number == schema_version.version_number # type: ignore
and schema_version.schema_version_status != DELETING_STATUS
):
get_schema_version_dict = schema_version.get_schema_version_as_dict()
@ -584,7 +614,9 @@ class GlueBackend(BaseBackend):
registry_name, schema_name, schema_arn, version_number, latest_version
)
def get_schema_by_definition(self, schema_id, schema_definition):
def get_schema_by_definition(
self, schema_id: Dict[str, str], schema_definition: str
) -> Dict[str, Any]:
# Validate SchemaId
validate_schema_definition_length(schema_definition)
registry_name, schema_name, schema_arn = validate_schema_id(
@ -606,8 +638,12 @@ class GlueBackend(BaseBackend):
raise SchemaNotFoundException(schema_name, registry_name, schema_arn)
def put_schema_version_metadata(
self, schema_id, schema_version_number, schema_version_id, metadata_key_value
):
self,
schema_id: Dict[str, Any],
schema_version_number: Dict[str, str],
schema_version_id: str,
metadata_key_value: Dict[str, str],
) -> Dict[str, Any]:
# Validate metadata_key_value and schema version params
(
metadata_key,
@ -620,7 +656,7 @@ class GlueBackend(BaseBackend):
schema_arn,
version_number,
latest_version,
) = validate_schema_version_params(
) = validate_schema_version_params( # type: ignore
self.registries, schema_id, schema_version_id, schema_version_number
)
@ -650,9 +686,9 @@ class GlueBackend(BaseBackend):
raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id)
# PutSchemaVersionMetadata using VersionNumber
schema = self.registries[registry_name].schemas[schema_name]
schema = self.registries[registry_name].schemas[schema_name] # type: ignore
for schema_version in schema.schema_versions.values():
if version_number == schema_version.version_number:
if version_number == schema_version.version_number: # type: ignore
validate_number_of_schema_version_metadata_allowed(
schema_version.metadata
)
@ -677,12 +713,12 @@ class GlueBackend(BaseBackend):
registry_name, schema_name, schema_arn, version_number, latest_version
)
def get_schema(self, schema_id):
def get_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]:
registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries)
schema = self.registries[registry_name].schemas[schema_name]
return schema.as_dict()
def delete_schema(self, schema_id):
def delete_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]:
# Validate schema_id
registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries)
@ -701,7 +737,9 @@ class GlueBackend(BaseBackend):
return response
def update_schema(self, schema_id, compatibility, description):
def update_schema(
self, schema_id: Dict[str, str], compatibility: str, description: str
) -> Dict[str, Any]:
"""
The SchemaVersionNumber-argument is not yet implemented
"""
@ -715,7 +753,9 @@ class GlueBackend(BaseBackend):
return schema.as_dict()
def batch_delete_table(self, database_name, tables):
def batch_delete_table(
self, database_name: str, tables: List[str]
) -> List[Dict[str, Any]]:
errors = []
for table_name in tables:
try:
@ -732,7 +772,12 @@ class GlueBackend(BaseBackend):
)
return errors
def batch_get_partition(self, database_name, table_name, partitions_to_get):
def batch_get_partition(
self,
database_name: str,
table_name: str,
partitions_to_get: List[Dict[str, str]],
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name)
partitions = []
@ -744,7 +789,9 @@ class GlueBackend(BaseBackend):
continue
return partitions
def batch_create_partition(self, database_name, table_name, partition_input):
def batch_create_partition(
self, database_name: str, table_name: str, partition_input: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name)
errors_output = []
@ -763,7 +810,9 @@ class GlueBackend(BaseBackend):
)
return errors_output
def batch_update_partition(self, database_name, table_name, entries):
def batch_update_partition(
self, database_name: str, table_name: str, entries: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name)
errors_output = []
@ -785,14 +834,16 @@ class GlueBackend(BaseBackend):
)
return errors_output
def batch_delete_partition(self, database_name, table_name, parts):
def batch_delete_partition(
self, database_name: str, table_name: str, parts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name)
errors_output = []
for part_input in parts:
values = part_input.get("Values")
try:
table.delete_partition(values)
table.delete_partition(values) # type: ignore
except PartitionNotFoundException:
errors_output.append(
{
@ -805,7 +856,7 @@ class GlueBackend(BaseBackend):
)
return errors_output
def batch_get_crawlers(self, crawler_names):
def batch_get_crawlers(self, crawler_names: List[str]) -> List[Dict[str, Any]]:
crawlers = []
for crawler in self.get_crawlers():
if crawler.as_dict()["Name"] in crawler_names:
@ -814,13 +865,13 @@ class GlueBackend(BaseBackend):
class FakeDatabase(BaseModel):
def __init__(self, database_name, database_input):
def __init__(self, database_name: str, database_input: Dict[str, Any]):
self.name = database_name
self.input = database_input
self.created_time = datetime.utcnow()
self.tables = OrderedDict()
self.tables: Dict[str, FakeTable] = OrderedDict()
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"Name": self.name,
"Description": self.input.get("Description"),
@ -836,23 +887,25 @@ class FakeDatabase(BaseModel):
class FakeTable(BaseModel):
def __init__(self, database_name: str, table_name: str, table_input):
def __init__(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
):
self.database_name = database_name
self.name = table_name
self.partitions = OrderedDict()
self.partitions: Dict[str, FakePartition] = OrderedDict()
self.created_time = datetime.utcnow()
self.updated_time = None
self.updated_time: Optional[datetime] = None
self._current_version = 1
self.versions: Dict[str, Dict[str, Any]] = {
str(self._current_version): table_input
}
def update(self, table_input):
def update(self, table_input: Dict[str, Any]) -> None:
self.versions[str(self._current_version + 1)] = table_input
self._current_version += 1
self.updated_time = datetime.utcnow()
def get_version(self, ver):
def get_version(self, ver: str) -> Dict[str, Any]:
try:
int(ver)
except ValueError as e:
@ -863,11 +916,11 @@ class FakeTable(BaseModel):
except KeyError:
raise VersionNotFoundException()
def delete_version(self, version_id):
def delete_version(self, version_id: str) -> None:
self.versions.pop(version_id)
def as_dict(self, version=None):
version = version or self._current_version
def as_dict(self, version: Optional[str] = None) -> Dict[str, Any]:
version = version or self._current_version # type: ignore
obj = {
"DatabaseName": self.database_name,
"Name": self.name,
@ -880,23 +933,23 @@ class FakeTable(BaseModel):
obj["UpdateTime"] = unix_time(self.updated_time)
return obj
def create_partition(self, partiton_input):
def create_partition(self, partiton_input: Dict[str, Any]) -> None:
partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values)
if key in self.partitions:
raise PartitionAlreadyExistsException()
self.partitions[str(partition.values)] = partition
def get_partitions(self, expression):
def get_partitions(self, expression: str) -> List["FakePartition"]:
return list(filter(PartitionFilter(expression, self), self.partitions.values()))
def get_partition(self, values):
def get_partition(self, values: str) -> "FakePartition":
try:
return self.partitions[str(values)]
except KeyError:
raise PartitionNotFoundException()
def update_partition(self, old_values, partiton_input):
def update_partition(self, old_values: str, partiton_input: Dict[str, Any]) -> None:
partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values)
if old_values == partiton_input["Values"]:
@ -913,7 +966,7 @@ class FakeTable(BaseModel):
raise PartitionAlreadyExistsException()
self.partitions[key] = partition
def delete_partition(self, values):
def delete_partition(self, values: int) -> None:
try:
del self.partitions[str(values)]
except KeyError:
@ -921,14 +974,16 @@ class FakeTable(BaseModel):
class FakePartition(BaseModel):
def __init__(self, database_name, table_name, partiton_input):
def __init__(
self, database_name: str, table_name: str, partiton_input: Dict[str, Any]
):
self.creation_time = time.time()
self.database_name = database_name
self.table_name = table_name
self.partition_input = partiton_input
self.values = self.partition_input.get("Values", [])
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
obj = {
"DatabaseName": self.database_name,
"TableName": self.table_name,
@ -941,21 +996,21 @@ class FakePartition(BaseModel):
class FakeCrawler(BaseModel):
def __init__(
self,
name,
role,
database_name,
description,
targets,
schedule,
classifiers,
table_prefix,
schema_change_policy,
recrawl_policy,
lineage_configuration,
configuration,
crawler_security_configuration,
tags,
backend,
name: str,
role: str,
database_name: str,
description: str,
targets: Dict[str, Any],
schedule: str,
classifiers: List[str],
table_prefix: str,
schema_change_policy: Dict[str, str],
recrawl_policy: Dict[str, str],
lineage_configuration: Dict[str, str],
configuration: str,
crawler_security_configuration: str,
tags: Dict[str, str],
backend: GlueBackend,
):
self.name = name
self.role = role
@ -980,11 +1035,11 @@ class FakeCrawler(BaseModel):
self.backend = backend
self.backend.tag_resource(self.arn, tags)
def get_name(self):
def get_name(self) -> str:
return self.name
def as_dict(self):
last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None
def as_dict(self) -> Dict[str, Any]:
last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None # type: ignore
data = {
"Name": self.name,
"Role": self.role,
@ -1017,14 +1072,14 @@ class FakeCrawler(BaseModel):
return data
def start_crawler(self):
def start_crawler(self) -> None:
if self.state == "RUNNING":
raise CrawlerRunningException(
f"Crawler with name {self.name} has already started"
)
self.state = "RUNNING"
def stop_crawler(self):
def stop_crawler(self) -> None:
if self.state != "RUNNING":
raise CrawlerNotRunningException(
f"Crawler with name {self.name} isn't running"
@ -1034,7 +1089,13 @@ class FakeCrawler(BaseModel):
class LastCrawlInfo(BaseModel):
def __init__(
self, error_message, log_group, log_stream, message_prefix, start_time, status
self,
error_message: str,
log_group: str,
log_stream: str,
message_prefix: str,
start_time: str,
status: str,
):
self.error_message = error_message
self.log_group = log_group
@ -1043,7 +1104,7 @@ class LastCrawlInfo(BaseModel):
self.start_time = start_time
self.status = status
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"ErrorMessage": self.error_message,
"LogGroup": self.log_group,
@ -1057,29 +1118,29 @@ class LastCrawlInfo(BaseModel):
class FakeJob:
def __init__(
self,
name,
role,
command,
description=None,
log_uri=None,
execution_property=None,
default_arguments=None,
non_overridable_arguments=None,
connections=None,
max_retries=None,
allocated_capacity=None,
timeout=None,
max_capacity=None,
security_configuration=None,
tags=None,
notification_property=None,
glue_version=None,
number_of_workers=None,
worker_type=None,
code_gen_configuration_nodes=None,
execution_class=None,
source_control_details=None,
backend=None,
name: str,
role: str,
command: str,
description: str,
log_uri: str,
execution_property: Dict[str, int],
default_arguments: Dict[str, str],
non_overridable_arguments: Dict[str, str],
connections: Dict[str, List[str]],
max_retries: int,
allocated_capacity: int,
timeout: int,
max_capacity: float,
security_configuration: str,
tags: Dict[str, str],
notification_property: Dict[str, int],
glue_version: str,
number_of_workers: int,
worker_type: str,
code_gen_configuration_nodes: Dict[str, Any],
execution_class: str,
source_control_details: Dict[str, str],
backend: GlueBackend,
):
self.name = name
self.description = description
@ -1112,10 +1173,10 @@ class FakeJob:
self.job_runs: List[FakeJobRun] = []
def get_name(self):
def get_name(self) -> str:
return self.name
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"Name": self.name,
"Description": self.description,
@ -1142,7 +1203,7 @@ class FakeJob:
"SourceControlDetails": self.source_control_details,
}
def start_job_run(self):
def start_job_run(self) -> str:
running_jobs = len(
[jr for jr in self.job_runs if jr.status in ["STARTING", "RUNNING"]]
)
@ -1154,7 +1215,7 @@ class FakeJob:
self.job_runs.append(fake_job_run)
return fake_job_run.job_run_id
def get_job_run(self, run_id):
def get_job_run(self, run_id: str) -> "FakeJobRun":
for job_run in self.job_runs:
if job_run.job_run_id == run_id:
job_run.advance()
@ -1165,11 +1226,11 @@ class FakeJob:
class FakeJobRun(ManagedState):
def __init__(
self,
job_name: int,
job_name: str,
job_run_id: str = "01",
arguments: dict = None,
allocated_capacity: int = None,
timeout: int = None,
arguments: Optional[Dict[str, Any]] = None,
allocated_capacity: Optional[int] = None,
timeout: Optional[int] = None,
worker_type: str = "Standard",
):
ManagedState.__init__(
@ -1187,10 +1248,10 @@ class FakeJobRun(ManagedState):
self.modified_on = datetime.utcnow()
self.completed_on = datetime.utcnow()
def get_name(self):
def get_name(self) -> str:
return self.job_name
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"Id": self.job_run_id,
"Attempt": 1,
@ -1220,7 +1281,13 @@ class FakeJobRun(ManagedState):
class FakeRegistry(BaseModel):
def __init__(self, backend, registry_name, description=None, tags=None):
def __init__(
self,
backend: GlueBackend,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
):
self.name = registry_name
self.description = description
self.tags = tags
@ -1230,7 +1297,7 @@ class FakeRegistry(BaseModel):
self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.name}"
self.schemas: Dict[str, FakeSchema] = OrderedDict()
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"RegistryArn": self.registry_arn,
"RegistryName": self.name,
@ -1243,12 +1310,12 @@ class FakeSchema(BaseModel):
def __init__(
self,
backend: GlueBackend,
registry_name,
schema_name,
data_format,
compatibility,
schema_version_id,
description=None,
registry_name: str,
schema_name: str,
data_format: str,
compatibility: str,
schema_version_id: str,
description: Optional[str] = None,
):
self.registry_name = registry_name
self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.registry_name}"
@ -1265,18 +1332,18 @@ class FakeSchema(BaseModel):
self.schema_version_status = AVAILABLE_STATUS
self.created_time = datetime.utcnow()
self.updated_time = datetime.utcnow()
self.schema_versions = OrderedDict()
self.schema_versions: Dict[str, FakeSchemaVersion] = OrderedDict()
def update_next_schema_version(self):
def update_next_schema_version(self) -> None:
self.next_schema_version += 1
def update_latest_schema_version(self):
def update_latest_schema_version(self) -> None:
self.latest_schema_version += 1
def get_next_schema_version(self):
def get_next_schema_version(self) -> int:
return self.next_schema_version
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"RegistryArn": self.registry_arn,
"RegistryName": self.registry_name,
@ -1298,10 +1365,10 @@ class FakeSchemaVersion(BaseModel):
def __init__(
self,
backend: GlueBackend,
registry_name,
schema_name,
schema_definition,
version_number,
registry_name: str,
schema_name: str,
schema_definition: str,
version_number: int,
):
self.registry_name = registry_name
self.schema_name = schema_name
@ -1312,19 +1379,19 @@ class FakeSchemaVersion(BaseModel):
self.schema_version_id = str(mock_random.uuid4())
self.created_time = datetime.utcnow()
self.updated_time = datetime.utcnow()
self.metadata = OrderedDict()
self.metadata: Dict[str, Any] = {}
def get_schema_version_id(self):
def get_schema_version_id(self) -> str:
return self.schema_version_id
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
"SchemaVersionId": self.schema_version_id,
"VersionNumber": self.version_number,
"Status": self.schema_version_status,
}
def get_schema_version_as_dict(self):
def get_schema_version_as_dict(self) -> Dict[str, Any]:
# add data_format for full return dictionary of get_schema_version
return {
"SchemaVersionId": self.schema_version_id,
@ -1335,7 +1402,7 @@ class FakeSchemaVersion(BaseModel):
"CreatedTime": str(self.created_time),
}
def get_schema_by_definition_as_dict(self):
def get_schema_by_definition_as_dict(self) -> Dict[str, Any]:
# add data_format for full return dictionary of get_schema_by_definition
return {
"SchemaVersionId": self.schema_version_id,

View File

@ -1,11 +1,13 @@
import json
from typing import Any, Dict, List
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import glue_backends, GlueBackend
from .models import glue_backends, GlueBackend, FakeJob, FakeCrawler
class GlueResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="glue")
@property
@ -13,66 +15,66 @@ class GlueResponse(BaseResponse):
return glue_backends[self.current_account][self.region]
@property
def parameters(self):
def parameters(self) -> Dict[str, Any]: # type: ignore[misc]
return json.loads(self.body)
def create_database(self):
def create_database(self) -> str:
database_input = self.parameters.get("DatabaseInput")
database_name = database_input.get("Name")
database_name = database_input.get("Name") # type: ignore
if "CatalogId" in self.parameters:
database_input["CatalogId"] = self.parameters.get("CatalogId")
self.glue_backend.create_database(database_name, database_input)
database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type]
return ""
def get_database(self):
def get_database(self) -> str:
database_name = self.parameters.get("Name")
database = self.glue_backend.get_database(database_name)
database = self.glue_backend.get_database(database_name) # type: ignore[arg-type]
return json.dumps({"Database": database.as_dict()})
def get_databases(self):
def get_databases(self) -> str:
database_list = self.glue_backend.get_databases()
return json.dumps(
{"DatabaseList": [database.as_dict() for database in database_list]}
)
def update_database(self):
def update_database(self) -> str:
database_input = self.parameters.get("DatabaseInput")
database_name = self.parameters.get("Name")
if "CatalogId" in self.parameters:
database_input["CatalogId"] = self.parameters.get("CatalogId")
self.glue_backend.update_database(database_name, database_input)
database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
self.glue_backend.update_database(database_name, database_input) # type: ignore[arg-type]
return ""
def delete_database(self):
def delete_database(self) -> str:
name = self.parameters.get("Name")
self.glue_backend.delete_database(name)
self.glue_backend.delete_database(name) # type: ignore[arg-type]
return json.dumps({})
def create_table(self):
def create_table(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_input = self.parameters.get("TableInput")
table_name = table_input.get("Name")
self.glue_backend.create_table(database_name, table_name, table_input)
table_name = table_input.get("Name") # type: ignore
self.glue_backend.create_table(database_name, table_name, table_input) # type: ignore[arg-type]
return ""
def get_table(self):
def get_table(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("Name")
table = self.glue_backend.get_table(database_name, table_name)
table = self.glue_backend.get_table(database_name, table_name) # type: ignore[arg-type]
return json.dumps({"Table": table.as_dict()})
def update_table(self):
def update_table(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_input = self.parameters.get("TableInput")
table_name = table_input.get("Name")
self.glue_backend.update_table(database_name, table_name, table_input)
table_name = table_input.get("Name") # type: ignore
self.glue_backend.update_table(database_name, table_name, table_input) # type: ignore[arg-type]
return ""
def get_table_versions(self):
def get_table_versions(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
versions = self.glue_backend.get_table_versions(database_name, table_name)
versions = self.glue_backend.get_table_versions(database_name, table_name) # type: ignore[arg-type]
return json.dumps(
{
"TableVersions": [
@ -82,36 +84,36 @@ class GlueResponse(BaseResponse):
}
)
def get_table_version(self):
def get_table_version(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
ver_id = self.parameters.get("VersionId")
return self.glue_backend.get_table_version(database_name, table_name, ver_id)
return self.glue_backend.get_table_version(database_name, table_name, ver_id) # type: ignore[arg-type]
def delete_table_version(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
version_id = self.parameters.get("VersionId")
self.glue_backend.delete_table_version(database_name, table_name, version_id)
self.glue_backend.delete_table_version(database_name, table_name, version_id) # type: ignore[arg-type]
return "{}"
def get_tables(self):
def get_tables(self) -> str:
database_name = self.parameters.get("DatabaseName")
expression = self.parameters.get("Expression")
tables = self.glue_backend.get_tables(database_name, expression)
tables = self.glue_backend.get_tables(database_name, expression) # type: ignore[arg-type]
return json.dumps({"TableList": [table.as_dict() for table in tables]})
def delete_table(self):
def delete_table(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("Name")
resp = self.glue_backend.delete_table(database_name, table_name)
return json.dumps(resp)
self.glue_backend.delete_table(database_name, table_name) # type: ignore[arg-type]
return "{}"
def batch_delete_table(self):
def batch_delete_table(self) -> str:
database_name = self.parameters.get("DatabaseName")
tables = self.parameters.get("TablesToDelete")
errors = self.glue_backend.batch_delete_table(database_name, tables)
errors = self.glue_backend.batch_delete_table(database_name, tables) # type: ignore[arg-type]
out = {}
if errors:
@ -119,50 +121,50 @@ class GlueResponse(BaseResponse):
return json.dumps(out)
def get_partitions(self):
def get_partitions(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
expression = self.parameters.get("Expression")
partitions = self.glue_backend.get_partitions(
database_name, table_name, expression
database_name, table_name, expression # type: ignore[arg-type]
)
return json.dumps({"Partitions": [p.as_dict() for p in partitions]})
def get_partition(self):
def get_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
values = self.parameters.get("PartitionValues")
p = self.glue_backend.get_partition(database_name, table_name, values)
p = self.glue_backend.get_partition(database_name, table_name, values) # type: ignore[arg-type]
return json.dumps({"Partition": p.as_dict()})
def batch_get_partition(self):
def batch_get_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
partitions_to_get = self.parameters.get("PartitionsToGet")
partitions = self.glue_backend.batch_get_partition(
database_name, table_name, partitions_to_get
database_name, table_name, partitions_to_get # type: ignore[arg-type]
)
return json.dumps({"Partitions": partitions})
def create_partition(self):
def create_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
part_input = self.parameters.get("PartitionInput")
self.glue_backend.create_partition(database_name, table_name, part_input)
self.glue_backend.create_partition(database_name, table_name, part_input) # type: ignore[arg-type]
return ""
def batch_create_partition(self):
def batch_create_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
partition_input = self.parameters.get("PartitionInputList")
errors_output = self.glue_backend.batch_create_partition(
database_name, table_name, partition_input
database_name, table_name, partition_input # type: ignore[arg-type]
)
out = {}
@ -171,24 +173,24 @@ class GlueResponse(BaseResponse):
return json.dumps(out)
def update_partition(self):
def update_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
part_input = self.parameters.get("PartitionInput")
part_to_update = self.parameters.get("PartitionValueList")
self.glue_backend.update_partition(
database_name, table_name, part_input, part_to_update
database_name, table_name, part_input, part_to_update # type: ignore[arg-type]
)
return ""
def batch_update_partition(self):
def batch_update_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
entries = self.parameters.get("Entries")
errors_output = self.glue_backend.batch_update_partition(
database_name, table_name, entries
database_name, table_name, entries # type: ignore[arg-type]
)
out = {}
@ -197,21 +199,21 @@ class GlueResponse(BaseResponse):
return json.dumps(out)
def delete_partition(self):
def delete_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
part_to_delete = self.parameters.get("PartitionValues")
self.glue_backend.delete_partition(database_name, table_name, part_to_delete)
self.glue_backend.delete_partition(database_name, table_name, part_to_delete) # type: ignore[arg-type]
return ""
def batch_delete_partition(self):
def batch_delete_partition(self) -> str:
database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName")
parts = self.parameters.get("PartitionsToDelete")
errors_output = self.glue_backend.batch_delete_partition(
database_name, table_name, parts
database_name, table_name, parts # type: ignore[arg-type]
)
out = {}
@ -220,37 +222,37 @@ class GlueResponse(BaseResponse):
return json.dumps(out)
def create_crawler(self):
def create_crawler(self) -> str:
self.glue_backend.create_crawler(
name=self.parameters.get("Name"),
role=self.parameters.get("Role"),
database_name=self.parameters.get("DatabaseName"),
description=self.parameters.get("Description"),
targets=self.parameters.get("Targets"),
schedule=self.parameters.get("Schedule"),
classifiers=self.parameters.get("Classifiers"),
table_prefix=self.parameters.get("TablePrefix"),
schema_change_policy=self.parameters.get("SchemaChangePolicy"),
recrawl_policy=self.parameters.get("RecrawlPolicy"),
lineage_configuration=self.parameters.get("LineageConfiguration"),
configuration=self.parameters.get("Configuration"),
crawler_security_configuration=self.parameters.get(
name=self.parameters.get("Name"), # type: ignore[arg-type]
role=self.parameters.get("Role"), # type: ignore[arg-type]
database_name=self.parameters.get("DatabaseName"), # type: ignore[arg-type]
description=self.parameters.get("Description"), # type: ignore[arg-type]
targets=self.parameters.get("Targets"), # type: ignore[arg-type]
schedule=self.parameters.get("Schedule"), # type: ignore[arg-type]
classifiers=self.parameters.get("Classifiers"), # type: ignore[arg-type]
table_prefix=self.parameters.get("TablePrefix"), # type: ignore[arg-type]
schema_change_policy=self.parameters.get("SchemaChangePolicy"), # type: ignore[arg-type]
recrawl_policy=self.parameters.get("RecrawlPolicy"), # type: ignore[arg-type]
lineage_configuration=self.parameters.get("LineageConfiguration"), # type: ignore[arg-type]
configuration=self.parameters.get("Configuration"), # type: ignore[arg-type]
crawler_security_configuration=self.parameters.get( # type: ignore[arg-type]
"CrawlerSecurityConfiguration"
),
tags=self.parameters.get("Tags"),
tags=self.parameters.get("Tags"), # type: ignore[arg-type]
)
return ""
def get_crawler(self):
def get_crawler(self) -> str:
name = self.parameters.get("Name")
crawler = self.glue_backend.get_crawler(name)
crawler = self.glue_backend.get_crawler(name) # type: ignore[arg-type]
return json.dumps({"Crawler": crawler.as_dict()})
def get_crawlers(self):
def get_crawlers(self) -> str:
crawlers = self.glue_backend.get_crawlers()
return json.dumps({"Crawlers": [crawler.as_dict() for crawler in crawlers]})
def list_crawlers(self):
def list_crawlers(self) -> str:
next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults")
tags = self._get_param("Tags")
@ -265,31 +267,33 @@ class GlueResponse(BaseResponse):
)
)
def filter_crawlers_by_tags(self, crawlers, tags):
def filter_crawlers_by_tags(
self, crawlers: List[FakeCrawler], tags: Dict[str, str]
) -> List[str]:
if not tags:
return [crawler.get_name() for crawler in crawlers]
return [
crawler.get_name()
for crawler in crawlers
if self.is_tags_match(self, crawler.arn, tags)
if self.is_tags_match(crawler.arn, tags)
]
def start_crawler(self):
def start_crawler(self) -> str:
name = self.parameters.get("Name")
self.glue_backend.start_crawler(name)
self.glue_backend.start_crawler(name) # type: ignore[arg-type]
return ""
def stop_crawler(self):
def stop_crawler(self) -> str:
name = self.parameters.get("Name")
self.glue_backend.stop_crawler(name)
self.glue_backend.stop_crawler(name) # type: ignore[arg-type]
return ""
def delete_crawler(self):
def delete_crawler(self) -> str:
name = self.parameters.get("Name")
self.glue_backend.delete_crawler(name)
self.glue_backend.delete_crawler(name) # type: ignore[arg-type]
return ""
def create_job(self):
def create_job(self) -> str:
name = self._get_param("Name")
description = self._get_param("Description")
log_uri = self._get_param("LogUri")
@ -312,7 +316,7 @@ class GlueResponse(BaseResponse):
code_gen_configuration_nodes = self._get_param("CodeGenConfigurationNodes")
execution_class = self._get_param("ExecutionClass")
source_control_details = self._get_param("SourceControlDetails")
name = self.glue_backend.create_job(
self.glue_backend.create_job(
name=name,
description=description,
log_uri=log_uri,
@ -338,12 +342,12 @@ class GlueResponse(BaseResponse):
)
return json.dumps(dict(Name=name))
def get_job(self):
def get_job(self) -> str:
name = self.parameters.get("JobName")
job = self.glue_backend.get_job(name)
job = self.glue_backend.get_job(name) # type: ignore[arg-type]
return json.dumps({"Job": job.as_dict()})
def get_jobs(self):
def get_jobs(self) -> str:
next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults")
jobs, next_token = self.glue_backend.get_jobs(
@ -356,18 +360,18 @@ class GlueResponse(BaseResponse):
)
)
def start_job_run(self):
def start_job_run(self) -> str:
name = self.parameters.get("JobName")
job_run_id = self.glue_backend.start_job_run(name)
job_run_id = self.glue_backend.start_job_run(name) # type: ignore[arg-type]
return json.dumps(dict(JobRunId=job_run_id))
def get_job_run(self):
def get_job_run(self) -> str:
name = self.parameters.get("JobName")
run_id = self.parameters.get("RunId")
job_run = self.glue_backend.get_job_run(name, run_id)
job_run = self.glue_backend.get_job_run(name, run_id) # type: ignore[arg-type]
return json.dumps({"JobRun": job_run.as_dict()})
def list_jobs(self):
def list_jobs(self) -> str:
next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults")
tags = self._get_param("Tags")
@ -382,32 +386,31 @@ class GlueResponse(BaseResponse):
)
)
def get_tags(self):
def get_tags(self) -> TYPE_RESPONSE:
resource_arn = self.parameters.get("ResourceArn")
tags = self.glue_backend.get_tags(resource_arn)
tags = self.glue_backend.get_tags(resource_arn) # type: ignore[arg-type]
return 200, {}, json.dumps({"Tags": tags})
def tag_resource(self):
def tag_resource(self) -> TYPE_RESPONSE:
resource_arn = self.parameters.get("ResourceArn")
tags = self.parameters.get("TagsToAdd", {})
self.glue_backend.tag_resource(resource_arn, tags)
self.glue_backend.tag_resource(resource_arn, tags) # type: ignore[arg-type]
return 201, {}, "{}"
def untag_resource(self):
def untag_resource(self) -> TYPE_RESPONSE:
resource_arn = self._get_param("ResourceArn")
tag_keys = self.parameters.get("TagsToRemove")
self.glue_backend.untag_resource(resource_arn, tag_keys)
self.glue_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type]
return 200, {}, "{}"
def filter_jobs_by_tags(self, jobs, tags):
def filter_jobs_by_tags(
self, jobs: List[FakeJob], tags: Dict[str, str]
) -> List[str]:
if not tags:
return [job.get_name() for job in jobs]
return [
job.get_name() for job in jobs if self.is_tags_match(self, job.arn, tags)
]
return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)]
@staticmethod
def is_tags_match(self, resource_arn, tags):
def is_tags_match(self, resource_arn: str, tags: Dict[str, str]) -> bool:
glue_resource_tags = self.glue_backend.get_tags(resource_arn)
mutual_keys = set(glue_resource_tags).intersection(tags)
for key in mutual_keys:
@ -415,28 +418,28 @@ class GlueResponse(BaseResponse):
return True
return False
def create_registry(self):
def create_registry(self) -> str:
registry_name = self._get_param("RegistryName")
description = self._get_param("Description")
tags = self._get_param("Tags")
registry = self.glue_backend.create_registry(registry_name, description, tags)
return json.dumps(registry)
def delete_registry(self):
def delete_registry(self) -> str:
registry_id = self._get_param("RegistryId")
registry = self.glue_backend.delete_registry(registry_id)
return json.dumps(registry)
def get_registry(self):
def get_registry(self) -> str:
registry_id = self._get_param("RegistryId")
registry = self.glue_backend.get_registry(registry_id)
return json.dumps(registry)
def list_registries(self):
def list_registries(self) -> str:
registries = self.glue_backend.list_registries()
return json.dumps({"Registries": registries})
def create_schema(self):
def create_schema(self) -> str:
registry_id = self._get_param("RegistryId")
schema_name = self._get_param("SchemaName")
data_format = self._get_param("DataFormat")
@ -455,7 +458,7 @@ class GlueResponse(BaseResponse):
)
return json.dumps(schema)
def register_schema_version(self):
def register_schema_version(self) -> str:
schema_id = self._get_param("SchemaId")
schema_definition = self._get_param("SchemaDefinition")
schema_version = self.glue_backend.register_schema_version(
@ -463,7 +466,7 @@ class GlueResponse(BaseResponse):
)
return json.dumps(schema_version)
def get_schema_version(self):
def get_schema_version(self) -> str:
schema_id = self._get_param("SchemaId")
schema_version_id = self._get_param("SchemaVersionId")
schema_version_number = self._get_param("SchemaVersionNumber")
@ -473,7 +476,7 @@ class GlueResponse(BaseResponse):
)
return json.dumps(schema_version)
def get_schema_by_definition(self):
def get_schema_by_definition(self) -> str:
schema_id = self._get_param("SchemaId")
schema_definition = self._get_param("SchemaDefinition")
schema_version = self.glue_backend.get_schema_by_definition(
@ -481,7 +484,7 @@ class GlueResponse(BaseResponse):
)
return json.dumps(schema_version)
def put_schema_version_metadata(self):
def put_schema_version_metadata(self) -> str:
schema_id = self._get_param("SchemaId")
schema_version_number = self._get_param("SchemaVersionNumber")
schema_version_id = self._get_param("SchemaVersionId")
@ -491,24 +494,24 @@ class GlueResponse(BaseResponse):
)
return json.dumps(schema_version)
def get_schema(self):
def get_schema(self) -> str:
schema_id = self._get_param("SchemaId")
schema = self.glue_backend.get_schema(schema_id)
return json.dumps(schema)
def delete_schema(self):
def delete_schema(self) -> str:
schema_id = self._get_param("SchemaId")
schema = self.glue_backend.delete_schema(schema_id)
return json.dumps(schema)
def update_schema(self):
def update_schema(self) -> str:
schema_id = self._get_param("SchemaId")
compatibility = self._get_param("Compatibility")
description = self._get_param("Description")
schema = self.glue_backend.update_schema(schema_id, compatibility, description)
return json.dumps(schema)
def batch_get_crawlers(self):
def batch_get_crawlers(self) -> str:
crawler_names = self._get_param("CrawlerNames")
crawlers = self.glue_backend.batch_get_crawlers(crawler_names)
crawlers_not_found = list(
@ -521,5 +524,5 @@ class GlueResponse(BaseResponse):
}
)
def get_partition_indexes(self):
def get_partition_indexes(self) -> str:
return json.dumps({"PartitionIndexDescriptorList": []})

View File

@ -97,7 +97,7 @@ def _escape_regex(pattern: str) -> str:
class _Expr(abc.ABC):
@abc.abstractmethod
def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: # type: ignore[misc]
raise NotImplementedError()
@ -196,7 +196,7 @@ class _Like(_Expr):
pattern = _cast("string", self.literal)
# prepare SQL pattern for conversion to regex pattern
pattern = _escape_regex(pattern)
pattern = _escape_regex(pattern) # type: ignore
# NOTE convert SQL wildcards to regex, no literal matches possible
pattern = pattern.replace("_", ".").replace("%", ".*")
@ -265,19 +265,19 @@ class _BoolOr(_Expr):
class _PartitionFilterExpressionCache:
def __init__(self):
def __init__(self) -> None:
# build grammar according to Glue.Client.get_partitions(Expression)
lpar, rpar = map(Suppress, "()")
# NOTE these are AWS Athena column name best practices
ident = Forward().set_name("ident")
ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar
ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar # type: ignore
number = Forward().set_name("number")
number <<= pyparsing_common.number | lpar + number + rpar
number <<= pyparsing_common.number | lpar + number + rpar # type: ignore
string = Forward().set_name("string")
string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar
string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar # type: ignore
literal = (number | string).set_name("literal")
literal_list = delimited_list(literal, min=1).set_name("list")
@ -293,7 +293,7 @@ class _PartitionFilterExpressionCache:
in_, between, like, not_, is_, null = map(
CaselessKeyword, "in between like not is null".split()
)
not_ = Suppress(not_) # only needed for matching
not_ = Suppress(not_) # type: ignore # only needed for matching
cond = (
(ident + is_ + null).set_parse_action(_IsNull)
@ -343,11 +343,11 @@ _PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache()
class PartitionFilter:
def __init__(self, expression: Optional[str], fake_table):
def __init__(self, expression: Optional[str], fake_table: Any):
self.expression = expression
self.fake_table = fake_table
def __call__(self, fake_partition) -> bool:
def __call__(self, fake_partition: Any) -> bool:
expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression)
if expression is None:
return True

View File

@ -6,36 +6,36 @@ class GreengrassClientError(JsonRESTError):
class IdNotFoundException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 404
super().__init__("IdNotFoundException", msg)
class InvalidContainerDefinitionException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 400
super().__init__("InvalidContainerDefinitionException", msg)
class VersionNotFoundException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 404
super().__init__("VersionNotFoundException", msg)
class InvalidInputException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 400
super().__init__("InvalidInputException", msg)
class MissingCoreException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 400
super().__init__("MissingCoreException", msg)
class ResourceNotFoundException(GreengrassClientError):
def __init__(self, msg):
def __init__(self, msg: str):
self.code = 404
super().__init__("ResourceNotFoundException", msg)

View File

@ -1,6 +1,7 @@
import json
from collections import OrderedDict
from datetime import datetime
from typing import Any, Dict, List, Iterable, Optional
import re
from moto.core import BaseBackend, BackendDict, BaseModel
@ -18,7 +19,7 @@ from .exceptions import (
class FakeCoreDefinition(BaseModel):
def __init__(self, account_id, region_name, name):
def __init__(self, account_id: str, region_name: str, name: str):
self.region_name = region_name
self.name = name
self.id = str(mock_random.uuid4())
@ -27,7 +28,7 @@ class FakeCoreDefinition(BaseModel):
self.latest_version = ""
self.latest_version_arn = ""
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -44,7 +45,13 @@ class FakeCoreDefinition(BaseModel):
class FakeCoreDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, core_definition_id, definition):
def __init__(
self,
account_id: str,
region_name: str,
core_definition_id: str,
definition: Dict[str, Any],
):
self.region_name = region_name
self.core_definition_id = core_definition_id
self.definition = definition
@ -52,8 +59,8 @@ class FakeCoreDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow()
def to_dict(self, include_detail=False):
obj = {
def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj: Dict[str, Any] = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime
@ -69,7 +76,13 @@ class FakeCoreDefinitionVersion(BaseModel):
class FakeDeviceDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version):
def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name
self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.id}"
@ -80,7 +93,7 @@ class FakeDeviceDefinition(BaseModel):
self.name = name
self.initial_version = initial_version
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
res = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -99,7 +112,13 @@ class FakeDeviceDefinition(BaseModel):
class FakeDeviceDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, device_definition_id, devices):
def __init__(
self,
account_id: str,
region_name: str,
device_definition_id: str,
devices: List[Dict[str, Any]],
):
self.region_name = region_name
self.device_definition_id = device_definition_id
self.devices = devices
@ -107,8 +126,8 @@ class FakeDeviceDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow()
def to_dict(self, include_detail=False):
obj = {
def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj: Dict[str, Any] = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime
@ -124,7 +143,13 @@ class FakeDeviceDefinitionVersion(BaseModel):
class FakeResourceDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version):
def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name
self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.id}"
@ -135,7 +160,7 @@ class FakeResourceDefinition(BaseModel):
self.name = name
self.initial_version = initial_version
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -152,7 +177,13 @@ class FakeResourceDefinition(BaseModel):
class FakeResourceDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, resource_definition_id, resources):
def __init__(
self,
account_id: str,
region_name: str,
resource_definition_id: str,
resources: List[Dict[str, Any]],
):
self.region_name = region_name
self.resource_definition_id = resource_definition_id
self.resources = resources
@ -160,7 +191,7 @@ class FakeResourceDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow()
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -173,7 +204,13 @@ class FakeResourceDefinitionVersion(BaseModel):
class FakeFunctionDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version):
def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name
self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.id}"
@ -184,7 +221,7 @@ class FakeFunctionDefinition(BaseModel):
self.name = name
self.initial_version = initial_version
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
res = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -204,7 +241,12 @@ class FakeFunctionDefinition(BaseModel):
class FakeFunctionDefinitionVersion(BaseModel):
def __init__(
self, account_id, region_name, function_definition_id, functions, default_config
self,
account_id: str,
region_name: str,
function_definition_id: str,
functions: List[Dict[str, Any]],
default_config: Dict[str, Any],
):
self.region_name = region_name
self.function_definition_id = function_definition_id
@ -214,7 +256,7 @@ class FakeFunctionDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow()
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -227,7 +269,13 @@ class FakeFunctionDefinitionVersion(BaseModel):
class FakeSubscriptionDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version):
def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name
self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.id}"
@ -238,7 +286,7 @@ class FakeSubscriptionDefinition(BaseModel):
self.name = name
self.initial_version = initial_version
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -256,7 +304,11 @@ class FakeSubscriptionDefinition(BaseModel):
class FakeSubscriptionDefinitionVersion(BaseModel):
def __init__(
self, account_id, region_name, subscription_definition_id, subscriptions
self,
account_id: str,
region_name: str,
subscription_definition_id: str,
subscriptions: List[Dict[str, Any]],
):
self.region_name = region_name
self.subscription_definition_id = subscription_definition_id
@ -265,7 +317,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow()
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -278,7 +330,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel):
class FakeGroup(BaseModel):
def __init__(self, account_id, region_name, name):
def __init__(self, account_id: str, region_name: str, name: str):
self.region_name = region_name
self.group_id = str(mock_random.uuid4())
self.name = name
@ -288,7 +340,7 @@ class FakeGroup(BaseModel):
self.latest_version = ""
self.latest_version_arn = ""
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
obj = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -308,14 +360,14 @@ class FakeGroup(BaseModel):
class FakeGroupVersion(BaseModel):
def __init__(
self,
account_id,
region_name,
group_id,
core_definition_version_arn,
device_definition_version_arn,
function_definition_version_arn,
resource_definition_version_arn,
subscription_definition_version_arn,
account_id: str,
region_name: str,
group_id: str,
core_definition_version_arn: Optional[str],
device_definition_version_arn: Optional[str],
function_definition_version_arn: Optional[str],
resource_definition_version_arn: Optional[str],
subscription_definition_version_arn: Optional[str],
):
self.region_name = region_name
self.group_id = group_id
@ -328,7 +380,7 @@ class FakeGroupVersion(BaseModel):
self.resource_definition_version_arn = resource_definition_version_arn
self.subscription_definition_version_arn = subscription_definition_version_arn
def to_dict(self, include_detail=False):
def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
definition = {}
if self.core_definition_version_arn:
@ -354,7 +406,7 @@ class FakeGroupVersion(BaseModel):
"SubscriptionDefinitionVersionArn"
] = self.subscription_definition_version_arn
obj = {
obj: Dict[str, Any] = {
"Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime
@ -370,7 +422,14 @@ class FakeGroupVersion(BaseModel):
class FakeDeployment(BaseModel):
def __init__(self, account_id, region_name, group_id, group_arn, deployment_type):
def __init__(
self,
account_id: str,
region_name: str,
group_id: str,
group_arn: str,
deployment_type: str,
):
self.region_name = region_name
self.id = str(mock_random.uuid4())
self.group_id = group_id
@ -381,7 +440,7 @@ class FakeDeployment(BaseModel):
self.deployment_type = deployment_type
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:/greengrass/groups/{self.group_id}/deployments/{self.id}"
def to_dict(self, include_detail=False):
def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = {"DeploymentId": self.id, "DeploymentArn": self.arn}
if include_detail:
@ -395,11 +454,11 @@ class FakeDeployment(BaseModel):
class FakeAssociatedRole(BaseModel):
def __init__(self, role_arn):
def __init__(self, role_arn: str):
self.role_arn = role_arn
self.associated_at = datetime.utcnow()
def to_dict(self, include_detail=False):
def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = {"AssociatedAt": iso_8601_datetime_with_milliseconds(self.associated_at)}
if include_detail:
@ -409,12 +468,17 @@ class FakeAssociatedRole(BaseModel):
class FakeDeploymentStatus(BaseModel):
def __init__(self, deployment_type, updated_at, deployment_status="InProgress"):
def __init__(
self,
deployment_type: str,
updated_at: datetime,
deployment_status: str = "InProgress",
):
self.deployment_type = deployment_type
self.update_at_datetime = updated_at
self.deployment_status = deployment_status
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"DeploymentStatus": self.deployment_status,
"DeploymentType": self.deployment_type,
@ -423,24 +487,38 @@ class FakeDeploymentStatus(BaseModel):
class GreengrassBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.groups = OrderedDict()
self.group_role_associations = OrderedDict()
self.group_versions = OrderedDict()
self.core_definitions = OrderedDict()
self.core_definition_versions = OrderedDict()
self.device_definitions = OrderedDict()
self.device_definition_versions = OrderedDict()
self.function_definitions = OrderedDict()
self.function_definition_versions = OrderedDict()
self.resource_definitions = OrderedDict()
self.resource_definition_versions = OrderedDict()
self.subscription_definitions = OrderedDict()
self.subscription_definition_versions = OrderedDict()
self.deployments = OrderedDict()
self.groups: Dict[str, FakeGroup] = OrderedDict()
self.group_role_associations: Dict[str, FakeAssociatedRole] = OrderedDict()
self.group_versions: Dict[str, Dict[str, FakeGroupVersion]] = OrderedDict()
self.core_definitions: Dict[str, FakeCoreDefinition] = OrderedDict()
self.core_definition_versions: Dict[
str, Dict[str, FakeCoreDefinitionVersion]
] = OrderedDict()
self.device_definitions: Dict[str, FakeDeviceDefinition] = OrderedDict()
self.device_definition_versions: Dict[
str, Dict[str, FakeDeviceDefinitionVersion]
] = OrderedDict()
self.function_definitions: Dict[str, FakeFunctionDefinition] = OrderedDict()
self.function_definition_versions: Dict[
str, Dict[str, FakeFunctionDefinitionVersion]
] = OrderedDict()
self.resource_definitions: Dict[str, FakeResourceDefinition] = OrderedDict()
self.resource_definition_versions: Dict[
str, Dict[str, FakeResourceDefinitionVersion]
] = OrderedDict()
self.subscription_definitions: Dict[
str, FakeSubscriptionDefinition
] = OrderedDict()
self.subscription_definition_versions: Dict[
str, Dict[str, FakeSubscriptionDefinitionVersion]
] = OrderedDict()
self.deployments: Dict[str, FakeDeployment] = OrderedDict()
def create_core_definition(self, name, initial_version):
def create_core_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeCoreDefinition:
core_definition = FakeCoreDefinition(self.account_id, self.region_name, name)
self.core_definitions[core_definition.id] = core_definition
@ -449,22 +527,22 @@ class GreengrassBackend(BaseBackend):
)
return core_definition
def list_core_definitions(self):
def list_core_definitions(self) -> Iterable[FakeCoreDefinition]:
return self.core_definitions.values()
def get_core_definition(self, core_definition_id):
def get_core_definition(self, core_definition_id: str) -> FakeCoreDefinition:
if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That Core List Definition does not exist")
return self.core_definitions[core_definition_id]
def delete_core_definition(self, core_definition_id):
def delete_core_definition(self, core_definition_id: str) -> None:
if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.")
del self.core_definitions[core_definition_id]
del self.core_definition_versions[core_definition_id]
def update_core_definition(self, core_definition_id, name):
def update_core_definition(self, core_definition_id: str, name: str) -> None:
if name == "":
raise InvalidContainerDefinitionException(
@ -474,7 +552,9 @@ class GreengrassBackend(BaseBackend):
raise IdNotFoundException("That cores definition does not exist.")
self.core_definitions[core_definition_id].name = name
def create_core_definition_version(self, core_definition_id, cores):
def create_core_definition_version(
self, core_definition_id: str, cores: List[Dict[str, Any]]
) -> FakeCoreDefinitionVersion:
definition = {"Cores": cores}
core_def_ver = FakeCoreDefinitionVersion(
@ -491,15 +571,17 @@ class GreengrassBackend(BaseBackend):
return core_def_ver
def list_core_definition_versions(self, core_definition_id):
def list_core_definition_versions(
self, core_definition_id: str
) -> Iterable[FakeCoreDefinitionVersion]:
if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.")
return self.core_definition_versions[core_definition_id].values()
def get_core_definition_version(
self, core_definition_id, core_definition_version_id
):
self, core_definition_id: str, core_definition_version_id: str
) -> FakeCoreDefinitionVersion:
if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.")
@ -516,7 +598,9 @@ class GreengrassBackend(BaseBackend):
core_definition_version_id
]
def create_device_definition(self, name, initial_version):
def create_device_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeDeviceDefinition:
device_def = FakeDeviceDefinition(
self.account_id, self.region_name, name, initial_version
)
@ -527,10 +611,12 @@ class GreengrassBackend(BaseBackend):
return device_def
def list_device_definitions(self):
def list_device_definitions(self) -> Iterable[FakeDeviceDefinition]:
return self.device_definitions.values()
def create_device_definition_version(self, device_definition_id, devices):
def create_device_definition_version(
self, device_definition_id: str, devices: List[Dict[str, Any]]
) -> FakeDeviceDefinitionVersion:
if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.")
@ -552,25 +638,27 @@ class GreengrassBackend(BaseBackend):
return device_ver
def list_device_definition_versions(self, device_definition_id):
def list_device_definition_versions(
self, device_definition_id: str
) -> Iterable[FakeDeviceDefinitionVersion]:
if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.")
return self.device_definition_versions[device_definition_id].values()
def get_device_definition(self, device_definition_id):
def get_device_definition(self, device_definition_id: str) -> FakeDeviceDefinition:
if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That Device List Definition does not exist.")
return self.device_definitions[device_definition_id]
def delete_device_definition(self, device_definition_id):
def delete_device_definition(self, device_definition_id: str) -> None:
if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.")
del self.device_definitions[device_definition_id]
del self.device_definition_versions[device_definition_id]
def update_device_definition(self, device_definition_id, name):
def update_device_definition(self, device_definition_id: str, name: str) -> None:
if name == "":
raise InvalidContainerDefinitionException(
@ -581,8 +669,8 @@ class GreengrassBackend(BaseBackend):
self.device_definitions[device_definition_id].name = name
def get_device_definition_version(
self, device_definition_id, device_definition_version_id
):
self, device_definition_id: str, device_definition_version_id: str
) -> FakeDeviceDefinitionVersion:
if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.")
@ -599,7 +687,9 @@ class GreengrassBackend(BaseBackend):
device_definition_version_id
]
def create_resource_definition(self, name, initial_version):
def create_resource_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeResourceDefinition:
resources = initial_version.get("Resources", [])
GreengrassBackend._validate_resources(resources)
@ -614,22 +704,26 @@ class GreengrassBackend(BaseBackend):
return resource_def
def list_resource_definitions(self):
return self.resource_definitions
def list_resource_definitions(self) -> Iterable[FakeResourceDefinition]:
return self.resource_definitions.values()
def get_resource_definition(self, resource_definition_id):
def get_resource_definition(
self, resource_definition_id: str
) -> FakeResourceDefinition:
if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That Resource List Definition does not exist.")
return self.resource_definitions[resource_definition_id]
def delete_resource_definition(self, resource_definition_id):
def delete_resource_definition(self, resource_definition_id: str) -> None:
if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That resources definition does not exist.")
del self.resource_definitions[resource_definition_id]
del self.resource_definition_versions[resource_definition_id]
def update_resource_definition(self, resource_definition_id, name):
def update_resource_definition(
self, resource_definition_id: str, name: str
) -> None:
if name == "":
raise InvalidInputException("Invalid resource name.")
@ -637,7 +731,9 @@ class GreengrassBackend(BaseBackend):
raise IdNotFoundException("That resources definition does not exist.")
self.resource_definitions[resource_definition_id].name = name
def create_resource_definition_version(self, resource_definition_id, resources):
def create_resource_definition_version(
self, resource_definition_id: str, resources: List[Dict[str, Any]]
) -> FakeResourceDefinitionVersion:
if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That resource definition does not exist.")
@ -666,7 +762,9 @@ class GreengrassBackend(BaseBackend):
return resource_def_ver
def list_resource_definition_versions(self, resource_definition_id):
def list_resource_definition_versions(
self, resource_definition_id: str
) -> Iterable[FakeResourceDefinitionVersion]:
if resource_definition_id not in self.resource_definition_versions:
raise IdNotFoundException("That resources definition does not exist.")
@ -674,8 +772,8 @@ class GreengrassBackend(BaseBackend):
return self.resource_definition_versions[resource_definition_id].values()
def get_resource_definition_version(
self, resource_definition_id, resource_definition_version_id
):
self, resource_definition_id: str, resource_definition_version_id: str
) -> FakeResourceDefinitionVersion:
if resource_definition_id not in self.resource_definition_versions:
raise IdNotFoundException("That resources definition does not exist.")
@ -693,7 +791,7 @@ class GreengrassBackend(BaseBackend):
]
@staticmethod
def _validate_resources(resources):
def _validate_resources(resources: List[Dict[str, Any]]) -> None: # type: ignore[misc]
for resource in resources:
volume_source_path = (
resource.get("ResourceDataContainer", {})
@ -719,7 +817,9 @@ class GreengrassBackend(BaseBackend):
f", but got: {device_source_path}])",
)
def create_function_definition(self, name, initial_version):
def create_function_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeFunctionDefinition:
func_def = FakeFunctionDefinition(
self.account_id, self.region_name, name, initial_version
)
@ -731,22 +831,26 @@ class GreengrassBackend(BaseBackend):
return func_def
def list_function_definitions(self):
return self.function_definitions.values()
def list_function_definitions(self) -> List[FakeFunctionDefinition]:
return list(self.function_definitions.values())
def get_function_definition(self, function_definition_id):
def get_function_definition(
self, function_definition_id: str
) -> FakeFunctionDefinition:
if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That Lambda List Definition does not exist.")
return self.function_definitions[function_definition_id]
def delete_function_definition(self, function_definition_id):
def delete_function_definition(self, function_definition_id: str) -> None:
if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That lambdas definition does not exist.")
del self.function_definitions[function_definition_id]
del self.function_definition_versions[function_definition_id]
def update_function_definition(self, function_definition_id, name):
def update_function_definition(
self, function_definition_id: str, name: str
) -> None:
if name == "":
raise InvalidContainerDefinitionException(
@ -757,8 +861,11 @@ class GreengrassBackend(BaseBackend):
self.function_definitions[function_definition_id].name = name
def create_function_definition_version(
self, function_definition_id, functions, default_config
):
self,
function_definition_id: str,
functions: List[Dict[str, Any]],
default_config: Dict[str, Any],
) -> FakeFunctionDefinitionVersion:
if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That lambdas does not exist.")
@ -784,14 +891,16 @@ class GreengrassBackend(BaseBackend):
return func_ver
def list_function_definition_versions(self, function_definition_id):
def list_function_definition_versions(
self, function_definition_id: str
) -> Dict[str, FakeFunctionDefinitionVersion]:
if function_definition_id not in self.function_definition_versions:
raise IdNotFoundException("That lambdas definition does not exist.")
return self.function_definition_versions[function_definition_id]
def get_function_definition_version(
self, function_definition_id, function_definition_version_id
):
self, function_definition_id: str, function_definition_version_id: str
) -> FakeFunctionDefinitionVersion:
if function_definition_id not in self.function_definition_versions:
raise IdNotFoundException("That lambdas definition does not exist.")
@ -809,7 +918,7 @@ class GreengrassBackend(BaseBackend):
]
@staticmethod
def _is_valid_subscription_target_or_source(target_or_source):
def _is_valid_subscription_target_or_source(target_or_source: str) -> bool:
if target_or_source in ["cloud", "GGShadowService"]:
return True
@ -829,10 +938,10 @@ class GreengrassBackend(BaseBackend):
return False
@staticmethod
def _validate_subscription_target_or_source(subscriptions):
def _validate_subscription_target_or_source(subscriptions: List[Dict[str, Any]]) -> None: # type: ignore[misc]
target_errors = []
source_errors = []
target_errors: List[str] = []
source_errors: List[str] = []
for subscription in subscriptions:
subscription_id = subscription["Id"]
@ -863,7 +972,9 @@ class GreengrassBackend(BaseBackend):
f"The subscriptions definition is invalid or corrupted. (ErrorDetails: [{error_msg}])",
)
def create_subscription_definition(self, name, initial_version):
def create_subscription_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeSubscriptionDefinition:
GreengrassBackend._validate_subscription_target_or_source(
initial_version["Subscriptions"]
@ -883,10 +994,12 @@ class GreengrassBackend(BaseBackend):
sub_def.latest_version_arn = sub_def_ver.arn
return sub_def
def list_subscription_definitions(self):
return self.subscription_definitions.values()
def list_subscription_definitions(self) -> List[FakeSubscriptionDefinition]:
return list(self.subscription_definitions.values())
def get_subscription_definition(self, subscription_definition_id):
def get_subscription_definition(
self, subscription_definition_id: str
) -> FakeSubscriptionDefinition:
if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException(
@ -894,13 +1007,15 @@ class GreengrassBackend(BaseBackend):
)
return self.subscription_definitions[subscription_definition_id]
def delete_subscription_definition(self, subscription_definition_id):
def delete_subscription_definition(self, subscription_definition_id: str) -> None:
if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException("That subscriptions definition does not exist.")
del self.subscription_definitions[subscription_definition_id]
del self.subscription_definition_versions[subscription_definition_id]
def update_subscription_definition(self, subscription_definition_id, name):
def update_subscription_definition(
self, subscription_definition_id: str, name: str
) -> None:
if name == "":
raise InvalidContainerDefinitionException(
@ -911,8 +1026,8 @@ class GreengrassBackend(BaseBackend):
self.subscription_definitions[subscription_definition_id].name = name
def create_subscription_definition_version(
self, subscription_definition_id, subscriptions
):
self, subscription_definition_id: str, subscriptions: List[Dict[str, Any]]
) -> FakeSubscriptionDefinitionVersion:
GreengrassBackend._validate_subscription_target_or_source(subscriptions)
@ -931,14 +1046,16 @@ class GreengrassBackend(BaseBackend):
return sub_def_ver
def list_subscription_definition_versions(self, subscription_definition_id):
def list_subscription_definition_versions(
self, subscription_definition_id: str
) -> Dict[str, FakeSubscriptionDefinitionVersion]:
if subscription_definition_id not in self.subscription_definition_versions:
raise IdNotFoundException("That subscriptions definition does not exist.")
return self.subscription_definition_versions[subscription_definition_id]
def get_subscription_definition_version(
self, subscription_definition_id, subscription_definition_version_id
):
self, subscription_definition_id: str, subscription_definition_version_id: str
) -> FakeSubscriptionDefinitionVersion:
if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException("That subscriptions definition does not exist.")
@ -955,7 +1072,7 @@ class GreengrassBackend(BaseBackend):
subscription_definition_version_id
]
def create_group(self, name, initial_version):
def create_group(self, name: str, initial_version: Dict[str, Any]) -> FakeGroup:
group = FakeGroup(self.account_id, self.region_name, name)
self.groups[group.group_id] = group
@ -983,22 +1100,22 @@ class GreengrassBackend(BaseBackend):
return group
def list_groups(self):
return self.groups.values()
def list_groups(self) -> List[FakeGroup]:
return list(self.groups.values())
def get_group(self, group_id):
def get_group(self, group_id: str) -> Optional[FakeGroup]:
if group_id not in self.groups:
raise IdNotFoundException("That Group Definition does not exist.")
return self.groups.get(group_id)
def delete_group(self, group_id):
def delete_group(self, group_id: str) -> None:
if group_id not in self.groups:
# I don't know why, the error message is different between get_group and delete_group
raise IdNotFoundException("That group definition does not exist.")
del self.groups[group_id]
del self.group_versions[group_id]
def update_group(self, group_id, name):
def update_group(self, group_id: str, name: str) -> None:
if name == "":
raise InvalidContainerDefinitionException(
@ -1010,13 +1127,13 @@ class GreengrassBackend(BaseBackend):
def create_group_version(
self,
group_id,
core_definition_version_arn,
device_definition_version_arn,
function_definition_version_arn,
resource_definition_version_arn,
subscription_definition_version_arn,
):
group_id: str,
core_definition_version_arn: Optional[str],
device_definition_version_arn: Optional[str],
function_definition_version_arn: Optional[str],
resource_definition_version_arn: Optional[str],
subscription_definition_version_arn: Optional[str],
) -> FakeGroupVersion:
if group_id not in self.groups:
raise IdNotFoundException("That group does not exist.")
@ -1048,19 +1165,21 @@ class GreengrassBackend(BaseBackend):
def _validate_group_version_definitions(
self,
core_definition_version_arn=None,
device_definition_version_arn=None,
function_definition_version_arn=None,
resource_definition_version_arn=None,
subscription_definition_version_arn=None,
):
def _is_valid_def_ver_arn(definition_version_arn, kind="cores"):
core_definition_version_arn: Optional[str] = None,
device_definition_version_arn: Optional[str] = None,
function_definition_version_arn: Optional[str] = None,
resource_definition_version_arn: Optional[str] = None,
subscription_definition_version_arn: Optional[str] = None,
) -> None:
def _is_valid_def_ver_arn(
definition_version_arn: Optional[str], kind: str = "cores"
) -> bool:
if definition_version_arn is None:
return True
if kind == "cores":
versions = self.core_definition_versions
versions: Any = self.core_definition_versions
elif kind == "devices":
versions = self.device_definition_versions
elif kind == "functions":
@ -1124,12 +1243,14 @@ class GreengrassBackend(BaseBackend):
f"The group is invalid or corrupted. (ErrorDetails: [{error_details}])",
)
def list_group_versions(self, group_id):
def list_group_versions(self, group_id: str) -> List[FakeGroupVersion]:
if group_id not in self.group_versions:
raise IdNotFoundException("That group definition does not exist.")
return self.group_versions[group_id].values()
return list(self.group_versions[group_id].values())
def get_group_version(self, group_id, group_version_id):
def get_group_version(
self, group_id: str, group_version_id: str
) -> FakeGroupVersion:
if group_id not in self.group_versions:
raise IdNotFoundException("That group definition does not exist.")
@ -1142,8 +1263,12 @@ class GreengrassBackend(BaseBackend):
return self.group_versions[group_id][group_version_id]
def create_deployment(
self, group_id, group_version_id, deployment_type, deployment_id=None
):
self,
group_id: str,
group_version_id: str,
deployment_type: str,
deployment_id: Optional[str] = None,
) -> FakeDeployment:
deployment_types = (
"NewDeployment",
@ -1199,7 +1324,7 @@ class GreengrassBackend(BaseBackend):
self.deployments[deployment.id] = deployment
return deployment
def list_deployments(self, group_id):
def list_deployments(self, group_id: str) -> List[FakeDeployment]:
# ListDeployments API does not check specified group is exists
return [
@ -1208,7 +1333,9 @@ class GreengrassBackend(BaseBackend):
if deployment.group_id == group_id
]
def get_deployment_status(self, group_id, deployment_id):
def get_deployment_status(
self, group_id: str, deployment_id: str
) -> FakeDeploymentStatus:
if deployment_id not in self.deployments:
raise InvalidInputException(f"Deployment '{deployment_id}' does not exist.")
@ -1224,7 +1351,7 @@ class GreengrassBackend(BaseBackend):
deployment.deployment_status,
)
def reset_deployments(self, group_id, force=False):
def reset_deployments(self, group_id: str, force: bool = False) -> FakeDeployment:
if group_id not in self.groups:
raise ResourceNotFoundException("That Group Definition does not exist.")
@ -1248,7 +1375,9 @@ class GreengrassBackend(BaseBackend):
self.deployments[deployment.id] = deployment
return deployment
def associate_role_to_group(self, group_id, role_arn):
def associate_role_to_group(
self, group_id: str, role_arn: str
) -> FakeAssociatedRole:
# I don't know why, AssociateRoleToGroup does not check specified group is exists
# So, this API allows any group id such as "a"
@ -1257,7 +1386,7 @@ class GreengrassBackend(BaseBackend):
self.group_role_associations[group_id] = associated_role
return associated_role
def get_associated_role(self, group_id):
def get_associated_role(self, group_id: str) -> FakeAssociatedRole:
if group_id not in self.group_role_associations:
raise GreengrassClientError(
@ -1266,7 +1395,7 @@ class GreengrassBackend(BaseBackend):
return self.group_role_associations[group_id]
def disassociate_role_from_group(self, group_id):
def disassociate_role_from_group(self, group_id: str) -> None:
if group_id not in self.group_role_associations:
return
del self.group_role_associations[group_id]

View File

@ -1,20 +1,22 @@
from datetime import datetime
from typing import Any
import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core.responses import BaseResponse
from .models import greengrass_backends
from .models import greengrass_backends, GreengrassBackend
class GreengrassResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="greengrass")
@property
def greengrass_backend(self):
def greengrass_backend(self) -> GreengrassBackend:
return greengrass_backends[self.current_account][self.region]
def core_definitions(self, request, full_url, headers):
def core_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -23,7 +25,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "POST":
return self.create_core_definition()
def list_core_definitions(self):
def list_core_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_core_definitions()
return (
200,
@ -33,7 +35,7 @@ class GreengrassResponse(BaseResponse):
),
)
def create_core_definition(self):
def create_core_definition(self) -> TYPE_RESPONSE:
name = self._get_param("Name")
initial_version = self._get_param("InitialVersion")
res = self.greengrass_backend.create_core_definition(
@ -41,7 +43,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def core_definition(self, request, full_url, headers):
def core_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -53,21 +55,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_core_definition()
def get_core_definition(self):
def get_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_core_definition(
core_definition_id=core_definition_id
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_core_definition(self):
def delete_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_core_definition(
core_definition_id=core_definition_id
)
return 200, {"status": 200}, json.dumps({})
def update_core_definition(self):
def update_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1]
name = self._get_param("Name")
self.greengrass_backend.update_core_definition(
@ -75,7 +77,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def core_definition_versions(self, request, full_url, headers):
def core_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -84,7 +86,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "POST":
return self.create_core_definition_version()
def create_core_definition_version(self):
def create_core_definition_version(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-2]
cores = self._get_param("Cores")
@ -93,7 +95,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_core_definition_versions(self):
def list_core_definition_versions(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_core_definition_versions(core_definition_id)
return (
@ -102,13 +104,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Versions": [core_def_ver.to_dict() for core_def_ver in res]}),
)
def core_definition_version(self, request, full_url, headers):
def core_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_core_definition_version()
def get_core_definition_version(self):
def get_core_definition_version(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-3]
core_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_core_definition_version(
@ -117,7 +119,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def device_definitions(self, request, full_url, headers):
def device_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -126,7 +128,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_device_definition()
def create_device_definition(self):
def create_device_definition(self) -> TYPE_RESPONSE:
name = self._get_param("Name")
initial_version = self._get_param("InitialVersion")
@ -135,7 +137,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_device_definition(self):
def list_device_definition(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_device_definitions()
return (
200,
@ -149,7 +151,7 @@ class GreengrassResponse(BaseResponse):
),
)
def device_definition_versions(self, request, full_url, headers):
def device_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -158,7 +160,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_device_definition_versions()
def create_device_definition_version(self):
def create_device_definition_version(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-2]
devices = self._get_param("Devices")
@ -168,7 +170,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_device_definition_versions(self):
def list_device_definition_versions(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_device_definition_versions(
@ -182,7 +184,7 @@ class GreengrassResponse(BaseResponse):
),
)
def device_definition(self, request, full_url, headers):
def device_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -194,14 +196,14 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_device_definition()
def get_device_definition(self):
def get_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_device_definition(
device_definition_id=device_definition_id
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_device_definition(self):
def delete_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_device_definition(
@ -209,7 +211,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def update_device_definition(self):
def update_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1]
name = self._get_param("Name")
@ -218,13 +220,13 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def device_definition_version(self, request, full_url, headers):
def device_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_device_definition_version()
def get_device_definition_version(self):
def get_device_definition_version(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-3]
device_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_device_definition_version(
@ -233,7 +235,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def resource_definitions(self, request, full_url, headers):
def resource_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -242,7 +244,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_resource_definitions()
def create_resource_definition(self):
def create_resource_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion")
name = self._get_param("Name")
@ -251,16 +253,16 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_resource_definitions(self):
def list_resource_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_resource_definitions()
return (
200,
{"status": 200},
json.dumps({"Definitions": [i.to_dict() for i in res.values()]}),
json.dumps({"Definitions": [i.to_dict() for i in res]}),
)
def resource_definition(self, request, full_url, headers):
def resource_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -272,14 +274,14 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_resource_definition()
def get_resource_definition(self):
def get_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_resource_definition(
resource_definition_id=resource_definition_id
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_resource_definition(self):
def delete_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_resource_definition(
@ -287,7 +289,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def update_resource_definition(self):
def update_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1]
name = self._get_param("Name")
@ -296,7 +298,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def resource_definition_versions(self, request, full_url, headers):
def resource_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -305,7 +307,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_resource_definition_versions()
def create_resource_definition_version(self):
def create_resource_definition_version(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-2]
resources = self._get_param("Resources")
@ -315,7 +317,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_resource_definition_versions(self):
def list_resource_definition_versions(self) -> TYPE_RESPONSE:
resource_device_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_resource_definition_versions(
@ -330,13 +332,13 @@ class GreengrassResponse(BaseResponse):
),
)
def resource_definition_version(self, request, full_url, headers):
def resource_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_resource_definition_version()
def get_resource_definition_version(self):
def get_resource_definition_version(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-3]
resource_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_resource_definition_version(
@ -345,7 +347,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def function_definitions(self, request, full_url, headers):
def function_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -354,7 +356,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_function_definitions()
def create_function_definition(self):
def create_function_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion")
name = self._get_param("Name")
@ -363,7 +365,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_function_definitions(self):
def list_function_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_function_definitions()
return (
200,
@ -373,7 +375,7 @@ class GreengrassResponse(BaseResponse):
),
)
def function_definition(self, request, full_url, headers):
def function_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -385,21 +387,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_function_definition()
def get_function_definition(self):
def get_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_function_definition(
function_definition_id=function_definition_id,
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_function_definition(self):
def delete_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_function_definition(
function_definition_id=function_definition_id,
)
return 200, {"status": 200}, json.dumps({})
def update_function_definition(self):
def update_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1]
name = self._get_param("Name")
self.greengrass_backend.update_function_definition(
@ -407,7 +409,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def function_definition_versions(self, request, full_url, headers):
def function_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -416,7 +418,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_function_definition_versions()
def create_function_definition_version(self):
def create_function_definition_version(self) -> TYPE_RESPONSE:
default_config = self._get_param("DefaultConfig")
function_definition_id = self.path.split("/")[-2]
@ -429,7 +431,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_function_definition_versions(self):
def list_function_definition_versions(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_function_definition_versions(
function_definition_id=function_definition_id
@ -437,13 +439,13 @@ class GreengrassResponse(BaseResponse):
versions = [i.to_dict() for i in res.values()]
return 200, {"status": 200}, json.dumps({"Versions": versions})
def function_definition_version(self, request, full_url, headers):
def function_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_function_definition_version()
def get_function_definition_version(self):
def get_function_definition_version(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-3]
function_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_function_definition_version(
@ -452,7 +454,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def subscription_definitions(self, request, full_url, headers):
def subscription_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -461,7 +463,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_subscription_definitions()
def create_subscription_definition(self):
def create_subscription_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion")
name = self._get_param("Name")
@ -470,7 +472,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_subscription_definitions(self):
def list_subscription_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_subscription_definitions()
return (
@ -486,7 +488,7 @@ class GreengrassResponse(BaseResponse):
),
)
def subscription_definition(self, request, full_url, headers):
def subscription_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -498,21 +500,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_subscription_definition()
def get_subscription_definition(self):
def get_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_subscription_definition(
subscription_definition_id=subscription_definition_id
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_subscription_definition(self):
def delete_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_subscription_definition(
subscription_definition_id=subscription_definition_id
)
return 200, {"status": 200}, json.dumps({})
def update_subscription_definition(self):
def update_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1]
name = self._get_param("Name")
self.greengrass_backend.update_subscription_definition(
@ -520,7 +522,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps({})
def subscription_definition_versions(self, request, full_url, headers):
def subscription_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -529,7 +531,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_subscription_definition_versions()
def create_subscription_definition_version(self):
def create_subscription_definition_version(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-2]
subscriptions = self._get_param("Subscriptions")
@ -539,7 +541,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_subscription_definition_versions(self):
def list_subscription_definition_versions(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_subscription_definition_versions(
subscription_definition_id=subscription_definition_id
@ -547,13 +549,13 @@ class GreengrassResponse(BaseResponse):
versions = [i.to_dict() for i in res.values()]
return 200, {"status": 200}, json.dumps({"Versions": versions})
def subscription_definition_version(self, request, full_url, headers):
def subscription_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_subscription_definition_version()
def get_subscription_definition_version(self):
def get_subscription_definition_version(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-3]
subscription_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_subscription_definition_version(
@ -562,7 +564,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def groups(self, request, full_url, headers):
def groups(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -571,7 +573,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_groups()
def create_group(self):
def create_group(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion")
name = self._get_param("Name")
@ -580,7 +582,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_groups(self):
def list_groups(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_groups()
return (
@ -589,7 +591,7 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Groups": [group.to_dict() for group in res]}),
)
def group(self, request, full_url, headers):
def group(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
@ -601,27 +603,25 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT":
return self.update_group()
def get_group(self):
def get_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_group(
group_id=group_id,
)
return 200, {"status": 200}, json.dumps(res.to_dict())
res = self.greengrass_backend.get_group(group_id=group_id)
return 200, {"status": 200}, json.dumps(res.to_dict()) # type: ignore
def delete_group(self):
def delete_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1]
self.greengrass_backend.delete_group(
group_id=group_id,
)
return 200, {"status": 200}, json.dumps({})
def update_group(self):
def update_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1]
name = self._get_param("Name")
self.greengrass_backend.update_group(group_id=group_id, name=name)
return 200, {"status": 200}, json.dumps({})
def group_versions(self, request, full_url, headers):
def group_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -630,7 +630,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_group_versions()
def create_group_version(self):
def create_group_version(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
@ -656,7 +656,7 @@ class GreengrassResponse(BaseResponse):
)
return 201, {"status": 201}, json.dumps(res.to_dict())
def list_group_versions(self):
def list_group_versions(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_group_versions(group_id=group_id)
return (
@ -665,13 +665,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Versions": [group_ver.to_dict() for group_ver in res]}),
)
def group_version(self, request, full_url, headers):
def group_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_group_version()
def get_group_version(self):
def get_group_version(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-3]
group_version_id = self.path.split("/")[-1]
@ -681,7 +681,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def deployments(self, request, full_url, headers):
def deployments(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
@ -690,7 +690,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET":
return self.list_deployments()
def create_deployment(self):
def create_deployment(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
group_version_id = self._get_param("GroupVersionId")
@ -705,7 +705,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def list_deployments(self):
def list_deployments(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_deployments(group_id=group_id)
@ -721,13 +721,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Deployments": deployments}),
)
def deployment_satus(self, request, full_url, headers):
def deployment_satus(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "GET":
return self.get_deployment_status()
def get_deployment_status(self):
def get_deployment_status(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-4]
deployment_id = self.path.split("/")[-2]
@ -737,13 +737,13 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def deployments_reset(self, request, full_url, headers):
def deployments_reset(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "POST":
return self.reset_deployments()
def reset_deployments(self):
def reset_deployments(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-3]
res = self.greengrass_backend.reset_deployments(
@ -751,7 +751,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def role(self, request, full_url, headers):
def role(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if self.method == "PUT":
@ -763,7 +763,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "DELETE":
return self.disassociate_role_from_group()
def associate_role_to_group(self):
def associate_role_to_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
role_arn = self._get_param("RoleArn")
@ -773,7 +773,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def get_associated_role(self):
def get_associated_role(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
res = self.greengrass_backend.get_associated_role(
@ -781,7 +781,7 @@ class GreengrassResponse(BaseResponse):
)
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def disassociate_role_from_group(self):
def disassociate_role_from_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2]
self.greengrass_backend.disassociate_role_from_group(
group_id=group_id,

View File

@ -1,3 +1,4 @@
from typing import Any, List, Tuple
from moto.core.exceptions import JsonRESTError
@ -8,24 +9,28 @@ class GuardDutyException(JsonRESTError):
class DetectorNotFoundException(GuardDutyException):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidInputException",
"The request is rejected because the input detectorId is not owned by the current account.",
)
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument
return {"X-Amzn-ErrorType": "BadRequestException"}
def get_headers(
self, *args: Any, **kwargs: Any
) -> List[Tuple[str, str]]: # pylint: disable=unused-argument
return [("X-Amzn-ErrorType", "BadRequestException")]
class FilterNotFoundException(GuardDutyException):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidInputException",
"The request is rejected since no such resource found.",
)
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument
return {"X-Amzn-ErrorType": "BadRequestException"}
def get_headers(
self, *args: Any, **kwargs: Any
) -> List[Tuple[str, str]]: # pylint: disable=unused-argument
return [("X-Amzn-ErrorType", "BadRequestException")]

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
from datetime import datetime
@ -6,12 +7,18 @@ from .exceptions import DetectorNotFoundException, FilterNotFoundException
class GuardDutyBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.admin_account_ids = []
self.detectors = {}
self.admin_account_ids: List[str] = []
self.detectors: Dict[str, Detector] = {}
def create_detector(self, enable, finding_publishing_frequency, data_sources, tags):
def create_detector(
self,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
tags: Dict[str, str],
) -> str:
if finding_publishing_frequency not in [
"FIFTEEN_MINUTES",
"ONE_HOUR",
@ -31,29 +38,35 @@ class GuardDutyBackend(BaseBackend):
return detector.id
def create_filter(
self, detector_id, name, action, description, finding_criteria, rank
):
self,
detector_id: str,
name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
detector = self.get_detector(detector_id)
_filter = Filter(name, action, description, finding_criteria, rank)
detector.add_filter(_filter)
def delete_detector(self, detector_id):
def delete_detector(self, detector_id: str) -> None:
self.detectors.pop(detector_id, None)
def delete_filter(self, detector_id, filter_name):
def delete_filter(self, detector_id: str, filter_name: str) -> None:
detector = self.get_detector(detector_id)
detector.delete_filter(filter_name)
def enable_organization_admin_account(self, admin_account_id):
def enable_organization_admin_account(self, admin_account_id: str) -> None:
self.admin_account_ids.append(admin_account_id)
def list_organization_admin_accounts(self):
def list_organization_admin_accounts(self) -> List[str]:
"""
Pagination is not yet implemented
"""
return self.admin_account_ids
def list_detectors(self):
def list_detectors(self) -> List[str]:
"""
The MaxResults and NextToken-parameter have not yet been implemented.
"""
@ -62,24 +75,34 @@ class GuardDutyBackend(BaseBackend):
detectorids.append(self.detectors[detector].id)
return detectorids
def get_detector(self, detector_id):
def get_detector(self, detector_id: str) -> "Detector":
if detector_id not in self.detectors:
raise DetectorNotFoundException
return self.detectors[detector_id]
def get_filter(self, detector_id, filter_name):
def get_filter(self, detector_id: str, filter_name: str) -> "Filter":
detector = self.get_detector(detector_id)
return detector.get_filter(filter_name)
def update_detector(
self, detector_id, enable, finding_publishing_frequency, data_sources
):
self,
detector_id: str,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
) -> None:
detector = self.get_detector(detector_id)
detector.update(enable, finding_publishing_frequency, data_sources)
def update_filter(
self, detector_id, filter_name, action, description, finding_criteria, rank
):
self,
detector_id: str,
filter_name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
detector = self.get_detector(detector_id)
detector.update_filter(
filter_name,
@ -91,14 +114,27 @@ class GuardDutyBackend(BaseBackend):
class Filter(BaseModel):
def __init__(self, name, action, description, finding_criteria, rank):
def __init__(
self,
name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
):
self.name = name
self.action = action
self.description = description
self.finding_criteria = finding_criteria
self.rank = rank or 1
def update(self, action, description, finding_criteria, rank):
def update(
self,
action: Optional[str],
description: Optional[str],
finding_criteria: Optional[Dict[str, Any]],
rank: Optional[int],
) -> None:
if action is not None:
self.action = action
if description is not None:
@ -108,7 +144,7 @@ class Filter(BaseModel):
if rank is not None:
self.rank = rank
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"name": self.name,
"action": self.action,
@ -121,12 +157,12 @@ class Filter(BaseModel):
class Detector(BaseModel):
def __init__(
self,
account_id,
created_at,
finding_publish_freq,
enabled,
datasources,
tags,
account_id: str,
created_at: datetime,
finding_publish_freq: str,
enabled: bool,
datasources: Dict[str, Any],
tags: Dict[str, str],
):
self.id = mock_random.get_random_hex(length=32)
self.created_at = created_at
@ -137,20 +173,27 @@ class Detector(BaseModel):
self.datasources = datasources or {}
self.tags = tags or {}
self.filters = dict()
self.filters: Dict[str, Filter] = dict()
def add_filter(self, _filter: Filter):
def add_filter(self, _filter: Filter) -> None:
self.filters[_filter.name] = _filter
def delete_filter(self, filter_name):
def delete_filter(self, filter_name: str) -> None:
self.filters.pop(filter_name, None)
def get_filter(self, filter_name: str):
def get_filter(self, filter_name: str) -> Filter:
if filter_name not in self.filters:
raise FilterNotFoundException
return self.filters[filter_name]
def update_filter(self, filter_name, action, description, finding_criteria, rank):
def update_filter(
self,
filter_name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
_filter = self.get_filter(filter_name)
_filter.update(
action=action,
@ -159,7 +202,12 @@ class Detector(BaseModel):
rank=rank,
)
def update(self, enable, finding_publishing_frequency, data_sources):
def update(
self,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
) -> None:
if enable is not None:
self.enabled = enable
if finding_publishing_frequency is not None:
@ -167,7 +215,7 @@ class Detector(BaseModel):
if data_sources is not None:
self.datasources = data_sources
def to_json(self):
def to_json(self) -> Dict[str, Any]:
data_sources = {
"cloudTrail": {"status": "DISABLED"},
"dnsLogs": {"status": "DISABLED"},

View File

@ -1,18 +1,20 @@
from typing import Any
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import guardduty_backends
from .models import guardduty_backends, GuardDutyBackend
import json
from urllib.parse import unquote
class GuardDutyResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="guardduty")
@property
def guardduty_backend(self):
def guardduty_backend(self) -> GuardDutyBackend:
return guardduty_backends[self.current_account][self.region]
def filter(self, request, full_url, headers):
def filter(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.get_filter()
@ -21,12 +23,12 @@ class GuardDutyResponse(BaseResponse):
elif request.method == "POST":
return self.update_filter()
def filters(self, request, full_url, headers):
def filters(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_filter()
def detectors(self, request, full_url, headers):
def detectors(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_detector()
@ -35,7 +37,7 @@ class GuardDutyResponse(BaseResponse):
else:
return 404, {}, ""
def detector(self, request, full_url, headers):
def detector(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.get_detector()
@ -44,7 +46,7 @@ class GuardDutyResponse(BaseResponse):
elif request.method == "POST":
return self.update_detector()
def create_filter(self):
def create_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-2]
name = self._get_param("name")
action = self._get_param("action")
@ -57,7 +59,7 @@ class GuardDutyResponse(BaseResponse):
)
return 200, {}, json.dumps({"name": name})
def create_detector(self):
def create_detector(self) -> TYPE_RESPONSE:
enable = self._get_param("enable")
finding_publishing_frequency = self._get_param("findingPublishingFrequency")
data_sources = self._get_param("dataSources")
@ -69,20 +71,22 @@ class GuardDutyResponse(BaseResponse):
return 200, {}, json.dumps(dict(detectorId=detector_id))
def delete_detector(self):
def delete_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1]
self.guardduty_backend.delete_detector(detector_id)
return 200, {}, "{}"
def delete_filter(self):
def delete_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1])
self.guardduty_backend.delete_filter(detector_id, filter_name)
return 200, {}, "{}"
def enable_organization_admin_account(self, request, full_url, headers):
def enable_organization_admin_account(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
admin_account = self._get_param("adminAccountId")
@ -90,7 +94,9 @@ class GuardDutyResponse(BaseResponse):
return 200, {}, "{}"
def list_organization_admin_accounts(self, request, full_url, headers):
def list_organization_admin_accounts(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
account_ids = self.guardduty_backend.list_organization_admin_accounts()
@ -108,25 +114,25 @@ class GuardDutyResponse(BaseResponse):
),
)
def list_detectors(self):
def list_detectors(self) -> TYPE_RESPONSE:
detector_ids = self.guardduty_backend.list_detectors()
return 200, {}, json.dumps({"detectorIds": detector_ids})
def get_detector(self):
def get_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1]
detector = self.guardduty_backend.get_detector(detector_id)
return 200, {}, json.dumps(detector.to_json())
def get_filter(self):
def get_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1])
_filter = self.guardduty_backend.get_filter(detector_id, filter_name)
return 200, {}, json.dumps(_filter.to_json())
def update_detector(self):
def update_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1]
enable = self._get_param("enable")
finding_publishing_frequency = self._get_param("findingPublishingFrequency")
@ -137,7 +143,7 @@ class GuardDutyResponse(BaseResponse):
)
return 200, {}, "{}"
def update_filter(self):
def update_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1])
action = self._get_param("action")

View File

@ -171,7 +171,9 @@ class TaggingService:
)
@staticmethod
def convert_dict_to_tags_input(tags: Dict[str, str]) -> List[Dict[str, str]]:
def convert_dict_to_tags_input(
tags: Optional[Dict[str, str]]
) -> List[Dict[str, str]]:
"""Given a dictionary, return generic boto params for tags"""
if not tags:
return []

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/moto_api,moto/neptune
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/moto_api,moto/neptune
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract