Techdebt: MyPy ECR (#5943)

This commit is contained in:
Bert Blommers 2023-02-18 09:48:26 -01:00 committed by GitHub
parent eb79d064e8
commit b241c16726
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 260 additions and 215 deletions

View File

@ -4,9 +4,9 @@ from moto.core.exceptions import JsonRESTError
class LifecyclePolicyNotFoundException(JsonRESTError):
code = 400
def __init__(self, repository_name, registry_id):
def __init__(self, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="LifecyclePolicyNotFoundException",
message=(
"Lifecycle policy does not exist "
f"for the repository with name '{repository_name}' "
@ -18,9 +18,9 @@ class LifecyclePolicyNotFoundException(JsonRESTError):
class LimitExceededException(JsonRESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
error_type=__class__.__name__,
error_type="LimitExceededException",
message=("The scan quota per image has been exceeded. Wait and try again."),
)
@ -28,9 +28,9 @@ class LimitExceededException(JsonRESTError):
class RegistryPolicyNotFoundException(JsonRESTError):
code = 400
def __init__(self, registry_id):
def __init__(self, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="RegistryPolicyNotFoundException",
message=(
f"Registry policy does not exist in the registry with id '{registry_id}'"
),
@ -40,9 +40,9 @@ class RegistryPolicyNotFoundException(JsonRESTError):
class RepositoryAlreadyExistsException(JsonRESTError):
code = 400
def __init__(self, repository_name, registry_id):
def __init__(self, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="RepositoryAlreadyExistsException",
message=(
f"The repository with name '{repository_name}' already exists "
f"in the registry with id '{registry_id}'"
@ -53,9 +53,9 @@ class RepositoryAlreadyExistsException(JsonRESTError):
class RepositoryNotEmptyException(JsonRESTError):
code = 400
def __init__(self, repository_name, registry_id):
def __init__(self, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="RepositoryNotEmptyException",
message=(
f"The repository with name '{repository_name}' "
f"in registry with id '{registry_id}' "
@ -67,9 +67,9 @@ class RepositoryNotEmptyException(JsonRESTError):
class RepositoryNotFoundException(JsonRESTError):
code = 400
def __init__(self, repository_name, registry_id):
def __init__(self, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="RepositoryNotFoundException",
message=(
f"The repository with name '{repository_name}' does not exist "
f"in the registry with id '{registry_id}'"
@ -80,9 +80,9 @@ class RepositoryNotFoundException(JsonRESTError):
class RepositoryPolicyNotFoundException(JsonRESTError):
code = 400
def __init__(self, repository_name, registry_id):
def __init__(self, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="RepositoryPolicyNotFoundException",
message=(
"Repository policy does not exist "
f"for the repository with name '{repository_name}' "
@ -94,9 +94,9 @@ class RepositoryPolicyNotFoundException(JsonRESTError):
class ImageNotFoundException(JsonRESTError):
code = 400
def __init__(self, image_id, repository_name, registry_id):
def __init__(self, image_id: str, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="ImageNotFoundException",
message=(
f"The image with imageId {image_id} does not exist "
f"within the repository with name '{repository_name}' "
@ -108,16 +108,16 @@ class ImageNotFoundException(JsonRESTError):
class InvalidParameterException(JsonRESTError):
code = 400
def __init__(self, message):
super().__init__(error_type=__class__.__name__, message=message)
def __init__(self, message: str):
super().__init__(error_type="InvalidParameterException", message=message)
class ScanNotFoundException(JsonRESTError):
code = 400
def __init__(self, image_id, repository_name, registry_id):
def __init__(self, image_id: str, repository_name: str, registry_id: str):
super().__init__(
error_type=__class__.__name__,
error_type="ScanNotFoundException",
message=(
f"Image scan does not exist for the image with '{image_id}' "
f"in the repository with name '{repository_name}' "
@ -129,5 +129,5 @@ class ScanNotFoundException(JsonRESTError):
class ValidationException(JsonRESTError):
code = 400
def __init__(self, message):
super().__init__(error_type=__class__.__name__, message=message)
def __init__(self, message: str):
super().__init__(error_type="ValidationException", message=message)

View File

@ -3,7 +3,7 @@ import json
import re
from collections import namedtuple
from datetime import datetime, timezone
from typing import Dict, List
from typing import Any, Dict, List, Iterable, Optional
from botocore.exceptions import ParamValidationError
@ -36,7 +36,7 @@ EcrRepositoryArn = namedtuple(
class BaseObject(BaseModel):
def camelCase(self, key):
def camelCase(self, key: str) -> str:
words = []
for i, word in enumerate(key.split("_")):
if i > 0:
@ -45,7 +45,7 @@ class BaseObject(BaseModel):
words.append(word)
return "".join(words)
def gen_response_object(self):
def gen_response_object(self) -> Dict[str, Any]:
response_object = dict()
for key, value in self.__dict__.items():
if "_" in key:
@ -55,20 +55,20 @@ class BaseObject(BaseModel):
return response_object
@property
def response_object(self):
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
return self.gen_response_object()
class Repository(BaseObject, CloudFormationModel):
def __init__(
self,
account_id,
region_name,
repository_name,
registry_id,
encryption_config,
image_scan_config,
image_tag_mutablility,
account_id: str,
region_name: str,
repository_name: str,
registry_id: str,
encryption_config: Optional[Dict[str, str]],
image_scan_config: str,
image_tag_mutablility: str,
):
self.account_id = account_id
self.region_name = region_name
@ -86,11 +86,13 @@ class Repository(BaseObject, CloudFormationModel):
self.encryption_configuration = self._determine_encryption_config(
encryption_config
)
self.policy = None
self.lifecycle_policy = None
self.policy: Optional[str] = None
self.lifecycle_policy: Optional[str] = None
self.images: List[Image] = []
def _determine_encryption_config(self, encryption_config):
def _determine_encryption_config(
self, encryption_config: Optional[Dict[str, str]]
) -> Dict[str, str]:
if not encryption_config:
return {"encryptionType": "AES256"}
if encryption_config == {"encryptionType": "KMS"}:
@ -99,7 +101,9 @@ class Repository(BaseObject, CloudFormationModel):
] = f"arn:aws:kms:{self.region_name}:{self.account_id}:key/{random.uuid4()}"
return encryption_config
def _get_image(self, image_tag, image_digest):
def _get_image(
self, image_tag: Optional[str], image_digest: Optional[str]
) -> "Image":
# you can either search for one or both
image = next(
(
@ -125,11 +129,11 @@ class Repository(BaseObject, CloudFormationModel):
return image
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.name
@property
def response_object(self):
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["registryId"] = self.registry_id
@ -142,21 +146,25 @@ class Repository(BaseObject, CloudFormationModel):
del response_object["arn"], response_object["name"], response_object["images"]
return response_object
def update(self, image_scan_config=None, image_tag_mutability=None):
def update(
self,
image_scan_config: Optional[Dict[str, Any]] = None,
image_tag_mutability: Optional[str] = None,
) -> None:
if image_scan_config:
self.image_scanning_configuration = image_scan_config
if image_tag_mutability:
self.image_tag_mutability = image_tag_mutability
def delete(self, account_id, region_name):
def delete(self, account_id: str, region_name: str) -> None:
ecr_backend = ecr_backends[account_id][region_name]
ecr_backend.delete_repository(self.name)
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn", "RepositoryUri"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
@ -167,18 +175,23 @@ class Repository(BaseObject, CloudFormationModel):
raise UnformattedGetAttTemplateException()
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "RepositoryName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ecr-repository.html
return "AWS::ECR::Repository"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Repository":
ecr_backend = ecr_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -199,14 +212,14 @@ class Repository(BaseObject, CloudFormationModel):
)
@classmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource,
new_resource_name,
cloudformation_json,
account_id,
region_name,
):
original_resource: Any,
new_resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> "Repository":
ecr_backend = ecr_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
encryption_configuration = properties.get(
@ -237,13 +250,13 @@ class Repository(BaseObject, CloudFormationModel):
class Image(BaseObject):
def __init__(
self,
account_id,
tag,
manifest,
repository,
image_manifest_mediatype=None,
digest=None,
registry_id=None,
account_id: str,
tag: str,
manifest: str,
repository: str,
image_manifest_mediatype: Optional[str] = None,
digest: Optional[str] = None,
registry_id: Optional[str] = None,
):
self.image_tag = tag
self.image_tags = [tag] if tag is not None else []
@ -253,9 +266,9 @@ class Image(BaseObject):
self.registry_id = registry_id or account_id
self.image_digest = digest
self.image_pushed_at = str(datetime.now(timezone.utc).isoformat())
self.last_scan = None
self.last_scan: Optional[datetime] = None
def _create_digest(self):
def _create_digest(self) -> None:
image_manifest = json.loads(self.image_manifest)
if "layers" in image_manifest:
layer_digests = [layer["digest"] for layer in image_manifest["layers"]]
@ -269,12 +282,12 @@ class Image(BaseObject):
).hexdigest()
self.image_digest = f"sha256:{random_sha}"
def get_image_digest(self):
def get_image_digest(self) -> str:
if not self.image_digest:
self._create_digest()
return self.image_digest
return self.image_digest # type: ignore[return-value]
def get_image_size_in_bytes(self):
def get_image_size_in_bytes(self) -> Optional[int]:
image_manifest = json.loads(self.image_manifest)
if "layers" in image_manifest:
try:
@ -284,22 +297,22 @@ class Image(BaseObject):
else:
return None
def get_image_manifest(self):
def get_image_manifest(self) -> str:
return self.image_manifest
def remove_tag(self, tag):
def remove_tag(self, tag: str) -> None:
if tag is not None and tag in self.image_tags:
self.image_tags.remove(tag)
if self.image_tags:
self.image_tag = self.image_tags[-1]
def update_tag(self, tag):
def update_tag(self, tag: str) -> None:
self.image_tag = tag
if tag not in self.image_tags and tag is not None:
self.image_tags.append(tag)
@property
def response_object(self):
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["imageId"] = {}
response_object["imageId"]["imageTag"] = self.image_tag
@ -312,7 +325,7 @@ class Image(BaseObject):
}
@property
def response_list_object(self):
def response_list_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["imageTag"] = self.image_tag
response_object["imageDigest"] = self.get_image_digest()
@ -321,7 +334,7 @@ class Image(BaseObject):
}
@property
def response_describe_object(self):
def response_describe_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["imageTags"] = self.image_tags
response_object["imageDigest"] = self.get_image_digest()
@ -334,38 +347,40 @@ class Image(BaseObject):
return {k: v for k, v in response_object.items() if v is not None and v != []}
@property
def response_batch_get_image(self):
response_object = {}
response_object["imageId"] = {}
response_object["imageId"]["imageTag"] = self.image_tag
response_object["imageId"]["imageDigest"] = self.get_image_digest()
response_object["imageManifest"] = self.image_manifest
response_object["repositoryName"] = self.repository
response_object["registryId"] = self.registry_id
def response_batch_get_image(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = {
"imageId": {
"imageTag": self.image_tag,
"imageDigest": self.get_image_digest(),
},
"imageManifest": self.image_manifest,
"repositoryName": self.repository,
"registryId": self.registry_id,
}
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
k: v for k, v in response_object.items() if v is not None and v != [None] # type: ignore
}
@property
def response_batch_delete_image(self):
def response_batch_delete_image(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = {}
response_object["imageDigest"] = self.get_image_digest()
response_object["imageTag"] = self.image_tag
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
k: v for k, v in response_object.items() if v is not None and v != [None] # type: ignore
}
class ECRBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.registry_policy = None
self.replication_config = {"rules": []}
self.registry_policy: Optional[str] = None
self.replication_config: Dict[str, Any] = {"rules": []}
self.repositories: Dict[str, Repository] = {}
self.tagger = TaggingService(tag_name="tags")
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service."""
docker_endpoint = {
"AcceptanceRequired": False,
@ -388,7 +403,9 @@ class ECRBackend(BaseBackend):
service_region, zones, "api.ecr", special_service_name="ecr.api"
) + [docker_endpoint]
def _get_repository(self, name, registry_id=None) -> Repository:
def _get_repository(
self, name: str, registry_id: Optional[str] = None
) -> Repository:
repo = self.repositories.get(name)
reg_id = registry_id or self.account_id
@ -397,7 +414,7 @@ class ECRBackend(BaseBackend):
return repo
@staticmethod
def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
def _parse_resource_arn(resource_arn: str) -> EcrRepositoryArn: # type: ignore[misc]
match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
if not match:
raise InvalidParameterException(
@ -406,7 +423,11 @@ class ECRBackend(BaseBackend):
)
return EcrRepositoryArn(**match.groupdict())
def describe_repositories(self, registry_id=None, repository_names=None):
def describe_repositories(
self,
registry_id: Optional[str] = None,
repository_names: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
maxResults and nextToken not implemented
"""
@ -433,13 +454,13 @@ class ECRBackend(BaseBackend):
def create_repository(
self,
repository_name,
registry_id,
encryption_config,
image_scan_config,
image_tag_mutablility,
tags,
):
repository_name: str,
registry_id: str,
encryption_config: Dict[str, str],
image_scan_config: Any,
image_tag_mutablility: str,
tags: List[Dict[str, str]],
) -> Repository:
if self.repositories.get(repository_name):
raise RepositoryAlreadyExistsException(repository_name, self.account_id)
@ -457,7 +478,12 @@ class ECRBackend(BaseBackend):
return repository
def delete_repository(self, repository_name, registry_id=None, force=False):
def delete_repository(
self,
repository_name: str,
registry_id: Optional[str] = None,
force: bool = False,
) -> Repository:
repo = self._get_repository(repository_name, registry_id)
if repo.images and not force:
@ -468,7 +494,9 @@ class ECRBackend(BaseBackend):
self.tagger.delete_all_tags_for_resource(repo.arn)
return self.repositories.pop(repository_name)
def list_images(self, repository_name, registry_id=None):
def list_images(
self, repository_name: str, registry_id: Optional[str] = None
) -> List[Image]:
"""
maxResults and filtering not implemented
"""
@ -487,16 +515,18 @@ class ECRBackend(BaseBackend):
repository_name, registry_id or self.account_id
)
images = []
for image in repository.images:
images.append(image)
return images
return list(repository.images) # type: ignore[union-attr]
def describe_images(self, repository_name, registry_id=None, image_ids=None):
def describe_images(
self,
repository_name: str,
registry_id: Optional[str] = None,
image_ids: Optional[List[Dict[str, str]]] = None,
) -> Iterable[Image]:
repository = self._get_repository(repository_name, registry_id)
if image_ids:
response = set(
return set(
repository._get_image(
image_id.get("imageTag"), image_id.get("imageDigest")
)
@ -504,15 +534,15 @@ class ECRBackend(BaseBackend):
)
else:
response = []
for image in repository.images:
response.append(image)
return response
return list(repository.images)
def put_image(
self, repository_name, image_manifest, image_tag, image_manifest_mediatype=None
):
self,
repository_name: str,
image_manifest: str,
image_tag: str,
image_manifest_mediatype: Optional[str] = None,
) -> Image:
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
@ -563,7 +593,12 @@ class ECRBackend(BaseBackend):
existing_images[0].update_tag(image_tag)
return existing_images[0]
def batch_get_image(self, repository_name, registry_id=None, image_ids=None):
def batch_get_image(
self,
repository_name: str,
registry_id: Optional[str] = None,
image_ids: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
The parameter AcceptedMediaTypes has not yet been implemented
"""
@ -579,7 +614,7 @@ class ECRBackend(BaseBackend):
msg='Missing required parameter in input: "imageIds"'
)
response = {"images": [], "failures": []}
response: Dict[str, Any] = {"images": [], "failures": []}
for image_id in image_ids:
found = False
@ -604,7 +639,12 @@ class ECRBackend(BaseBackend):
return response
def batch_delete_image(self, repository_name, registry_id=None, image_ids=None):
def batch_delete_image(
self,
repository_name: str,
registry_id: Optional[str] = None,
image_ids: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
@ -617,7 +657,7 @@ class ECRBackend(BaseBackend):
msg='Missing required parameter in input: "imageIds"'
)
response = {"imageIds": [], "failures": []}
response: Dict[str, Any] = {"imageIds": [], "failures": []}
for image_id in image_ids:
image_found = False
@ -636,12 +676,10 @@ class ECRBackend(BaseBackend):
# If we have a digest, is it valid?
if "imageDigest" in image_id:
pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
if not pattern.match(image_id.get("imageDigest")):
if not pattern.match(image_id["imageDigest"]):
response["failures"].append(
{
"imageId": {
"imageDigest": image_id.get("imageDigest", "null")
},
"imageId": {"imageDigest": image_id["imageDigest"]},
"failureCode": "InvalidImageDigest",
"failureReason": "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
}
@ -690,7 +728,7 @@ class ECRBackend(BaseBackend):
repository.images.remove(image)
if not image_found:
failure_response = {
failure_response: Dict[str, Any] = {
"imageId": {},
"failureCode": "ImageNotFound",
"failureReason": "Requested image not found",
@ -710,29 +748,25 @@ class ECRBackend(BaseBackend):
return response
def list_tags_for_resource(self, arn):
def list_tags_for_resource(self, arn: str) -> Dict[str, List[Dict[str, str]]]:
resource = self._parse_resource_arn(arn)
repo = self._get_repository(resource.repo_name, resource.account_id)
return self.tagger.list_tags_for_resource(repo.arn)
def tag_resource(self, arn, tags):
def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None:
resource = self._parse_resource_arn(arn)
repo = self._get_repository(resource.repo_name, resource.account_id)
self.tagger.tag_resource(repo.arn, tags)
return {}
def untag_resource(self, arn, tag_keys):
def untag_resource(self, arn: str, tag_keys: List[str]) -> None:
resource = self._parse_resource_arn(arn)
repo = self._get_repository(resource.repo_name, resource.account_id)
self.tagger.untag_resource_using_names(repo.arn, tag_keys)
return {}
def put_image_tag_mutability(
self, registry_id, repository_name, image_tag_mutability
):
self, registry_id: str, repository_name: str, image_tag_mutability: str
) -> Dict[str, str]:
if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
raise InvalidParameterException(
"Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
@ -749,8 +783,8 @@ class ECRBackend(BaseBackend):
}
def put_image_scanning_configuration(
self, registry_id, repository_name, image_scan_config
):
self, registry_id: str, repository_name: str, image_scan_config: Dict[str, Any]
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
repo.update(image_scan_config=image_scan_config)
@ -760,15 +794,17 @@ class ECRBackend(BaseBackend):
"imageScanningConfiguration": repo.image_scanning_configuration,
}
def set_repository_policy(self, registry_id, repository_name, policy_text):
def set_repository_policy(
self, registry_id: str, repository_name: str, policy_text: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
try:
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_text)
# the repository policy can be defined without a resource field
iam_policy_document_validator._validate_resource_exist = lambda: None
iam_policy_document_validator._validate_resource_exist = lambda: None # type: ignore
# the repository policy can have the old version 2008-10-17
iam_policy_document_validator._validate_version = lambda: None
iam_policy_document_validator._validate_version = lambda: None # type: ignore
iam_policy_document_validator.validate()
except MalformedPolicyDocument:
raise InvalidParameterException(
@ -784,7 +820,9 @@ class ECRBackend(BaseBackend):
"policyText": repo.policy,
}
def get_repository_policy(self, registry_id, repository_name):
def get_repository_policy(
self, registry_id: str, repository_name: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
if not repo.policy:
@ -796,7 +834,9 @@ class ECRBackend(BaseBackend):
"policyText": repo.policy,
}
def delete_repository_policy(self, registry_id, repository_name):
def delete_repository_policy(
self, registry_id: str, repository_name: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
policy = repo.policy
@ -811,7 +851,9 @@ class ECRBackend(BaseBackend):
"policyText": policy,
}
def put_lifecycle_policy(self, registry_id, repository_name, lifecycle_policy_text):
def put_lifecycle_policy(
self, registry_id: str, repository_name: str, lifecycle_policy_text: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
@ -825,7 +867,9 @@ class ECRBackend(BaseBackend):
"lifecyclePolicyText": repo.lifecycle_policy,
}
def get_lifecycle_policy(self, registry_id, repository_name):
def get_lifecycle_policy(
self, registry_id: str, repository_name: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
if not repo.lifecycle_policy:
@ -840,7 +884,9 @@ class ECRBackend(BaseBackend):
),
}
def delete_lifecycle_policy(self, registry_id, repository_name):
def delete_lifecycle_policy(
self, registry_id: str, repository_name: str
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
policy = repo.lifecycle_policy
@ -858,7 +904,7 @@ class ECRBackend(BaseBackend):
),
}
def _validate_registry_policy_action(self, policy_text):
def _validate_registry_policy_action(self, policy_text: str) -> None:
# only CreateRepository & ReplicateImage actions are allowed
VALID_ACTIONS = {"ecr:CreateRepository", "ecr:ReplicateImage"}
@ -870,7 +916,7 @@ class ECRBackend(BaseBackend):
if set(action) - VALID_ACTIONS:
raise MalformedPolicyDocument()
def put_registry_policy(self, policy_text):
def put_registry_policy(self, policy_text: str) -> Dict[str, Any]:
try:
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_text)
iam_policy_document_validator.validate()
@ -889,7 +935,7 @@ class ECRBackend(BaseBackend):
"policyText": policy_text,
}
def get_registry_policy(self):
def get_registry_policy(self) -> Dict[str, Any]:
if not self.registry_policy:
raise RegistryPolicyNotFoundException(self.account_id)
@ -898,7 +944,7 @@ class ECRBackend(BaseBackend):
"policyText": self.registry_policy,
}
def delete_registry_policy(self):
def delete_registry_policy(self) -> Dict[str, Any]:
policy = self.registry_policy
if not policy:
raise RegistryPolicyNotFoundException(self.account_id)
@ -910,7 +956,9 @@ class ECRBackend(BaseBackend):
"policyText": policy,
}
def start_image_scan(self, registry_id, repository_name, image_id):
def start_image_scan(
self, registry_id: str, repository_name: str, image_id: Dict[str, str]
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
image = repo._get_image(image_id.get("imageTag"), image_id.get("imageDigest"))
@ -931,7 +979,9 @@ class ECRBackend(BaseBackend):
"imageScanStatus": {"status": "IN_PROGRESS"},
}
def describe_image_scan_findings(self, registry_id, repository_name, image_id):
def describe_image_scan_findings(
self, registry_id: str, repository_name: str, image_id: Dict[str, Any]
) -> Dict[str, Any]:
repo = self._get_repository(repository_name, registry_id)
image = repo._get_image(image_id.get("imageTag"), image_id.get("imageDigest"))
@ -984,7 +1034,9 @@ class ECRBackend(BaseBackend):
},
}
def put_replication_configuration(self, replication_config):
def put_replication_configuration(
self, replication_config: Dict[str, Any]
) -> Dict[str, Any]:
rules = replication_config["rules"]
if len(rules) > 1:
raise ValidationException("This feature is disabled")
@ -1004,7 +1056,7 @@ class ECRBackend(BaseBackend):
return {"replicationConfiguration": replication_config}
def describe_registry(self):
def describe_registry(self) -> Dict[str, Any]:
return {
"registryId": self.account_id,
"replicationConfiguration": self.replication_config,

View File

@ -1,4 +1,5 @@
import json
from typing import Any, Dict, List
from moto.ecr.exceptions import InvalidParameterException
@ -28,12 +29,12 @@ class EcrLifecyclePolicyValidator:
"'Lifecycle policy validation failure: "
)
def __init__(self, policy_text):
def __init__(self, policy_text: str):
self._policy_text = policy_text
self._policy_json = {}
self._rules = []
self._policy_json: Dict[str, Any] = {}
self._rules: List[Any] = []
def validate(self):
def validate(self) -> None:
try:
self._parse_policy()
except Exception:
@ -61,17 +62,17 @@ class EcrLifecyclePolicyValidator:
self._validate_rule_type()
self._validate_rule_top_properties()
def _parse_policy(self):
def _parse_policy(self) -> None:
self._policy_json = json.loads(self._policy_text)
assert isinstance(self._policy_json, dict)
def _extract_rules(self):
def _extract_rules(self) -> None:
assert "rules" in self._policy_json
assert isinstance(self._policy_json["rules"], list)
self._rules = self._policy_json["rules"]
def _validate_rule_type(self):
def _validate_rule_type(self) -> None:
for rule in self._rules:
if not isinstance(rule, dict):
raise InvalidParameterException(
@ -83,7 +84,7 @@ class EcrLifecyclePolicyValidator:
)
)
def _validate_rule_top_properties(self):
def _validate_rule_top_properties(self) -> None:
for rule in self._rules:
rule_properties = set(rule.keys())
missing_properties = REQUIRED_RULE_PROPERTIES - rule_properties
@ -111,7 +112,7 @@ class EcrLifecyclePolicyValidator:
self._validate_action(rule["action"])
self._validate_selection(rule["selection"])
def _validate_action(self, action):
def _validate_action(self, action: Any) -> None:
given_properties = set(action.keys())
missing_properties = REQUIRED_ACTION_PROPERTIES - given_properties
@ -139,7 +140,7 @@ class EcrLifecyclePolicyValidator:
self._validate_action_type(action["type"])
def _validate_action_type(self, action_type):
def _validate_action_type(self, action_type: str) -> None:
if action_type not in VALID_ACTION_TYPE_VALUES:
raise InvalidParameterException(
"".join(
@ -151,7 +152,7 @@ class EcrLifecyclePolicyValidator:
)
)
def _validate_selection(self, selection):
def _validate_selection(self, selection: Any) -> None:
given_properties = set(selection.keys())
missing_properties = REQUIRED_SELECTION_PROPERTIES - given_properties
@ -182,7 +183,7 @@ class EcrLifecyclePolicyValidator:
self._validate_selection_count_unit(selection.get("countUnit"))
self._validate_selection_count_number(selection["countNumber"])
def _validate_selection_tag_status(self, tag_status):
def _validate_selection_tag_status(self, tag_status: Any) -> None:
if tag_status not in VALID_SELECTION_TAG_STATUS_VALUES:
raise InvalidParameterException(
"".join(
@ -194,7 +195,7 @@ class EcrLifecyclePolicyValidator:
)
)
def _validate_selection_count_type(self, count_type):
def _validate_selection_count_type(self, count_type: Any) -> None:
if count_type not in VALID_SELECTION_COUNT_TYPE_VALUES:
raise InvalidParameterException(
"".join(
@ -205,7 +206,7 @@ class EcrLifecyclePolicyValidator:
)
)
def _validate_selection_count_unit(self, count_unit):
def _validate_selection_count_unit(self, count_unit: Any) -> None:
if not count_unit:
return None
@ -220,7 +221,7 @@ class EcrLifecyclePolicyValidator:
)
)
def _validate_selection_count_number(self, count_number):
def _validate_selection_count_number(self, count_number: int) -> None:
if count_number < 1:
raise InvalidParameterException(
"".join(

View File

@ -4,28 +4,18 @@ from datetime import datetime
import time
from moto.core.responses import BaseResponse
from .models import ecr_backends
from .models import ecr_backends, ECRBackend
class ECRResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="ecr")
@property
def ecr_backend(self):
def ecr_backend(self) -> ECRBackend:
return ecr_backends[self.current_account][self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param_name, if_none=None):
return self.request_params.get(param_name, if_none)
def create_repository(self):
def create_repository(self) -> str:
repository_name = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
encryption_config = self._get_param("encryptionConfiguration")
@ -43,7 +33,7 @@ class ECRResponse(BaseResponse):
)
return json.dumps({"repository": repository.response_object})
def describe_repositories(self):
def describe_repositories(self) -> str:
describe_repositories_name = self._get_param("repositoryNames")
registry_id = self._get_param("registryId")
@ -52,7 +42,7 @@ class ECRResponse(BaseResponse):
)
return json.dumps({"repositories": repositories, "failures": []})
def delete_repository(self):
def delete_repository(self) -> str:
repository_str = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
force = self._get_param("force")
@ -62,7 +52,7 @@ class ECRResponse(BaseResponse):
)
return json.dumps({"repository": repository.response_object})
def put_image(self):
def put_image(self) -> str:
repository_str = self._get_param("repositoryName")
image_manifest = self._get_param("imageManifest")
image_tag = self._get_param("imageTag")
@ -70,7 +60,7 @@ class ECRResponse(BaseResponse):
return json.dumps({"image": image.response_object})
def list_images(self):
def list_images(self) -> str:
repository_str = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
images = self.ecr_backend.list_images(repository_str, registry_id)
@ -78,7 +68,7 @@ class ECRResponse(BaseResponse):
{"imageIds": [image.response_list_object for image in images]}
)
def describe_images(self):
def describe_images(self) -> str:
repository_str = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
image_ids = self._get_param("imageIds")
@ -89,13 +79,13 @@ class ECRResponse(BaseResponse):
{"imageDetails": [image.response_describe_object for image in images]}
)
def batch_check_layer_availability(self):
def batch_check_layer_availability(self) -> None:
self.error_on_dryrun()
raise NotImplementedError(
"ECR.batch_check_layer_availability is not yet implemented"
)
def batch_delete_image(self):
def batch_delete_image(self) -> str:
repository_str = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
image_ids = self._get_param("imageIds")
@ -105,7 +95,7 @@ class ECRResponse(BaseResponse):
)
return json.dumps(response)
def batch_get_image(self):
def batch_get_image(self) -> str:
repository_str = self._get_param("repositoryName")
registry_id = self._get_param("registryId")
image_ids = self._get_param("imageIds")
@ -115,11 +105,11 @@ class ECRResponse(BaseResponse):
)
return json.dumps(response)
def complete_layer_upload(self):
def complete_layer_upload(self) -> None:
self.error_on_dryrun()
raise NotImplementedError("ECR.complete_layer_upload is not yet implemented")
def delete_repository_policy(self):
def delete_repository_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
@ -129,7 +119,7 @@ class ECRResponse(BaseResponse):
)
)
def get_authorization_token(self):
def get_authorization_token(self) -> str:
registry_ids = self._get_param("registryIds")
if not registry_ids:
registry_ids = [self.current_account]
@ -146,13 +136,13 @@ class ECRResponse(BaseResponse):
)
return json.dumps({"authorizationData": auth_data})
def get_download_url_for_layer(self):
def get_download_url_for_layer(self) -> None:
self.error_on_dryrun()
raise NotImplementedError(
"ECR.get_download_url_for_layer is not yet implemented"
)
def get_repository_policy(self):
def get_repository_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
@ -162,11 +152,11 @@ class ECRResponse(BaseResponse):
)
)
def initiate_layer_upload(self):
def initiate_layer_upload(self) -> None:
self.error_on_dryrun()
raise NotImplementedError("ECR.initiate_layer_upload is not yet implemented")
def set_repository_policy(self):
def set_repository_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
policy_text = self._get_param("policyText")
@ -182,28 +172,30 @@ class ECRResponse(BaseResponse):
)
)
def upload_layer_part(self):
def upload_layer_part(self) -> None:
self.error_on_dryrun()
raise NotImplementedError("ECR.upload_layer_part is not yet implemented")
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> str:
arn = self._get_param("resourceArn")
return json.dumps(self.ecr_backend.list_tags_for_resource(arn))
def tag_resource(self):
def tag_resource(self) -> str:
arn = self._get_param("resourceArn")
tags = self._get_param("tags", [])
return json.dumps(self.ecr_backend.tag_resource(arn, tags))
self.ecr_backend.tag_resource(arn, tags)
return "{}"
def untag_resource(self):
def untag_resource(self) -> str:
arn = self._get_param("resourceArn")
tag_keys = self._get_param("tagKeys", [])
return json.dumps(self.ecr_backend.untag_resource(arn, tag_keys))
self.ecr_backend.untag_resource(arn, tag_keys)
return "{}"
def put_image_tag_mutability(self):
def put_image_tag_mutability(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
image_tag_mutability = self._get_param("imageTagMutability")
@ -216,7 +208,7 @@ class ECRResponse(BaseResponse):
)
)
def put_image_scanning_configuration(self):
def put_image_scanning_configuration(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
image_scan_config = self._get_param("imageScanningConfiguration")
@ -229,7 +221,7 @@ class ECRResponse(BaseResponse):
)
)
def put_lifecycle_policy(self):
def put_lifecycle_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
lifecycle_policy_text = self._get_param("lifecyclePolicyText")
@ -242,7 +234,7 @@ class ECRResponse(BaseResponse):
)
)
def get_lifecycle_policy(self):
def get_lifecycle_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
@ -252,7 +244,7 @@ class ECRResponse(BaseResponse):
)
)
def delete_lifecycle_policy(self):
def delete_lifecycle_policy(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
@ -262,18 +254,18 @@ class ECRResponse(BaseResponse):
)
)
def put_registry_policy(self):
def put_registry_policy(self) -> str:
policy_text = self._get_param("policyText")
return json.dumps(self.ecr_backend.put_registry_policy(policy_text=policy_text))
def get_registry_policy(self):
def get_registry_policy(self) -> str:
return json.dumps(self.ecr_backend.get_registry_policy())
def delete_registry_policy(self):
def delete_registry_policy(self) -> str:
return json.dumps(self.ecr_backend.delete_registry_policy())
def start_image_scan(self):
def start_image_scan(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
image_id = self._get_param("imageId")
@ -286,7 +278,7 @@ class ECRResponse(BaseResponse):
)
)
def describe_image_scan_findings(self):
def describe_image_scan_findings(self) -> str:
registry_id = self._get_param("registryId")
repository_name = self._get_param("repositoryName")
image_id = self._get_param("imageId")
@ -299,7 +291,7 @@ class ECRResponse(BaseResponse):
)
)
def put_replication_configuration(self):
def put_replication_configuration(self) -> str:
replication_config = self._get_param("replicationConfiguration")
return json.dumps(
@ -308,5 +300,5 @@ class ECRResponse(BaseResponse):
)
)
def describe_registry(self):
def describe_registry(self) -> str:
return json.dumps(self.ecr_backend.describe_registry())

View File

@ -43,7 +43,7 @@ class MalformedCertificate(RESTError):
class MalformedPolicyDocument(RESTError):
code = 400
def __init__(self, message=""):
def __init__(self, message: str = ""):
super().__init__(
"MalformedPolicyDocument",
message,

View File

@ -513,10 +513,10 @@ class BaseIAMPolicyValidator:
class IAMPolicyDocumentValidator(BaseIAMPolicyValidator):
def __init__(self, policy_document):
def __init__(self, policy_document: str):
super().__init__(policy_document)
def validate(self):
def validate(self) -> None:
super().validate()
try:
self._validate_resource_exist()

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/es,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/es,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract