Techdebt: MyPy g-models (#6048)

This commit is contained in:
Bert Blommers 2023-03-11 16:00:52 -01:00 committed by GitHub
parent bb39b02098
commit f1f4454b0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1130 additions and 797 deletions

View File

@ -1,6 +1,6 @@
import hashlib import hashlib
import datetime import datetime
from typing import Any, Dict, List, Optional, Union
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.utilities.utils import md5_hash from moto.utilities.utils import md5_hash
@ -9,7 +9,7 @@ from .utils import get_job_id
class Job(BaseModel): class Job(BaseModel):
def __init__(self, tier): def __init__(self, tier: str):
self.st = datetime.datetime.now() self.st = datetime.datetime.now()
if tier.lower() == "expedited": if tier.lower() == "expedited":
@ -20,16 +20,19 @@ class Job(BaseModel):
# Standard # Standard
self.et = self.st + datetime.timedelta(seconds=5) self.et = self.st + datetime.timedelta(seconds=5)
def to_dict(self) -> Dict[str, Any]:
return {}
class ArchiveJob(Job): class ArchiveJob(Job):
def __init__(self, job_id, tier, arn, archive_id): def __init__(self, job_id: str, tier: str, arn: str, archive_id: Optional[str]):
self.job_id = job_id self.job_id = job_id
self.tier = tier self.tier = tier
self.arn = arn self.arn = arn
self.archive_id = archive_id self.archive_id = archive_id
Job.__init__(self, tier) Job.__init__(self, tier)
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
d = { d = {
"Action": "ArchiveRetrieval", "Action": "ArchiveRetrieval",
"ArchiveId": self.archive_id, "ArchiveId": self.archive_id,
@ -57,13 +60,13 @@ class ArchiveJob(Job):
class InventoryJob(Job): class InventoryJob(Job):
def __init__(self, job_id, tier, arn): def __init__(self, job_id: str, tier: str, arn: str):
self.job_id = job_id self.job_id = job_id
self.tier = tier self.tier = tier
self.arn = arn self.arn = arn
Job.__init__(self, tier) Job.__init__(self, tier)
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
d = { d = {
"Action": "InventoryRetrieval", "Action": "InventoryRetrieval",
"ArchiveSHA256TreeHash": None, "ArchiveSHA256TreeHash": None,
@ -89,15 +92,15 @@ class InventoryJob(Job):
class Vault(BaseModel): class Vault(BaseModel):
def __init__(self, vault_name, account_id, region): def __init__(self, vault_name: str, account_id: str, region: str):
self.st = datetime.datetime.now() self.st = datetime.datetime.now()
self.vault_name = vault_name self.vault_name = vault_name
self.region = region self.region = region
self.archives = {} self.archives: Dict[str, Dict[str, Any]] = {}
self.jobs = {} self.jobs: Dict[str, Job] = {}
self.arn = f"arn:aws:glacier:{region}:{account_id}:vaults/{vault_name}" self.arn = f"arn:aws:glacier:{region}:{account_id}:vaults/{vault_name}"
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
archives_size = 0 archives_size = 0
for k in self.archives: for k in self.archives:
archives_size += self.archives[k]["size"] archives_size += self.archives[k]["size"]
@ -111,7 +114,7 @@ class Vault(BaseModel):
} }
return d return d
def create_archive(self, body, description): def create_archive(self, body: bytes, description: str) -> Dict[str, Any]:
archive_id = md5_hash(body).hexdigest() archive_id = md5_hash(body).hexdigest()
self.archives[archive_id] = {} self.archives[archive_id] = {}
self.archives[archive_id]["archive_id"] = archive_id self.archives[archive_id]["archive_id"] = archive_id
@ -124,10 +127,10 @@ class Vault(BaseModel):
self.archives[archive_id]["description"] = description self.archives[archive_id]["description"] = description
return self.archives[archive_id] return self.archives[archive_id]
def get_archive_body(self, archive_id): def get_archive_body(self, archive_id: str) -> str:
return self.archives[archive_id]["body"] return self.archives[archive_id]["body"]
def get_archive_list(self): def get_archive_list(self) -> List[Dict[str, Any]]:
archive_list = [] archive_list = []
for a in self.archives: for a in self.archives:
archive = self.archives[a] archive = self.archives[a]
@ -141,34 +144,33 @@ class Vault(BaseModel):
archive_list.append(aobj) archive_list.append(aobj)
return archive_list return archive_list
def delete_archive(self, archive_id): def delete_archive(self, archive_id: str) -> Dict[str, Any]:
return self.archives.pop(archive_id) return self.archives.pop(archive_id)
def initiate_job(self, job_type, tier, archive_id): def initiate_job(self, job_type: str, tier: str, archive_id: Optional[str]) -> str:
job_id = get_job_id() job_id = get_job_id()
if job_type == "inventory-retrieval": if job_type == "inventory-retrieval":
job = InventoryJob(job_id, tier, self.arn) self.jobs[job_id] = InventoryJob(job_id, tier, self.arn)
elif job_type == "archive-retrieval": elif job_type == "archive-retrieval":
job = ArchiveJob(job_id, tier, self.arn, archive_id) self.jobs[job_id] = ArchiveJob(job_id, tier, self.arn, archive_id)
self.jobs[job_id] = job
return job_id return job_id
def list_jobs(self): def list_jobs(self) -> List[Job]:
return self.jobs.values() return list(self.jobs.values())
def describe_job(self, job_id): def describe_job(self, job_id: str) -> Optional[Job]:
return self.jobs.get(job_id) return self.jobs.get(job_id)
def job_ready(self, job_id): def job_ready(self, job_id: str) -> str:
job = self.describe_job(job_id) job = self.describe_job(job_id)
jobj = job.to_dict() jobj = job.to_dict() # type: ignore
return jobj["Completed"] return jobj["Completed"]
def get_job_output(self, job_id): def get_job_output(self, job_id: str) -> Union[str, Dict[str, Any]]:
job = self.describe_job(job_id) job = self.describe_job(job_id)
jobj = job.to_dict() jobj = job.to_dict() # type: ignore
if jobj["Action"] == "InventoryRetrieval": if jobj["Action"] == "InventoryRetrieval":
archives = self.get_archive_list() archives = self.get_archive_list()
return { return {
@ -177,48 +179,54 @@ class Vault(BaseModel):
"ArchiveList": archives, "ArchiveList": archives,
} }
else: else:
archive_body = self.get_archive_body(job.archive_id) archive_body = self.get_archive_body(job.archive_id) # type: ignore
return archive_body return archive_body
class GlacierBackend(BaseBackend): class GlacierBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.vaults = {} self.vaults: Dict[str, Vault] = {}
def get_vault(self, vault_name): def get_vault(self, vault_name: str) -> Vault:
return self.vaults[vault_name] return self.vaults[vault_name]
def create_vault(self, vault_name): def create_vault(self, vault_name: str) -> None:
self.vaults[vault_name] = Vault(vault_name, self.account_id, self.region_name) self.vaults[vault_name] = Vault(vault_name, self.account_id, self.region_name)
def list_vaults(self): def list_vaults(self) -> List[Vault]:
return self.vaults.values() return list(self.vaults.values())
def delete_vault(self, vault_name): def delete_vault(self, vault_name: str) -> None:
self.vaults.pop(vault_name) self.vaults.pop(vault_name)
def initiate_job(self, vault_name, job_type, tier, archive_id): def initiate_job(
self, vault_name: str, job_type: str, tier: str, archive_id: Optional[str]
) -> str:
vault = self.get_vault(vault_name) vault = self.get_vault(vault_name)
job_id = vault.initiate_job(job_type, tier, archive_id) job_id = vault.initiate_job(job_type, tier, archive_id)
return job_id return job_id
def describe_job(self, vault_name, archive_id): def describe_job(self, vault_name: str, archive_id: str) -> Optional[Job]:
vault = self.get_vault(vault_name) vault = self.get_vault(vault_name)
return vault.describe_job(archive_id) return vault.describe_job(archive_id)
def list_jobs(self, vault_name): def list_jobs(self, vault_name: str) -> List[Job]:
vault = self.get_vault(vault_name) vault = self.get_vault(vault_name)
return vault.list_jobs() return vault.list_jobs()
def get_job_output(self, vault_name, job_id): def get_job_output(
self, vault_name: str, job_id: str
) -> Union[str, Dict[str, Any], None]:
vault = self.get_vault(vault_name) vault = self.get_vault(vault_name)
if vault.job_ready(job_id): if vault.job_ready(job_id):
return vault.get_job_output(job_id) return vault.get_job_output(job_id)
else: else:
return None return None
def upload_archive(self, vault_name, body, description): def upload_archive(
self, vault_name: str, body: bytes, description: str
) -> Dict[str, Any]:
vault = self.get_vault(vault_name) vault = self.get_vault(vault_name)
return vault.create_archive(body, description) return vault.create_archive(body, description)

View File

@ -1,23 +1,26 @@
import json import json
from typing import Any, Dict
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import glacier_backends from .models import glacier_backends, GlacierBackend
from .utils import vault_from_glacier_url from .utils import vault_from_glacier_url
class GlacierResponse(BaseResponse): class GlacierResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="glacier") super().__init__(service_name="glacier")
@property @property
def glacier_backend(self): def glacier_backend(self) -> GlacierBackend:
return glacier_backends[self.current_account][self.region] return glacier_backends[self.current_account][self.region]
def all_vault_response(self, request, full_url, headers): def all_vault_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._all_vault_response(headers) return self._all_vault_response(headers)
def _all_vault_response(self, headers): def _all_vault_response(self, headers: Any) -> TYPE_RESPONSE:
vaults = self.glacier_backend.list_vaults() vaults = self.glacier_backend.list_vaults()
response = json.dumps( response = json.dumps(
{"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]} {"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]}
@ -26,11 +29,13 @@ class GlacierResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, response return 200, headers, response
def vault_response(self, request, full_url, headers): def vault_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._vault_response(request, full_url, headers) return self._vault_response(request, full_url, headers)
def _vault_response(self, request, full_url, headers): def _vault_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method method = request.method
vault_name = vault_from_glacier_url(full_url) vault_name = vault_from_glacier_url(full_url)
@ -41,23 +46,27 @@ class GlacierResponse(BaseResponse):
elif method == "DELETE": elif method == "DELETE":
return self._vault_response_delete(vault_name, headers) return self._vault_response_delete(vault_name, headers)
def _vault_response_get(self, vault_name, headers): def _vault_response_get(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
vault = self.glacier_backend.get_vault(vault_name) vault = self.glacier_backend.get_vault(vault_name)
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, json.dumps(vault.to_dict()) return 200, headers, json.dumps(vault.to_dict())
def _vault_response_put(self, vault_name, headers): def _vault_response_put(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
self.glacier_backend.create_vault(vault_name) self.glacier_backend.create_vault(vault_name)
return 201, headers, "" return 201, headers, ""
def _vault_response_delete(self, vault_name, headers): def _vault_response_delete(self, vault_name: str, headers: Any) -> TYPE_RESPONSE:
self.glacier_backend.delete_vault(vault_name) self.glacier_backend.delete_vault(vault_name)
return 204, headers, "" return 204, headers, ""
def vault_archive_response(self, request, full_url, headers): def vault_archive_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
return self._vault_archive_response(request, full_url, headers) return self._vault_archive_response(request, full_url, headers)
def _vault_archive_response(self, request, full_url, headers): def _vault_archive_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
method = request.method method = request.method
if hasattr(request, "body"): if hasattr(request, "body"):
body = request.body body = request.body
@ -75,17 +84,21 @@ class GlacierResponse(BaseResponse):
else: else:
return 400, headers, "400 Bad Request" return 400, headers, "400 Bad Request"
def _vault_archive_response_post(self, vault_name, body, description, headers): def _vault_archive_response_post(
self, vault_name: str, body: bytes, description: str, headers: Dict[str, Any]
) -> TYPE_RESPONSE:
vault = self.glacier_backend.upload_archive(vault_name, body, description) vault = self.glacier_backend.upload_archive(vault_name, body, description)
headers["x-amz-archive-id"] = vault["archive_id"] headers["x-amz-archive-id"] = vault["archive_id"]
headers["x-amz-sha256-tree-hash"] = vault["sha256"] headers["x-amz-sha256-tree-hash"] = vault["sha256"]
return 201, headers, "" return 201, headers, ""
def vault_archive_individual_response(self, request, full_url, headers): def vault_archive_individual_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._vault_archive_individual_response(request, full_url, headers) return self._vault_archive_individual_response(request, full_url, headers)
def _vault_archive_individual_response(self, request, full_url, headers): def _vault_archive_individual_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method method = request.method
vault_name = full_url.split("/")[-3] vault_name = full_url.split("/")[-3]
archive_id = full_url.split("/")[-1] archive_id = full_url.split("/")[-1]
@ -95,11 +108,13 @@ class GlacierResponse(BaseResponse):
vault.delete_archive(archive_id) vault.delete_archive(archive_id)
return 204, headers, "" return 204, headers, ""
def vault_jobs_response(self, request, full_url, headers): def vault_jobs_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._vault_jobs_response(request, full_url, headers) return self._vault_jobs_response(request, full_url, headers)
def _vault_jobs_response(self, request, full_url, headers): def _vault_jobs_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
method = request.method method = request.method
if hasattr(request, "body"): if hasattr(request, "body"):
body = request.body body = request.body
@ -135,22 +150,28 @@ class GlacierResponse(BaseResponse):
headers["Location"] = f"/{account_id}/vaults/{vault_name}/jobs/{job_id}" headers["Location"] = f"/{account_id}/vaults/{vault_name}/jobs/{job_id}"
return 202, headers, "" return 202, headers, ""
def vault_jobs_individual_response(self, request, full_url, headers): def vault_jobs_individual_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._vault_jobs_individual_response(full_url, headers) return self._vault_jobs_individual_response(full_url, headers)
def _vault_jobs_individual_response(self, full_url, headers): def _vault_jobs_individual_response(
self, full_url: str, headers: Any
) -> TYPE_RESPONSE:
vault_name = full_url.split("/")[-3] vault_name = full_url.split("/")[-3]
archive_id = full_url.split("/")[-1] archive_id = full_url.split("/")[-1]
job = self.glacier_backend.describe_job(vault_name, archive_id) job = self.glacier_backend.describe_job(vault_name, archive_id)
return 200, headers, json.dumps(job.to_dict()) return 200, headers, json.dumps(job.to_dict()) # type: ignore
def vault_jobs_output_response(self, request, full_url, headers): def vault_jobs_output_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self._vault_jobs_output_response(full_url, headers) return self._vault_jobs_output_response(full_url, headers)
def _vault_jobs_output_response(self, full_url, headers): def _vault_jobs_output_response(self, full_url: str, headers: Any) -> TYPE_RESPONSE:
vault_name = full_url.split("/")[-4] vault_name = full_url.split("/")[-4]
job_id = full_url.split("/")[-2] job_id = full_url.split("/")[-2]
output = self.glacier_backend.get_job_output(vault_name, job_id) output = self.glacier_backend.get_job_output(vault_name, job_id)

View File

@ -2,11 +2,11 @@ from moto.moto_api._internal import mock_random as random
import string import string
def vault_from_glacier_url(full_url): def vault_from_glacier_url(full_url: str) -> str:
return full_url.split("/")[-1] return full_url.split("/")[-1]
def get_job_id(): def get_job_id() -> str:
return "".join( return "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(92) random.choice(string.ascii_uppercase + string.digits) for _ in range(92)
) )

View File

@ -1,3 +1,4 @@
from typing import Optional
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
@ -6,72 +7,78 @@ class GlueClientError(JsonRESTError):
class AlreadyExistsException(GlueClientError): class AlreadyExistsException(GlueClientError):
def __init__(self, typ): def __init__(self, typ: str):
super().__init__("AlreadyExistsException", f"{typ} already exists.") super().__init__("AlreadyExistsException", f"{typ} already exists.")
class DatabaseAlreadyExistsException(AlreadyExistsException): class DatabaseAlreadyExistsException(AlreadyExistsException):
def __init__(self): def __init__(self) -> None:
super().__init__("Database") super().__init__("Database")
class TableAlreadyExistsException(AlreadyExistsException): class TableAlreadyExistsException(AlreadyExistsException):
def __init__(self): def __init__(self) -> None:
super().__init__("Table") super().__init__("Table")
class PartitionAlreadyExistsException(AlreadyExistsException): class PartitionAlreadyExistsException(AlreadyExistsException):
def __init__(self): def __init__(self) -> None:
super().__init__("Partition") super().__init__("Partition")
class CrawlerAlreadyExistsException(AlreadyExistsException): class CrawlerAlreadyExistsException(AlreadyExistsException):
def __init__(self): def __init__(self) -> None:
super().__init__("Crawler") super().__init__("Crawler")
class EntityNotFoundException(GlueClientError): class EntityNotFoundException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__("EntityNotFoundException", msg) super().__init__("EntityNotFoundException", msg)
class DatabaseNotFoundException(EntityNotFoundException): class DatabaseNotFoundException(EntityNotFoundException):
def __init__(self, db): def __init__(self, db: str):
super().__init__(f"Database {db} not found.") super().__init__(f"Database {db} not found.")
class TableNotFoundException(EntityNotFoundException): class TableNotFoundException(EntityNotFoundException):
def __init__(self, tbl): def __init__(self, tbl: str):
super().__init__(f"Table {tbl} not found.") super().__init__(f"Table {tbl} not found.")
class PartitionNotFoundException(EntityNotFoundException): class PartitionNotFoundException(EntityNotFoundException):
def __init__(self): def __init__(self) -> None:
super().__init__("Cannot find partition.") super().__init__("Cannot find partition.")
class CrawlerNotFoundException(EntityNotFoundException): class CrawlerNotFoundException(EntityNotFoundException):
def __init__(self, crawler): def __init__(self, crawler: str):
super().__init__(f"Crawler {crawler} not found.") super().__init__(f"Crawler {crawler} not found.")
class JobNotFoundException(EntityNotFoundException): class JobNotFoundException(EntityNotFoundException):
def __init__(self, job): def __init__(self, job: str):
super().__init__(f"Job {job} not found.") super().__init__(f"Job {job} not found.")
class JobRunNotFoundException(EntityNotFoundException): class JobRunNotFoundException(EntityNotFoundException):
def __init__(self, job_run): def __init__(self, job_run: str):
super().__init__(f"Job run {job_run} not found.") super().__init__(f"Job run {job_run} not found.")
class VersionNotFoundException(EntityNotFoundException): class VersionNotFoundException(EntityNotFoundException):
def __init__(self): def __init__(self) -> None:
super().__init__("Version not found.") super().__init__("Version not found.")
class SchemaNotFoundException(EntityNotFoundException): class SchemaNotFoundException(EntityNotFoundException):
def __init__(self, schema_name, registry_name, schema_arn, null="null"): def __init__(
self,
schema_name: str,
registry_name: str,
schema_arn: Optional[str],
null: str = "null",
):
super().__init__( super().__init__(
f"Schema is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}", f"Schema is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}",
) )
@ -80,13 +87,13 @@ class SchemaNotFoundException(EntityNotFoundException):
class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException): class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException):
def __init__( def __init__(
self, self,
registry_name, registry_name: Optional[str],
schema_name, schema_name: Optional[str],
schema_arn, schema_arn: Optional[str],
version_number, version_number: Optional[str],
latest_version, latest_version: Optional[str],
null="null", null: str = "null",
false="false", false: str = "false",
): ):
super().__init__( super().__init__(
f"Schema version is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}, VersionNumber: {version_number if version_number else null}, isLatestVersion: {latest_version if latest_version else false}", f"Schema version is not found. RegistryName: {registry_name if registry_name else null}, SchemaName: {schema_name if schema_name else null}, SchemaArn: {schema_arn if schema_arn else null}, VersionNumber: {version_number if version_number else null}, isLatestVersion: {latest_version if latest_version else false}",
@ -94,36 +101,36 @@ class SchemaVersionNotFoundFromSchemaIdException(EntityNotFoundException):
class SchemaVersionNotFoundFromSchemaVersionIdException(EntityNotFoundException): class SchemaVersionNotFoundFromSchemaVersionIdException(EntityNotFoundException):
def __init__(self, schema_version_id): def __init__(self, schema_version_id: str):
super().__init__( super().__init__(
f"Schema version is not found. SchemaVersionId: {schema_version_id}", f"Schema version is not found. SchemaVersionId: {schema_version_id}",
) )
class RegistryNotFoundException(EntityNotFoundException): class RegistryNotFoundException(EntityNotFoundException):
def __init__(self, resource, param_name, param_value): def __init__(self, resource: str, param_name: str, param_value: Optional[str]):
super().__init__( super().__init__(
resource + " is not found. " + param_name + ": " + param_value, resource + " is not found. " + param_name + ": " + param_value, # type: ignore
) )
class CrawlerRunningException(GlueClientError): class CrawlerRunningException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__("CrawlerRunningException", msg) super().__init__("CrawlerRunningException", msg)
class CrawlerNotRunningException(GlueClientError): class CrawlerNotRunningException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__("CrawlerNotRunningException", msg) super().__init__("CrawlerNotRunningException", msg)
class ConcurrentRunsExceededException(GlueClientError): class ConcurrentRunsExceededException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__("ConcurrentRunsExceededException", msg) super().__init__("ConcurrentRunsExceededException", msg)
class ResourceNumberLimitExceededException(GlueClientError): class ResourceNumberLimitExceededException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__( super().__init__(
"ResourceNumberLimitExceededException", "ResourceNumberLimitExceededException",
msg, msg,
@ -131,7 +138,7 @@ class ResourceNumberLimitExceededException(GlueClientError):
class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededException): class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededException):
def __init__(self, resource): def __init__(self, resource: str):
super().__init__( super().__init__(
"More " "More "
+ resource + resource
@ -140,14 +147,14 @@ class GeneralResourceNumberLimitExceededException(ResourceNumberLimitExceededExc
class SchemaVersionMetadataLimitExceededException(ResourceNumberLimitExceededException): class SchemaVersionMetadataLimitExceededException(ResourceNumberLimitExceededException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"Your resource limits for Schema Version Metadata have been exceeded.", "Your resource limits for Schema Version Metadata have been exceeded.",
) )
class GSRAlreadyExistsException(GlueClientError): class GSRAlreadyExistsException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__( super().__init__(
"AlreadyExistsException", "AlreadyExistsException",
msg, msg,
@ -155,21 +162,21 @@ class GSRAlreadyExistsException(GlueClientError):
class SchemaVersionMetadataAlreadyExistsException(GSRAlreadyExistsException): class SchemaVersionMetadataAlreadyExistsException(GSRAlreadyExistsException):
def __init__(self, schema_version_id, metadata_key, metadata_value): def __init__(self, schema_version_id: str, metadata_key: str, metadata_value: str):
super().__init__( super().__init__(
f"Resource already exist for schema version id: {schema_version_id}, metadata key: {metadata_key}, metadata value: {metadata_value}", f"Resource already exist for schema version id: {schema_version_id}, metadata key: {metadata_key}, metadata value: {metadata_value}",
) )
class GeneralGSRAlreadyExistsException(GSRAlreadyExistsException): class GeneralGSRAlreadyExistsException(GSRAlreadyExistsException):
def __init__(self, resource, param_name, param_value): def __init__(self, resource: str, param_name: str, param_value: str):
super().__init__( super().__init__(
resource + " already exists. " + param_name + ": " + param_value, resource + " already exists. " + param_name + ": " + param_value,
) )
class _InvalidOperationException(GlueClientError): class _InvalidOperationException(GlueClientError):
def __init__(self, error_type, op, msg): def __init__(self, error_type: str, op: str, msg: str):
super().__init__( super().__init__(
error_type, error_type,
"An error occurred (%s) when calling the %s operation: %s" "An error occurred (%s) when calling the %s operation: %s"
@ -178,22 +185,22 @@ class _InvalidOperationException(GlueClientError):
class InvalidStateException(_InvalidOperationException): class InvalidStateException(_InvalidOperationException):
def __init__(self, op, msg): def __init__(self, op: str, msg: str):
super().__init__("InvalidStateException", op, msg) super().__init__("InvalidStateException", op, msg)
class InvalidInputException(_InvalidOperationException): class InvalidInputException(_InvalidOperationException):
def __init__(self, op, msg): def __init__(self, op: str, msg: str):
super().__init__("InvalidInputException", op, msg) super().__init__("InvalidInputException", op, msg)
class GSRInvalidInputException(GlueClientError): class GSRInvalidInputException(GlueClientError):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__("InvalidInputException", msg) super().__init__("InvalidInputException", msg)
class ResourceNameTooLongException(GSRInvalidInputException): class ResourceNameTooLongException(GSRInvalidInputException):
def __init__(self, param_name): def __init__(self, param_name: str):
super().__init__( super().__init__(
"The resource name contains too many or too few characters. Parameter Name: " "The resource name contains too many or too few characters. Parameter Name: "
+ param_name, + param_name,
@ -201,7 +208,7 @@ class ResourceNameTooLongException(GSRInvalidInputException):
class ParamValueContainsInvalidCharactersException(GSRInvalidInputException): class ParamValueContainsInvalidCharactersException(GSRInvalidInputException):
def __init__(self, param_name): def __init__(self, param_name: str):
super().__init__( super().__init__(
"The parameter value contains one or more characters that are not valid. Parameter Name: " "The parameter value contains one or more characters that are not valid. Parameter Name: "
+ param_name, + param_name,
@ -209,28 +216,28 @@ class ParamValueContainsInvalidCharactersException(GSRInvalidInputException):
class InvalidNumberOfTagsException(GSRInvalidInputException): class InvalidNumberOfTagsException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"New Tags cannot be empty or more than 50", "New Tags cannot be empty or more than 50",
) )
class InvalidDataFormatException(GSRInvalidInputException): class InvalidDataFormatException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"Data format is not valid.", "Data format is not valid.",
) )
class InvalidCompatibilityException(GSRInvalidInputException): class InvalidCompatibilityException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"Compatibility is not valid.", "Compatibility is not valid.",
) )
class InvalidSchemaDefinitionException(GSRInvalidInputException): class InvalidSchemaDefinitionException(GSRInvalidInputException):
def __init__(self, data_format_name, err): def __init__(self, data_format_name: str, err: ValueError):
super().__init__( super().__init__(
"Schema definition of " "Schema definition of "
+ data_format_name + data_format_name
@ -240,45 +247,51 @@ class InvalidSchemaDefinitionException(GSRInvalidInputException):
class InvalidRegistryIdBothParamsProvidedException(GSRInvalidInputException): class InvalidRegistryIdBothParamsProvidedException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"One of registryName or registryArn has to be provided, both cannot be provided.", "One of registryName or registryArn has to be provided, both cannot be provided.",
) )
class InvalidSchemaIdBothParamsProvidedException(GSRInvalidInputException): class InvalidSchemaIdBothParamsProvidedException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"One of (registryName and schemaName) or schemaArn has to be provided, both cannot be provided.", "One of (registryName and schemaName) or schemaArn has to be provided, both cannot be provided.",
) )
class InvalidSchemaIdNotProvidedException(GSRInvalidInputException): class InvalidSchemaIdNotProvidedException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"At least one of (registryName and schemaName) or schemaArn has to be provided.", "At least one of (registryName and schemaName) or schemaArn has to be provided.",
) )
class InvalidSchemaVersionNumberBothParamsProvidedException(GSRInvalidInputException): class InvalidSchemaVersionNumberBothParamsProvidedException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__("Only one of VersionNumber or LatestVersion is required.") super().__init__("Only one of VersionNumber or LatestVersion is required.")
class InvalidSchemaVersionNumberNotProvidedException(GSRInvalidInputException): class InvalidSchemaVersionNumberNotProvidedException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__("One of version number (or) latest version is required.") super().__init__("One of version number (or) latest version is required.")
class InvalidSchemaVersionIdProvidedWithOtherParamsException(GSRInvalidInputException): class InvalidSchemaVersionIdProvidedWithOtherParamsException(GSRInvalidInputException):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"No other input parameters can be specified when fetching by SchemaVersionId." "No other input parameters can be specified when fetching by SchemaVersionId."
) )
class DisabledCompatibilityVersioningException(GSRInvalidInputException): class DisabledCompatibilityVersioningException(GSRInvalidInputException):
def __init__(self, schema_name, registry_name, schema_arn, null="null"): def __init__(
self,
schema_name: str,
registry_name: str,
schema_arn: Optional[str],
null: str = "null",
):
super().__init__( super().__init__(
f"Compatibility DISABLED does not allow versioning. SchemaId: SchemaId(schemaArn={schema_arn if schema_arn else null}, schemaName={schema_name if schema_name else null}, registryName={registry_name if registry_name else null})" f"Compatibility DISABLED does not allow versioning. SchemaId: SchemaId(schemaArn={schema_arn if schema_arn else null}, schemaName={schema_name if schema_name else null}, registryName={registry_name if registry_name else null})"
) )

View File

@ -1,5 +1,6 @@
import re import re
import json import json
from typing import Any, Dict, Optional, Tuple, Pattern
from .glue_schema_registry_constants import ( from .glue_schema_registry_constants import (
MAX_REGISTRY_NAME_LENGTH, MAX_REGISTRY_NAME_LENGTH,
@ -53,7 +54,7 @@ from .exceptions import (
) )
def validate_registry_name_pattern_and_length(param_value): def validate_registry_name_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="registryName", param_name="registryName",
@ -62,7 +63,7 @@ def validate_registry_name_pattern_and_length(param_value):
) )
def validate_arn_pattern_and_length(param_value): def validate_arn_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="registryArn", param_name="registryArn",
@ -71,7 +72,7 @@ def validate_arn_pattern_and_length(param_value):
) )
def validate_description_pattern_and_length(param_value): def validate_description_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="description", param_name="description",
@ -80,7 +81,7 @@ def validate_description_pattern_and_length(param_value):
) )
def validate_schema_name_pattern_and_length(param_value): def validate_schema_name_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="schemaName", param_name="schemaName",
@ -89,7 +90,7 @@ def validate_schema_name_pattern_and_length(param_value):
) )
def validate_schema_version_metadata_key_pattern_and_length(param_value): def validate_schema_version_metadata_key_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="key", param_name="key",
@ -98,7 +99,7 @@ def validate_schema_version_metadata_key_pattern_and_length(param_value):
) )
def validate_schema_version_metadata_value_pattern_and_length(param_value): def validate_schema_version_metadata_value_pattern_and_length(param_value: str) -> None:
validate_param_pattern_and_length( validate_param_pattern_and_length(
param_value, param_value,
param_name="value", param_name="value",
@ -108,8 +109,8 @@ def validate_schema_version_metadata_value_pattern_and_length(param_value):
def validate_param_pattern_and_length( def validate_param_pattern_and_length(
param_value, param_name, max_name_length, pattern param_value: str, param_name: str, max_name_length: int, pattern: Pattern[str]
): ) -> None:
if len(param_value.encode("utf-8")) > max_name_length: if len(param_value.encode("utf-8")) > max_name_length:
raise ResourceNameTooLongException(param_name) raise ResourceNameTooLongException(param_name)
@ -117,7 +118,7 @@ def validate_param_pattern_and_length(
raise ParamValueContainsInvalidCharactersException(param_name) raise ParamValueContainsInvalidCharactersException(param_name)
def validate_schema_definition(schema_definition, data_format): def validate_schema_definition(schema_definition: str, data_format: str) -> None:
validate_schema_definition_length(schema_definition) validate_schema_definition_length(schema_definition)
if data_format in ["AVRO", "JSON"]: if data_format in ["AVRO", "JSON"]:
try: try:
@ -126,38 +127,39 @@ def validate_schema_definition(schema_definition, data_format):
raise InvalidSchemaDefinitionException(data_format, err) raise InvalidSchemaDefinitionException(data_format, err)
def validate_schema_definition_length(schema_definition): def validate_schema_definition_length(schema_definition: str) -> None:
if len(schema_definition) > MAX_SCHEMA_DEFINITION_LENGTH: if len(schema_definition) > MAX_SCHEMA_DEFINITION_LENGTH:
param_name = SCHEMA_DEFINITION param_name = SCHEMA_DEFINITION
raise ResourceNameTooLongException(param_name) raise ResourceNameTooLongException(param_name)
def validate_schema_version_id_pattern(schema_version_id): def validate_schema_version_id_pattern(schema_version_id: str) -> None:
if re.match(SCHEMA_VERSION_ID_PATTERN, schema_version_id) is None: if re.match(SCHEMA_VERSION_ID_PATTERN, schema_version_id) is None:
raise ParamValueContainsInvalidCharactersException(SCHEMA_VERSION_ID) raise ParamValueContainsInvalidCharactersException(SCHEMA_VERSION_ID)
def validate_number_of_tags(tags): def validate_number_of_tags(tags: Dict[str, str]) -> None:
if len(tags) > MAX_TAGS_ALLOWED: if len(tags) > MAX_TAGS_ALLOWED:
raise InvalidNumberOfTagsException() raise InvalidNumberOfTagsException()
def validate_registry_id(registry_id, registries): def validate_registry_id(
registry_id: Dict[str, Any], registries: Dict[str, Any]
) -> str:
if not registry_id: if not registry_id:
registry_name = DEFAULT_REGISTRY_NAME return DEFAULT_REGISTRY_NAME
return registry_name
if registry_id.get(REGISTRY_NAME) and registry_id.get(REGISTRY_ARN): if registry_id.get(REGISTRY_NAME) and registry_id.get(REGISTRY_ARN):
raise InvalidRegistryIdBothParamsProvidedException() raise InvalidRegistryIdBothParamsProvidedException()
if registry_id.get(REGISTRY_NAME): if registry_id.get(REGISTRY_NAME):
registry_name = registry_id.get(REGISTRY_NAME) registry_name = registry_id.get(REGISTRY_NAME)
validate_registry_name_pattern_and_length(registry_name) validate_registry_name_pattern_and_length(registry_name) # type: ignore
elif registry_id.get(REGISTRY_ARN): elif registry_id.get(REGISTRY_ARN):
registry_arn = registry_id.get(REGISTRY_ARN) registry_arn = registry_id.get(REGISTRY_ARN)
validate_arn_pattern_and_length(registry_arn) validate_arn_pattern_and_length(registry_arn) # type: ignore
registry_name = registry_arn.split("/")[-1] registry_name = registry_arn.split("/")[-1] # type: ignore
if registry_name != DEFAULT_REGISTRY_NAME and registry_name not in registries: if registry_name != DEFAULT_REGISTRY_NAME and registry_name not in registries:
if registry_id.get(REGISTRY_NAME): if registry_id.get(REGISTRY_NAME):
@ -174,10 +176,15 @@ def validate_registry_id(registry_id, registries):
param_value=registry_arn, param_value=registry_arn,
) )
return registry_name return registry_name # type: ignore
def validate_registry_params(registries, registry_name, description=None, tags=None): def validate_registry_params(
registries: Any,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> None:
validate_registry_name_pattern_and_length(registry_name) validate_registry_name_pattern_and_length(registry_name)
if description: if description:
@ -197,7 +204,9 @@ def validate_registry_params(registries, registry_name, description=None, tags=N
) )
def validate_schema_id(schema_id, registries): def validate_schema_id(
schema_id: Dict[str, str], registries: Dict[str, Any]
) -> Tuple[str, str, Optional[str]]:
schema_arn = schema_id.get(SCHEMA_ARN) schema_arn = schema_id.get(SCHEMA_ARN)
registry_name = schema_id.get(REGISTRY_NAME) registry_name = schema_id.get(REGISTRY_NAME)
schema_name = schema_id.get(SCHEMA_NAME) schema_name = schema_id.get(SCHEMA_NAME)
@ -225,15 +234,15 @@ def validate_schema_id(schema_id, registries):
def validate_schema_params( def validate_schema_params(
registry, registry: Any,
schema_name, schema_name: str,
data_format, data_format: str,
compatibility, compatibility: str,
schema_definition, schema_definition: str,
num_schemas, num_schemas: int,
description=None, description: Optional[str] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
): ) -> None:
validate_schema_name_pattern_and_length(schema_name) validate_schema_name_pattern_and_length(schema_name)
if data_format not in ["AVRO", "JSON", "PROTOBUF"]: if data_format not in ["AVRO", "JSON", "PROTOBUF"]:
@ -271,14 +280,14 @@ def validate_schema_params(
def validate_register_schema_version_params( def validate_register_schema_version_params(
registry_name, registry_name: str,
schema_name, schema_name: str,
schema_arn, schema_arn: Optional[str],
num_schema_versions, num_schema_versions: int,
schema_definition, schema_definition: str,
compatibility, compatibility: str,
data_format, data_format: str,
): ) -> None:
if compatibility == "DISABLED": if compatibility == "DISABLED":
raise DisabledCompatibilityVersioningException( raise DisabledCompatibilityVersioningException(
schema_name, registry_name, schema_arn schema_name, registry_name, schema_arn
@ -290,9 +299,19 @@ def validate_register_schema_version_params(
raise GeneralResourceNumberLimitExceededException(resource="schema versions") raise GeneralResourceNumberLimitExceededException(resource="schema versions")
def validate_schema_version_params( def validate_schema_version_params( # type: ignore[return]
registries, schema_id, schema_version_id, schema_version_number registries: Dict[str, Any],
): schema_id: Optional[Dict[str, Any]],
schema_version_id: Optional[str],
schema_version_number: Optional[Dict[str, Any]],
) -> Tuple[
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
]:
if not schema_version_id and not schema_id and not schema_version_number: if not schema_version_id and not schema_id and not schema_version_number:
raise InvalidSchemaIdNotProvidedException() raise InvalidSchemaIdNotProvidedException()
@ -329,8 +348,11 @@ def validate_schema_version_params(
def validate_schema_version_number( def validate_schema_version_number(
registries, registry_name, schema_name, schema_version_number registries: Dict[str, Any],
): registry_name: str,
schema_name: str,
schema_version_number: Dict[str, str],
) -> Tuple[str, str]:
latest_version = schema_version_number.get(LATEST_VERSION) latest_version = schema_version_number.get(LATEST_VERSION)
version_number = schema_version_number.get(VERSION_NUMBER) version_number = schema_version_number.get(VERSION_NUMBER)
schema = registries[registry_name].schemas[schema_name] schema = registries[registry_name].schemas[schema_name]
@ -339,20 +361,24 @@ def validate_schema_version_number(
raise InvalidSchemaVersionNumberBothParamsProvidedException() raise InvalidSchemaVersionNumberBothParamsProvidedException()
return schema.latest_schema_version, latest_version return schema.latest_schema_version, latest_version
return version_number, latest_version return version_number, latest_version # type: ignore
def validate_schema_version_metadata_pattern_and_length(metadata_key_value): def validate_schema_version_metadata_pattern_and_length(
metadata_key_value: Dict[str, str]
) -> Tuple[str, str]:
metadata_key = metadata_key_value.get(METADATA_KEY) metadata_key = metadata_key_value.get(METADATA_KEY)
metadata_value = metadata_key_value.get(METADATA_VALUE) metadata_value = metadata_key_value.get(METADATA_VALUE)
validate_schema_version_metadata_key_pattern_and_length(metadata_key) validate_schema_version_metadata_key_pattern_and_length(metadata_key) # type: ignore
validate_schema_version_metadata_value_pattern_and_length(metadata_value) validate_schema_version_metadata_value_pattern_and_length(metadata_value) # type: ignore
return metadata_key, metadata_value return metadata_key, metadata_value # type: ignore[return-value]
def validate_number_of_schema_version_metadata_allowed(metadata): def validate_number_of_schema_version_metadata_allowed(
metadata: Dict[str, Any]
) -> None:
num_metadata_key_value_pairs = 0 num_metadata_key_value_pairs = 0
for m in metadata.values(): for m in metadata.values():
num_metadata_key_value_pairs += len(m) num_metadata_key_value_pairs += len(m)
@ -362,8 +388,8 @@ def validate_number_of_schema_version_metadata_allowed(metadata):
def get_schema_version_if_definition_exists( def get_schema_version_if_definition_exists(
schema_versions, data_format, schema_definition schema_versions: Any, data_format: str, schema_definition: str
): ) -> Optional[Dict[str, Any]]:
if data_format in ["AVRO", "JSON"]: if data_format in ["AVRO", "JSON"]:
for schema_version in schema_versions: for schema_version in schema_versions:
if json.loads(schema_definition) == json.loads( if json.loads(schema_definition) == json.loads(
@ -378,9 +404,12 @@ def get_schema_version_if_definition_exists(
def get_put_schema_version_metadata_response( def get_put_schema_version_metadata_response(
schema_id, schema_version_number, schema_version_id, metadata_key_value schema_id: Dict[str, Any],
): schema_version_number: Optional[Dict[str, str]],
put_schema_version_metadata_response_dict = {} schema_version_id: str,
metadata_key_value: Dict[str, str],
) -> Dict[str, Any]:
put_schema_version_metadata_response_dict: Dict[str, Any] = {}
if schema_version_id: if schema_version_id:
put_schema_version_metadata_response_dict[SCHEMA_VERSION_ID] = schema_version_id put_schema_version_metadata_response_dict[SCHEMA_VERSION_ID] = schema_version_id
if schema_id: if schema_id:
@ -416,7 +445,9 @@ def get_put_schema_version_metadata_response(
return put_schema_version_metadata_response_dict return put_schema_version_metadata_response_dict
def delete_schema_response(schema_name, schema_arn, status): def delete_schema_response(
schema_name: str, schema_arn: str, status: str
) -> Dict[str, Any]:
return { return {
"SchemaName": schema_name, "SchemaName": schema_name,
"SchemaArn": schema_arn, "SchemaArn": schema_arn,

View File

@ -3,7 +3,7 @@ import time
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
import re import re
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api import state_manager from moto.moto_api import state_manager
@ -77,12 +77,12 @@ class GlueBackend(BaseBackend):
}, },
} }
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.databases = OrderedDict() self.databases: Dict[str, FakeDatabase] = OrderedDict()
self.crawlers = OrderedDict() self.crawlers: Dict[str, FakeCrawler] = OrderedDict()
self.jobs = OrderedDict() self.jobs: Dict[str, FakeJob] = OrderedDict()
self.job_runs = OrderedDict() self.job_runs: Dict[str, FakeJobRun] = OrderedDict()
self.tagger = TaggingService() self.tagger = TaggingService()
self.registries: Dict[str, FakeRegistry] = OrderedDict() self.registries: Dict[str, FakeRegistry] = OrderedDict()
self.num_schemas = 0 self.num_schemas = 0
@ -93,13 +93,17 @@ class GlueBackend(BaseBackend):
) )
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "glue" service_region, zones, "glue"
) )
def create_database(self, database_name, database_input): def create_database(
self, database_name: str, database_input: Dict[str, Any]
) -> "FakeDatabase":
if database_name in self.databases: if database_name in self.databases:
raise DatabaseAlreadyExistsException() raise DatabaseAlreadyExistsException()
@ -107,27 +111,31 @@ class GlueBackend(BaseBackend):
self.databases[database_name] = database self.databases[database_name] = database
return database return database
def get_database(self, database_name): def get_database(self, database_name: str) -> "FakeDatabase":
try: try:
return self.databases[database_name] return self.databases[database_name]
except KeyError: except KeyError:
raise DatabaseNotFoundException(database_name) raise DatabaseNotFoundException(database_name)
def update_database(self, database_name, database_input): def update_database(
self, database_name: str, database_input: Dict[str, Any]
) -> None:
if database_name not in self.databases: if database_name not in self.databases:
raise DatabaseNotFoundException(database_name) raise DatabaseNotFoundException(database_name)
self.databases[database_name].input = database_input self.databases[database_name].input = database_input
def get_databases(self): def get_databases(self) -> List["FakeDatabase"]:
return [self.databases[key] for key in self.databases] if self.databases else [] return [self.databases[key] for key in self.databases] if self.databases else []
def delete_database(self, database_name): def delete_database(self, database_name: str) -> None:
if database_name not in self.databases: if database_name not in self.databases:
raise DatabaseNotFoundException(database_name) raise DatabaseNotFoundException(database_name)
del self.databases[database_name] del self.databases[database_name]
def create_table(self, database_name, table_name, table_input): def create_table(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
) -> "FakeTable":
database = self.get_database(database_name) database = self.get_database(database_name)
if table_name in database.tables: if table_name in database.tables:
@ -144,7 +152,9 @@ class GlueBackend(BaseBackend):
except KeyError: except KeyError:
raise TableNotFoundException(table_name) raise TableNotFoundException(table_name)
def get_tables(self, database_name, expression): def get_tables(
self, database_name: str, expression: Optional[str]
) -> List["FakeTable"]:
database = self.get_database(database_name) database = self.get_database(database_name)
if expression: if expression:
# sanitise expression, * is treated as a glob-like wildcard # sanitise expression, * is treated as a glob-like wildcard
@ -164,15 +174,16 @@ class GlueBackend(BaseBackend):
else: else:
return [table for table_name, table in database.tables.items()] return [table for table_name, table in database.tables.items()]
def delete_table(self, database_name, table_name): def delete_table(self, database_name: str, table_name: str) -> None:
database = self.get_database(database_name) database = self.get_database(database_name)
try: try:
del database.tables[table_name] del database.tables[table_name]
except KeyError: except KeyError:
raise TableNotFoundException(table_name) raise TableNotFoundException(table_name)
return {}
def update_table(self, database_name, table_name: str, table_input) -> None: def update_table(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
) -> None:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
table.update(table_input) table.update(table_input)
@ -202,15 +213,21 @@ class GlueBackend(BaseBackend):
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
table.delete_version(version_id) table.delete_version(version_id)
def create_partition(self, database_name: str, table_name: str, part_input) -> None: def create_partition(
self, database_name: str, table_name: str, part_input: Dict[str, Any]
) -> None:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
table.create_partition(part_input) table.create_partition(part_input)
def get_partition(self, database_name: str, table_name: str, values): def get_partition(
self, database_name: str, table_name: str, values: str
) -> "FakePartition":
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
return table.get_partition(values) return table.get_partition(values)
def get_partitions(self, database_name, table_name, expression): def get_partitions(
self, database_name: str, table_name: str, expression: str
) -> List["FakePartition"]:
""" """
See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html
for supported expressions. for supported expressions.
@ -226,34 +243,38 @@ class GlueBackend(BaseBackend):
return table.get_partitions(expression) return table.get_partitions(expression)
def update_partition( def update_partition(
self, database_name, table_name, part_input, part_to_update self,
database_name: str,
table_name: str,
part_input: Dict[str, Any],
part_to_update: str,
) -> None: ) -> None:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
table.update_partition(part_to_update, part_input) table.update_partition(part_to_update, part_input)
def delete_partition( def delete_partition(
self, database_name: str, table_name: str, part_to_delete self, database_name: str, table_name: str, part_to_delete: int
) -> None: ) -> None:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
table.delete_partition(part_to_delete) table.delete_partition(part_to_delete)
def create_crawler( def create_crawler(
self, self,
name, name: str,
role, role: str,
database_name, database_name: str,
description, description: str,
targets, targets: Dict[str, Any],
schedule, schedule: str,
classifiers, classifiers: List[str],
table_prefix, table_prefix: str,
schema_change_policy, schema_change_policy: Dict[str, str],
recrawl_policy, recrawl_policy: Dict[str, str],
lineage_configuration, lineage_configuration: Dict[str, str],
configuration, configuration: str,
crawler_security_configuration, crawler_security_configuration: str,
tags, tags: Dict[str, str],
): ) -> None:
if name in self.crawlers: if name in self.crawlers:
raise CrawlerAlreadyExistsException() raise CrawlerAlreadyExistsException()
@ -276,28 +297,28 @@ class GlueBackend(BaseBackend):
) )
self.crawlers[name] = crawler self.crawlers[name] = crawler
def get_crawler(self, name): def get_crawler(self, name: str) -> "FakeCrawler":
try: try:
return self.crawlers[name] return self.crawlers[name]
except KeyError: except KeyError:
raise CrawlerNotFoundException(name) raise CrawlerNotFoundException(name)
def get_crawlers(self): def get_crawlers(self) -> List["FakeCrawler"]:
return [self.crawlers[key] for key in self.crawlers] if self.crawlers else [] return [self.crawlers[key] for key in self.crawlers] if self.crawlers else []
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_crawlers(self): def list_crawlers(self) -> List["FakeCrawler"]: # type: ignore[misc]
return [crawler for _, crawler in self.crawlers.items()] return [crawler for _, crawler in self.crawlers.items()]
def start_crawler(self, name): def start_crawler(self, name: str) -> None:
crawler = self.get_crawler(name) crawler = self.get_crawler(name)
crawler.start_crawler() crawler.start_crawler()
def stop_crawler(self, name): def stop_crawler(self, name: str) -> None:
crawler = self.get_crawler(name) crawler = self.get_crawler(name)
crawler.stop_crawler() crawler.stop_crawler()
def delete_crawler(self, name): def delete_crawler(self, name: str) -> None:
try: try:
del self.crawlers[name] del self.crawlers[name]
except KeyError: except KeyError:
@ -305,29 +326,29 @@ class GlueBackend(BaseBackend):
def create_job( def create_job(
self, self,
name, name: str,
role, role: str,
command, command: str,
description, description: str,
log_uri, log_uri: str,
execution_property, execution_property: Dict[str, int],
default_arguments, default_arguments: Dict[str, str],
non_overridable_arguments, non_overridable_arguments: Dict[str, str],
connections, connections: Dict[str, List[str]],
max_retries, max_retries: int,
allocated_capacity, allocated_capacity: int,
timeout, timeout: int,
max_capacity, max_capacity: float,
security_configuration, security_configuration: str,
tags, tags: Dict[str, str],
notification_property, notification_property: Dict[str, int],
glue_version, glue_version: str,
number_of_workers, number_of_workers: int,
worker_type, worker_type: str,
code_gen_configuration_nodes, code_gen_configuration_nodes: Dict[str, Any],
execution_class, execution_class: str,
source_control_details, source_control_details: Dict[str, str],
): ) -> None:
self.jobs[name] = FakeJob( self.jobs[name] = FakeJob(
name, name,
role, role,
@ -353,46 +374,50 @@ class GlueBackend(BaseBackend):
source_control_details, source_control_details,
backend=self, backend=self,
) )
return name
def get_job(self, name): def get_job(self, name: str) -> "FakeJob":
try: try:
return self.jobs[name] return self.jobs[name]
except KeyError: except KeyError:
raise JobNotFoundException(name) raise JobNotFoundException(name)
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def get_jobs(self): def get_jobs(self) -> List["FakeJob"]: # type: ignore
return [job for _, job in self.jobs.items()] return [job for _, job in self.jobs.items()]
def start_job_run(self, name): def start_job_run(self, name: str) -> str:
job = self.get_job(name) job = self.get_job(name)
return job.start_job_run() return job.start_job_run()
def get_job_run(self, name, run_id): def get_job_run(self, name: str, run_id: str) -> "FakeJobRun":
job = self.get_job(name) job = self.get_job(name)
return job.get_job_run(run_id) return job.get_job_run(run_id)
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_jobs(self): def list_jobs(self) -> List["FakeJob"]: # type: ignore
return [job for _, job in self.jobs.items()] return [job for _, job in self.jobs.items()]
def get_tags(self, resource_id): def get_tags(self, resource_id: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_id) return self.tagger.get_tag_dict_for_resource(resource_id)
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
tags = TaggingService.convert_dict_to_tags_input(tags or {}) tag_list = TaggingService.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tags) self.tagger.tag_resource(resource_arn, tag_list)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def create_registry(self, registry_name, description=None, tags=None): def create_registry(
self,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
# If registry name id default-registry, create default-registry # If registry name id default-registry, create default-registry
if registry_name == DEFAULT_REGISTRY_NAME: if registry_name == DEFAULT_REGISTRY_NAME:
registry = FakeRegistry(self, registry_name, description, tags) registry = FakeRegistry(self, registry_name, description, tags)
self.registries[registry_name] = registry self.registries[registry_name] = registry
return registry return registry # type: ignore
# Validate Registry Parameters # Validate Registry Parameters
validate_registry_params(self.registries, registry_name, description, tags) validate_registry_params(self.registries, registry_name, description, tags)
@ -401,27 +426,27 @@ class GlueBackend(BaseBackend):
self.registries[registry_name] = registry self.registries[registry_name] = registry
return registry.as_dict() return registry.as_dict()
def delete_registry(self, registry_id): def delete_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]:
registry_name = validate_registry_id(registry_id, self.registries) registry_name = validate_registry_id(registry_id, self.registries)
return self.registries.pop(registry_name).as_dict() return self.registries.pop(registry_name).as_dict()
def get_registry(self, registry_id): def get_registry(self, registry_id: Dict[str, Any]) -> Dict[str, Any]:
registry_name = validate_registry_id(registry_id, self.registries) registry_name = validate_registry_id(registry_id, self.registries)
return self.registries[registry_name].as_dict() return self.registries[registry_name].as_dict()
def list_registries(self): def list_registries(self) -> List[Dict[str, Any]]:
return [reg.as_dict() for reg in self.registries.values()] return [reg.as_dict() for reg in self.registries.values()]
def create_schema( def create_schema(
self, self,
registry_id, registry_id: Dict[str, Any],
schema_name, schema_name: str,
data_format, data_format: str,
compatibility, compatibility: str,
schema_definition, schema_definition: str,
description=None, description: Optional[str] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
): ) -> Dict[str, Any]:
""" """
The following parameters/features are not yet implemented: Glue Schema Registry: compatibility checks NONE | BACKWARD | BACKWARD_ALL | FORWARD | FORWARD_ALL | FULL | FULL_ALL and Data format parsing and syntax validation. The following parameters/features are not yet implemented: Glue Schema Registry: compatibility checks NONE | BACKWARD | BACKWARD_ALL | FORWARD | FORWARD_ALL | FULL | FULL_ALL and Data format parsing and syntax validation.
""" """
@ -479,7 +504,9 @@ class GlueBackend(BaseBackend):
resp.update({"Tags": tags}) resp.update({"Tags": tags})
return resp return resp
def register_schema_version(self, schema_id, schema_definition): def register_schema_version(
self, schema_id: Dict[str, Any], schema_definition: str
) -> Dict[str, Any]:
# Validate Schema Id # Validate Schema Id
registry_name, schema_name, schema_arn = validate_schema_id( registry_name, schema_name, schema_arn = validate_schema_id(
schema_id, self.registries schema_id, self.registries
@ -538,8 +565,11 @@ class GlueBackend(BaseBackend):
return schema_version.as_dict() return schema_version.as_dict()
def get_schema_version( def get_schema_version(
self, schema_id=None, schema_version_id=None, schema_version_number=None self,
): schema_id: Optional[Dict[str, str]] = None,
schema_version_id: Optional[str] = None,
schema_version_number: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# Validate Schema Parameters # Validate Schema Parameters
( (
schema_version_id, schema_version_id,
@ -571,10 +601,10 @@ class GlueBackend(BaseBackend):
raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id) raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id)
# GetSchemaVersion using VersionNumber # GetSchemaVersion using VersionNumber
schema = self.registries[registry_name].schemas[schema_name] schema = self.registries[registry_name].schemas[schema_name] # type: ignore
for schema_version in schema.schema_versions.values(): for schema_version in schema.schema_versions.values():
if ( if (
version_number == schema_version.version_number version_number == schema_version.version_number # type: ignore
and schema_version.schema_version_status != DELETING_STATUS and schema_version.schema_version_status != DELETING_STATUS
): ):
get_schema_version_dict = schema_version.get_schema_version_as_dict() get_schema_version_dict = schema_version.get_schema_version_as_dict()
@ -584,7 +614,9 @@ class GlueBackend(BaseBackend):
registry_name, schema_name, schema_arn, version_number, latest_version registry_name, schema_name, schema_arn, version_number, latest_version
) )
def get_schema_by_definition(self, schema_id, schema_definition): def get_schema_by_definition(
self, schema_id: Dict[str, str], schema_definition: str
) -> Dict[str, Any]:
# Validate SchemaId # Validate SchemaId
validate_schema_definition_length(schema_definition) validate_schema_definition_length(schema_definition)
registry_name, schema_name, schema_arn = validate_schema_id( registry_name, schema_name, schema_arn = validate_schema_id(
@ -606,8 +638,12 @@ class GlueBackend(BaseBackend):
raise SchemaNotFoundException(schema_name, registry_name, schema_arn) raise SchemaNotFoundException(schema_name, registry_name, schema_arn)
def put_schema_version_metadata( def put_schema_version_metadata(
self, schema_id, schema_version_number, schema_version_id, metadata_key_value self,
): schema_id: Dict[str, Any],
schema_version_number: Dict[str, str],
schema_version_id: str,
metadata_key_value: Dict[str, str],
) -> Dict[str, Any]:
# Validate metadata_key_value and schema version params # Validate metadata_key_value and schema version params
( (
metadata_key, metadata_key,
@ -620,7 +656,7 @@ class GlueBackend(BaseBackend):
schema_arn, schema_arn,
version_number, version_number,
latest_version, latest_version,
) = validate_schema_version_params( ) = validate_schema_version_params( # type: ignore
self.registries, schema_id, schema_version_id, schema_version_number self.registries, schema_id, schema_version_id, schema_version_number
) )
@ -650,9 +686,9 @@ class GlueBackend(BaseBackend):
raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id) raise SchemaVersionNotFoundFromSchemaVersionIdException(schema_version_id)
# PutSchemaVersionMetadata using VersionNumber # PutSchemaVersionMetadata using VersionNumber
schema = self.registries[registry_name].schemas[schema_name] schema = self.registries[registry_name].schemas[schema_name] # type: ignore
for schema_version in schema.schema_versions.values(): for schema_version in schema.schema_versions.values():
if version_number == schema_version.version_number: if version_number == schema_version.version_number: # type: ignore
validate_number_of_schema_version_metadata_allowed( validate_number_of_schema_version_metadata_allowed(
schema_version.metadata schema_version.metadata
) )
@ -677,12 +713,12 @@ class GlueBackend(BaseBackend):
registry_name, schema_name, schema_arn, version_number, latest_version registry_name, schema_name, schema_arn, version_number, latest_version
) )
def get_schema(self, schema_id): def get_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]:
registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries)
schema = self.registries[registry_name].schemas[schema_name] schema = self.registries[registry_name].schemas[schema_name]
return schema.as_dict() return schema.as_dict()
def delete_schema(self, schema_id): def delete_schema(self, schema_id: Dict[str, str]) -> Dict[str, Any]:
# Validate schema_id # Validate schema_id
registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries) registry_name, schema_name, _ = validate_schema_id(schema_id, self.registries)
@ -701,7 +737,9 @@ class GlueBackend(BaseBackend):
return response return response
def update_schema(self, schema_id, compatibility, description): def update_schema(
self, schema_id: Dict[str, str], compatibility: str, description: str
) -> Dict[str, Any]:
""" """
The SchemaVersionNumber-argument is not yet implemented The SchemaVersionNumber-argument is not yet implemented
""" """
@ -715,7 +753,9 @@ class GlueBackend(BaseBackend):
return schema.as_dict() return schema.as_dict()
def batch_delete_table(self, database_name, tables): def batch_delete_table(
self, database_name: str, tables: List[str]
) -> List[Dict[str, Any]]:
errors = [] errors = []
for table_name in tables: for table_name in tables:
try: try:
@ -732,7 +772,12 @@ class GlueBackend(BaseBackend):
) )
return errors return errors
def batch_get_partition(self, database_name, table_name, partitions_to_get): def batch_get_partition(
self,
database_name: str,
table_name: str,
partitions_to_get: List[Dict[str, str]],
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
partitions = [] partitions = []
@ -744,7 +789,9 @@ class GlueBackend(BaseBackend):
continue continue
return partitions return partitions
def batch_create_partition(self, database_name, table_name, partition_input): def batch_create_partition(
self, database_name: str, table_name: str, partition_input: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
errors_output = [] errors_output = []
@ -763,7 +810,9 @@ class GlueBackend(BaseBackend):
) )
return errors_output return errors_output
def batch_update_partition(self, database_name, table_name, entries): def batch_update_partition(
self, database_name: str, table_name: str, entries: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
errors_output = [] errors_output = []
@ -785,14 +834,16 @@ class GlueBackend(BaseBackend):
) )
return errors_output return errors_output
def batch_delete_partition(self, database_name, table_name, parts): def batch_delete_partition(
self, database_name: str, table_name: str, parts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
table = self.get_table(database_name, table_name) table = self.get_table(database_name, table_name)
errors_output = [] errors_output = []
for part_input in parts: for part_input in parts:
values = part_input.get("Values") values = part_input.get("Values")
try: try:
table.delete_partition(values) table.delete_partition(values) # type: ignore
except PartitionNotFoundException: except PartitionNotFoundException:
errors_output.append( errors_output.append(
{ {
@ -805,7 +856,7 @@ class GlueBackend(BaseBackend):
) )
return errors_output return errors_output
def batch_get_crawlers(self, crawler_names): def batch_get_crawlers(self, crawler_names: List[str]) -> List[Dict[str, Any]]:
crawlers = [] crawlers = []
for crawler in self.get_crawlers(): for crawler in self.get_crawlers():
if crawler.as_dict()["Name"] in crawler_names: if crawler.as_dict()["Name"] in crawler_names:
@ -814,13 +865,13 @@ class GlueBackend(BaseBackend):
class FakeDatabase(BaseModel): class FakeDatabase(BaseModel):
def __init__(self, database_name, database_input): def __init__(self, database_name: str, database_input: Dict[str, Any]):
self.name = database_name self.name = database_name
self.input = database_input self.input = database_input
self.created_time = datetime.utcnow() self.created_time = datetime.utcnow()
self.tables = OrderedDict() self.tables: Dict[str, FakeTable] = OrderedDict()
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"Name": self.name, "Name": self.name,
"Description": self.input.get("Description"), "Description": self.input.get("Description"),
@ -836,23 +887,25 @@ class FakeDatabase(BaseModel):
class FakeTable(BaseModel): class FakeTable(BaseModel):
def __init__(self, database_name: str, table_name: str, table_input): def __init__(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
):
self.database_name = database_name self.database_name = database_name
self.name = table_name self.name = table_name
self.partitions = OrderedDict() self.partitions: Dict[str, FakePartition] = OrderedDict()
self.created_time = datetime.utcnow() self.created_time = datetime.utcnow()
self.updated_time = None self.updated_time: Optional[datetime] = None
self._current_version = 1 self._current_version = 1
self.versions: Dict[str, Dict[str, Any]] = { self.versions: Dict[str, Dict[str, Any]] = {
str(self._current_version): table_input str(self._current_version): table_input
} }
def update(self, table_input): def update(self, table_input: Dict[str, Any]) -> None:
self.versions[str(self._current_version + 1)] = table_input self.versions[str(self._current_version + 1)] = table_input
self._current_version += 1 self._current_version += 1
self.updated_time = datetime.utcnow() self.updated_time = datetime.utcnow()
def get_version(self, ver): def get_version(self, ver: str) -> Dict[str, Any]:
try: try:
int(ver) int(ver)
except ValueError as e: except ValueError as e:
@ -863,11 +916,11 @@ class FakeTable(BaseModel):
except KeyError: except KeyError:
raise VersionNotFoundException() raise VersionNotFoundException()
def delete_version(self, version_id): def delete_version(self, version_id: str) -> None:
self.versions.pop(version_id) self.versions.pop(version_id)
def as_dict(self, version=None): def as_dict(self, version: Optional[str] = None) -> Dict[str, Any]:
version = version or self._current_version version = version or self._current_version # type: ignore
obj = { obj = {
"DatabaseName": self.database_name, "DatabaseName": self.database_name,
"Name": self.name, "Name": self.name,
@ -880,23 +933,23 @@ class FakeTable(BaseModel):
obj["UpdateTime"] = unix_time(self.updated_time) obj["UpdateTime"] = unix_time(self.updated_time)
return obj return obj
def create_partition(self, partiton_input): def create_partition(self, partiton_input: Dict[str, Any]) -> None:
partition = FakePartition(self.database_name, self.name, partiton_input) partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values) key = str(partition.values)
if key in self.partitions: if key in self.partitions:
raise PartitionAlreadyExistsException() raise PartitionAlreadyExistsException()
self.partitions[str(partition.values)] = partition self.partitions[str(partition.values)] = partition
def get_partitions(self, expression): def get_partitions(self, expression: str) -> List["FakePartition"]:
return list(filter(PartitionFilter(expression, self), self.partitions.values())) return list(filter(PartitionFilter(expression, self), self.partitions.values()))
def get_partition(self, values): def get_partition(self, values: str) -> "FakePartition":
try: try:
return self.partitions[str(values)] return self.partitions[str(values)]
except KeyError: except KeyError:
raise PartitionNotFoundException() raise PartitionNotFoundException()
def update_partition(self, old_values, partiton_input): def update_partition(self, old_values: str, partiton_input: Dict[str, Any]) -> None:
partition = FakePartition(self.database_name, self.name, partiton_input) partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values) key = str(partition.values)
if old_values == partiton_input["Values"]: if old_values == partiton_input["Values"]:
@ -913,7 +966,7 @@ class FakeTable(BaseModel):
raise PartitionAlreadyExistsException() raise PartitionAlreadyExistsException()
self.partitions[key] = partition self.partitions[key] = partition
def delete_partition(self, values): def delete_partition(self, values: int) -> None:
try: try:
del self.partitions[str(values)] del self.partitions[str(values)]
except KeyError: except KeyError:
@ -921,14 +974,16 @@ class FakeTable(BaseModel):
class FakePartition(BaseModel): class FakePartition(BaseModel):
def __init__(self, database_name, table_name, partiton_input): def __init__(
self, database_name: str, table_name: str, partiton_input: Dict[str, Any]
):
self.creation_time = time.time() self.creation_time = time.time()
self.database_name = database_name self.database_name = database_name
self.table_name = table_name self.table_name = table_name
self.partition_input = partiton_input self.partition_input = partiton_input
self.values = self.partition_input.get("Values", []) self.values = self.partition_input.get("Values", [])
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
obj = { obj = {
"DatabaseName": self.database_name, "DatabaseName": self.database_name,
"TableName": self.table_name, "TableName": self.table_name,
@ -941,21 +996,21 @@ class FakePartition(BaseModel):
class FakeCrawler(BaseModel): class FakeCrawler(BaseModel):
def __init__( def __init__(
self, self,
name, name: str,
role, role: str,
database_name, database_name: str,
description, description: str,
targets, targets: Dict[str, Any],
schedule, schedule: str,
classifiers, classifiers: List[str],
table_prefix, table_prefix: str,
schema_change_policy, schema_change_policy: Dict[str, str],
recrawl_policy, recrawl_policy: Dict[str, str],
lineage_configuration, lineage_configuration: Dict[str, str],
configuration, configuration: str,
crawler_security_configuration, crawler_security_configuration: str,
tags, tags: Dict[str, str],
backend, backend: GlueBackend,
): ):
self.name = name self.name = name
self.role = role self.role = role
@ -980,11 +1035,11 @@ class FakeCrawler(BaseModel):
self.backend = backend self.backend = backend
self.backend.tag_resource(self.arn, tags) self.backend.tag_resource(self.arn, tags)
def get_name(self): def get_name(self) -> str:
return self.name return self.name
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None last_crawl = self.last_crawl_info.as_dict() if self.last_crawl_info else None # type: ignore
data = { data = {
"Name": self.name, "Name": self.name,
"Role": self.role, "Role": self.role,
@ -1017,14 +1072,14 @@ class FakeCrawler(BaseModel):
return data return data
def start_crawler(self): def start_crawler(self) -> None:
if self.state == "RUNNING": if self.state == "RUNNING":
raise CrawlerRunningException( raise CrawlerRunningException(
f"Crawler with name {self.name} has already started" f"Crawler with name {self.name} has already started"
) )
self.state = "RUNNING" self.state = "RUNNING"
def stop_crawler(self): def stop_crawler(self) -> None:
if self.state != "RUNNING": if self.state != "RUNNING":
raise CrawlerNotRunningException( raise CrawlerNotRunningException(
f"Crawler with name {self.name} isn't running" f"Crawler with name {self.name} isn't running"
@ -1034,7 +1089,13 @@ class FakeCrawler(BaseModel):
class LastCrawlInfo(BaseModel): class LastCrawlInfo(BaseModel):
def __init__( def __init__(
self, error_message, log_group, log_stream, message_prefix, start_time, status self,
error_message: str,
log_group: str,
log_stream: str,
message_prefix: str,
start_time: str,
status: str,
): ):
self.error_message = error_message self.error_message = error_message
self.log_group = log_group self.log_group = log_group
@ -1043,7 +1104,7 @@ class LastCrawlInfo(BaseModel):
self.start_time = start_time self.start_time = start_time
self.status = status self.status = status
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"ErrorMessage": self.error_message, "ErrorMessage": self.error_message,
"LogGroup": self.log_group, "LogGroup": self.log_group,
@ -1057,29 +1118,29 @@ class LastCrawlInfo(BaseModel):
class FakeJob: class FakeJob:
def __init__( def __init__(
self, self,
name, name: str,
role, role: str,
command, command: str,
description=None, description: str,
log_uri=None, log_uri: str,
execution_property=None, execution_property: Dict[str, int],
default_arguments=None, default_arguments: Dict[str, str],
non_overridable_arguments=None, non_overridable_arguments: Dict[str, str],
connections=None, connections: Dict[str, List[str]],
max_retries=None, max_retries: int,
allocated_capacity=None, allocated_capacity: int,
timeout=None, timeout: int,
max_capacity=None, max_capacity: float,
security_configuration=None, security_configuration: str,
tags=None, tags: Dict[str, str],
notification_property=None, notification_property: Dict[str, int],
glue_version=None, glue_version: str,
number_of_workers=None, number_of_workers: int,
worker_type=None, worker_type: str,
code_gen_configuration_nodes=None, code_gen_configuration_nodes: Dict[str, Any],
execution_class=None, execution_class: str,
source_control_details=None, source_control_details: Dict[str, str],
backend=None, backend: GlueBackend,
): ):
self.name = name self.name = name
self.description = description self.description = description
@ -1112,10 +1173,10 @@ class FakeJob:
self.job_runs: List[FakeJobRun] = [] self.job_runs: List[FakeJobRun] = []
def get_name(self): def get_name(self) -> str:
return self.name return self.name
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"Name": self.name, "Name": self.name,
"Description": self.description, "Description": self.description,
@ -1142,7 +1203,7 @@ class FakeJob:
"SourceControlDetails": self.source_control_details, "SourceControlDetails": self.source_control_details,
} }
def start_job_run(self): def start_job_run(self) -> str:
running_jobs = len( running_jobs = len(
[jr for jr in self.job_runs if jr.status in ["STARTING", "RUNNING"]] [jr for jr in self.job_runs if jr.status in ["STARTING", "RUNNING"]]
) )
@ -1154,7 +1215,7 @@ class FakeJob:
self.job_runs.append(fake_job_run) self.job_runs.append(fake_job_run)
return fake_job_run.job_run_id return fake_job_run.job_run_id
def get_job_run(self, run_id): def get_job_run(self, run_id: str) -> "FakeJobRun":
for job_run in self.job_runs: for job_run in self.job_runs:
if job_run.job_run_id == run_id: if job_run.job_run_id == run_id:
job_run.advance() job_run.advance()
@ -1165,11 +1226,11 @@ class FakeJob:
class FakeJobRun(ManagedState): class FakeJobRun(ManagedState):
def __init__( def __init__(
self, self,
job_name: int, job_name: str,
job_run_id: str = "01", job_run_id: str = "01",
arguments: dict = None, arguments: Optional[Dict[str, Any]] = None,
allocated_capacity: int = None, allocated_capacity: Optional[int] = None,
timeout: int = None, timeout: Optional[int] = None,
worker_type: str = "Standard", worker_type: str = "Standard",
): ):
ManagedState.__init__( ManagedState.__init__(
@ -1187,10 +1248,10 @@ class FakeJobRun(ManagedState):
self.modified_on = datetime.utcnow() self.modified_on = datetime.utcnow()
self.completed_on = datetime.utcnow() self.completed_on = datetime.utcnow()
def get_name(self): def get_name(self) -> str:
return self.job_name return self.job_name
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"Id": self.job_run_id, "Id": self.job_run_id,
"Attempt": 1, "Attempt": 1,
@ -1220,7 +1281,13 @@ class FakeJobRun(ManagedState):
class FakeRegistry(BaseModel): class FakeRegistry(BaseModel):
def __init__(self, backend, registry_name, description=None, tags=None): def __init__(
self,
backend: GlueBackend,
registry_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
):
self.name = registry_name self.name = registry_name
self.description = description self.description = description
self.tags = tags self.tags = tags
@ -1230,7 +1297,7 @@ class FakeRegistry(BaseModel):
self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.name}" self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.name}"
self.schemas: Dict[str, FakeSchema] = OrderedDict() self.schemas: Dict[str, FakeSchema] = OrderedDict()
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"RegistryArn": self.registry_arn, "RegistryArn": self.registry_arn,
"RegistryName": self.name, "RegistryName": self.name,
@ -1243,12 +1310,12 @@ class FakeSchema(BaseModel):
def __init__( def __init__(
self, self,
backend: GlueBackend, backend: GlueBackend,
registry_name, registry_name: str,
schema_name, schema_name: str,
data_format, data_format: str,
compatibility, compatibility: str,
schema_version_id, schema_version_id: str,
description=None, description: Optional[str] = None,
): ):
self.registry_name = registry_name self.registry_name = registry_name
self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.registry_name}" self.registry_arn = f"arn:aws:glue:{backend.region_name}:{backend.account_id}:registry/{self.registry_name}"
@ -1265,18 +1332,18 @@ class FakeSchema(BaseModel):
self.schema_version_status = AVAILABLE_STATUS self.schema_version_status = AVAILABLE_STATUS
self.created_time = datetime.utcnow() self.created_time = datetime.utcnow()
self.updated_time = datetime.utcnow() self.updated_time = datetime.utcnow()
self.schema_versions = OrderedDict() self.schema_versions: Dict[str, FakeSchemaVersion] = OrderedDict()
def update_next_schema_version(self): def update_next_schema_version(self) -> None:
self.next_schema_version += 1 self.next_schema_version += 1
def update_latest_schema_version(self): def update_latest_schema_version(self) -> None:
self.latest_schema_version += 1 self.latest_schema_version += 1
def get_next_schema_version(self): def get_next_schema_version(self) -> int:
return self.next_schema_version return self.next_schema_version
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"RegistryArn": self.registry_arn, "RegistryArn": self.registry_arn,
"RegistryName": self.registry_name, "RegistryName": self.registry_name,
@ -1298,10 +1365,10 @@ class FakeSchemaVersion(BaseModel):
def __init__( def __init__(
self, self,
backend: GlueBackend, backend: GlueBackend,
registry_name, registry_name: str,
schema_name, schema_name: str,
schema_definition, schema_definition: str,
version_number, version_number: int,
): ):
self.registry_name = registry_name self.registry_name = registry_name
self.schema_name = schema_name self.schema_name = schema_name
@ -1312,19 +1379,19 @@ class FakeSchemaVersion(BaseModel):
self.schema_version_id = str(mock_random.uuid4()) self.schema_version_id = str(mock_random.uuid4())
self.created_time = datetime.utcnow() self.created_time = datetime.utcnow()
self.updated_time = datetime.utcnow() self.updated_time = datetime.utcnow()
self.metadata = OrderedDict() self.metadata: Dict[str, Any] = {}
def get_schema_version_id(self): def get_schema_version_id(self) -> str:
return self.schema_version_id return self.schema_version_id
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
"SchemaVersionId": self.schema_version_id, "SchemaVersionId": self.schema_version_id,
"VersionNumber": self.version_number, "VersionNumber": self.version_number,
"Status": self.schema_version_status, "Status": self.schema_version_status,
} }
def get_schema_version_as_dict(self): def get_schema_version_as_dict(self) -> Dict[str, Any]:
# add data_format for full return dictionary of get_schema_version # add data_format for full return dictionary of get_schema_version
return { return {
"SchemaVersionId": self.schema_version_id, "SchemaVersionId": self.schema_version_id,
@ -1335,7 +1402,7 @@ class FakeSchemaVersion(BaseModel):
"CreatedTime": str(self.created_time), "CreatedTime": str(self.created_time),
} }
def get_schema_by_definition_as_dict(self): def get_schema_by_definition_as_dict(self) -> Dict[str, Any]:
# add data_format for full return dictionary of get_schema_by_definition # add data_format for full return dictionary of get_schema_by_definition
return { return {
"SchemaVersionId": self.schema_version_id, "SchemaVersionId": self.schema_version_id,

View File

@ -1,11 +1,13 @@
import json import json
from typing import Any, Dict, List
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import glue_backends, GlueBackend from .models import glue_backends, GlueBackend, FakeJob, FakeCrawler
class GlueResponse(BaseResponse): class GlueResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="glue") super().__init__(service_name="glue")
@property @property
@ -13,66 +15,66 @@ class GlueResponse(BaseResponse):
return glue_backends[self.current_account][self.region] return glue_backends[self.current_account][self.region]
@property @property
def parameters(self): def parameters(self) -> Dict[str, Any]: # type: ignore[misc]
return json.loads(self.body) return json.loads(self.body)
def create_database(self): def create_database(self) -> str:
database_input = self.parameters.get("DatabaseInput") database_input = self.parameters.get("DatabaseInput")
database_name = database_input.get("Name") database_name = database_input.get("Name") # type: ignore
if "CatalogId" in self.parameters: if "CatalogId" in self.parameters:
database_input["CatalogId"] = self.parameters.get("CatalogId") database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
self.glue_backend.create_database(database_name, database_input) self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type]
return "" return ""
def get_database(self): def get_database(self) -> str:
database_name = self.parameters.get("Name") database_name = self.parameters.get("Name")
database = self.glue_backend.get_database(database_name) database = self.glue_backend.get_database(database_name) # type: ignore[arg-type]
return json.dumps({"Database": database.as_dict()}) return json.dumps({"Database": database.as_dict()})
def get_databases(self): def get_databases(self) -> str:
database_list = self.glue_backend.get_databases() database_list = self.glue_backend.get_databases()
return json.dumps( return json.dumps(
{"DatabaseList": [database.as_dict() for database in database_list]} {"DatabaseList": [database.as_dict() for database in database_list]}
) )
def update_database(self): def update_database(self) -> str:
database_input = self.parameters.get("DatabaseInput") database_input = self.parameters.get("DatabaseInput")
database_name = self.parameters.get("Name") database_name = self.parameters.get("Name")
if "CatalogId" in self.parameters: if "CatalogId" in self.parameters:
database_input["CatalogId"] = self.parameters.get("CatalogId") database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
self.glue_backend.update_database(database_name, database_input) self.glue_backend.update_database(database_name, database_input) # type: ignore[arg-type]
return "" return ""
def delete_database(self): def delete_database(self) -> str:
name = self.parameters.get("Name") name = self.parameters.get("Name")
self.glue_backend.delete_database(name) self.glue_backend.delete_database(name) # type: ignore[arg-type]
return json.dumps({}) return json.dumps({})
def create_table(self): def create_table(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_input = self.parameters.get("TableInput") table_input = self.parameters.get("TableInput")
table_name = table_input.get("Name") table_name = table_input.get("Name") # type: ignore
self.glue_backend.create_table(database_name, table_name, table_input) self.glue_backend.create_table(database_name, table_name, table_input) # type: ignore[arg-type]
return "" return ""
def get_table(self): def get_table(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("Name") table_name = self.parameters.get("Name")
table = self.glue_backend.get_table(database_name, table_name) table = self.glue_backend.get_table(database_name, table_name) # type: ignore[arg-type]
return json.dumps({"Table": table.as_dict()}) return json.dumps({"Table": table.as_dict()})
def update_table(self): def update_table(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_input = self.parameters.get("TableInput") table_input = self.parameters.get("TableInput")
table_name = table_input.get("Name") table_name = table_input.get("Name") # type: ignore
self.glue_backend.update_table(database_name, table_name, table_input) self.glue_backend.update_table(database_name, table_name, table_input) # type: ignore[arg-type]
return "" return ""
def get_table_versions(self): def get_table_versions(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
versions = self.glue_backend.get_table_versions(database_name, table_name) versions = self.glue_backend.get_table_versions(database_name, table_name) # type: ignore[arg-type]
return json.dumps( return json.dumps(
{ {
"TableVersions": [ "TableVersions": [
@ -82,36 +84,36 @@ class GlueResponse(BaseResponse):
} }
) )
def get_table_version(self): def get_table_version(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
ver_id = self.parameters.get("VersionId") ver_id = self.parameters.get("VersionId")
return self.glue_backend.get_table_version(database_name, table_name, ver_id) return self.glue_backend.get_table_version(database_name, table_name, ver_id) # type: ignore[arg-type]
def delete_table_version(self) -> str: def delete_table_version(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
version_id = self.parameters.get("VersionId") version_id = self.parameters.get("VersionId")
self.glue_backend.delete_table_version(database_name, table_name, version_id) self.glue_backend.delete_table_version(database_name, table_name, version_id) # type: ignore[arg-type]
return "{}" return "{}"
def get_tables(self): def get_tables(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
expression = self.parameters.get("Expression") expression = self.parameters.get("Expression")
tables = self.glue_backend.get_tables(database_name, expression) tables = self.glue_backend.get_tables(database_name, expression) # type: ignore[arg-type]
return json.dumps({"TableList": [table.as_dict() for table in tables]}) return json.dumps({"TableList": [table.as_dict() for table in tables]})
def delete_table(self): def delete_table(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("Name") table_name = self.parameters.get("Name")
resp = self.glue_backend.delete_table(database_name, table_name) self.glue_backend.delete_table(database_name, table_name) # type: ignore[arg-type]
return json.dumps(resp) return "{}"
def batch_delete_table(self): def batch_delete_table(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
tables = self.parameters.get("TablesToDelete") tables = self.parameters.get("TablesToDelete")
errors = self.glue_backend.batch_delete_table(database_name, tables) errors = self.glue_backend.batch_delete_table(database_name, tables) # type: ignore[arg-type]
out = {} out = {}
if errors: if errors:
@ -119,50 +121,50 @@ class GlueResponse(BaseResponse):
return json.dumps(out) return json.dumps(out)
def get_partitions(self): def get_partitions(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
expression = self.parameters.get("Expression") expression = self.parameters.get("Expression")
partitions = self.glue_backend.get_partitions( partitions = self.glue_backend.get_partitions(
database_name, table_name, expression database_name, table_name, expression # type: ignore[arg-type]
) )
return json.dumps({"Partitions": [p.as_dict() for p in partitions]}) return json.dumps({"Partitions": [p.as_dict() for p in partitions]})
def get_partition(self): def get_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
values = self.parameters.get("PartitionValues") values = self.parameters.get("PartitionValues")
p = self.glue_backend.get_partition(database_name, table_name, values) p = self.glue_backend.get_partition(database_name, table_name, values) # type: ignore[arg-type]
return json.dumps({"Partition": p.as_dict()}) return json.dumps({"Partition": p.as_dict()})
def batch_get_partition(self): def batch_get_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
partitions_to_get = self.parameters.get("PartitionsToGet") partitions_to_get = self.parameters.get("PartitionsToGet")
partitions = self.glue_backend.batch_get_partition( partitions = self.glue_backend.batch_get_partition(
database_name, table_name, partitions_to_get database_name, table_name, partitions_to_get # type: ignore[arg-type]
) )
return json.dumps({"Partitions": partitions}) return json.dumps({"Partitions": partitions})
def create_partition(self): def create_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
part_input = self.parameters.get("PartitionInput") part_input = self.parameters.get("PartitionInput")
self.glue_backend.create_partition(database_name, table_name, part_input) self.glue_backend.create_partition(database_name, table_name, part_input) # type: ignore[arg-type]
return "" return ""
def batch_create_partition(self): def batch_create_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
partition_input = self.parameters.get("PartitionInputList") partition_input = self.parameters.get("PartitionInputList")
errors_output = self.glue_backend.batch_create_partition( errors_output = self.glue_backend.batch_create_partition(
database_name, table_name, partition_input database_name, table_name, partition_input # type: ignore[arg-type]
) )
out = {} out = {}
@ -171,24 +173,24 @@ class GlueResponse(BaseResponse):
return json.dumps(out) return json.dumps(out)
def update_partition(self): def update_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
part_input = self.parameters.get("PartitionInput") part_input = self.parameters.get("PartitionInput")
part_to_update = self.parameters.get("PartitionValueList") part_to_update = self.parameters.get("PartitionValueList")
self.glue_backend.update_partition( self.glue_backend.update_partition(
database_name, table_name, part_input, part_to_update database_name, table_name, part_input, part_to_update # type: ignore[arg-type]
) )
return "" return ""
def batch_update_partition(self): def batch_update_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
entries = self.parameters.get("Entries") entries = self.parameters.get("Entries")
errors_output = self.glue_backend.batch_update_partition( errors_output = self.glue_backend.batch_update_partition(
database_name, table_name, entries database_name, table_name, entries # type: ignore[arg-type]
) )
out = {} out = {}
@ -197,21 +199,21 @@ class GlueResponse(BaseResponse):
return json.dumps(out) return json.dumps(out)
def delete_partition(self): def delete_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
part_to_delete = self.parameters.get("PartitionValues") part_to_delete = self.parameters.get("PartitionValues")
self.glue_backend.delete_partition(database_name, table_name, part_to_delete) self.glue_backend.delete_partition(database_name, table_name, part_to_delete) # type: ignore[arg-type]
return "" return ""
def batch_delete_partition(self): def batch_delete_partition(self) -> str:
database_name = self.parameters.get("DatabaseName") database_name = self.parameters.get("DatabaseName")
table_name = self.parameters.get("TableName") table_name = self.parameters.get("TableName")
parts = self.parameters.get("PartitionsToDelete") parts = self.parameters.get("PartitionsToDelete")
errors_output = self.glue_backend.batch_delete_partition( errors_output = self.glue_backend.batch_delete_partition(
database_name, table_name, parts database_name, table_name, parts # type: ignore[arg-type]
) )
out = {} out = {}
@ -220,37 +222,37 @@ class GlueResponse(BaseResponse):
return json.dumps(out) return json.dumps(out)
def create_crawler(self): def create_crawler(self) -> str:
self.glue_backend.create_crawler( self.glue_backend.create_crawler(
name=self.parameters.get("Name"), name=self.parameters.get("Name"), # type: ignore[arg-type]
role=self.parameters.get("Role"), role=self.parameters.get("Role"), # type: ignore[arg-type]
database_name=self.parameters.get("DatabaseName"), database_name=self.parameters.get("DatabaseName"), # type: ignore[arg-type]
description=self.parameters.get("Description"), description=self.parameters.get("Description"), # type: ignore[arg-type]
targets=self.parameters.get("Targets"), targets=self.parameters.get("Targets"), # type: ignore[arg-type]
schedule=self.parameters.get("Schedule"), schedule=self.parameters.get("Schedule"), # type: ignore[arg-type]
classifiers=self.parameters.get("Classifiers"), classifiers=self.parameters.get("Classifiers"), # type: ignore[arg-type]
table_prefix=self.parameters.get("TablePrefix"), table_prefix=self.parameters.get("TablePrefix"), # type: ignore[arg-type]
schema_change_policy=self.parameters.get("SchemaChangePolicy"), schema_change_policy=self.parameters.get("SchemaChangePolicy"), # type: ignore[arg-type]
recrawl_policy=self.parameters.get("RecrawlPolicy"), recrawl_policy=self.parameters.get("RecrawlPolicy"), # type: ignore[arg-type]
lineage_configuration=self.parameters.get("LineageConfiguration"), lineage_configuration=self.parameters.get("LineageConfiguration"), # type: ignore[arg-type]
configuration=self.parameters.get("Configuration"), configuration=self.parameters.get("Configuration"), # type: ignore[arg-type]
crawler_security_configuration=self.parameters.get( crawler_security_configuration=self.parameters.get( # type: ignore[arg-type]
"CrawlerSecurityConfiguration" "CrawlerSecurityConfiguration"
), ),
tags=self.parameters.get("Tags"), tags=self.parameters.get("Tags"), # type: ignore[arg-type]
) )
return "" return ""
def get_crawler(self): def get_crawler(self) -> str:
name = self.parameters.get("Name") name = self.parameters.get("Name")
crawler = self.glue_backend.get_crawler(name) crawler = self.glue_backend.get_crawler(name) # type: ignore[arg-type]
return json.dumps({"Crawler": crawler.as_dict()}) return json.dumps({"Crawler": crawler.as_dict()})
def get_crawlers(self): def get_crawlers(self) -> str:
crawlers = self.glue_backend.get_crawlers() crawlers = self.glue_backend.get_crawlers()
return json.dumps({"Crawlers": [crawler.as_dict() for crawler in crawlers]}) return json.dumps({"Crawlers": [crawler.as_dict() for crawler in crawlers]})
def list_crawlers(self): def list_crawlers(self) -> str:
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults") max_results = self._get_int_param("MaxResults")
tags = self._get_param("Tags") tags = self._get_param("Tags")
@ -265,31 +267,33 @@ class GlueResponse(BaseResponse):
) )
) )
def filter_crawlers_by_tags(self, crawlers, tags): def filter_crawlers_by_tags(
self, crawlers: List[FakeCrawler], tags: Dict[str, str]
) -> List[str]:
if not tags: if not tags:
return [crawler.get_name() for crawler in crawlers] return [crawler.get_name() for crawler in crawlers]
return [ return [
crawler.get_name() crawler.get_name()
for crawler in crawlers for crawler in crawlers
if self.is_tags_match(self, crawler.arn, tags) if self.is_tags_match(crawler.arn, tags)
] ]
def start_crawler(self): def start_crawler(self) -> str:
name = self.parameters.get("Name") name = self.parameters.get("Name")
self.glue_backend.start_crawler(name) self.glue_backend.start_crawler(name) # type: ignore[arg-type]
return "" return ""
def stop_crawler(self): def stop_crawler(self) -> str:
name = self.parameters.get("Name") name = self.parameters.get("Name")
self.glue_backend.stop_crawler(name) self.glue_backend.stop_crawler(name) # type: ignore[arg-type]
return "" return ""
def delete_crawler(self): def delete_crawler(self) -> str:
name = self.parameters.get("Name") name = self.parameters.get("Name")
self.glue_backend.delete_crawler(name) self.glue_backend.delete_crawler(name) # type: ignore[arg-type]
return "" return ""
def create_job(self): def create_job(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
description = self._get_param("Description") description = self._get_param("Description")
log_uri = self._get_param("LogUri") log_uri = self._get_param("LogUri")
@ -312,7 +316,7 @@ class GlueResponse(BaseResponse):
code_gen_configuration_nodes = self._get_param("CodeGenConfigurationNodes") code_gen_configuration_nodes = self._get_param("CodeGenConfigurationNodes")
execution_class = self._get_param("ExecutionClass") execution_class = self._get_param("ExecutionClass")
source_control_details = self._get_param("SourceControlDetails") source_control_details = self._get_param("SourceControlDetails")
name = self.glue_backend.create_job( self.glue_backend.create_job(
name=name, name=name,
description=description, description=description,
log_uri=log_uri, log_uri=log_uri,
@ -338,12 +342,12 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(dict(Name=name)) return json.dumps(dict(Name=name))
def get_job(self): def get_job(self) -> str:
name = self.parameters.get("JobName") name = self.parameters.get("JobName")
job = self.glue_backend.get_job(name) job = self.glue_backend.get_job(name) # type: ignore[arg-type]
return json.dumps({"Job": job.as_dict()}) return json.dumps({"Job": job.as_dict()})
def get_jobs(self): def get_jobs(self) -> str:
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults") max_results = self._get_int_param("MaxResults")
jobs, next_token = self.glue_backend.get_jobs( jobs, next_token = self.glue_backend.get_jobs(
@ -356,18 +360,18 @@ class GlueResponse(BaseResponse):
) )
) )
def start_job_run(self): def start_job_run(self) -> str:
name = self.parameters.get("JobName") name = self.parameters.get("JobName")
job_run_id = self.glue_backend.start_job_run(name) job_run_id = self.glue_backend.start_job_run(name) # type: ignore[arg-type]
return json.dumps(dict(JobRunId=job_run_id)) return json.dumps(dict(JobRunId=job_run_id))
def get_job_run(self): def get_job_run(self) -> str:
name = self.parameters.get("JobName") name = self.parameters.get("JobName")
run_id = self.parameters.get("RunId") run_id = self.parameters.get("RunId")
job_run = self.glue_backend.get_job_run(name, run_id) job_run = self.glue_backend.get_job_run(name, run_id) # type: ignore[arg-type]
return json.dumps({"JobRun": job_run.as_dict()}) return json.dumps({"JobRun": job_run.as_dict()})
def list_jobs(self): def list_jobs(self) -> str:
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
max_results = self._get_int_param("MaxResults") max_results = self._get_int_param("MaxResults")
tags = self._get_param("Tags") tags = self._get_param("Tags")
@ -382,32 +386,31 @@ class GlueResponse(BaseResponse):
) )
) )
def get_tags(self): def get_tags(self) -> TYPE_RESPONSE:
resource_arn = self.parameters.get("ResourceArn") resource_arn = self.parameters.get("ResourceArn")
tags = self.glue_backend.get_tags(resource_arn) tags = self.glue_backend.get_tags(resource_arn) # type: ignore[arg-type]
return 200, {}, json.dumps({"Tags": tags}) return 200, {}, json.dumps({"Tags": tags})
def tag_resource(self): def tag_resource(self) -> TYPE_RESPONSE:
resource_arn = self.parameters.get("ResourceArn") resource_arn = self.parameters.get("ResourceArn")
tags = self.parameters.get("TagsToAdd", {}) tags = self.parameters.get("TagsToAdd", {})
self.glue_backend.tag_resource(resource_arn, tags) self.glue_backend.tag_resource(resource_arn, tags) # type: ignore[arg-type]
return 201, {}, "{}" return 201, {}, "{}"
def untag_resource(self): def untag_resource(self) -> TYPE_RESPONSE:
resource_arn = self._get_param("ResourceArn") resource_arn = self._get_param("ResourceArn")
tag_keys = self.parameters.get("TagsToRemove") tag_keys = self.parameters.get("TagsToRemove")
self.glue_backend.untag_resource(resource_arn, tag_keys) self.glue_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type]
return 200, {}, "{}" return 200, {}, "{}"
def filter_jobs_by_tags(self, jobs, tags): def filter_jobs_by_tags(
self, jobs: List[FakeJob], tags: Dict[str, str]
) -> List[str]:
if not tags: if not tags:
return [job.get_name() for job in jobs] return [job.get_name() for job in jobs]
return [ return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)]
job.get_name() for job in jobs if self.is_tags_match(self, job.arn, tags)
]
@staticmethod def is_tags_match(self, resource_arn: str, tags: Dict[str, str]) -> bool:
def is_tags_match(self, resource_arn, tags):
glue_resource_tags = self.glue_backend.get_tags(resource_arn) glue_resource_tags = self.glue_backend.get_tags(resource_arn)
mutual_keys = set(glue_resource_tags).intersection(tags) mutual_keys = set(glue_resource_tags).intersection(tags)
for key in mutual_keys: for key in mutual_keys:
@ -415,28 +418,28 @@ class GlueResponse(BaseResponse):
return True return True
return False return False
def create_registry(self): def create_registry(self) -> str:
registry_name = self._get_param("RegistryName") registry_name = self._get_param("RegistryName")
description = self._get_param("Description") description = self._get_param("Description")
tags = self._get_param("Tags") tags = self._get_param("Tags")
registry = self.glue_backend.create_registry(registry_name, description, tags) registry = self.glue_backend.create_registry(registry_name, description, tags)
return json.dumps(registry) return json.dumps(registry)
def delete_registry(self): def delete_registry(self) -> str:
registry_id = self._get_param("RegistryId") registry_id = self._get_param("RegistryId")
registry = self.glue_backend.delete_registry(registry_id) registry = self.glue_backend.delete_registry(registry_id)
return json.dumps(registry) return json.dumps(registry)
def get_registry(self): def get_registry(self) -> str:
registry_id = self._get_param("RegistryId") registry_id = self._get_param("RegistryId")
registry = self.glue_backend.get_registry(registry_id) registry = self.glue_backend.get_registry(registry_id)
return json.dumps(registry) return json.dumps(registry)
def list_registries(self): def list_registries(self) -> str:
registries = self.glue_backend.list_registries() registries = self.glue_backend.list_registries()
return json.dumps({"Registries": registries}) return json.dumps({"Registries": registries})
def create_schema(self): def create_schema(self) -> str:
registry_id = self._get_param("RegistryId") registry_id = self._get_param("RegistryId")
schema_name = self._get_param("SchemaName") schema_name = self._get_param("SchemaName")
data_format = self._get_param("DataFormat") data_format = self._get_param("DataFormat")
@ -455,7 +458,7 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(schema) return json.dumps(schema)
def register_schema_version(self): def register_schema_version(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema_definition = self._get_param("SchemaDefinition") schema_definition = self._get_param("SchemaDefinition")
schema_version = self.glue_backend.register_schema_version( schema_version = self.glue_backend.register_schema_version(
@ -463,7 +466,7 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(schema_version) return json.dumps(schema_version)
def get_schema_version(self): def get_schema_version(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema_version_id = self._get_param("SchemaVersionId") schema_version_id = self._get_param("SchemaVersionId")
schema_version_number = self._get_param("SchemaVersionNumber") schema_version_number = self._get_param("SchemaVersionNumber")
@ -473,7 +476,7 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(schema_version) return json.dumps(schema_version)
def get_schema_by_definition(self): def get_schema_by_definition(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema_definition = self._get_param("SchemaDefinition") schema_definition = self._get_param("SchemaDefinition")
schema_version = self.glue_backend.get_schema_by_definition( schema_version = self.glue_backend.get_schema_by_definition(
@ -481,7 +484,7 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(schema_version) return json.dumps(schema_version)
def put_schema_version_metadata(self): def put_schema_version_metadata(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema_version_number = self._get_param("SchemaVersionNumber") schema_version_number = self._get_param("SchemaVersionNumber")
schema_version_id = self._get_param("SchemaVersionId") schema_version_id = self._get_param("SchemaVersionId")
@ -491,24 +494,24 @@ class GlueResponse(BaseResponse):
) )
return json.dumps(schema_version) return json.dumps(schema_version)
def get_schema(self): def get_schema(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema = self.glue_backend.get_schema(schema_id) schema = self.glue_backend.get_schema(schema_id)
return json.dumps(schema) return json.dumps(schema)
def delete_schema(self): def delete_schema(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
schema = self.glue_backend.delete_schema(schema_id) schema = self.glue_backend.delete_schema(schema_id)
return json.dumps(schema) return json.dumps(schema)
def update_schema(self): def update_schema(self) -> str:
schema_id = self._get_param("SchemaId") schema_id = self._get_param("SchemaId")
compatibility = self._get_param("Compatibility") compatibility = self._get_param("Compatibility")
description = self._get_param("Description") description = self._get_param("Description")
schema = self.glue_backend.update_schema(schema_id, compatibility, description) schema = self.glue_backend.update_schema(schema_id, compatibility, description)
return json.dumps(schema) return json.dumps(schema)
def batch_get_crawlers(self): def batch_get_crawlers(self) -> str:
crawler_names = self._get_param("CrawlerNames") crawler_names = self._get_param("CrawlerNames")
crawlers = self.glue_backend.batch_get_crawlers(crawler_names) crawlers = self.glue_backend.batch_get_crawlers(crawler_names)
crawlers_not_found = list( crawlers_not_found = list(
@ -521,5 +524,5 @@ class GlueResponse(BaseResponse):
} }
) )
def get_partition_indexes(self): def get_partition_indexes(self) -> str:
return json.dumps({"PartitionIndexDescriptorList": []}) return json.dumps({"PartitionIndexDescriptorList": []})

View File

@ -97,7 +97,7 @@ def _escape_regex(pattern: str) -> str:
class _Expr(abc.ABC): class _Expr(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: # type: ignore[misc]
raise NotImplementedError() raise NotImplementedError()
@ -196,7 +196,7 @@ class _Like(_Expr):
pattern = _cast("string", self.literal) pattern = _cast("string", self.literal)
# prepare SQL pattern for conversion to regex pattern # prepare SQL pattern for conversion to regex pattern
pattern = _escape_regex(pattern) pattern = _escape_regex(pattern) # type: ignore
# NOTE convert SQL wildcards to regex, no literal matches possible # NOTE convert SQL wildcards to regex, no literal matches possible
pattern = pattern.replace("_", ".").replace("%", ".*") pattern = pattern.replace("_", ".").replace("%", ".*")
@ -265,19 +265,19 @@ class _BoolOr(_Expr):
class _PartitionFilterExpressionCache: class _PartitionFilterExpressionCache:
def __init__(self): def __init__(self) -> None:
# build grammar according to Glue.Client.get_partitions(Expression) # build grammar according to Glue.Client.get_partitions(Expression)
lpar, rpar = map(Suppress, "()") lpar, rpar = map(Suppress, "()")
# NOTE these are AWS Athena column name best practices # NOTE these are AWS Athena column name best practices
ident = Forward().set_name("ident") ident = Forward().set_name("ident")
ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar # type: ignore
number = Forward().set_name("number") number = Forward().set_name("number")
number <<= pyparsing_common.number | lpar + number + rpar number <<= pyparsing_common.number | lpar + number + rpar # type: ignore
string = Forward().set_name("string") string = Forward().set_name("string")
string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar # type: ignore
literal = (number | string).set_name("literal") literal = (number | string).set_name("literal")
literal_list = delimited_list(literal, min=1).set_name("list") literal_list = delimited_list(literal, min=1).set_name("list")
@ -293,7 +293,7 @@ class _PartitionFilterExpressionCache:
in_, between, like, not_, is_, null = map( in_, between, like, not_, is_, null = map(
CaselessKeyword, "in between like not is null".split() CaselessKeyword, "in between like not is null".split()
) )
not_ = Suppress(not_) # only needed for matching not_ = Suppress(not_) # type: ignore # only needed for matching
cond = ( cond = (
(ident + is_ + null).set_parse_action(_IsNull) (ident + is_ + null).set_parse_action(_IsNull)
@ -343,11 +343,11 @@ _PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache()
class PartitionFilter: class PartitionFilter:
def __init__(self, expression: Optional[str], fake_table): def __init__(self, expression: Optional[str], fake_table: Any):
self.expression = expression self.expression = expression
self.fake_table = fake_table self.fake_table = fake_table
def __call__(self, fake_partition) -> bool: def __call__(self, fake_partition: Any) -> bool:
expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression) expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression)
if expression is None: if expression is None:
return True return True

View File

@ -6,36 +6,36 @@ class GreengrassClientError(JsonRESTError):
class IdNotFoundException(GreengrassClientError): class IdNotFoundException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 404 self.code = 404
super().__init__("IdNotFoundException", msg) super().__init__("IdNotFoundException", msg)
class InvalidContainerDefinitionException(GreengrassClientError): class InvalidContainerDefinitionException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 400 self.code = 400
super().__init__("InvalidContainerDefinitionException", msg) super().__init__("InvalidContainerDefinitionException", msg)
class VersionNotFoundException(GreengrassClientError): class VersionNotFoundException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 404 self.code = 404
super().__init__("VersionNotFoundException", msg) super().__init__("VersionNotFoundException", msg)
class InvalidInputException(GreengrassClientError): class InvalidInputException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 400 self.code = 400
super().__init__("InvalidInputException", msg) super().__init__("InvalidInputException", msg)
class MissingCoreException(GreengrassClientError): class MissingCoreException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 400 self.code = 400
super().__init__("MissingCoreException", msg) super().__init__("MissingCoreException", msg)
class ResourceNotFoundException(GreengrassClientError): class ResourceNotFoundException(GreengrassClientError):
def __init__(self, msg): def __init__(self, msg: str):
self.code = 404 self.code = 404
super().__init__("ResourceNotFoundException", msg) super().__init__("ResourceNotFoundException", msg)

View File

@ -1,6 +1,7 @@
import json import json
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Iterable, Optional
import re import re
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
@ -18,7 +19,7 @@ from .exceptions import (
class FakeCoreDefinition(BaseModel): class FakeCoreDefinition(BaseModel):
def __init__(self, account_id, region_name, name): def __init__(self, account_id: str, region_name: str, name: str):
self.region_name = region_name self.region_name = region_name
self.name = name self.name = name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
@ -27,7 +28,7 @@ class FakeCoreDefinition(BaseModel):
self.latest_version = "" self.latest_version = ""
self.latest_version_arn = "" self.latest_version_arn = ""
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -44,7 +45,13 @@ class FakeCoreDefinition(BaseModel):
class FakeCoreDefinitionVersion(BaseModel): class FakeCoreDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, core_definition_id, definition): def __init__(
self,
account_id: str,
region_name: str,
core_definition_id: str,
definition: Dict[str, Any],
):
self.region_name = region_name self.region_name = region_name
self.core_definition_id = core_definition_id self.core_definition_id = core_definition_id
self.definition = definition self.definition = definition
@ -52,8 +59,8 @@ class FakeCoreDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}" self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow() self.created_at_datetime = datetime.utcnow()
def to_dict(self, include_detail=False): def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = { obj: Dict[str, Any] = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime self.created_at_datetime
@ -69,7 +76,13 @@ class FakeCoreDefinitionVersion(BaseModel):
class FakeDeviceDefinition(BaseModel): class FakeDeviceDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version): def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name self.region_name = region_name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.id}" self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.id}"
@ -80,7 +93,7 @@ class FakeDeviceDefinition(BaseModel):
self.name = name self.name = name
self.initial_version = initial_version self.initial_version = initial_version
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
res = { res = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -99,7 +112,13 @@ class FakeDeviceDefinition(BaseModel):
class FakeDeviceDefinitionVersion(BaseModel): class FakeDeviceDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, device_definition_id, devices): def __init__(
self,
account_id: str,
region_name: str,
device_definition_id: str,
devices: List[Dict[str, Any]],
):
self.region_name = region_name self.region_name = region_name
self.device_definition_id = device_definition_id self.device_definition_id = device_definition_id
self.devices = devices self.devices = devices
@ -107,8 +126,8 @@ class FakeDeviceDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}" self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow() self.created_at_datetime = datetime.utcnow()
def to_dict(self, include_detail=False): def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = { obj: Dict[str, Any] = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime self.created_at_datetime
@ -124,7 +143,13 @@ class FakeDeviceDefinitionVersion(BaseModel):
class FakeResourceDefinition(BaseModel): class FakeResourceDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version): def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name self.region_name = region_name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.id}" self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.id}"
@ -135,7 +160,7 @@ class FakeResourceDefinition(BaseModel):
self.name = name self.name = name
self.initial_version = initial_version self.initial_version = initial_version
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -152,7 +177,13 @@ class FakeResourceDefinition(BaseModel):
class FakeResourceDefinitionVersion(BaseModel): class FakeResourceDefinitionVersion(BaseModel):
def __init__(self, account_id, region_name, resource_definition_id, resources): def __init__(
self,
account_id: str,
region_name: str,
resource_definition_id: str,
resources: List[Dict[str, Any]],
):
self.region_name = region_name self.region_name = region_name
self.resource_definition_id = resource_definition_id self.resource_definition_id = resource_definition_id
self.resources = resources self.resources = resources
@ -160,7 +191,7 @@ class FakeResourceDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}" self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow() self.created_at_datetime = datetime.utcnow()
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -173,7 +204,13 @@ class FakeResourceDefinitionVersion(BaseModel):
class FakeFunctionDefinition(BaseModel): class FakeFunctionDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version): def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name self.region_name = region_name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.id}" self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.id}"
@ -184,7 +221,7 @@ class FakeFunctionDefinition(BaseModel):
self.name = name self.name = name
self.initial_version = initial_version self.initial_version = initial_version
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
res = { res = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -204,7 +241,12 @@ class FakeFunctionDefinition(BaseModel):
class FakeFunctionDefinitionVersion(BaseModel): class FakeFunctionDefinitionVersion(BaseModel):
def __init__( def __init__(
self, account_id, region_name, function_definition_id, functions, default_config self,
account_id: str,
region_name: str,
function_definition_id: str,
functions: List[Dict[str, Any]],
default_config: Dict[str, Any],
): ):
self.region_name = region_name self.region_name = region_name
self.function_definition_id = function_definition_id self.function_definition_id = function_definition_id
@ -214,7 +256,7 @@ class FakeFunctionDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}" self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow() self.created_at_datetime = datetime.utcnow()
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -227,7 +269,13 @@ class FakeFunctionDefinitionVersion(BaseModel):
class FakeSubscriptionDefinition(BaseModel): class FakeSubscriptionDefinition(BaseModel):
def __init__(self, account_id, region_name, name, initial_version): def __init__(
self,
account_id: str,
region_name: str,
name: str,
initial_version: Dict[str, Any],
):
self.region_name = region_name self.region_name = region_name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.id}" self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.id}"
@ -238,7 +286,7 @@ class FakeSubscriptionDefinition(BaseModel):
self.name = name self.name = name
self.initial_version = initial_version self.initial_version = initial_version
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -256,7 +304,11 @@ class FakeSubscriptionDefinition(BaseModel):
class FakeSubscriptionDefinitionVersion(BaseModel): class FakeSubscriptionDefinitionVersion(BaseModel):
def __init__( def __init__(
self, account_id, region_name, subscription_definition_id, subscriptions self,
account_id: str,
region_name: str,
subscription_definition_id: str,
subscriptions: List[Dict[str, Any]],
): ):
self.region_name = region_name self.region_name = region_name
self.subscription_definition_id = subscription_definition_id self.subscription_definition_id = subscription_definition_id
@ -265,7 +317,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel):
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}" self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}"
self.created_at_datetime = datetime.utcnow() self.created_at_datetime = datetime.utcnow()
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -278,7 +330,7 @@ class FakeSubscriptionDefinitionVersion(BaseModel):
class FakeGroup(BaseModel): class FakeGroup(BaseModel):
def __init__(self, account_id, region_name, name): def __init__(self, account_id: str, region_name: str, name: str):
self.region_name = region_name self.region_name = region_name
self.group_id = str(mock_random.uuid4()) self.group_id = str(mock_random.uuid4())
self.name = name self.name = name
@ -288,7 +340,7 @@ class FakeGroup(BaseModel):
self.latest_version = "" self.latest_version = ""
self.latest_version_arn = "" self.latest_version_arn = ""
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
obj = { obj = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
@ -308,14 +360,14 @@ class FakeGroup(BaseModel):
class FakeGroupVersion(BaseModel): class FakeGroupVersion(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
group_id, group_id: str,
core_definition_version_arn, core_definition_version_arn: Optional[str],
device_definition_version_arn, device_definition_version_arn: Optional[str],
function_definition_version_arn, function_definition_version_arn: Optional[str],
resource_definition_version_arn, resource_definition_version_arn: Optional[str],
subscription_definition_version_arn, subscription_definition_version_arn: Optional[str],
): ):
self.region_name = region_name self.region_name = region_name
self.group_id = group_id self.group_id = group_id
@ -328,7 +380,7 @@ class FakeGroupVersion(BaseModel):
self.resource_definition_version_arn = resource_definition_version_arn self.resource_definition_version_arn = resource_definition_version_arn
self.subscription_definition_version_arn = subscription_definition_version_arn self.subscription_definition_version_arn = subscription_definition_version_arn
def to_dict(self, include_detail=False): def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
definition = {} definition = {}
if self.core_definition_version_arn: if self.core_definition_version_arn:
@ -354,7 +406,7 @@ class FakeGroupVersion(BaseModel):
"SubscriptionDefinitionVersionArn" "SubscriptionDefinitionVersionArn"
] = self.subscription_definition_version_arn ] = self.subscription_definition_version_arn
obj = { obj: Dict[str, Any] = {
"Arn": self.arn, "Arn": self.arn,
"CreationTimestamp": iso_8601_datetime_with_milliseconds( "CreationTimestamp": iso_8601_datetime_with_milliseconds(
self.created_at_datetime self.created_at_datetime
@ -370,7 +422,14 @@ class FakeGroupVersion(BaseModel):
class FakeDeployment(BaseModel): class FakeDeployment(BaseModel):
def __init__(self, account_id, region_name, group_id, group_arn, deployment_type): def __init__(
self,
account_id: str,
region_name: str,
group_id: str,
group_arn: str,
deployment_type: str,
):
self.region_name = region_name self.region_name = region_name
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.group_id = group_id self.group_id = group_id
@ -381,7 +440,7 @@ class FakeDeployment(BaseModel):
self.deployment_type = deployment_type self.deployment_type = deployment_type
self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:/greengrass/groups/{self.group_id}/deployments/{self.id}" self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:/greengrass/groups/{self.group_id}/deployments/{self.id}"
def to_dict(self, include_detail=False): def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = {"DeploymentId": self.id, "DeploymentArn": self.arn} obj = {"DeploymentId": self.id, "DeploymentArn": self.arn}
if include_detail: if include_detail:
@ -395,11 +454,11 @@ class FakeDeployment(BaseModel):
class FakeAssociatedRole(BaseModel): class FakeAssociatedRole(BaseModel):
def __init__(self, role_arn): def __init__(self, role_arn: str):
self.role_arn = role_arn self.role_arn = role_arn
self.associated_at = datetime.utcnow() self.associated_at = datetime.utcnow()
def to_dict(self, include_detail=False): def to_dict(self, include_detail: bool = False) -> Dict[str, Any]:
obj = {"AssociatedAt": iso_8601_datetime_with_milliseconds(self.associated_at)} obj = {"AssociatedAt": iso_8601_datetime_with_milliseconds(self.associated_at)}
if include_detail: if include_detail:
@ -409,12 +468,17 @@ class FakeAssociatedRole(BaseModel):
class FakeDeploymentStatus(BaseModel): class FakeDeploymentStatus(BaseModel):
def __init__(self, deployment_type, updated_at, deployment_status="InProgress"): def __init__(
self,
deployment_type: str,
updated_at: datetime,
deployment_status: str = "InProgress",
):
self.deployment_type = deployment_type self.deployment_type = deployment_type
self.update_at_datetime = updated_at self.update_at_datetime = updated_at
self.deployment_status = deployment_status self.deployment_status = deployment_status
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"DeploymentStatus": self.deployment_status, "DeploymentStatus": self.deployment_status,
"DeploymentType": self.deployment_type, "DeploymentType": self.deployment_type,
@ -423,24 +487,38 @@ class FakeDeploymentStatus(BaseModel):
class GreengrassBackend(BaseBackend): class GreengrassBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.groups = OrderedDict() self.groups: Dict[str, FakeGroup] = OrderedDict()
self.group_role_associations = OrderedDict() self.group_role_associations: Dict[str, FakeAssociatedRole] = OrderedDict()
self.group_versions = OrderedDict() self.group_versions: Dict[str, Dict[str, FakeGroupVersion]] = OrderedDict()
self.core_definitions = OrderedDict() self.core_definitions: Dict[str, FakeCoreDefinition] = OrderedDict()
self.core_definition_versions = OrderedDict() self.core_definition_versions: Dict[
self.device_definitions = OrderedDict() str, Dict[str, FakeCoreDefinitionVersion]
self.device_definition_versions = OrderedDict() ] = OrderedDict()
self.function_definitions = OrderedDict() self.device_definitions: Dict[str, FakeDeviceDefinition] = OrderedDict()
self.function_definition_versions = OrderedDict() self.device_definition_versions: Dict[
self.resource_definitions = OrderedDict() str, Dict[str, FakeDeviceDefinitionVersion]
self.resource_definition_versions = OrderedDict() ] = OrderedDict()
self.subscription_definitions = OrderedDict() self.function_definitions: Dict[str, FakeFunctionDefinition] = OrderedDict()
self.subscription_definition_versions = OrderedDict() self.function_definition_versions: Dict[
self.deployments = OrderedDict() str, Dict[str, FakeFunctionDefinitionVersion]
] = OrderedDict()
self.resource_definitions: Dict[str, FakeResourceDefinition] = OrderedDict()
self.resource_definition_versions: Dict[
str, Dict[str, FakeResourceDefinitionVersion]
] = OrderedDict()
self.subscription_definitions: Dict[
str, FakeSubscriptionDefinition
] = OrderedDict()
self.subscription_definition_versions: Dict[
str, Dict[str, FakeSubscriptionDefinitionVersion]
] = OrderedDict()
self.deployments: Dict[str, FakeDeployment] = OrderedDict()
def create_core_definition(self, name, initial_version): def create_core_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeCoreDefinition:
core_definition = FakeCoreDefinition(self.account_id, self.region_name, name) core_definition = FakeCoreDefinition(self.account_id, self.region_name, name)
self.core_definitions[core_definition.id] = core_definition self.core_definitions[core_definition.id] = core_definition
@ -449,22 +527,22 @@ class GreengrassBackend(BaseBackend):
) )
return core_definition return core_definition
def list_core_definitions(self): def list_core_definitions(self) -> Iterable[FakeCoreDefinition]:
return self.core_definitions.values() return self.core_definitions.values()
def get_core_definition(self, core_definition_id): def get_core_definition(self, core_definition_id: str) -> FakeCoreDefinition:
if core_definition_id not in self.core_definitions: if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That Core List Definition does not exist") raise IdNotFoundException("That Core List Definition does not exist")
return self.core_definitions[core_definition_id] return self.core_definitions[core_definition_id]
def delete_core_definition(self, core_definition_id): def delete_core_definition(self, core_definition_id: str) -> None:
if core_definition_id not in self.core_definitions: if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.") raise IdNotFoundException("That cores definition does not exist.")
del self.core_definitions[core_definition_id] del self.core_definitions[core_definition_id]
del self.core_definition_versions[core_definition_id] del self.core_definition_versions[core_definition_id]
def update_core_definition(self, core_definition_id, name): def update_core_definition(self, core_definition_id: str, name: str) -> None:
if name == "": if name == "":
raise InvalidContainerDefinitionException( raise InvalidContainerDefinitionException(
@ -474,7 +552,9 @@ class GreengrassBackend(BaseBackend):
raise IdNotFoundException("That cores definition does not exist.") raise IdNotFoundException("That cores definition does not exist.")
self.core_definitions[core_definition_id].name = name self.core_definitions[core_definition_id].name = name
def create_core_definition_version(self, core_definition_id, cores): def create_core_definition_version(
self, core_definition_id: str, cores: List[Dict[str, Any]]
) -> FakeCoreDefinitionVersion:
definition = {"Cores": cores} definition = {"Cores": cores}
core_def_ver = FakeCoreDefinitionVersion( core_def_ver = FakeCoreDefinitionVersion(
@ -491,15 +571,17 @@ class GreengrassBackend(BaseBackend):
return core_def_ver return core_def_ver
def list_core_definition_versions(self, core_definition_id): def list_core_definition_versions(
self, core_definition_id: str
) -> Iterable[FakeCoreDefinitionVersion]:
if core_definition_id not in self.core_definitions: if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.") raise IdNotFoundException("That cores definition does not exist.")
return self.core_definition_versions[core_definition_id].values() return self.core_definition_versions[core_definition_id].values()
def get_core_definition_version( def get_core_definition_version(
self, core_definition_id, core_definition_version_id self, core_definition_id: str, core_definition_version_id: str
): ) -> FakeCoreDefinitionVersion:
if core_definition_id not in self.core_definitions: if core_definition_id not in self.core_definitions:
raise IdNotFoundException("That cores definition does not exist.") raise IdNotFoundException("That cores definition does not exist.")
@ -516,7 +598,9 @@ class GreengrassBackend(BaseBackend):
core_definition_version_id core_definition_version_id
] ]
def create_device_definition(self, name, initial_version): def create_device_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeDeviceDefinition:
device_def = FakeDeviceDefinition( device_def = FakeDeviceDefinition(
self.account_id, self.region_name, name, initial_version self.account_id, self.region_name, name, initial_version
) )
@ -527,10 +611,12 @@ class GreengrassBackend(BaseBackend):
return device_def return device_def
def list_device_definitions(self): def list_device_definitions(self) -> Iterable[FakeDeviceDefinition]:
return self.device_definitions.values() return self.device_definitions.values()
def create_device_definition_version(self, device_definition_id, devices): def create_device_definition_version(
self, device_definition_id: str, devices: List[Dict[str, Any]]
) -> FakeDeviceDefinitionVersion:
if device_definition_id not in self.device_definitions: if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.") raise IdNotFoundException("That devices definition does not exist.")
@ -552,25 +638,27 @@ class GreengrassBackend(BaseBackend):
return device_ver return device_ver
def list_device_definition_versions(self, device_definition_id): def list_device_definition_versions(
self, device_definition_id: str
) -> Iterable[FakeDeviceDefinitionVersion]:
if device_definition_id not in self.device_definitions: if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.") raise IdNotFoundException("That devices definition does not exist.")
return self.device_definition_versions[device_definition_id].values() return self.device_definition_versions[device_definition_id].values()
def get_device_definition(self, device_definition_id): def get_device_definition(self, device_definition_id: str) -> FakeDeviceDefinition:
if device_definition_id not in self.device_definitions: if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That Device List Definition does not exist.") raise IdNotFoundException("That Device List Definition does not exist.")
return self.device_definitions[device_definition_id] return self.device_definitions[device_definition_id]
def delete_device_definition(self, device_definition_id): def delete_device_definition(self, device_definition_id: str) -> None:
if device_definition_id not in self.device_definitions: if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.") raise IdNotFoundException("That devices definition does not exist.")
del self.device_definitions[device_definition_id] del self.device_definitions[device_definition_id]
del self.device_definition_versions[device_definition_id] del self.device_definition_versions[device_definition_id]
def update_device_definition(self, device_definition_id, name): def update_device_definition(self, device_definition_id: str, name: str) -> None:
if name == "": if name == "":
raise InvalidContainerDefinitionException( raise InvalidContainerDefinitionException(
@ -581,8 +669,8 @@ class GreengrassBackend(BaseBackend):
self.device_definitions[device_definition_id].name = name self.device_definitions[device_definition_id].name = name
def get_device_definition_version( def get_device_definition_version(
self, device_definition_id, device_definition_version_id self, device_definition_id: str, device_definition_version_id: str
): ) -> FakeDeviceDefinitionVersion:
if device_definition_id not in self.device_definitions: if device_definition_id not in self.device_definitions:
raise IdNotFoundException("That devices definition does not exist.") raise IdNotFoundException("That devices definition does not exist.")
@ -599,7 +687,9 @@ class GreengrassBackend(BaseBackend):
device_definition_version_id device_definition_version_id
] ]
def create_resource_definition(self, name, initial_version): def create_resource_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeResourceDefinition:
resources = initial_version.get("Resources", []) resources = initial_version.get("Resources", [])
GreengrassBackend._validate_resources(resources) GreengrassBackend._validate_resources(resources)
@ -614,22 +704,26 @@ class GreengrassBackend(BaseBackend):
return resource_def return resource_def
def list_resource_definitions(self): def list_resource_definitions(self) -> Iterable[FakeResourceDefinition]:
return self.resource_definitions return self.resource_definitions.values()
def get_resource_definition(self, resource_definition_id): def get_resource_definition(
self, resource_definition_id: str
) -> FakeResourceDefinition:
if resource_definition_id not in self.resource_definitions: if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That Resource List Definition does not exist.") raise IdNotFoundException("That Resource List Definition does not exist.")
return self.resource_definitions[resource_definition_id] return self.resource_definitions[resource_definition_id]
def delete_resource_definition(self, resource_definition_id): def delete_resource_definition(self, resource_definition_id: str) -> None:
if resource_definition_id not in self.resource_definitions: if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That resources definition does not exist.") raise IdNotFoundException("That resources definition does not exist.")
del self.resource_definitions[resource_definition_id] del self.resource_definitions[resource_definition_id]
del self.resource_definition_versions[resource_definition_id] del self.resource_definition_versions[resource_definition_id]
def update_resource_definition(self, resource_definition_id, name): def update_resource_definition(
self, resource_definition_id: str, name: str
) -> None:
if name == "": if name == "":
raise InvalidInputException("Invalid resource name.") raise InvalidInputException("Invalid resource name.")
@ -637,7 +731,9 @@ class GreengrassBackend(BaseBackend):
raise IdNotFoundException("That resources definition does not exist.") raise IdNotFoundException("That resources definition does not exist.")
self.resource_definitions[resource_definition_id].name = name self.resource_definitions[resource_definition_id].name = name
def create_resource_definition_version(self, resource_definition_id, resources): def create_resource_definition_version(
self, resource_definition_id: str, resources: List[Dict[str, Any]]
) -> FakeResourceDefinitionVersion:
if resource_definition_id not in self.resource_definitions: if resource_definition_id not in self.resource_definitions:
raise IdNotFoundException("That resource definition does not exist.") raise IdNotFoundException("That resource definition does not exist.")
@ -666,7 +762,9 @@ class GreengrassBackend(BaseBackend):
return resource_def_ver return resource_def_ver
def list_resource_definition_versions(self, resource_definition_id): def list_resource_definition_versions(
self, resource_definition_id: str
) -> Iterable[FakeResourceDefinitionVersion]:
if resource_definition_id not in self.resource_definition_versions: if resource_definition_id not in self.resource_definition_versions:
raise IdNotFoundException("That resources definition does not exist.") raise IdNotFoundException("That resources definition does not exist.")
@ -674,8 +772,8 @@ class GreengrassBackend(BaseBackend):
return self.resource_definition_versions[resource_definition_id].values() return self.resource_definition_versions[resource_definition_id].values()
def get_resource_definition_version( def get_resource_definition_version(
self, resource_definition_id, resource_definition_version_id self, resource_definition_id: str, resource_definition_version_id: str
): ) -> FakeResourceDefinitionVersion:
if resource_definition_id not in self.resource_definition_versions: if resource_definition_id not in self.resource_definition_versions:
raise IdNotFoundException("That resources definition does not exist.") raise IdNotFoundException("That resources definition does not exist.")
@ -693,7 +791,7 @@ class GreengrassBackend(BaseBackend):
] ]
@staticmethod @staticmethod
def _validate_resources(resources): def _validate_resources(resources: List[Dict[str, Any]]) -> None: # type: ignore[misc]
for resource in resources: for resource in resources:
volume_source_path = ( volume_source_path = (
resource.get("ResourceDataContainer", {}) resource.get("ResourceDataContainer", {})
@ -719,7 +817,9 @@ class GreengrassBackend(BaseBackend):
f", but got: {device_source_path}])", f", but got: {device_source_path}])",
) )
def create_function_definition(self, name, initial_version): def create_function_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeFunctionDefinition:
func_def = FakeFunctionDefinition( func_def = FakeFunctionDefinition(
self.account_id, self.region_name, name, initial_version self.account_id, self.region_name, name, initial_version
) )
@ -731,22 +831,26 @@ class GreengrassBackend(BaseBackend):
return func_def return func_def
def list_function_definitions(self): def list_function_definitions(self) -> List[FakeFunctionDefinition]:
return self.function_definitions.values() return list(self.function_definitions.values())
def get_function_definition(self, function_definition_id): def get_function_definition(
self, function_definition_id: str
) -> FakeFunctionDefinition:
if function_definition_id not in self.function_definitions: if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That Lambda List Definition does not exist.") raise IdNotFoundException("That Lambda List Definition does not exist.")
return self.function_definitions[function_definition_id] return self.function_definitions[function_definition_id]
def delete_function_definition(self, function_definition_id): def delete_function_definition(self, function_definition_id: str) -> None:
if function_definition_id not in self.function_definitions: if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That lambdas definition does not exist.") raise IdNotFoundException("That lambdas definition does not exist.")
del self.function_definitions[function_definition_id] del self.function_definitions[function_definition_id]
del self.function_definition_versions[function_definition_id] del self.function_definition_versions[function_definition_id]
def update_function_definition(self, function_definition_id, name): def update_function_definition(
self, function_definition_id: str, name: str
) -> None:
if name == "": if name == "":
raise InvalidContainerDefinitionException( raise InvalidContainerDefinitionException(
@ -757,8 +861,11 @@ class GreengrassBackend(BaseBackend):
self.function_definitions[function_definition_id].name = name self.function_definitions[function_definition_id].name = name
def create_function_definition_version( def create_function_definition_version(
self, function_definition_id, functions, default_config self,
): function_definition_id: str,
functions: List[Dict[str, Any]],
default_config: Dict[str, Any],
) -> FakeFunctionDefinitionVersion:
if function_definition_id not in self.function_definitions: if function_definition_id not in self.function_definitions:
raise IdNotFoundException("That lambdas does not exist.") raise IdNotFoundException("That lambdas does not exist.")
@ -784,14 +891,16 @@ class GreengrassBackend(BaseBackend):
return func_ver return func_ver
def list_function_definition_versions(self, function_definition_id): def list_function_definition_versions(
self, function_definition_id: str
) -> Dict[str, FakeFunctionDefinitionVersion]:
if function_definition_id not in self.function_definition_versions: if function_definition_id not in self.function_definition_versions:
raise IdNotFoundException("That lambdas definition does not exist.") raise IdNotFoundException("That lambdas definition does not exist.")
return self.function_definition_versions[function_definition_id] return self.function_definition_versions[function_definition_id]
def get_function_definition_version( def get_function_definition_version(
self, function_definition_id, function_definition_version_id self, function_definition_id: str, function_definition_version_id: str
): ) -> FakeFunctionDefinitionVersion:
if function_definition_id not in self.function_definition_versions: if function_definition_id not in self.function_definition_versions:
raise IdNotFoundException("That lambdas definition does not exist.") raise IdNotFoundException("That lambdas definition does not exist.")
@ -809,7 +918,7 @@ class GreengrassBackend(BaseBackend):
] ]
@staticmethod @staticmethod
def _is_valid_subscription_target_or_source(target_or_source): def _is_valid_subscription_target_or_source(target_or_source: str) -> bool:
if target_or_source in ["cloud", "GGShadowService"]: if target_or_source in ["cloud", "GGShadowService"]:
return True return True
@ -829,10 +938,10 @@ class GreengrassBackend(BaseBackend):
return False return False
@staticmethod @staticmethod
def _validate_subscription_target_or_source(subscriptions): def _validate_subscription_target_or_source(subscriptions: List[Dict[str, Any]]) -> None: # type: ignore[misc]
target_errors = [] target_errors: List[str] = []
source_errors = [] source_errors: List[str] = []
for subscription in subscriptions: for subscription in subscriptions:
subscription_id = subscription["Id"] subscription_id = subscription["Id"]
@ -863,7 +972,9 @@ class GreengrassBackend(BaseBackend):
f"The subscriptions definition is invalid or corrupted. (ErrorDetails: [{error_msg}])", f"The subscriptions definition is invalid or corrupted. (ErrorDetails: [{error_msg}])",
) )
def create_subscription_definition(self, name, initial_version): def create_subscription_definition(
self, name: str, initial_version: Dict[str, Any]
) -> FakeSubscriptionDefinition:
GreengrassBackend._validate_subscription_target_or_source( GreengrassBackend._validate_subscription_target_or_source(
initial_version["Subscriptions"] initial_version["Subscriptions"]
@ -883,10 +994,12 @@ class GreengrassBackend(BaseBackend):
sub_def.latest_version_arn = sub_def_ver.arn sub_def.latest_version_arn = sub_def_ver.arn
return sub_def return sub_def
def list_subscription_definitions(self): def list_subscription_definitions(self) -> List[FakeSubscriptionDefinition]:
return self.subscription_definitions.values() return list(self.subscription_definitions.values())
def get_subscription_definition(self, subscription_definition_id): def get_subscription_definition(
self, subscription_definition_id: str
) -> FakeSubscriptionDefinition:
if subscription_definition_id not in self.subscription_definitions: if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException( raise IdNotFoundException(
@ -894,13 +1007,15 @@ class GreengrassBackend(BaseBackend):
) )
return self.subscription_definitions[subscription_definition_id] return self.subscription_definitions[subscription_definition_id]
def delete_subscription_definition(self, subscription_definition_id): def delete_subscription_definition(self, subscription_definition_id: str) -> None:
if subscription_definition_id not in self.subscription_definitions: if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException("That subscriptions definition does not exist.") raise IdNotFoundException("That subscriptions definition does not exist.")
del self.subscription_definitions[subscription_definition_id] del self.subscription_definitions[subscription_definition_id]
del self.subscription_definition_versions[subscription_definition_id] del self.subscription_definition_versions[subscription_definition_id]
def update_subscription_definition(self, subscription_definition_id, name): def update_subscription_definition(
self, subscription_definition_id: str, name: str
) -> None:
if name == "": if name == "":
raise InvalidContainerDefinitionException( raise InvalidContainerDefinitionException(
@ -911,8 +1026,8 @@ class GreengrassBackend(BaseBackend):
self.subscription_definitions[subscription_definition_id].name = name self.subscription_definitions[subscription_definition_id].name = name
def create_subscription_definition_version( def create_subscription_definition_version(
self, subscription_definition_id, subscriptions self, subscription_definition_id: str, subscriptions: List[Dict[str, Any]]
): ) -> FakeSubscriptionDefinitionVersion:
GreengrassBackend._validate_subscription_target_or_source(subscriptions) GreengrassBackend._validate_subscription_target_or_source(subscriptions)
@ -931,14 +1046,16 @@ class GreengrassBackend(BaseBackend):
return sub_def_ver return sub_def_ver
def list_subscription_definition_versions(self, subscription_definition_id): def list_subscription_definition_versions(
self, subscription_definition_id: str
) -> Dict[str, FakeSubscriptionDefinitionVersion]:
if subscription_definition_id not in self.subscription_definition_versions: if subscription_definition_id not in self.subscription_definition_versions:
raise IdNotFoundException("That subscriptions definition does not exist.") raise IdNotFoundException("That subscriptions definition does not exist.")
return self.subscription_definition_versions[subscription_definition_id] return self.subscription_definition_versions[subscription_definition_id]
def get_subscription_definition_version( def get_subscription_definition_version(
self, subscription_definition_id, subscription_definition_version_id self, subscription_definition_id: str, subscription_definition_version_id: str
): ) -> FakeSubscriptionDefinitionVersion:
if subscription_definition_id not in self.subscription_definitions: if subscription_definition_id not in self.subscription_definitions:
raise IdNotFoundException("That subscriptions definition does not exist.") raise IdNotFoundException("That subscriptions definition does not exist.")
@ -955,7 +1072,7 @@ class GreengrassBackend(BaseBackend):
subscription_definition_version_id subscription_definition_version_id
] ]
def create_group(self, name, initial_version): def create_group(self, name: str, initial_version: Dict[str, Any]) -> FakeGroup:
group = FakeGroup(self.account_id, self.region_name, name) group = FakeGroup(self.account_id, self.region_name, name)
self.groups[group.group_id] = group self.groups[group.group_id] = group
@ -983,22 +1100,22 @@ class GreengrassBackend(BaseBackend):
return group return group
def list_groups(self): def list_groups(self) -> List[FakeGroup]:
return self.groups.values() return list(self.groups.values())
def get_group(self, group_id): def get_group(self, group_id: str) -> Optional[FakeGroup]:
if group_id not in self.groups: if group_id not in self.groups:
raise IdNotFoundException("That Group Definition does not exist.") raise IdNotFoundException("That Group Definition does not exist.")
return self.groups.get(group_id) return self.groups.get(group_id)
def delete_group(self, group_id): def delete_group(self, group_id: str) -> None:
if group_id not in self.groups: if group_id not in self.groups:
# I don't know why, the error message is different between get_group and delete_group # I don't know why, the error message is different between get_group and delete_group
raise IdNotFoundException("That group definition does not exist.") raise IdNotFoundException("That group definition does not exist.")
del self.groups[group_id] del self.groups[group_id]
del self.group_versions[group_id] del self.group_versions[group_id]
def update_group(self, group_id, name): def update_group(self, group_id: str, name: str) -> None:
if name == "": if name == "":
raise InvalidContainerDefinitionException( raise InvalidContainerDefinitionException(
@ -1010,13 +1127,13 @@ class GreengrassBackend(BaseBackend):
def create_group_version( def create_group_version(
self, self,
group_id, group_id: str,
core_definition_version_arn, core_definition_version_arn: Optional[str],
device_definition_version_arn, device_definition_version_arn: Optional[str],
function_definition_version_arn, function_definition_version_arn: Optional[str],
resource_definition_version_arn, resource_definition_version_arn: Optional[str],
subscription_definition_version_arn, subscription_definition_version_arn: Optional[str],
): ) -> FakeGroupVersion:
if group_id not in self.groups: if group_id not in self.groups:
raise IdNotFoundException("That group does not exist.") raise IdNotFoundException("That group does not exist.")
@ -1048,19 +1165,21 @@ class GreengrassBackend(BaseBackend):
def _validate_group_version_definitions( def _validate_group_version_definitions(
self, self,
core_definition_version_arn=None, core_definition_version_arn: Optional[str] = None,
device_definition_version_arn=None, device_definition_version_arn: Optional[str] = None,
function_definition_version_arn=None, function_definition_version_arn: Optional[str] = None,
resource_definition_version_arn=None, resource_definition_version_arn: Optional[str] = None,
subscription_definition_version_arn=None, subscription_definition_version_arn: Optional[str] = None,
): ) -> None:
def _is_valid_def_ver_arn(definition_version_arn, kind="cores"): def _is_valid_def_ver_arn(
definition_version_arn: Optional[str], kind: str = "cores"
) -> bool:
if definition_version_arn is None: if definition_version_arn is None:
return True return True
if kind == "cores": if kind == "cores":
versions = self.core_definition_versions versions: Any = self.core_definition_versions
elif kind == "devices": elif kind == "devices":
versions = self.device_definition_versions versions = self.device_definition_versions
elif kind == "functions": elif kind == "functions":
@ -1124,12 +1243,14 @@ class GreengrassBackend(BaseBackend):
f"The group is invalid or corrupted. (ErrorDetails: [{error_details}])", f"The group is invalid or corrupted. (ErrorDetails: [{error_details}])",
) )
def list_group_versions(self, group_id): def list_group_versions(self, group_id: str) -> List[FakeGroupVersion]:
if group_id not in self.group_versions: if group_id not in self.group_versions:
raise IdNotFoundException("That group definition does not exist.") raise IdNotFoundException("That group definition does not exist.")
return self.group_versions[group_id].values() return list(self.group_versions[group_id].values())
def get_group_version(self, group_id, group_version_id): def get_group_version(
self, group_id: str, group_version_id: str
) -> FakeGroupVersion:
if group_id not in self.group_versions: if group_id not in self.group_versions:
raise IdNotFoundException("That group definition does not exist.") raise IdNotFoundException("That group definition does not exist.")
@ -1142,8 +1263,12 @@ class GreengrassBackend(BaseBackend):
return self.group_versions[group_id][group_version_id] return self.group_versions[group_id][group_version_id]
def create_deployment( def create_deployment(
self, group_id, group_version_id, deployment_type, deployment_id=None self,
): group_id: str,
group_version_id: str,
deployment_type: str,
deployment_id: Optional[str] = None,
) -> FakeDeployment:
deployment_types = ( deployment_types = (
"NewDeployment", "NewDeployment",
@ -1199,7 +1324,7 @@ class GreengrassBackend(BaseBackend):
self.deployments[deployment.id] = deployment self.deployments[deployment.id] = deployment
return deployment return deployment
def list_deployments(self, group_id): def list_deployments(self, group_id: str) -> List[FakeDeployment]:
# ListDeployments API does not check specified group is exists # ListDeployments API does not check specified group is exists
return [ return [
@ -1208,7 +1333,9 @@ class GreengrassBackend(BaseBackend):
if deployment.group_id == group_id if deployment.group_id == group_id
] ]
def get_deployment_status(self, group_id, deployment_id): def get_deployment_status(
self, group_id: str, deployment_id: str
) -> FakeDeploymentStatus:
if deployment_id not in self.deployments: if deployment_id not in self.deployments:
raise InvalidInputException(f"Deployment '{deployment_id}' does not exist.") raise InvalidInputException(f"Deployment '{deployment_id}' does not exist.")
@ -1224,7 +1351,7 @@ class GreengrassBackend(BaseBackend):
deployment.deployment_status, deployment.deployment_status,
) )
def reset_deployments(self, group_id, force=False): def reset_deployments(self, group_id: str, force: bool = False) -> FakeDeployment:
if group_id not in self.groups: if group_id not in self.groups:
raise ResourceNotFoundException("That Group Definition does not exist.") raise ResourceNotFoundException("That Group Definition does not exist.")
@ -1248,7 +1375,9 @@ class GreengrassBackend(BaseBackend):
self.deployments[deployment.id] = deployment self.deployments[deployment.id] = deployment
return deployment return deployment
def associate_role_to_group(self, group_id, role_arn): def associate_role_to_group(
self, group_id: str, role_arn: str
) -> FakeAssociatedRole:
# I don't know why, AssociateRoleToGroup does not check specified group is exists # I don't know why, AssociateRoleToGroup does not check specified group is exists
# So, this API allows any group id such as "a" # So, this API allows any group id such as "a"
@ -1257,7 +1386,7 @@ class GreengrassBackend(BaseBackend):
self.group_role_associations[group_id] = associated_role self.group_role_associations[group_id] = associated_role
return associated_role return associated_role
def get_associated_role(self, group_id): def get_associated_role(self, group_id: str) -> FakeAssociatedRole:
if group_id not in self.group_role_associations: if group_id not in self.group_role_associations:
raise GreengrassClientError( raise GreengrassClientError(
@ -1266,7 +1395,7 @@ class GreengrassBackend(BaseBackend):
return self.group_role_associations[group_id] return self.group_role_associations[group_id]
def disassociate_role_from_group(self, group_id): def disassociate_role_from_group(self, group_id: str) -> None:
if group_id not in self.group_role_associations: if group_id not in self.group_role_associations:
return return
del self.group_role_associations[group_id] del self.group_role_associations[group_id]

View File

@ -1,20 +1,22 @@
from datetime import datetime from datetime import datetime
from typing import Any
import json import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import greengrass_backends from .models import greengrass_backends, GreengrassBackend
class GreengrassResponse(BaseResponse): class GreengrassResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="greengrass") super().__init__(service_name="greengrass")
@property @property
def greengrass_backend(self): def greengrass_backend(self) -> GreengrassBackend:
return greengrass_backends[self.current_account][self.region] return greengrass_backends[self.current_account][self.region]
def core_definitions(self, request, full_url, headers): def core_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -23,7 +25,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_core_definition() return self.create_core_definition()
def list_core_definitions(self): def list_core_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_core_definitions() res = self.greengrass_backend.list_core_definitions()
return ( return (
200, 200,
@ -33,7 +35,7 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def create_core_definition(self): def create_core_definition(self) -> TYPE_RESPONSE:
name = self._get_param("Name") name = self._get_param("Name")
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
res = self.greengrass_backend.create_core_definition( res = self.greengrass_backend.create_core_definition(
@ -41,7 +43,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def core_definition(self, request, full_url, headers): def core_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -53,21 +55,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_core_definition() return self.update_core_definition()
def get_core_definition(self): def get_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1] core_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_core_definition( res = self.greengrass_backend.get_core_definition(
core_definition_id=core_definition_id core_definition_id=core_definition_id
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_core_definition(self): def delete_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1] core_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_core_definition( self.greengrass_backend.delete_core_definition(
core_definition_id=core_definition_id core_definition_id=core_definition_id
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_core_definition(self): def update_core_definition(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-1] core_definition_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
self.greengrass_backend.update_core_definition( self.greengrass_backend.update_core_definition(
@ -75,7 +77,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def core_definition_versions(self, request, full_url, headers): def core_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -84,7 +86,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_core_definition_version() return self.create_core_definition_version()
def create_core_definition_version(self): def create_core_definition_version(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-2] core_definition_id = self.path.split("/")[-2]
cores = self._get_param("Cores") cores = self._get_param("Cores")
@ -93,7 +95,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_core_definition_versions(self): def list_core_definition_versions(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-2] core_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_core_definition_versions(core_definition_id) res = self.greengrass_backend.list_core_definition_versions(core_definition_id)
return ( return (
@ -102,13 +104,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Versions": [core_def_ver.to_dict() for core_def_ver in res]}), json.dumps({"Versions": [core_def_ver.to_dict() for core_def_ver in res]}),
) )
def core_definition_version(self, request, full_url, headers): def core_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_core_definition_version() return self.get_core_definition_version()
def get_core_definition_version(self): def get_core_definition_version(self) -> TYPE_RESPONSE:
core_definition_id = self.path.split("/")[-3] core_definition_id = self.path.split("/")[-3]
core_definition_version_id = self.path.split("/")[-1] core_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_core_definition_version( res = self.greengrass_backend.get_core_definition_version(
@ -117,7 +119,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def device_definitions(self, request, full_url, headers): def device_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -126,7 +128,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_device_definition() return self.list_device_definition()
def create_device_definition(self): def create_device_definition(self) -> TYPE_RESPONSE:
name = self._get_param("Name") name = self._get_param("Name")
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
@ -135,7 +137,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_device_definition(self): def list_device_definition(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_device_definitions() res = self.greengrass_backend.list_device_definitions()
return ( return (
200, 200,
@ -149,7 +151,7 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def device_definition_versions(self, request, full_url, headers): def device_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -158,7 +160,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_device_definition_versions() return self.list_device_definition_versions()
def create_device_definition_version(self): def create_device_definition_version(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-2] device_definition_id = self.path.split("/")[-2]
devices = self._get_param("Devices") devices = self._get_param("Devices")
@ -168,7 +170,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_device_definition_versions(self): def list_device_definition_versions(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-2] device_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_device_definition_versions( res = self.greengrass_backend.list_device_definition_versions(
@ -182,7 +184,7 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def device_definition(self, request, full_url, headers): def device_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -194,14 +196,14 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_device_definition() return self.update_device_definition()
def get_device_definition(self): def get_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1] device_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_device_definition( res = self.greengrass_backend.get_device_definition(
device_definition_id=device_definition_id device_definition_id=device_definition_id
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_device_definition(self): def delete_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1] device_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_device_definition( self.greengrass_backend.delete_device_definition(
@ -209,7 +211,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_device_definition(self): def update_device_definition(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-1] device_definition_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
@ -218,13 +220,13 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def device_definition_version(self, request, full_url, headers): def device_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_device_definition_version() return self.get_device_definition_version()
def get_device_definition_version(self): def get_device_definition_version(self) -> TYPE_RESPONSE:
device_definition_id = self.path.split("/")[-3] device_definition_id = self.path.split("/")[-3]
device_definition_version_id = self.path.split("/")[-1] device_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_device_definition_version( res = self.greengrass_backend.get_device_definition_version(
@ -233,7 +235,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def resource_definitions(self, request, full_url, headers): def resource_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -242,7 +244,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_resource_definitions() return self.list_resource_definitions()
def create_resource_definition(self): def create_resource_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
name = self._get_param("Name") name = self._get_param("Name")
@ -251,16 +253,16 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_resource_definitions(self): def list_resource_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_resource_definitions() res = self.greengrass_backend.list_resource_definitions()
return ( return (
200, 200,
{"status": 200}, {"status": 200},
json.dumps({"Definitions": [i.to_dict() for i in res.values()]}), json.dumps({"Definitions": [i.to_dict() for i in res]}),
) )
def resource_definition(self, request, full_url, headers): def resource_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -272,14 +274,14 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_resource_definition() return self.update_resource_definition()
def get_resource_definition(self): def get_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1] resource_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_resource_definition( res = self.greengrass_backend.get_resource_definition(
resource_definition_id=resource_definition_id resource_definition_id=resource_definition_id
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_resource_definition(self): def delete_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1] resource_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_resource_definition( self.greengrass_backend.delete_resource_definition(
@ -287,7 +289,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_resource_definition(self): def update_resource_definition(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-1] resource_definition_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
@ -296,7 +298,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def resource_definition_versions(self, request, full_url, headers): def resource_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -305,7 +307,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_resource_definition_versions() return self.list_resource_definition_versions()
def create_resource_definition_version(self): def create_resource_definition_version(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-2] resource_definition_id = self.path.split("/")[-2]
resources = self._get_param("Resources") resources = self._get_param("Resources")
@ -315,7 +317,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_resource_definition_versions(self): def list_resource_definition_versions(self) -> TYPE_RESPONSE:
resource_device_definition_id = self.path.split("/")[-2] resource_device_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_resource_definition_versions( res = self.greengrass_backend.list_resource_definition_versions(
@ -330,13 +332,13 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def resource_definition_version(self, request, full_url, headers): def resource_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_resource_definition_version() return self.get_resource_definition_version()
def get_resource_definition_version(self): def get_resource_definition_version(self) -> TYPE_RESPONSE:
resource_definition_id = self.path.split("/")[-3] resource_definition_id = self.path.split("/")[-3]
resource_definition_version_id = self.path.split("/")[-1] resource_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_resource_definition_version( res = self.greengrass_backend.get_resource_definition_version(
@ -345,7 +347,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def function_definitions(self, request, full_url, headers): def function_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -354,7 +356,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_function_definitions() return self.list_function_definitions()
def create_function_definition(self): def create_function_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
name = self._get_param("Name") name = self._get_param("Name")
@ -363,7 +365,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_function_definitions(self): def list_function_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_function_definitions() res = self.greengrass_backend.list_function_definitions()
return ( return (
200, 200,
@ -373,7 +375,7 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def function_definition(self, request, full_url, headers): def function_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -385,21 +387,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_function_definition() return self.update_function_definition()
def get_function_definition(self): def get_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1] function_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_function_definition( res = self.greengrass_backend.get_function_definition(
function_definition_id=function_definition_id, function_definition_id=function_definition_id,
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_function_definition(self): def delete_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1] function_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_function_definition( self.greengrass_backend.delete_function_definition(
function_definition_id=function_definition_id, function_definition_id=function_definition_id,
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_function_definition(self): def update_function_definition(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-1] function_definition_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
self.greengrass_backend.update_function_definition( self.greengrass_backend.update_function_definition(
@ -407,7 +409,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def function_definition_versions(self, request, full_url, headers): def function_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -416,7 +418,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_function_definition_versions() return self.list_function_definition_versions()
def create_function_definition_version(self): def create_function_definition_version(self) -> TYPE_RESPONSE:
default_config = self._get_param("DefaultConfig") default_config = self._get_param("DefaultConfig")
function_definition_id = self.path.split("/")[-2] function_definition_id = self.path.split("/")[-2]
@ -429,7 +431,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_function_definition_versions(self): def list_function_definition_versions(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-2] function_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_function_definition_versions( res = self.greengrass_backend.list_function_definition_versions(
function_definition_id=function_definition_id function_definition_id=function_definition_id
@ -437,13 +439,13 @@ class GreengrassResponse(BaseResponse):
versions = [i.to_dict() for i in res.values()] versions = [i.to_dict() for i in res.values()]
return 200, {"status": 200}, json.dumps({"Versions": versions}) return 200, {"status": 200}, json.dumps({"Versions": versions})
def function_definition_version(self, request, full_url, headers): def function_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_function_definition_version() return self.get_function_definition_version()
def get_function_definition_version(self): def get_function_definition_version(self) -> TYPE_RESPONSE:
function_definition_id = self.path.split("/")[-3] function_definition_id = self.path.split("/")[-3]
function_definition_version_id = self.path.split("/")[-1] function_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_function_definition_version( res = self.greengrass_backend.get_function_definition_version(
@ -452,7 +454,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def subscription_definitions(self, request, full_url, headers): def subscription_definitions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -461,7 +463,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_subscription_definitions() return self.list_subscription_definitions()
def create_subscription_definition(self): def create_subscription_definition(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
name = self._get_param("Name") name = self._get_param("Name")
@ -470,7 +472,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_subscription_definitions(self): def list_subscription_definitions(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_subscription_definitions() res = self.greengrass_backend.list_subscription_definitions()
return ( return (
@ -486,7 +488,7 @@ class GreengrassResponse(BaseResponse):
), ),
) )
def subscription_definition(self, request, full_url, headers): def subscription_definition(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -498,21 +500,21 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_subscription_definition() return self.update_subscription_definition()
def get_subscription_definition(self): def get_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1] subscription_definition_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_subscription_definition( res = self.greengrass_backend.get_subscription_definition(
subscription_definition_id=subscription_definition_id subscription_definition_id=subscription_definition_id
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_subscription_definition(self): def delete_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1] subscription_definition_id = self.path.split("/")[-1]
self.greengrass_backend.delete_subscription_definition( self.greengrass_backend.delete_subscription_definition(
subscription_definition_id=subscription_definition_id subscription_definition_id=subscription_definition_id
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_subscription_definition(self): def update_subscription_definition(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-1] subscription_definition_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
self.greengrass_backend.update_subscription_definition( self.greengrass_backend.update_subscription_definition(
@ -520,7 +522,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def subscription_definition_versions(self, request, full_url, headers): def subscription_definition_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -529,7 +531,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_subscription_definition_versions() return self.list_subscription_definition_versions()
def create_subscription_definition_version(self): def create_subscription_definition_version(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-2] subscription_definition_id = self.path.split("/")[-2]
subscriptions = self._get_param("Subscriptions") subscriptions = self._get_param("Subscriptions")
@ -539,7 +541,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_subscription_definition_versions(self): def list_subscription_definition_versions(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-2] subscription_definition_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_subscription_definition_versions( res = self.greengrass_backend.list_subscription_definition_versions(
subscription_definition_id=subscription_definition_id subscription_definition_id=subscription_definition_id
@ -547,13 +549,13 @@ class GreengrassResponse(BaseResponse):
versions = [i.to_dict() for i in res.values()] versions = [i.to_dict() for i in res.values()]
return 200, {"status": 200}, json.dumps({"Versions": versions}) return 200, {"status": 200}, json.dumps({"Versions": versions})
def subscription_definition_version(self, request, full_url, headers): def subscription_definition_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_subscription_definition_version() return self.get_subscription_definition_version()
def get_subscription_definition_version(self): def get_subscription_definition_version(self) -> TYPE_RESPONSE:
subscription_definition_id = self.path.split("/")[-3] subscription_definition_id = self.path.split("/")[-3]
subscription_definition_version_id = self.path.split("/")[-1] subscription_definition_version_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_subscription_definition_version( res = self.greengrass_backend.get_subscription_definition_version(
@ -562,7 +564,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def groups(self, request, full_url, headers): def groups(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -571,7 +573,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_groups() return self.list_groups()
def create_group(self): def create_group(self) -> TYPE_RESPONSE:
initial_version = self._get_param("InitialVersion") initial_version = self._get_param("InitialVersion")
name = self._get_param("Name") name = self._get_param("Name")
@ -580,7 +582,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_groups(self): def list_groups(self) -> TYPE_RESPONSE:
res = self.greengrass_backend.list_groups() res = self.greengrass_backend.list_groups()
return ( return (
@ -589,7 +591,7 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Groups": [group.to_dict() for group in res]}), json.dumps({"Groups": [group.to_dict() for group in res]}),
) )
def group(self, request, full_url, headers): def group(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
@ -601,27 +603,25 @@ class GreengrassResponse(BaseResponse):
if self.method == "PUT": if self.method == "PUT":
return self.update_group() return self.update_group()
def get_group(self): def get_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1] group_id = self.path.split("/")[-1]
res = self.greengrass_backend.get_group( res = self.greengrass_backend.get_group(group_id=group_id)
group_id=group_id, return 200, {"status": 200}, json.dumps(res.to_dict()) # type: ignore
)
return 200, {"status": 200}, json.dumps(res.to_dict())
def delete_group(self): def delete_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1] group_id = self.path.split("/")[-1]
self.greengrass_backend.delete_group( self.greengrass_backend.delete_group(
group_id=group_id, group_id=group_id,
) )
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def update_group(self): def update_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-1] group_id = self.path.split("/")[-1]
name = self._get_param("Name") name = self._get_param("Name")
self.greengrass_backend.update_group(group_id=group_id, name=name) self.greengrass_backend.update_group(group_id=group_id, name=name)
return 200, {"status": 200}, json.dumps({}) return 200, {"status": 200}, json.dumps({})
def group_versions(self, request, full_url, headers): def group_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -630,7 +630,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_group_versions() return self.list_group_versions()
def create_group_version(self): def create_group_version(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
@ -656,7 +656,7 @@ class GreengrassResponse(BaseResponse):
) )
return 201, {"status": 201}, json.dumps(res.to_dict()) return 201, {"status": 201}, json.dumps(res.to_dict())
def list_group_versions(self): def list_group_versions(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_group_versions(group_id=group_id) res = self.greengrass_backend.list_group_versions(group_id=group_id)
return ( return (
@ -665,13 +665,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Versions": [group_ver.to_dict() for group_ver in res]}), json.dumps({"Versions": [group_ver.to_dict() for group_ver in res]}),
) )
def group_version(self, request, full_url, headers): def group_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_group_version() return self.get_group_version()
def get_group_version(self): def get_group_version(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-3] group_id = self.path.split("/")[-3]
group_version_id = self.path.split("/")[-1] group_version_id = self.path.split("/")[-1]
@ -681,7 +681,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def deployments(self, request, full_url, headers): def deployments(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
@ -690,7 +690,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.list_deployments() return self.list_deployments()
def create_deployment(self): def create_deployment(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
group_version_id = self._get_param("GroupVersionId") group_version_id = self._get_param("GroupVersionId")
@ -705,7 +705,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def list_deployments(self): def list_deployments(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
res = self.greengrass_backend.list_deployments(group_id=group_id) res = self.greengrass_backend.list_deployments(group_id=group_id)
@ -721,13 +721,13 @@ class GreengrassResponse(BaseResponse):
json.dumps({"Deployments": deployments}), json.dumps({"Deployments": deployments}),
) )
def deployment_satus(self, request, full_url, headers): def deployment_satus(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "GET": if self.method == "GET":
return self.get_deployment_status() return self.get_deployment_status()
def get_deployment_status(self): def get_deployment_status(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-4] group_id = self.path.split("/")[-4]
deployment_id = self.path.split("/")[-2] deployment_id = self.path.split("/")[-2]
@ -737,13 +737,13 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def deployments_reset(self, request, full_url, headers): def deployments_reset(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
return self.reset_deployments() return self.reset_deployments()
def reset_deployments(self): def reset_deployments(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-3] group_id = self.path.split("/")[-3]
res = self.greengrass_backend.reset_deployments( res = self.greengrass_backend.reset_deployments(
@ -751,7 +751,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def role(self, request, full_url, headers): def role(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "PUT": if self.method == "PUT":
@ -763,7 +763,7 @@ class GreengrassResponse(BaseResponse):
if self.method == "DELETE": if self.method == "DELETE":
return self.disassociate_role_from_group() return self.disassociate_role_from_group()
def associate_role_to_group(self): def associate_role_to_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
role_arn = self._get_param("RoleArn") role_arn = self._get_param("RoleArn")
@ -773,7 +773,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict()) return 200, {"status": 200}, json.dumps(res.to_dict())
def get_associated_role(self): def get_associated_role(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
res = self.greengrass_backend.get_associated_role( res = self.greengrass_backend.get_associated_role(
@ -781,7 +781,7 @@ class GreengrassResponse(BaseResponse):
) )
return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True)) return 200, {"status": 200}, json.dumps(res.to_dict(include_detail=True))
def disassociate_role_from_group(self): def disassociate_role_from_group(self) -> TYPE_RESPONSE:
group_id = self.path.split("/")[-2] group_id = self.path.split("/")[-2]
self.greengrass_backend.disassociate_role_from_group( self.greengrass_backend.disassociate_role_from_group(
group_id=group_id, group_id=group_id,

View File

@ -1,3 +1,4 @@
from typing import Any, List, Tuple
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
@ -8,24 +9,28 @@ class GuardDutyException(JsonRESTError):
class DetectorNotFoundException(GuardDutyException): class DetectorNotFoundException(GuardDutyException):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidInputException", "InvalidInputException",
"The request is rejected because the input detectorId is not owned by the current account.", "The request is rejected because the input detectorId is not owned by the current account.",
) )
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument def get_headers(
return {"X-Amzn-ErrorType": "BadRequestException"} self, *args: Any, **kwargs: Any
) -> List[Tuple[str, str]]: # pylint: disable=unused-argument
return [("X-Amzn-ErrorType", "BadRequestException")]
class FilterNotFoundException(GuardDutyException): class FilterNotFoundException(GuardDutyException):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidInputException", "InvalidInputException",
"The request is rejected since no such resource found.", "The request is rejected since no such resource found.",
) )
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument def get_headers(
return {"X-Amzn-ErrorType": "BadRequestException"} self, *args: Any, **kwargs: Any
) -> List[Tuple[str, str]]: # pylint: disable=unused-argument
return [("X-Amzn-ErrorType", "BadRequestException")]

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from datetime import datetime from datetime import datetime
@ -6,12 +7,18 @@ from .exceptions import DetectorNotFoundException, FilterNotFoundException
class GuardDutyBackend(BaseBackend): class GuardDutyBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.admin_account_ids = [] self.admin_account_ids: List[str] = []
self.detectors = {} self.detectors: Dict[str, Detector] = {}
def create_detector(self, enable, finding_publishing_frequency, data_sources, tags): def create_detector(
self,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
tags: Dict[str, str],
) -> str:
if finding_publishing_frequency not in [ if finding_publishing_frequency not in [
"FIFTEEN_MINUTES", "FIFTEEN_MINUTES",
"ONE_HOUR", "ONE_HOUR",
@ -31,29 +38,35 @@ class GuardDutyBackend(BaseBackend):
return detector.id return detector.id
def create_filter( def create_filter(
self, detector_id, name, action, description, finding_criteria, rank self,
): detector_id: str,
name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
detector = self.get_detector(detector_id) detector = self.get_detector(detector_id)
_filter = Filter(name, action, description, finding_criteria, rank) _filter = Filter(name, action, description, finding_criteria, rank)
detector.add_filter(_filter) detector.add_filter(_filter)
def delete_detector(self, detector_id): def delete_detector(self, detector_id: str) -> None:
self.detectors.pop(detector_id, None) self.detectors.pop(detector_id, None)
def delete_filter(self, detector_id, filter_name): def delete_filter(self, detector_id: str, filter_name: str) -> None:
detector = self.get_detector(detector_id) detector = self.get_detector(detector_id)
detector.delete_filter(filter_name) detector.delete_filter(filter_name)
def enable_organization_admin_account(self, admin_account_id): def enable_organization_admin_account(self, admin_account_id: str) -> None:
self.admin_account_ids.append(admin_account_id) self.admin_account_ids.append(admin_account_id)
def list_organization_admin_accounts(self): def list_organization_admin_accounts(self) -> List[str]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
return self.admin_account_ids return self.admin_account_ids
def list_detectors(self): def list_detectors(self) -> List[str]:
""" """
The MaxResults and NextToken-parameter have not yet been implemented. The MaxResults and NextToken-parameter have not yet been implemented.
""" """
@ -62,24 +75,34 @@ class GuardDutyBackend(BaseBackend):
detectorids.append(self.detectors[detector].id) detectorids.append(self.detectors[detector].id)
return detectorids return detectorids
def get_detector(self, detector_id): def get_detector(self, detector_id: str) -> "Detector":
if detector_id not in self.detectors: if detector_id not in self.detectors:
raise DetectorNotFoundException raise DetectorNotFoundException
return self.detectors[detector_id] return self.detectors[detector_id]
def get_filter(self, detector_id, filter_name): def get_filter(self, detector_id: str, filter_name: str) -> "Filter":
detector = self.get_detector(detector_id) detector = self.get_detector(detector_id)
return detector.get_filter(filter_name) return detector.get_filter(filter_name)
def update_detector( def update_detector(
self, detector_id, enable, finding_publishing_frequency, data_sources self,
): detector_id: str,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
) -> None:
detector = self.get_detector(detector_id) detector = self.get_detector(detector_id)
detector.update(enable, finding_publishing_frequency, data_sources) detector.update(enable, finding_publishing_frequency, data_sources)
def update_filter( def update_filter(
self, detector_id, filter_name, action, description, finding_criteria, rank self,
): detector_id: str,
filter_name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
detector = self.get_detector(detector_id) detector = self.get_detector(detector_id)
detector.update_filter( detector.update_filter(
filter_name, filter_name,
@ -91,14 +114,27 @@ class GuardDutyBackend(BaseBackend):
class Filter(BaseModel): class Filter(BaseModel):
def __init__(self, name, action, description, finding_criteria, rank): def __init__(
self,
name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
):
self.name = name self.name = name
self.action = action self.action = action
self.description = description self.description = description
self.finding_criteria = finding_criteria self.finding_criteria = finding_criteria
self.rank = rank or 1 self.rank = rank or 1
def update(self, action, description, finding_criteria, rank): def update(
self,
action: Optional[str],
description: Optional[str],
finding_criteria: Optional[Dict[str, Any]],
rank: Optional[int],
) -> None:
if action is not None: if action is not None:
self.action = action self.action = action
if description is not None: if description is not None:
@ -108,7 +144,7 @@ class Filter(BaseModel):
if rank is not None: if rank is not None:
self.rank = rank self.rank = rank
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"name": self.name, "name": self.name,
"action": self.action, "action": self.action,
@ -121,12 +157,12 @@ class Filter(BaseModel):
class Detector(BaseModel): class Detector(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
created_at, created_at: datetime,
finding_publish_freq, finding_publish_freq: str,
enabled, enabled: bool,
datasources, datasources: Dict[str, Any],
tags, tags: Dict[str, str],
): ):
self.id = mock_random.get_random_hex(length=32) self.id = mock_random.get_random_hex(length=32)
self.created_at = created_at self.created_at = created_at
@ -137,20 +173,27 @@ class Detector(BaseModel):
self.datasources = datasources or {} self.datasources = datasources or {}
self.tags = tags or {} self.tags = tags or {}
self.filters = dict() self.filters: Dict[str, Filter] = dict()
def add_filter(self, _filter: Filter): def add_filter(self, _filter: Filter) -> None:
self.filters[_filter.name] = _filter self.filters[_filter.name] = _filter
def delete_filter(self, filter_name): def delete_filter(self, filter_name: str) -> None:
self.filters.pop(filter_name, None) self.filters.pop(filter_name, None)
def get_filter(self, filter_name: str): def get_filter(self, filter_name: str) -> Filter:
if filter_name not in self.filters: if filter_name not in self.filters:
raise FilterNotFoundException raise FilterNotFoundException
return self.filters[filter_name] return self.filters[filter_name]
def update_filter(self, filter_name, action, description, finding_criteria, rank): def update_filter(
self,
filter_name: str,
action: str,
description: str,
finding_criteria: Dict[str, Any],
rank: int,
) -> None:
_filter = self.get_filter(filter_name) _filter = self.get_filter(filter_name)
_filter.update( _filter.update(
action=action, action=action,
@ -159,7 +202,12 @@ class Detector(BaseModel):
rank=rank, rank=rank,
) )
def update(self, enable, finding_publishing_frequency, data_sources): def update(
self,
enable: bool,
finding_publishing_frequency: str,
data_sources: Dict[str, Any],
) -> None:
if enable is not None: if enable is not None:
self.enabled = enable self.enabled = enable
if finding_publishing_frequency is not None: if finding_publishing_frequency is not None:
@ -167,7 +215,7 @@ class Detector(BaseModel):
if data_sources is not None: if data_sources is not None:
self.datasources = data_sources self.datasources = data_sources
def to_json(self): def to_json(self) -> Dict[str, Any]:
data_sources = { data_sources = {
"cloudTrail": {"status": "DISABLED"}, "cloudTrail": {"status": "DISABLED"},
"dnsLogs": {"status": "DISABLED"}, "dnsLogs": {"status": "DISABLED"},

View File

@ -1,18 +1,20 @@
from typing import Any
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import guardduty_backends from .models import guardduty_backends, GuardDutyBackend
import json import json
from urllib.parse import unquote from urllib.parse import unquote
class GuardDutyResponse(BaseResponse): class GuardDutyResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="guardduty") super().__init__(service_name="guardduty")
@property @property
def guardduty_backend(self): def guardduty_backend(self) -> GuardDutyBackend:
return guardduty_backends[self.current_account][self.region] return guardduty_backends[self.current_account][self.region]
def filter(self, request, full_url, headers): def filter(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.get_filter() return self.get_filter()
@ -21,12 +23,12 @@ class GuardDutyResponse(BaseResponse):
elif request.method == "POST": elif request.method == "POST":
return self.update_filter() return self.update_filter()
def filters(self, request, full_url, headers): def filters(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_filter() return self.create_filter()
def detectors(self, request, full_url, headers): def detectors(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_detector() return self.create_detector()
@ -35,7 +37,7 @@ class GuardDutyResponse(BaseResponse):
else: else:
return 404, {}, "" return 404, {}, ""
def detector(self, request, full_url, headers): def detector(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.get_detector() return self.get_detector()
@ -44,7 +46,7 @@ class GuardDutyResponse(BaseResponse):
elif request.method == "POST": elif request.method == "POST":
return self.update_detector() return self.update_detector()
def create_filter(self): def create_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-2] detector_id = self.path.split("/")[-2]
name = self._get_param("name") name = self._get_param("name")
action = self._get_param("action") action = self._get_param("action")
@ -57,7 +59,7 @@ class GuardDutyResponse(BaseResponse):
) )
return 200, {}, json.dumps({"name": name}) return 200, {}, json.dumps({"name": name})
def create_detector(self): def create_detector(self) -> TYPE_RESPONSE:
enable = self._get_param("enable") enable = self._get_param("enable")
finding_publishing_frequency = self._get_param("findingPublishingFrequency") finding_publishing_frequency = self._get_param("findingPublishingFrequency")
data_sources = self._get_param("dataSources") data_sources = self._get_param("dataSources")
@ -69,20 +71,22 @@ class GuardDutyResponse(BaseResponse):
return 200, {}, json.dumps(dict(detectorId=detector_id)) return 200, {}, json.dumps(dict(detectorId=detector_id))
def delete_detector(self): def delete_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1] detector_id = self.path.split("/")[-1]
self.guardduty_backend.delete_detector(detector_id) self.guardduty_backend.delete_detector(detector_id)
return 200, {}, "{}" return 200, {}, "{}"
def delete_filter(self): def delete_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3] detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1]) filter_name = unquote(self.path.split("/")[-1])
self.guardduty_backend.delete_filter(detector_id, filter_name) self.guardduty_backend.delete_filter(detector_id, filter_name)
return 200, {}, "{}" return 200, {}, "{}"
def enable_organization_admin_account(self, request, full_url, headers): def enable_organization_admin_account(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
admin_account = self._get_param("adminAccountId") admin_account = self._get_param("adminAccountId")
@ -90,7 +94,9 @@ class GuardDutyResponse(BaseResponse):
return 200, {}, "{}" return 200, {}, "{}"
def list_organization_admin_accounts(self, request, full_url, headers): def list_organization_admin_accounts(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
account_ids = self.guardduty_backend.list_organization_admin_accounts() account_ids = self.guardduty_backend.list_organization_admin_accounts()
@ -108,25 +114,25 @@ class GuardDutyResponse(BaseResponse):
), ),
) )
def list_detectors(self): def list_detectors(self) -> TYPE_RESPONSE:
detector_ids = self.guardduty_backend.list_detectors() detector_ids = self.guardduty_backend.list_detectors()
return 200, {}, json.dumps({"detectorIds": detector_ids}) return 200, {}, json.dumps({"detectorIds": detector_ids})
def get_detector(self): def get_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1] detector_id = self.path.split("/")[-1]
detector = self.guardduty_backend.get_detector(detector_id) detector = self.guardduty_backend.get_detector(detector_id)
return 200, {}, json.dumps(detector.to_json()) return 200, {}, json.dumps(detector.to_json())
def get_filter(self): def get_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3] detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1]) filter_name = unquote(self.path.split("/")[-1])
_filter = self.guardduty_backend.get_filter(detector_id, filter_name) _filter = self.guardduty_backend.get_filter(detector_id, filter_name)
return 200, {}, json.dumps(_filter.to_json()) return 200, {}, json.dumps(_filter.to_json())
def update_detector(self): def update_detector(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-1] detector_id = self.path.split("/")[-1]
enable = self._get_param("enable") enable = self._get_param("enable")
finding_publishing_frequency = self._get_param("findingPublishingFrequency") finding_publishing_frequency = self._get_param("findingPublishingFrequency")
@ -137,7 +143,7 @@ class GuardDutyResponse(BaseResponse):
) )
return 200, {}, "{}" return 200, {}, "{}"
def update_filter(self): def update_filter(self) -> TYPE_RESPONSE:
detector_id = self.path.split("/")[-3] detector_id = self.path.split("/")[-3]
filter_name = unquote(self.path.split("/")[-1]) filter_name = unquote(self.path.split("/")[-1])
action = self._get_param("action") action = self._get_param("action")

View File

@ -171,7 +171,9 @@ class TaggingService:
) )
@staticmethod @staticmethod
def convert_dict_to_tags_input(tags: Dict[str, str]) -> List[Dict[str, str]]: def convert_dict_to_tags_input(
tags: Optional[Dict[str, str]]
) -> List[Dict[str, str]]:
"""Given a dictionary, return generic boto params for tags""" """Given a dictionary, return generic boto params for tags"""
if not tags: if not tags:
return [] return []

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/moto_api,moto/neptune files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/moto_api,moto/neptune
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract