From 1a8f93dce33f1c1c7983a8052ef54de30ff64407 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Fri, 7 Oct 2022 14:41:31 +0000 Subject: [PATCH] TechDebt - MyPy the ACM module (#5540) --- Makefile | 2 +- moto/acm/exceptions.py | 20 +++ moto/acm/models.py | 252 +++++++++++++++++-------------------- moto/acm/responses.py | 55 ++++---- moto/acm/utils.py | 2 +- moto/core/exceptions.py | 7 +- moto/elb/models.py | 4 +- moto/elbv2/models.py | 5 +- tests/test_acm/test_acm.py | 12 ++ 9 files changed, 178 insertions(+), 181 deletions(-) create mode 100644 moto/acm/exceptions.py diff --git a/Makefile b/Makefile index 0b87251c0..d1636d048 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ lint: @echo "Running pylint..." pylint -j 0 moto tests @echo "Running MyPy..." - mypy --install-types --non-interactive moto/applicationautoscaling/ + mypy --install-types --non-interactive moto/acm moto/applicationautoscaling/ format: black moto/ tests/ diff --git a/moto/acm/exceptions.py b/moto/acm/exceptions.py new file mode 100644 index 000000000..70233a1cb --- /dev/null +++ b/moto/acm/exceptions.py @@ -0,0 +1,20 @@ +from moto.core.exceptions import AWSError + + +class AWSValidationException(AWSError): + TYPE = "ValidationException" + + +class AWSResourceNotFoundException(AWSError): + TYPE = "ResourceNotFoundException" + + +class CertificateNotFound(AWSResourceNotFoundException): + def __init__(self, arn: str, account_id: str): + super().__init__( + message=f"Certificate with arn {arn} not found in account {account_id}" + ) + + +class AWSTooManyTagsException(AWSError): + TYPE = "TooManyTagsException" diff --git a/moto/acm/models.py b/moto/acm/models.py index 3627f8a12..a1cd531b2 100644 --- a/moto/acm/models.py +++ b/moto/acm/models.py @@ -2,13 +2,19 @@ import base64 import re import datetime from moto.core import BaseBackend, BaseModel -from moto.core.exceptions import AWSError from moto.core.utils import BackendDict from moto import settings +from typing import Any, Dict, List, Iterable, Optional, Tuple, Set +from .exceptions import ( + AWSValidationException, + AWSTooManyTagsException, + CertificateNotFound, +) from .utils import make_arn_for_certificate import cryptography.x509 +from cryptography.x509 import OID_COMMON_NAME, NameOID, DNSName import cryptography.hazmat.primitives.asymmetric.rsa from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.backends import default_backend @@ -43,28 +49,16 @@ RnGfN8j8KLDVmWyTYMk8V+6j0LI4+4zFh2upqGMQHL3VFVFWBek6vCDWhB/b # so for now a cheap response is just give any old root CA -def datetime_to_epoch(date): +def datetime_to_epoch(date: datetime.datetime) -> float: return date.timestamp() -class AWSValidationException(AWSError): - TYPE = "ValidationException" - - -class AWSResourceNotFoundException(AWSError): - TYPE = "ResourceNotFoundException" - - -class AWSTooManyTagsException(AWSError): - TYPE = "TooManyTagsException" - - -class TagHolder(dict): +class TagHolder(Dict[str, Optional[str]]): MAX_TAG_COUNT = 50 MAX_KEY_LENGTH = 128 MAX_VALUE_LENGTH = 256 - def _validate_kv(self, key, value, index): + def _validate_kv(self, key: str, value: Optional[str], index: int) -> None: if len(key) > self.MAX_KEY_LENGTH: raise AWSValidationException( "Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s" @@ -81,11 +75,11 @@ class TagHolder(dict): % key ) - def add(self, tags): + def add(self, tags: List[Dict[str, str]]) -> None: tags_copy = self.copy() for i, tag in enumerate(tags): key = tag["Key"] - value = tag.get("Value", None) + value = tag.get("Value") self._validate_kv(key, value, i + 1) tags_copy[key] = value @@ -97,10 +91,10 @@ class TagHolder(dict): self.update(tags_copy) - def remove(self, tags): + def remove(self, tags: List[Dict[str, str]]) -> None: for i, tag in enumerate(tags): key = tag["Key"] - value = tag.get("Value", None) + value = tag.get("Value") self._validate_kv(key, value, i + 1) try: # If value isnt provided, just delete key @@ -112,45 +106,41 @@ class TagHolder(dict): except KeyError: pass - def equals(self, tags): - tags = {t["Key"]: t.get("Value", None) for t in tags} if tags else {} - return self == tags + def equals(self, tags: List[Dict[str, str]]) -> bool: + flat_tags = {t["Key"]: t.get("Value") for t in tags} if tags else {} + return self == flat_tags class CertBundle(BaseModel): def __init__( self, - account_id, - certificate, - private_key, - chain=None, - region="us-east-1", - arn=None, - cert_type="IMPORTED", - cert_status="ISSUED", + account_id: str, + certificate: bytes, + private_key: bytes, + chain: Optional[bytes] = None, + region: str = "us-east-1", + arn: Optional[str] = None, + cert_type: str = "IMPORTED", + cert_status: str = "ISSUED", ): self.created_at = datetime.datetime.utcnow() self.cert = certificate - self._cert = None - self.common_name = None self.key = private_key - self._key = None - self.chain = chain + # AWS always returns your chain + root CA + self.chain = chain + b"\n" + AWS_ROOT_CA if chain else AWS_ROOT_CA self.tags = TagHolder() - self._chain = None self.type = cert_type # Should really be an enum self.status = cert_status # Should really be an enum - self.in_use_by = [] - - # AWS always returns your chain + root CA - if self.chain is None: - self.chain = AWS_ROOT_CA - else: - self.chain += b"\n" + AWS_ROOT_CA + self.in_use_by: List[str] = [] # Takes care of PEM checking - self.validate_pk() - self.validate_certificate() + self._key = self.validate_pk() + self._cert = self.validate_certificate() + # Extracting some common fields for ease of use + # Have to search through cert.subject for OIDs + self.common_name: Any = self._cert.subject.get_attributes_for_oid( + OID_COMMON_NAME + )[0].value if chain is not None: self.validate_chain() @@ -158,57 +148,43 @@ class CertBundle(BaseModel): # raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.') # Used for when one wants to overwrite an arn - if arn is None: - self.arn = make_arn_for_certificate(account_id, region) - else: - self.arn = arn + self.arn = arn or make_arn_for_certificate(account_id, region) @classmethod - def generate_cert(cls, domain_name, account_id, region, sans=None): - if sans is None: - sans = set() - else: - sans = set(sans) + def generate_cert( + cls, + domain_name: str, + account_id: str, + region: str, + sans: Optional[List[str]] = None, + ) -> "CertBundle": + unique_sans: Set[str] = set(sans) if sans else set() - sans.add(domain_name) - sans = [cryptography.x509.DNSName(item) for item in sans] + unique_sans.add(domain_name) + unique_dns_names = [DNSName(item) for item in unique_sans] key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) subject = cryptography.x509.Name( [ + cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + cryptography.x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "CA"), + cryptography.x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), cryptography.x509.NameAttribute( - cryptography.x509.NameOID.COUNTRY_NAME, "US" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.COMMON_NAME, domain_name + NameOID.ORGANIZATION_NAME, "My Company" ), + cryptography.x509.NameAttribute(NameOID.COMMON_NAME, domain_name), ] ) issuer = cryptography.x509.Name( [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon + cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + cryptography.x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"), cryptography.x509.NameAttribute( - cryptography.x509.NameOID.COUNTRY_NAME, "US" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B" - ), - cryptography.x509.NameAttribute( - cryptography.x509.NameOID.COMMON_NAME, "Amazon" + NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B" ), + cryptography.x509.NameAttribute(NameOID.COMMON_NAME, "Amazon"), ] ) cert = ( @@ -220,7 +196,8 @@ class CertBundle(BaseModel): .not_valid_before(datetime.datetime.utcnow()) .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) .add_extension( - cryptography.x509.SubjectAlternativeName(sans), critical=False + cryptography.x509.SubjectAlternativeName(unique_dns_names), + critical=False, ) .sign(key, hashes.SHA512(), default_backend()) ) @@ -241,17 +218,11 @@ class CertBundle(BaseModel): region=region, ) - def validate_pk(self): + def validate_pk(self) -> Any: try: - self._key = serialization.load_pem_private_key( + return serialization.load_pem_private_key( self.key, password=None, backend=default_backend() ) - - if self._key.key_size > 2048: - AWSValidationException( - "The private key length is not supported. Only 1024-bit and 2048-bit are allowed." - ) - except Exception as err: if isinstance(err, AWSValidationException): raise @@ -259,48 +230,40 @@ class CertBundle(BaseModel): "The private key is not PEM-encoded or is not valid." ) - def validate_certificate(self): + def validate_certificate(self) -> cryptography.x509.base.Certificate: try: - self._cert = cryptography.x509.load_pem_x509_certificate( + _cert = cryptography.x509.load_pem_x509_certificate( self.cert, default_backend() ) now = datetime.datetime.utcnow() - if self._cert.not_valid_after < now: + if _cert.not_valid_after < now: raise AWSValidationException( "The certificate has expired, is not valid." ) - if self._cert.not_valid_before > now: + if _cert.not_valid_before > now: raise AWSValidationException( "The certificate is not in effect yet, is not valid." ) - # Extracting some common fields for ease of use - # Have to search through cert.subject for OIDs - self.common_name = self._cert.subject.get_attributes_for_oid( - cryptography.x509.OID_COMMON_NAME - )[0].value - except Exception as err: if isinstance(err, AWSValidationException): raise raise AWSValidationException( "The certificate is not PEM-encoded or is not valid." ) + return _cert - def validate_chain(self): + def validate_chain(self) -> None: try: - self._chain = [] - for cert_armored in self.chain.split(b"-\n-"): # Fix missing -'s on split cert_armored = re.sub(b"^----B", b"-----B", cert_armored) cert_armored = re.sub(b"E----$", b"E-----", cert_armored) - cert = cryptography.x509.load_pem_x509_certificate( + cryptography.x509.load_pem_x509_certificate( cert_armored, default_backend() ) - self._chain.append(cert) now = datetime.datetime.utcnow() if self._cert.not_valid_after < now: @@ -320,7 +283,7 @@ class CertBundle(BaseModel): "The certificate is not PEM-encoded or is not valid." ) - def check(self): + def check(self) -> None: # Basically, if the certificate is pending, and then checked again after a # while, it will appear as if its been validated. The default wait time is 60 # seconds but you can set an environment to change it. @@ -332,7 +295,7 @@ class CertBundle(BaseModel): ): self.status = "ISSUED" - def describe(self): + def describe(self) -> Dict[str, Any]: # 'RenewalSummary': {}, # Only when cert is amazon issued if self._key.key_size == 1024: key_algo = "RSA_1024" @@ -343,7 +306,7 @@ class CertBundle(BaseModel): # Look for SANs try: - san_obj = self._cert.extensions.get_extension_for_oid( + san_obj: Any = self._cert.extensions.get_extension_for_oid( cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME ) except cryptography.x509.ExtensionNotFound: @@ -352,14 +315,14 @@ class CertBundle(BaseModel): if san_obj is not None: sans = [item.value for item in san_obj.value] - result = { + result: Dict[str, Any] = { "Certificate": { "CertificateArn": self.arn, "DomainName": self.common_name, "InUseBy": self.in_use_by, - "Issuer": self._cert.issuer.get_attributes_for_oid( - cryptography.x509.OID_COMMON_NAME - )[0].value, + "Issuer": self._cert.issuer.get_attributes_for_oid(OID_COMMON_NAME)[ + 0 + ].value, "KeyAlgorithm": key_algo, "NotAfter": datetime_to_epoch(self._cert.not_valid_after), "NotBefore": datetime_to_epoch(self._cert.not_valid_before), @@ -401,7 +364,7 @@ class CertBundle(BaseModel): return result - def serialize_pk(self, passphrase_bytes): + def serialize_pk(self, passphrase_bytes: bytes) -> str: pk_bytes = self._key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, @@ -411,38 +374,36 @@ class CertBundle(BaseModel): ) return pk_bytes.decode("utf-8") - def __str__(self): + def __str__(self) -> str: return self.arn - def __repr__(self): + def __repr__(self) -> str: return "" class AWSCertificateManagerBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._certificates = {} - self._idempotency_tokens = {} + self._certificates: Dict[str, CertBundle] = {} + self._idempotency_tokens: Dict[str, Any] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "acm-pca" ) - def _arn_not_found(self, arn): - msg = f"Certificate with arn {arn} not found in account {self.account_id}" - return AWSResourceNotFoundException(msg) - - def set_certificate_in_use_by(self, arn, load_balancer_name): + def set_certificate_in_use_by(self, arn: str, load_balancer_name: str) -> None: if arn not in self._certificates: - raise self._arn_not_found(arn) + raise CertificateNotFound(arn=arn, account_id=self.account_id) cert_bundle = self._certificates[arn] cert_bundle.in_use_by.append(load_balancer_name) - def _get_arn_from_idempotency_token(self, token): + def _get_arn_from_idempotency_token(self, token: str) -> Optional[str]: """ If token doesnt exist, return None, later it will be set with an expiry and arn. @@ -465,16 +426,23 @@ class AWSCertificateManagerBackend(BaseBackend): return None - def _set_idempotency_token_arn(self, token, arn): + def _set_idempotency_token_arn(self, token: str, arn: str) -> None: self._idempotency_tokens[token] = { "arn": arn, "expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1), } - def import_cert(self, certificate, private_key, chain=None, arn=None, tags=None): + def import_cert( + self, + certificate: bytes, + private_key: bytes, + chain: Optional[bytes], + arn: Optional[str], + tags: List[Dict[str, str]], + ) -> str: if arn is not None: if arn not in self._certificates: - raise self._arn_not_found(arn) + raise CertificateNotFound(arn=arn, account_id=self.account_id) else: # Will reuse provided ARN bundle = CertBundle( @@ -502,7 +470,7 @@ class AWSCertificateManagerBackend(BaseBackend): return bundle.arn - def get_certificates_list(self, statuses): + def get_certificates_list(self, statuses: List[str]) -> Iterable[CertBundle]: """ Get list of certificates @@ -514,27 +482,27 @@ class AWSCertificateManagerBackend(BaseBackend): if not statuses or cert.status in statuses: yield cert - def get_certificate(self, arn): + def get_certificate(self, arn: str) -> CertBundle: if arn not in self._certificates: - raise self._arn_not_found(arn) + raise CertificateNotFound(arn=arn, account_id=self.account_id) cert_bundle = self._certificates[arn] cert_bundle.check() return cert_bundle - def delete_certificate(self, arn): + def delete_certificate(self, arn: str) -> None: if arn not in self._certificates: - raise self._arn_not_found(arn) + raise CertificateNotFound(arn=arn, account_id=self.account_id) del self._certificates[arn] def request_certificate( self, - domain_name, - idempotency_token, - subject_alt_names, - tags=None, - ): + domain_name: str, + idempotency_token: str, + subject_alt_names: List[str], + tags: List[Dict[str, str]], + ) -> str: """ The parameter DomainValidationOptions has not yet been implemented """ @@ -558,17 +526,21 @@ class AWSCertificateManagerBackend(BaseBackend): return cert.arn - def add_tags_to_certificate(self, arn, tags): + def add_tags_to_certificate(self, arn: str, tags: List[Dict[str, str]]) -> None: # get_cert does arn check cert_bundle = self.get_certificate(arn) cert_bundle.tags.add(tags) - def remove_tags_from_certificate(self, arn, tags): + def remove_tags_from_certificate( + self, arn: str, tags: List[Dict[str, str]] + ) -> None: # get_cert does arn check cert_bundle = self.get_certificate(arn) cert_bundle.tags.remove(tags) - def export_certificate(self, certificate_arn, passphrase): + def export_certificate( + self, certificate_arn: str, passphrase: str + ) -> Tuple[str, str, str]: passphrase_bytes = base64.standard_b64decode(passphrase) cert_bundle = self.get_certificate(certificate_arn) diff --git a/moto/acm/responses.py b/moto/acm/responses.py index 8e6e534ef..8194cd5a5 100644 --- a/moto/acm/responses.py +++ b/moto/acm/responses.py @@ -2,34 +2,23 @@ import json import base64 from moto.core.responses import BaseResponse -from .models import acm_backends, AWSValidationException +from typing import Dict, List, Tuple, Union +from .models import acm_backends, AWSCertificateManagerBackend +from .exceptions import AWSValidationException + + +GENERIC_RESPONSE_TYPE = Union[str, Tuple[str, Dict[str, int]]] class AWSCertificateManagerResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="acm") @property - def acm_backend(self): - """ - ACM Backend - - :return: ACM Backend object - :rtype: moto.acm.models.AWSCertificateManagerBackend - """ + def acm_backend(self) -> AWSCertificateManagerBackend: return acm_backends[self.current_account][self.region] - @property - def request_params(self): - try: - return json.loads(self.body) - except ValueError: - return {} - - def _get_param(self, param_name, if_none=None): - return self.request_params.get(param_name, if_none) - - def add_tags_to_certificate(self): + def add_tags_to_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") tags = self._get_param("Tags") @@ -44,7 +33,7 @@ class AWSCertificateManagerResponse(BaseResponse): return "" - def delete_certificate(self): + def delete_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") if arn is None: @@ -58,7 +47,7 @@ class AWSCertificateManagerResponse(BaseResponse): return "" - def describe_certificate(self): + def describe_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") if arn is None: @@ -72,7 +61,7 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps(cert_bundle.describe()) - def get_certificate(self): + def get_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") if arn is None: @@ -90,7 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse): } return json.dumps(result) - def import_certificate(self): + def import_certificate(self) -> str: """ Returns errors on: Certificate, PrivateKey or Chain not being properly formatted @@ -137,7 +126,7 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps({"CertificateArn": arn}) - def list_certificates(self): + def list_certificates(self) -> str: certs = [] statuses = self._get_param("CertificateStatuses") for cert_bundle in self.acm_backend.get_certificates_list(statuses): @@ -151,16 +140,18 @@ class AWSCertificateManagerResponse(BaseResponse): result = {"CertificateSummaryList": certs} return json.dumps(result) - def list_tags_for_certificate(self): + def list_tags_for_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") if arn is None: msg = "A required parameter for the specified action is not supplied." - return {"__type": "MissingParameter", "message": msg}, dict(status=400) + return json.dumps({"__type": "MissingParameter", "message": msg}), dict( + status=400 + ) cert_bundle = self.acm_backend.get_certificate(arn) - result = {"Tags": []} + result: Dict[str, List[Dict[str, str]]] = {"Tags": []} # Tag "objects" can not contain the Value part for key, value in cert_bundle.tags.items(): tag_dict = {"Key": key} @@ -170,7 +161,7 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps(result) - def remove_tags_from_certificate(self): + def remove_tags_from_certificate(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") tags = self._get_param("Tags") @@ -185,7 +176,7 @@ class AWSCertificateManagerResponse(BaseResponse): return "" - def request_certificate(self): + def request_certificate(self) -> GENERIC_RESPONSE_TYPE: domain_name = self._get_param("DomainName") idempotency_token = self._get_param("IdempotencyToken") subject_alt_names = self._get_param("SubjectAlternativeNames") @@ -210,7 +201,7 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps({"CertificateArn": arn}) - def resend_validation_email(self): + def resend_validation_email(self) -> GENERIC_RESPONSE_TYPE: arn = self._get_param("CertificateArn") domain = self._get_param("Domain") # ValidationDomain not used yet. @@ -235,7 +226,7 @@ class AWSCertificateManagerResponse(BaseResponse): return "" - def export_certificate(self): + def export_certificate(self) -> GENERIC_RESPONSE_TYPE: certificate_arn = self._get_param("CertificateArn") passphrase = self._get_param("Passphrase") diff --git a/moto/acm/utils.py b/moto/acm/utils.py index c25e17213..a109f9d06 100644 --- a/moto/acm/utils.py +++ b/moto/acm/utils.py @@ -1,7 +1,7 @@ from moto.moto_api._internal import mock_random -def make_arn_for_certificate(account_id, region_name): +def make_arn_for_certificate(account_id: str, region_name: str) -> str: # Example # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b return "arn:aws:acm:{0}:{1}:certificate/{2}".format( diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 33ec9d6f8..4f71c54d2 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -1,5 +1,6 @@ from werkzeug.exceptions import HTTPException from jinja2 import DictLoader, Environment +from typing import Optional import json # TODO: add "Sender" to error responses below? @@ -133,10 +134,12 @@ class AuthFailureError(RESTError): class AWSError(JsonRESTError): - TYPE = None + TYPE: Optional[str] = None STATUS = 400 - def __init__(self, message, exception_type=None, status=None): + def __init__( + self, message: str, exception_type: str = None, status: Optional[int] = None + ): super().__init__(exception_type or self.TYPE, message) self.code = status or self.STATUS diff --git a/moto/elb/models.py b/moto/elb/models.py index 1b9810c29..aa287fcd1 100644 --- a/moto/elb/models.py +++ b/moto/elb/models.py @@ -575,12 +575,12 @@ class ELBBackend(BaseBackend): return load_balancer def _register_certificate(self, ssl_certificate_id, dns_name): - from moto.acm.models import acm_backends, AWSResourceNotFoundException + from moto.acm.models import acm_backends, CertificateNotFound acm_backend = acm_backends[self.account_id][self.region_name] try: acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name) - except AWSResourceNotFoundException: + except CertificateNotFound: raise CertificateNotFoundException() def enable_availability_zones_for_load_balancer( diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 39c9f445a..4b8702645 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -1548,14 +1548,13 @@ Member must satisfy regular expression pattern: {}".format( """ Verify the provided certificate exists in either ACM or IAM """ - from moto.acm import acm_backends - from moto.acm.models import AWSResourceNotFoundException + from moto.acm.models import acm_backends, CertificateNotFound try: acm_backend = acm_backends[self.account_id][self.region_name] acm_backend.get_certificate(certificate_arn) return True - except AWSResourceNotFoundException: + except CertificateNotFound: pass from moto.iam import iam_backends diff --git a/tests/test_acm/test_acm.py b/tests/test_acm/test_acm.py index cdbcea379..ff362cfdb 100644 --- a/tests/test_acm/test_acm.py +++ b/tests/test_acm/test_acm.py @@ -433,6 +433,18 @@ def test_request_certificate_with_tags(): {"Key": "WithEmptyStr", "Value": ""}, ], ) + arn_3 = resp["CertificateArn"] + + assert arn_1 != arn_3 # if tags are matched, ACM would have returned same arn + + resp = client.request_certificate( + DomainName="google.com", + IdempotencyToken=token, + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], + ) + arn_4 = resp["CertificateArn"] + + assert arn_1 != arn_4 # if tags are matched, ACM would have returned same arn @mock_acm