TechDebt - MyPy the ACM module (#5540)

This commit is contained in:
Bert Blommers 2022-10-07 14:41:31 +00:00 committed by GitHub
parent e98341fa89
commit 1a8f93dce3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 178 additions and 181 deletions

View File

@ -28,7 +28,7 @@ lint:
@echo "Running pylint..." @echo "Running pylint..."
pylint -j 0 moto tests pylint -j 0 moto tests
@echo "Running MyPy..." @echo "Running MyPy..."
mypy --install-types --non-interactive moto/applicationautoscaling/ mypy --install-types --non-interactive moto/acm moto/applicationautoscaling/
format: format:
black moto/ tests/ black moto/ tests/

20
moto/acm/exceptions.py Normal file
View File

@ -0,0 +1,20 @@
from moto.core.exceptions import AWSError
class AWSValidationException(AWSError):
TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError):
TYPE = "ResourceNotFoundException"
class CertificateNotFound(AWSResourceNotFoundException):
def __init__(self, arn: str, account_id: str):
super().__init__(
message=f"Certificate with arn {arn} not found in account {account_id}"
)
class AWSTooManyTagsException(AWSError):
TYPE = "TooManyTagsException"

View File

@ -2,13 +2,19 @@ import base64
import re import re
import datetime import datetime
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import AWSError
from moto.core.utils import BackendDict from moto.core.utils import BackendDict
from moto import settings from moto import settings
from typing import Any, Dict, List, Iterable, Optional, Tuple, Set
from .exceptions import (
AWSValidationException,
AWSTooManyTagsException,
CertificateNotFound,
)
from .utils import make_arn_for_certificate from .utils import make_arn_for_certificate
import cryptography.x509 import cryptography.x509
from cryptography.x509 import OID_COMMON_NAME, NameOID, DNSName
import cryptography.hazmat.primitives.asymmetric.rsa import cryptography.hazmat.primitives.asymmetric.rsa
from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -43,28 +49,16 @@ RnGfN8j8KLDVmWyTYMk8V+6j0LI4+4zFh2upqGMQHL3VFVFWBek6vCDWhB/b
# so for now a cheap response is just give any old root CA # so for now a cheap response is just give any old root CA
def datetime_to_epoch(date): def datetime_to_epoch(date: datetime.datetime) -> float:
return date.timestamp() return date.timestamp()
class AWSValidationException(AWSError): class TagHolder(Dict[str, Optional[str]]):
TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError):
TYPE = "ResourceNotFoundException"
class AWSTooManyTagsException(AWSError):
TYPE = "TooManyTagsException"
class TagHolder(dict):
MAX_TAG_COUNT = 50 MAX_TAG_COUNT = 50
MAX_KEY_LENGTH = 128 MAX_KEY_LENGTH = 128
MAX_VALUE_LENGTH = 256 MAX_VALUE_LENGTH = 256
def _validate_kv(self, key, value, index): def _validate_kv(self, key: str, value: Optional[str], index: int) -> None:
if len(key) > self.MAX_KEY_LENGTH: if len(key) > self.MAX_KEY_LENGTH:
raise AWSValidationException( raise AWSValidationException(
"Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s" "Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s"
@ -81,11 +75,11 @@ class TagHolder(dict):
% key % key
) )
def add(self, tags): def add(self, tags: List[Dict[str, str]]) -> None:
tags_copy = self.copy() tags_copy = self.copy()
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
key = tag["Key"] key = tag["Key"]
value = tag.get("Value", None) value = tag.get("Value")
self._validate_kv(key, value, i + 1) self._validate_kv(key, value, i + 1)
tags_copy[key] = value tags_copy[key] = value
@ -97,10 +91,10 @@ class TagHolder(dict):
self.update(tags_copy) self.update(tags_copy)
def remove(self, tags): def remove(self, tags: List[Dict[str, str]]) -> None:
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
key = tag["Key"] key = tag["Key"]
value = tag.get("Value", None) value = tag.get("Value")
self._validate_kv(key, value, i + 1) self._validate_kv(key, value, i + 1)
try: try:
# If value isnt provided, just delete key # If value isnt provided, just delete key
@ -112,45 +106,41 @@ class TagHolder(dict):
except KeyError: except KeyError:
pass pass
def equals(self, tags): def equals(self, tags: List[Dict[str, str]]) -> bool:
tags = {t["Key"]: t.get("Value", None) for t in tags} if tags else {} flat_tags = {t["Key"]: t.get("Value") for t in tags} if tags else {}
return self == tags return self == flat_tags
class CertBundle(BaseModel): class CertBundle(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
certificate, certificate: bytes,
private_key, private_key: bytes,
chain=None, chain: Optional[bytes] = None,
region="us-east-1", region: str = "us-east-1",
arn=None, arn: Optional[str] = None,
cert_type="IMPORTED", cert_type: str = "IMPORTED",
cert_status="ISSUED", cert_status: str = "ISSUED",
): ):
self.created_at = datetime.datetime.utcnow() self.created_at = datetime.datetime.utcnow()
self.cert = certificate self.cert = certificate
self._cert = None
self.common_name = None
self.key = private_key self.key = private_key
self._key = None # AWS always returns your chain + root CA
self.chain = chain self.chain = chain + b"\n" + AWS_ROOT_CA if chain else AWS_ROOT_CA
self.tags = TagHolder() self.tags = TagHolder()
self._chain = None
self.type = cert_type # Should really be an enum self.type = cert_type # Should really be an enum
self.status = cert_status # Should really be an enum self.status = cert_status # Should really be an enum
self.in_use_by = [] self.in_use_by: List[str] = []
# AWS always returns your chain + root CA
if self.chain is None:
self.chain = AWS_ROOT_CA
else:
self.chain += b"\n" + AWS_ROOT_CA
# Takes care of PEM checking # Takes care of PEM checking
self.validate_pk() self._key = self.validate_pk()
self.validate_certificate() self._cert = self.validate_certificate()
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name: Any = self._cert.subject.get_attributes_for_oid(
OID_COMMON_NAME
)[0].value
if chain is not None: if chain is not None:
self.validate_chain() self.validate_chain()
@ -158,57 +148,43 @@ class CertBundle(BaseModel):
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.') # raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
# Used for when one wants to overwrite an arn # Used for when one wants to overwrite an arn
if arn is None: self.arn = arn or make_arn_for_certificate(account_id, region)
self.arn = make_arn_for_certificate(account_id, region)
else:
self.arn = arn
@classmethod @classmethod
def generate_cert(cls, domain_name, account_id, region, sans=None): def generate_cert(
if sans is None: cls,
sans = set() domain_name: str,
else: account_id: str,
sans = set(sans) region: str,
sans: Optional[List[str]] = None,
) -> "CertBundle":
unique_sans: Set[str] = set(sans) if sans else set()
sans.add(domain_name) unique_sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans] unique_dns_names = [DNSName(item) for item in unique_sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key( key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend() public_exponent=65537, key_size=2048, backend=default_backend()
) )
subject = cryptography.x509.Name( subject = cryptography.x509.Name(
[ [
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
cryptography.x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "CA"),
cryptography.x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
cryptography.x509.NameAttribute( cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US" NameOID.ORGANIZATION_NAME, "My Company"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, domain_name
), ),
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, domain_name),
] ]
) )
issuer = cryptography.x509.Name( issuer = cryptography.x509.Name(
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
cryptography.x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"),
cryptography.x509.NameAttribute( cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US" NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
), ),
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, "Amazon"),
] ]
) )
cert = ( cert = (
@ -220,7 +196,8 @@ class CertBundle(BaseModel):
.not_valid_before(datetime.datetime.utcnow()) .not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension( .add_extension(
cryptography.x509.SubjectAlternativeName(sans), critical=False cryptography.x509.SubjectAlternativeName(unique_dns_names),
critical=False,
) )
.sign(key, hashes.SHA512(), default_backend()) .sign(key, hashes.SHA512(), default_backend())
) )
@ -241,17 +218,11 @@ class CertBundle(BaseModel):
region=region, region=region,
) )
def validate_pk(self): def validate_pk(self) -> Any:
try: try:
self._key = serialization.load_pem_private_key( return serialization.load_pem_private_key(
self.key, password=None, backend=default_backend() self.key, password=None, backend=default_backend()
) )
if self._key.key_size > 2048:
AWSValidationException(
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
)
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
@ -259,48 +230,40 @@ class CertBundle(BaseModel):
"The private key is not PEM-encoded or is not valid." "The private key is not PEM-encoded or is not valid."
) )
def validate_certificate(self): def validate_certificate(self) -> cryptography.x509.base.Certificate:
try: try:
self._cert = cryptography.x509.load_pem_x509_certificate( _cert = cryptography.x509.load_pem_x509_certificate(
self.cert, default_backend() self.cert, default_backend()
) )
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now: if _cert.not_valid_after < now:
raise AWSValidationException( raise AWSValidationException(
"The certificate has expired, is not valid." "The certificate has expired, is not valid."
) )
if self._cert.not_valid_before > now: if _cert.not_valid_before > now:
raise AWSValidationException( raise AWSValidationException(
"The certificate is not in effect yet, is not valid." "The certificate is not in effect yet, is not valid."
) )
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException( raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid." "The certificate is not PEM-encoded or is not valid."
) )
return _cert
def validate_chain(self): def validate_chain(self) -> None:
try: try:
self._chain = []
for cert_armored in self.chain.split(b"-\n-"): for cert_armored in self.chain.split(b"-\n-"):
# Fix missing -'s on split # Fix missing -'s on split
cert_armored = re.sub(b"^----B", b"-----B", cert_armored) cert_armored = re.sub(b"^----B", b"-----B", cert_armored)
cert_armored = re.sub(b"E----$", b"E-----", cert_armored) cert_armored = re.sub(b"E----$", b"E-----", cert_armored)
cert = cryptography.x509.load_pem_x509_certificate( cryptography.x509.load_pem_x509_certificate(
cert_armored, default_backend() cert_armored, default_backend()
) )
self._chain.append(cert)
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now: if self._cert.not_valid_after < now:
@ -320,7 +283,7 @@ class CertBundle(BaseModel):
"The certificate is not PEM-encoded or is not valid." "The certificate is not PEM-encoded or is not valid."
) )
def check(self): def check(self) -> None:
# Basically, if the certificate is pending, and then checked again after a # Basically, if the certificate is pending, and then checked again after a
# while, it will appear as if its been validated. The default wait time is 60 # while, it will appear as if its been validated. The default wait time is 60
# seconds but you can set an environment to change it. # seconds but you can set an environment to change it.
@ -332,7 +295,7 @@ class CertBundle(BaseModel):
): ):
self.status = "ISSUED" self.status = "ISSUED"
def describe(self): def describe(self) -> Dict[str, Any]:
# 'RenewalSummary': {}, # Only when cert is amazon issued # 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024: if self._key.key_size == 1024:
key_algo = "RSA_1024" key_algo = "RSA_1024"
@ -343,7 +306,7 @@ class CertBundle(BaseModel):
# Look for SANs # Look for SANs
try: try:
san_obj = self._cert.extensions.get_extension_for_oid( san_obj: Any = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
) )
except cryptography.x509.ExtensionNotFound: except cryptography.x509.ExtensionNotFound:
@ -352,14 +315,14 @@ class CertBundle(BaseModel):
if san_obj is not None: if san_obj is not None:
sans = [item.value for item in san_obj.value] sans = [item.value for item in san_obj.value]
result = { result: Dict[str, Any] = {
"Certificate": { "Certificate": {
"CertificateArn": self.arn, "CertificateArn": self.arn,
"DomainName": self.common_name, "DomainName": self.common_name,
"InUseBy": self.in_use_by, "InUseBy": self.in_use_by,
"Issuer": self._cert.issuer.get_attributes_for_oid( "Issuer": self._cert.issuer.get_attributes_for_oid(OID_COMMON_NAME)[
cryptography.x509.OID_COMMON_NAME 0
)[0].value, ].value,
"KeyAlgorithm": key_algo, "KeyAlgorithm": key_algo,
"NotAfter": datetime_to_epoch(self._cert.not_valid_after), "NotAfter": datetime_to_epoch(self._cert.not_valid_after),
"NotBefore": datetime_to_epoch(self._cert.not_valid_before), "NotBefore": datetime_to_epoch(self._cert.not_valid_before),
@ -401,7 +364,7 @@ class CertBundle(BaseModel):
return result return result
def serialize_pk(self, passphrase_bytes): def serialize_pk(self, passphrase_bytes: bytes) -> str:
pk_bytes = self._key.private_bytes( pk_bytes = self._key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
@ -411,38 +374,36 @@ class CertBundle(BaseModel):
) )
return pk_bytes.decode("utf-8") return pk_bytes.decode("utf-8")
def __str__(self): def __str__(self) -> str:
return self.arn return self.arn
def __repr__(self): def __repr__(self) -> str:
return "<Certificate>" return "<Certificate>"
class AWSCertificateManagerBackend(BaseBackend): class AWSCertificateManagerBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._certificates = {} self._certificates: Dict[str, CertBundle] = {}
self._idempotency_tokens = {} self._idempotency_tokens: Dict[str, Any] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "acm-pca" service_region, zones, "acm-pca"
) )
def _arn_not_found(self, arn): def set_certificate_in_use_by(self, arn: str, load_balancer_name: str) -> None:
msg = f"Certificate with arn {arn} not found in account {self.account_id}"
return AWSResourceNotFoundException(msg)
def set_certificate_in_use_by(self, arn, load_balancer_name):
if arn not in self._certificates: if arn not in self._certificates:
raise self._arn_not_found(arn) raise CertificateNotFound(arn=arn, account_id=self.account_id)
cert_bundle = self._certificates[arn] cert_bundle = self._certificates[arn]
cert_bundle.in_use_by.append(load_balancer_name) cert_bundle.in_use_by.append(load_balancer_name)
def _get_arn_from_idempotency_token(self, token): def _get_arn_from_idempotency_token(self, token: str) -> Optional[str]:
""" """
If token doesnt exist, return None, later it will be If token doesnt exist, return None, later it will be
set with an expiry and arn. set with an expiry and arn.
@ -465,16 +426,23 @@ class AWSCertificateManagerBackend(BaseBackend):
return None return None
def _set_idempotency_token_arn(self, token, arn): def _set_idempotency_token_arn(self, token: str, arn: str) -> None:
self._idempotency_tokens[token] = { self._idempotency_tokens[token] = {
"arn": arn, "arn": arn,
"expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1), "expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
} }
def import_cert(self, certificate, private_key, chain=None, arn=None, tags=None): def import_cert(
self,
certificate: bytes,
private_key: bytes,
chain: Optional[bytes],
arn: Optional[str],
tags: List[Dict[str, str]],
) -> str:
if arn is not None: if arn is not None:
if arn not in self._certificates: if arn not in self._certificates:
raise self._arn_not_found(arn) raise CertificateNotFound(arn=arn, account_id=self.account_id)
else: else:
# Will reuse provided ARN # Will reuse provided ARN
bundle = CertBundle( bundle = CertBundle(
@ -502,7 +470,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return bundle.arn return bundle.arn
def get_certificates_list(self, statuses): def get_certificates_list(self, statuses: List[str]) -> Iterable[CertBundle]:
""" """
Get list of certificates Get list of certificates
@ -514,27 +482,27 @@ class AWSCertificateManagerBackend(BaseBackend):
if not statuses or cert.status in statuses: if not statuses or cert.status in statuses:
yield cert yield cert
def get_certificate(self, arn): def get_certificate(self, arn: str) -> CertBundle:
if arn not in self._certificates: if arn not in self._certificates:
raise self._arn_not_found(arn) raise CertificateNotFound(arn=arn, account_id=self.account_id)
cert_bundle = self._certificates[arn] cert_bundle = self._certificates[arn]
cert_bundle.check() cert_bundle.check()
return cert_bundle return cert_bundle
def delete_certificate(self, arn): def delete_certificate(self, arn: str) -> None:
if arn not in self._certificates: if arn not in self._certificates:
raise self._arn_not_found(arn) raise CertificateNotFound(arn=arn, account_id=self.account_id)
del self._certificates[arn] del self._certificates[arn]
def request_certificate( def request_certificate(
self, self,
domain_name, domain_name: str,
idempotency_token, idempotency_token: str,
subject_alt_names, subject_alt_names: List[str],
tags=None, tags: List[Dict[str, str]],
): ) -> str:
""" """
The parameter DomainValidationOptions has not yet been implemented The parameter DomainValidationOptions has not yet been implemented
""" """
@ -558,17 +526,21 @@ class AWSCertificateManagerBackend(BaseBackend):
return cert.arn return cert.arn
def add_tags_to_certificate(self, arn, tags): def add_tags_to_certificate(self, arn: str, tags: List[Dict[str, str]]) -> None:
# get_cert does arn check # get_cert does arn check
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
cert_bundle.tags.add(tags) cert_bundle.tags.add(tags)
def remove_tags_from_certificate(self, arn, tags): def remove_tags_from_certificate(
self, arn: str, tags: List[Dict[str, str]]
) -> None:
# get_cert does arn check # get_cert does arn check
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
cert_bundle.tags.remove(tags) cert_bundle.tags.remove(tags)
def export_certificate(self, certificate_arn, passphrase): def export_certificate(
self, certificate_arn: str, passphrase: str
) -> Tuple[str, str, str]:
passphrase_bytes = base64.standard_b64decode(passphrase) passphrase_bytes = base64.standard_b64decode(passphrase)
cert_bundle = self.get_certificate(certificate_arn) cert_bundle = self.get_certificate(certificate_arn)

View File

@ -2,34 +2,23 @@ import json
import base64 import base64
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import acm_backends, AWSValidationException from typing import Dict, List, Tuple, Union
from .models import acm_backends, AWSCertificateManagerBackend
from .exceptions import AWSValidationException
GENERIC_RESPONSE_TYPE = Union[str, Tuple[str, Dict[str, int]]]
class AWSCertificateManagerResponse(BaseResponse): class AWSCertificateManagerResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="acm") super().__init__(service_name="acm")
@property @property
def acm_backend(self): def acm_backend(self) -> AWSCertificateManagerBackend:
"""
ACM Backend
:return: ACM Backend object
:rtype: moto.acm.models.AWSCertificateManagerBackend
"""
return acm_backends[self.current_account][self.region] return acm_backends[self.current_account][self.region]
@property def add_tags_to_certificate(self) -> GENERIC_RESPONSE_TYPE:
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param_name, if_none=None):
return self.request_params.get(param_name, if_none)
def add_tags_to_certificate(self):
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
tags = self._get_param("Tags") tags = self._get_param("Tags")
@ -44,7 +33,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return "" return ""
def delete_certificate(self): def delete_certificate(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
@ -58,7 +47,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return "" return ""
def describe_certificate(self): def describe_certificate(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
@ -72,7 +61,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(cert_bundle.describe()) return json.dumps(cert_bundle.describe())
def get_certificate(self): def get_certificate(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
@ -90,7 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse):
} }
return json.dumps(result) return json.dumps(result)
def import_certificate(self): def import_certificate(self) -> str:
""" """
Returns errors on: Returns errors on:
Certificate, PrivateKey or Chain not being properly formatted Certificate, PrivateKey or Chain not being properly formatted
@ -137,7 +126,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps({"CertificateArn": arn}) return json.dumps({"CertificateArn": arn})
def list_certificates(self): def list_certificates(self) -> str:
certs = [] certs = []
statuses = self._get_param("CertificateStatuses") statuses = self._get_param("CertificateStatuses")
for cert_bundle in self.acm_backend.get_certificates_list(statuses): for cert_bundle in self.acm_backend.get_certificates_list(statuses):
@ -151,16 +140,18 @@ class AWSCertificateManagerResponse(BaseResponse):
result = {"CertificateSummaryList": certs} result = {"CertificateSummaryList": certs}
return json.dumps(result) return json.dumps(result)
def list_tags_for_certificate(self): def list_tags_for_certificate(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = "A required parameter for the specified action is not supplied." msg = "A required parameter for the specified action is not supplied."
return {"__type": "MissingParameter", "message": msg}, dict(status=400) return json.dumps({"__type": "MissingParameter", "message": msg}), dict(
status=400
)
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
result = {"Tags": []} result: Dict[str, List[Dict[str, str]]] = {"Tags": []}
# Tag "objects" can not contain the Value part # Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items(): for key, value in cert_bundle.tags.items():
tag_dict = {"Key": key} tag_dict = {"Key": key}
@ -170,7 +161,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
def remove_tags_from_certificate(self): def remove_tags_from_certificate(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
tags = self._get_param("Tags") tags = self._get_param("Tags")
@ -185,7 +176,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return "" return ""
def request_certificate(self): def request_certificate(self) -> GENERIC_RESPONSE_TYPE:
domain_name = self._get_param("DomainName") domain_name = self._get_param("DomainName")
idempotency_token = self._get_param("IdempotencyToken") idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames") subject_alt_names = self._get_param("SubjectAlternativeNames")
@ -210,7 +201,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps({"CertificateArn": arn}) return json.dumps({"CertificateArn": arn})
def resend_validation_email(self): def resend_validation_email(self) -> GENERIC_RESPONSE_TYPE:
arn = self._get_param("CertificateArn") arn = self._get_param("CertificateArn")
domain = self._get_param("Domain") domain = self._get_param("Domain")
# ValidationDomain not used yet. # ValidationDomain not used yet.
@ -235,7 +226,7 @@ class AWSCertificateManagerResponse(BaseResponse):
return "" return ""
def export_certificate(self): def export_certificate(self) -> GENERIC_RESPONSE_TYPE:
certificate_arn = self._get_param("CertificateArn") certificate_arn = self._get_param("CertificateArn")
passphrase = self._get_param("Passphrase") passphrase = self._get_param("Passphrase")

View File

@ -1,7 +1,7 @@
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
def make_arn_for_certificate(account_id, region_name): def make_arn_for_certificate(account_id: str, region_name: str) -> str:
# Example # Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format( return "arn:aws:acm:{0}:{1}:certificate/{2}".format(

View File

@ -1,5 +1,6 @@
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
from typing import Optional
import json import json
# TODO: add "<Type>Sender</Type>" to error responses below? # TODO: add "<Type>Sender</Type>" to error responses below?
@ -133,10 +134,12 @@ class AuthFailureError(RESTError):
class AWSError(JsonRESTError): class AWSError(JsonRESTError):
TYPE = None TYPE: Optional[str] = None
STATUS = 400 STATUS = 400
def __init__(self, message, exception_type=None, status=None): def __init__(
self, message: str, exception_type: str = None, status: Optional[int] = None
):
super().__init__(exception_type or self.TYPE, message) super().__init__(exception_type or self.TYPE, message)
self.code = status or self.STATUS self.code = status or self.STATUS

View File

@ -575,12 +575,12 @@ class ELBBackend(BaseBackend):
return load_balancer return load_balancer
def _register_certificate(self, ssl_certificate_id, dns_name): def _register_certificate(self, ssl_certificate_id, dns_name):
from moto.acm.models import acm_backends, AWSResourceNotFoundException from moto.acm.models import acm_backends, CertificateNotFound
acm_backend = acm_backends[self.account_id][self.region_name] acm_backend = acm_backends[self.account_id][self.region_name]
try: try:
acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name) acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name)
except AWSResourceNotFoundException: except CertificateNotFound:
raise CertificateNotFoundException() raise CertificateNotFoundException()
def enable_availability_zones_for_load_balancer( def enable_availability_zones_for_load_balancer(

View File

@ -1548,14 +1548,13 @@ Member must satisfy regular expression pattern: {}".format(
""" """
Verify the provided certificate exists in either ACM or IAM Verify the provided certificate exists in either ACM or IAM
""" """
from moto.acm import acm_backends from moto.acm.models import acm_backends, CertificateNotFound
from moto.acm.models import AWSResourceNotFoundException
try: try:
acm_backend = acm_backends[self.account_id][self.region_name] acm_backend = acm_backends[self.account_id][self.region_name]
acm_backend.get_certificate(certificate_arn) acm_backend.get_certificate(certificate_arn)
return True return True
except AWSResourceNotFoundException: except CertificateNotFound:
pass pass
from moto.iam import iam_backends from moto.iam import iam_backends

View File

@ -433,6 +433,18 @@ def test_request_certificate_with_tags():
{"Key": "WithEmptyStr", "Value": ""}, {"Key": "WithEmptyStr", "Value": ""},
], ],
) )
arn_3 = resp["CertificateArn"]
assert arn_1 != arn_3 # if tags are matched, ACM would have returned same arn
resp = client.request_certificate(
DomainName="google.com",
IdempotencyToken=token,
SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"],
)
arn_4 = resp["CertificateArn"]
assert arn_1 != arn_4 # if tags are matched, ACM would have returned same arn
@mock_acm @mock_acm