TechDebt - MyPy the ACM module (#5540)
This commit is contained in:
parent
e98341fa89
commit
1a8f93dce3
2
Makefile
2
Makefile
@ -28,7 +28,7 @@ lint:
|
||||
@echo "Running pylint..."
|
||||
pylint -j 0 moto tests
|
||||
@echo "Running MyPy..."
|
||||
mypy --install-types --non-interactive moto/applicationautoscaling/
|
||||
mypy --install-types --non-interactive moto/acm moto/applicationautoscaling/
|
||||
|
||||
format:
|
||||
black moto/ tests/
|
||||
|
20
moto/acm/exceptions.py
Normal file
20
moto/acm/exceptions.py
Normal file
@ -0,0 +1,20 @@
|
||||
from moto.core.exceptions import AWSError
|
||||
|
||||
|
||||
class AWSValidationException(AWSError):
|
||||
TYPE = "ValidationException"
|
||||
|
||||
|
||||
class AWSResourceNotFoundException(AWSError):
|
||||
TYPE = "ResourceNotFoundException"
|
||||
|
||||
|
||||
class CertificateNotFound(AWSResourceNotFoundException):
|
||||
def __init__(self, arn: str, account_id: str):
|
||||
super().__init__(
|
||||
message=f"Certificate with arn {arn} not found in account {account_id}"
|
||||
)
|
||||
|
||||
|
||||
class AWSTooManyTagsException(AWSError):
|
||||
TYPE = "TooManyTagsException"
|
@ -2,13 +2,19 @@ import base64
|
||||
import re
|
||||
import datetime
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.core.exceptions import AWSError
|
||||
from moto.core.utils import BackendDict
|
||||
from moto import settings
|
||||
from typing import Any, Dict, List, Iterable, Optional, Tuple, Set
|
||||
|
||||
from .exceptions import (
|
||||
AWSValidationException,
|
||||
AWSTooManyTagsException,
|
||||
CertificateNotFound,
|
||||
)
|
||||
from .utils import make_arn_for_certificate
|
||||
|
||||
import cryptography.x509
|
||||
from cryptography.x509 import OID_COMMON_NAME, NameOID, DNSName
|
||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||
from cryptography.hazmat.primitives import serialization, hashes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@ -43,28 +49,16 @@ RnGfN8j8KLDVmWyTYMk8V+6j0LI4+4zFh2upqGMQHL3VFVFWBek6vCDWhB/b
|
||||
# so for now a cheap response is just give any old root CA
|
||||
|
||||
|
||||
def datetime_to_epoch(date):
|
||||
def datetime_to_epoch(date: datetime.datetime) -> float:
|
||||
return date.timestamp()
|
||||
|
||||
|
||||
class AWSValidationException(AWSError):
|
||||
TYPE = "ValidationException"
|
||||
|
||||
|
||||
class AWSResourceNotFoundException(AWSError):
|
||||
TYPE = "ResourceNotFoundException"
|
||||
|
||||
|
||||
class AWSTooManyTagsException(AWSError):
|
||||
TYPE = "TooManyTagsException"
|
||||
|
||||
|
||||
class TagHolder(dict):
|
||||
class TagHolder(Dict[str, Optional[str]]):
|
||||
MAX_TAG_COUNT = 50
|
||||
MAX_KEY_LENGTH = 128
|
||||
MAX_VALUE_LENGTH = 256
|
||||
|
||||
def _validate_kv(self, key, value, index):
|
||||
def _validate_kv(self, key: str, value: Optional[str], index: int) -> None:
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
raise AWSValidationException(
|
||||
"Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s"
|
||||
@ -81,11 +75,11 @@ class TagHolder(dict):
|
||||
% key
|
||||
)
|
||||
|
||||
def add(self, tags):
|
||||
def add(self, tags: List[Dict[str, str]]) -> None:
|
||||
tags_copy = self.copy()
|
||||
for i, tag in enumerate(tags):
|
||||
key = tag["Key"]
|
||||
value = tag.get("Value", None)
|
||||
value = tag.get("Value")
|
||||
self._validate_kv(key, value, i + 1)
|
||||
|
||||
tags_copy[key] = value
|
||||
@ -97,10 +91,10 @@ class TagHolder(dict):
|
||||
|
||||
self.update(tags_copy)
|
||||
|
||||
def remove(self, tags):
|
||||
def remove(self, tags: List[Dict[str, str]]) -> None:
|
||||
for i, tag in enumerate(tags):
|
||||
key = tag["Key"]
|
||||
value = tag.get("Value", None)
|
||||
value = tag.get("Value")
|
||||
self._validate_kv(key, value, i + 1)
|
||||
try:
|
||||
# If value isnt provided, just delete key
|
||||
@ -112,45 +106,41 @@ class TagHolder(dict):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def equals(self, tags):
|
||||
tags = {t["Key"]: t.get("Value", None) for t in tags} if tags else {}
|
||||
return self == tags
|
||||
def equals(self, tags: List[Dict[str, str]]) -> bool:
|
||||
flat_tags = {t["Key"]: t.get("Value") for t in tags} if tags else {}
|
||||
return self == flat_tags
|
||||
|
||||
|
||||
class CertBundle(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
account_id,
|
||||
certificate,
|
||||
private_key,
|
||||
chain=None,
|
||||
region="us-east-1",
|
||||
arn=None,
|
||||
cert_type="IMPORTED",
|
||||
cert_status="ISSUED",
|
||||
account_id: str,
|
||||
certificate: bytes,
|
||||
private_key: bytes,
|
||||
chain: Optional[bytes] = None,
|
||||
region: str = "us-east-1",
|
||||
arn: Optional[str] = None,
|
||||
cert_type: str = "IMPORTED",
|
||||
cert_status: str = "ISSUED",
|
||||
):
|
||||
self.created_at = datetime.datetime.utcnow()
|
||||
self.cert = certificate
|
||||
self._cert = None
|
||||
self.common_name = None
|
||||
self.key = private_key
|
||||
self._key = None
|
||||
self.chain = chain
|
||||
# AWS always returns your chain + root CA
|
||||
self.chain = chain + b"\n" + AWS_ROOT_CA if chain else AWS_ROOT_CA
|
||||
self.tags = TagHolder()
|
||||
self._chain = None
|
||||
self.type = cert_type # Should really be an enum
|
||||
self.status = cert_status # Should really be an enum
|
||||
self.in_use_by = []
|
||||
|
||||
# AWS always returns your chain + root CA
|
||||
if self.chain is None:
|
||||
self.chain = AWS_ROOT_CA
|
||||
else:
|
||||
self.chain += b"\n" + AWS_ROOT_CA
|
||||
self.in_use_by: List[str] = []
|
||||
|
||||
# Takes care of PEM checking
|
||||
self.validate_pk()
|
||||
self.validate_certificate()
|
||||
self._key = self.validate_pk()
|
||||
self._cert = self.validate_certificate()
|
||||
# Extracting some common fields for ease of use
|
||||
# Have to search through cert.subject for OIDs
|
||||
self.common_name: Any = self._cert.subject.get_attributes_for_oid(
|
||||
OID_COMMON_NAME
|
||||
)[0].value
|
||||
if chain is not None:
|
||||
self.validate_chain()
|
||||
|
||||
@ -158,57 +148,43 @@ class CertBundle(BaseModel):
|
||||
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
|
||||
|
||||
# Used for when one wants to overwrite an arn
|
||||
if arn is None:
|
||||
self.arn = make_arn_for_certificate(account_id, region)
|
||||
else:
|
||||
self.arn = arn
|
||||
self.arn = arn or make_arn_for_certificate(account_id, region)
|
||||
|
||||
@classmethod
|
||||
def generate_cert(cls, domain_name, account_id, region, sans=None):
|
||||
if sans is None:
|
||||
sans = set()
|
||||
else:
|
||||
sans = set(sans)
|
||||
def generate_cert(
|
||||
cls,
|
||||
domain_name: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
sans: Optional[List[str]] = None,
|
||||
) -> "CertBundle":
|
||||
unique_sans: Set[str] = set(sans) if sans else set()
|
||||
|
||||
sans.add(domain_name)
|
||||
sans = [cryptography.x509.DNSName(item) for item in sans]
|
||||
unique_sans.add(domain_name)
|
||||
unique_dns_names = [DNSName(item) for item in unique_sans]
|
||||
|
||||
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
)
|
||||
subject = cryptography.x509.Name(
|
||||
[
|
||||
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
|
||||
cryptography.x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "CA"),
|
||||
cryptography.x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.COUNTRY_NAME, "US"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.COMMON_NAME, domain_name
|
||||
NameOID.ORGANIZATION_NAME, "My Company"
|
||||
),
|
||||
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, domain_name),
|
||||
]
|
||||
)
|
||||
issuer = cryptography.x509.Name(
|
||||
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
|
||||
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
|
||||
cryptography.x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.COUNTRY_NAME, "US"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
|
||||
),
|
||||
cryptography.x509.NameAttribute(
|
||||
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
|
||||
NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
|
||||
),
|
||||
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, "Amazon"),
|
||||
]
|
||||
)
|
||||
cert = (
|
||||
@ -220,7 +196,8 @@ class CertBundle(BaseModel):
|
||||
.not_valid_before(datetime.datetime.utcnow())
|
||||
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
|
||||
.add_extension(
|
||||
cryptography.x509.SubjectAlternativeName(sans), critical=False
|
||||
cryptography.x509.SubjectAlternativeName(unique_dns_names),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA512(), default_backend())
|
||||
)
|
||||
@ -241,17 +218,11 @@ class CertBundle(BaseModel):
|
||||
region=region,
|
||||
)
|
||||
|
||||
def validate_pk(self):
|
||||
def validate_pk(self) -> Any:
|
||||
try:
|
||||
self._key = serialization.load_pem_private_key(
|
||||
return serialization.load_pem_private_key(
|
||||
self.key, password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
if self._key.key_size > 2048:
|
||||
AWSValidationException(
|
||||
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
if isinstance(err, AWSValidationException):
|
||||
raise
|
||||
@ -259,48 +230,40 @@ class CertBundle(BaseModel):
|
||||
"The private key is not PEM-encoded or is not valid."
|
||||
)
|
||||
|
||||
def validate_certificate(self):
|
||||
def validate_certificate(self) -> cryptography.x509.base.Certificate:
|
||||
try:
|
||||
self._cert = cryptography.x509.load_pem_x509_certificate(
|
||||
_cert = cryptography.x509.load_pem_x509_certificate(
|
||||
self.cert, default_backend()
|
||||
)
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
if self._cert.not_valid_after < now:
|
||||
if _cert.not_valid_after < now:
|
||||
raise AWSValidationException(
|
||||
"The certificate has expired, is not valid."
|
||||
)
|
||||
|
||||
if self._cert.not_valid_before > now:
|
||||
if _cert.not_valid_before > now:
|
||||
raise AWSValidationException(
|
||||
"The certificate is not in effect yet, is not valid."
|
||||
)
|
||||
|
||||
# Extracting some common fields for ease of use
|
||||
# Have to search through cert.subject for OIDs
|
||||
self.common_name = self._cert.subject.get_attributes_for_oid(
|
||||
cryptography.x509.OID_COMMON_NAME
|
||||
)[0].value
|
||||
|
||||
except Exception as err:
|
||||
if isinstance(err, AWSValidationException):
|
||||
raise
|
||||
raise AWSValidationException(
|
||||
"The certificate is not PEM-encoded or is not valid."
|
||||
)
|
||||
return _cert
|
||||
|
||||
def validate_chain(self):
|
||||
def validate_chain(self) -> None:
|
||||
try:
|
||||
self._chain = []
|
||||
|
||||
for cert_armored in self.chain.split(b"-\n-"):
|
||||
# Fix missing -'s on split
|
||||
cert_armored = re.sub(b"^----B", b"-----B", cert_armored)
|
||||
cert_armored = re.sub(b"E----$", b"E-----", cert_armored)
|
||||
cert = cryptography.x509.load_pem_x509_certificate(
|
||||
cryptography.x509.load_pem_x509_certificate(
|
||||
cert_armored, default_backend()
|
||||
)
|
||||
self._chain.append(cert)
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
if self._cert.not_valid_after < now:
|
||||
@ -320,7 +283,7 @@ class CertBundle(BaseModel):
|
||||
"The certificate is not PEM-encoded or is not valid."
|
||||
)
|
||||
|
||||
def check(self):
|
||||
def check(self) -> None:
|
||||
# Basically, if the certificate is pending, and then checked again after a
|
||||
# while, it will appear as if its been validated. The default wait time is 60
|
||||
# seconds but you can set an environment to change it.
|
||||
@ -332,7 +295,7 @@ class CertBundle(BaseModel):
|
||||
):
|
||||
self.status = "ISSUED"
|
||||
|
||||
def describe(self):
|
||||
def describe(self) -> Dict[str, Any]:
|
||||
# 'RenewalSummary': {}, # Only when cert is amazon issued
|
||||
if self._key.key_size == 1024:
|
||||
key_algo = "RSA_1024"
|
||||
@ -343,7 +306,7 @@ class CertBundle(BaseModel):
|
||||
|
||||
# Look for SANs
|
||||
try:
|
||||
san_obj = self._cert.extensions.get_extension_for_oid(
|
||||
san_obj: Any = self._cert.extensions.get_extension_for_oid(
|
||||
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
|
||||
)
|
||||
except cryptography.x509.ExtensionNotFound:
|
||||
@ -352,14 +315,14 @@ class CertBundle(BaseModel):
|
||||
if san_obj is not None:
|
||||
sans = [item.value for item in san_obj.value]
|
||||
|
||||
result = {
|
||||
result: Dict[str, Any] = {
|
||||
"Certificate": {
|
||||
"CertificateArn": self.arn,
|
||||
"DomainName": self.common_name,
|
||||
"InUseBy": self.in_use_by,
|
||||
"Issuer": self._cert.issuer.get_attributes_for_oid(
|
||||
cryptography.x509.OID_COMMON_NAME
|
||||
)[0].value,
|
||||
"Issuer": self._cert.issuer.get_attributes_for_oid(OID_COMMON_NAME)[
|
||||
0
|
||||
].value,
|
||||
"KeyAlgorithm": key_algo,
|
||||
"NotAfter": datetime_to_epoch(self._cert.not_valid_after),
|
||||
"NotBefore": datetime_to_epoch(self._cert.not_valid_before),
|
||||
@ -401,7 +364,7 @@ class CertBundle(BaseModel):
|
||||
|
||||
return result
|
||||
|
||||
def serialize_pk(self, passphrase_bytes):
|
||||
def serialize_pk(self, passphrase_bytes: bytes) -> str:
|
||||
pk_bytes = self._key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
@ -411,38 +374,36 @@ class CertBundle(BaseModel):
|
||||
)
|
||||
return pk_bytes.decode("utf-8")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.arn
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "<Certificate>"
|
||||
|
||||
|
||||
class AWSCertificateManagerBackend(BaseBackend):
|
||||
def __init__(self, region_name, account_id):
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self._certificates = {}
|
||||
self._idempotency_tokens = {}
|
||||
self._certificates: Dict[str, CertBundle] = {}
|
||||
self._idempotency_tokens: Dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(service_region, zones):
|
||||
def default_vpc_endpoint_service(
|
||||
service_region: str, zones: List[str]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Default VPC endpoint service."""
|
||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||
service_region, zones, "acm-pca"
|
||||
)
|
||||
|
||||
def _arn_not_found(self, arn):
|
||||
msg = f"Certificate with arn {arn} not found in account {self.account_id}"
|
||||
return AWSResourceNotFoundException(msg)
|
||||
|
||||
def set_certificate_in_use_by(self, arn, load_balancer_name):
|
||||
def set_certificate_in_use_by(self, arn: str, load_balancer_name: str) -> None:
|
||||
if arn not in self._certificates:
|
||||
raise self._arn_not_found(arn)
|
||||
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||
|
||||
cert_bundle = self._certificates[arn]
|
||||
cert_bundle.in_use_by.append(load_balancer_name)
|
||||
|
||||
def _get_arn_from_idempotency_token(self, token):
|
||||
def _get_arn_from_idempotency_token(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
If token doesnt exist, return None, later it will be
|
||||
set with an expiry and arn.
|
||||
@ -465,16 +426,23 @@ class AWSCertificateManagerBackend(BaseBackend):
|
||||
|
||||
return None
|
||||
|
||||
def _set_idempotency_token_arn(self, token, arn):
|
||||
def _set_idempotency_token_arn(self, token: str, arn: str) -> None:
|
||||
self._idempotency_tokens[token] = {
|
||||
"arn": arn,
|
||||
"expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
|
||||
}
|
||||
|
||||
def import_cert(self, certificate, private_key, chain=None, arn=None, tags=None):
|
||||
def import_cert(
|
||||
self,
|
||||
certificate: bytes,
|
||||
private_key: bytes,
|
||||
chain: Optional[bytes],
|
||||
arn: Optional[str],
|
||||
tags: List[Dict[str, str]],
|
||||
) -> str:
|
||||
if arn is not None:
|
||||
if arn not in self._certificates:
|
||||
raise self._arn_not_found(arn)
|
||||
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||
else:
|
||||
# Will reuse provided ARN
|
||||
bundle = CertBundle(
|
||||
@ -502,7 +470,7 @@ class AWSCertificateManagerBackend(BaseBackend):
|
||||
|
||||
return bundle.arn
|
||||
|
||||
def get_certificates_list(self, statuses):
|
||||
def get_certificates_list(self, statuses: List[str]) -> Iterable[CertBundle]:
|
||||
"""
|
||||
Get list of certificates
|
||||
|
||||
@ -514,27 +482,27 @@ class AWSCertificateManagerBackend(BaseBackend):
|
||||
if not statuses or cert.status in statuses:
|
||||
yield cert
|
||||
|
||||
def get_certificate(self, arn):
|
||||
def get_certificate(self, arn: str) -> CertBundle:
|
||||
if arn not in self._certificates:
|
||||
raise self._arn_not_found(arn)
|
||||
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||
|
||||
cert_bundle = self._certificates[arn]
|
||||
cert_bundle.check()
|
||||
return cert_bundle
|
||||
|
||||
def delete_certificate(self, arn):
|
||||
def delete_certificate(self, arn: str) -> None:
|
||||
if arn not in self._certificates:
|
||||
raise self._arn_not_found(arn)
|
||||
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||
|
||||
del self._certificates[arn]
|
||||
|
||||
def request_certificate(
|
||||
self,
|
||||
domain_name,
|
||||
idempotency_token,
|
||||
subject_alt_names,
|
||||
tags=None,
|
||||
):
|
||||
domain_name: str,
|
||||
idempotency_token: str,
|
||||
subject_alt_names: List[str],
|
||||
tags: List[Dict[str, str]],
|
||||
) -> str:
|
||||
"""
|
||||
The parameter DomainValidationOptions has not yet been implemented
|
||||
"""
|
||||
@ -558,17 +526,21 @@ class AWSCertificateManagerBackend(BaseBackend):
|
||||
|
||||
return cert.arn
|
||||
|
||||
def add_tags_to_certificate(self, arn, tags):
|
||||
def add_tags_to_certificate(self, arn: str, tags: List[Dict[str, str]]) -> None:
|
||||
# get_cert does arn check
|
||||
cert_bundle = self.get_certificate(arn)
|
||||
cert_bundle.tags.add(tags)
|
||||
|
||||
def remove_tags_from_certificate(self, arn, tags):
|
||||
def remove_tags_from_certificate(
|
||||
self, arn: str, tags: List[Dict[str, str]]
|
||||
) -> None:
|
||||
# get_cert does arn check
|
||||
cert_bundle = self.get_certificate(arn)
|
||||
cert_bundle.tags.remove(tags)
|
||||
|
||||
def export_certificate(self, certificate_arn, passphrase):
|
||||
def export_certificate(
|
||||
self, certificate_arn: str, passphrase: str
|
||||
) -> Tuple[str, str, str]:
|
||||
passphrase_bytes = base64.standard_b64decode(passphrase)
|
||||
cert_bundle = self.get_certificate(certificate_arn)
|
||||
|
||||
|
@ -2,34 +2,23 @@ import json
|
||||
import base64
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from .models import acm_backends, AWSValidationException
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from .models import acm_backends, AWSCertificateManagerBackend
|
||||
from .exceptions import AWSValidationException
|
||||
|
||||
|
||||
GENERIC_RESPONSE_TYPE = Union[str, Tuple[str, Dict[str, int]]]
|
||||
|
||||
|
||||
class AWSCertificateManagerResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="acm")
|
||||
|
||||
@property
|
||||
def acm_backend(self):
|
||||
"""
|
||||
ACM Backend
|
||||
|
||||
:return: ACM Backend object
|
||||
:rtype: moto.acm.models.AWSCertificateManagerBackend
|
||||
"""
|
||||
def acm_backend(self) -> AWSCertificateManagerBackend:
|
||||
return acm_backends[self.current_account][self.region]
|
||||
|
||||
@property
|
||||
def request_params(self):
|
||||
try:
|
||||
return json.loads(self.body)
|
||||
except ValueError:
|
||||
return {}
|
||||
|
||||
def _get_param(self, param_name, if_none=None):
|
||||
return self.request_params.get(param_name, if_none)
|
||||
|
||||
def add_tags_to_certificate(self):
|
||||
def add_tags_to_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
tags = self._get_param("Tags")
|
||||
|
||||
@ -44,7 +33,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return ""
|
||||
|
||||
def delete_certificate(self):
|
||||
def delete_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
|
||||
if arn is None:
|
||||
@ -58,7 +47,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return ""
|
||||
|
||||
def describe_certificate(self):
|
||||
def describe_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
|
||||
if arn is None:
|
||||
@ -72,7 +61,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return json.dumps(cert_bundle.describe())
|
||||
|
||||
def get_certificate(self):
|
||||
def get_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
|
||||
if arn is None:
|
||||
@ -90,7 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
}
|
||||
return json.dumps(result)
|
||||
|
||||
def import_certificate(self):
|
||||
def import_certificate(self) -> str:
|
||||
"""
|
||||
Returns errors on:
|
||||
Certificate, PrivateKey or Chain not being properly formatted
|
||||
@ -137,7 +126,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"CertificateArn": arn})
|
||||
|
||||
def list_certificates(self):
|
||||
def list_certificates(self) -> str:
|
||||
certs = []
|
||||
statuses = self._get_param("CertificateStatuses")
|
||||
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
|
||||
@ -151,16 +140,18 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
result = {"CertificateSummaryList": certs}
|
||||
return json.dumps(result)
|
||||
|
||||
def list_tags_for_certificate(self):
|
||||
def list_tags_for_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
|
||||
if arn is None:
|
||||
msg = "A required parameter for the specified action is not supplied."
|
||||
return {"__type": "MissingParameter", "message": msg}, dict(status=400)
|
||||
return json.dumps({"__type": "MissingParameter", "message": msg}), dict(
|
||||
status=400
|
||||
)
|
||||
|
||||
cert_bundle = self.acm_backend.get_certificate(arn)
|
||||
|
||||
result = {"Tags": []}
|
||||
result: Dict[str, List[Dict[str, str]]] = {"Tags": []}
|
||||
# Tag "objects" can not contain the Value part
|
||||
for key, value in cert_bundle.tags.items():
|
||||
tag_dict = {"Key": key}
|
||||
@ -170,7 +161,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return json.dumps(result)
|
||||
|
||||
def remove_tags_from_certificate(self):
|
||||
def remove_tags_from_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
tags = self._get_param("Tags")
|
||||
|
||||
@ -185,7 +176,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return ""
|
||||
|
||||
def request_certificate(self):
|
||||
def request_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
domain_name = self._get_param("DomainName")
|
||||
idempotency_token = self._get_param("IdempotencyToken")
|
||||
subject_alt_names = self._get_param("SubjectAlternativeNames")
|
||||
@ -210,7 +201,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"CertificateArn": arn})
|
||||
|
||||
def resend_validation_email(self):
|
||||
def resend_validation_email(self) -> GENERIC_RESPONSE_TYPE:
|
||||
arn = self._get_param("CertificateArn")
|
||||
domain = self._get_param("Domain")
|
||||
# ValidationDomain not used yet.
|
||||
@ -235,7 +226,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
||||
|
||||
return ""
|
||||
|
||||
def export_certificate(self):
|
||||
def export_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||
certificate_arn = self._get_param("CertificateArn")
|
||||
passphrase = self._get_param("Passphrase")
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from moto.moto_api._internal import mock_random
|
||||
|
||||
|
||||
def make_arn_for_certificate(account_id, region_name):
|
||||
def make_arn_for_certificate(account_id: str, region_name: str) -> str:
|
||||
# Example
|
||||
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
|
||||
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
|
||||
|
@ -1,5 +1,6 @@
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from jinja2 import DictLoader, Environment
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
# TODO: add "<Type>Sender</Type>" to error responses below?
|
||||
@ -133,10 +134,12 @@ class AuthFailureError(RESTError):
|
||||
|
||||
|
||||
class AWSError(JsonRESTError):
|
||||
TYPE = None
|
||||
TYPE: Optional[str] = None
|
||||
STATUS = 400
|
||||
|
||||
def __init__(self, message, exception_type=None, status=None):
|
||||
def __init__(
|
||||
self, message: str, exception_type: str = None, status: Optional[int] = None
|
||||
):
|
||||
super().__init__(exception_type or self.TYPE, message)
|
||||
self.code = status or self.STATUS
|
||||
|
||||
|
@ -575,12 +575,12 @@ class ELBBackend(BaseBackend):
|
||||
return load_balancer
|
||||
|
||||
def _register_certificate(self, ssl_certificate_id, dns_name):
|
||||
from moto.acm.models import acm_backends, AWSResourceNotFoundException
|
||||
from moto.acm.models import acm_backends, CertificateNotFound
|
||||
|
||||
acm_backend = acm_backends[self.account_id][self.region_name]
|
||||
try:
|
||||
acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name)
|
||||
except AWSResourceNotFoundException:
|
||||
except CertificateNotFound:
|
||||
raise CertificateNotFoundException()
|
||||
|
||||
def enable_availability_zones_for_load_balancer(
|
||||
|
@ -1548,14 +1548,13 @@ Member must satisfy regular expression pattern: {}".format(
|
||||
"""
|
||||
Verify the provided certificate exists in either ACM or IAM
|
||||
"""
|
||||
from moto.acm import acm_backends
|
||||
from moto.acm.models import AWSResourceNotFoundException
|
||||
from moto.acm.models import acm_backends, CertificateNotFound
|
||||
|
||||
try:
|
||||
acm_backend = acm_backends[self.account_id][self.region_name]
|
||||
acm_backend.get_certificate(certificate_arn)
|
||||
return True
|
||||
except AWSResourceNotFoundException:
|
||||
except CertificateNotFound:
|
||||
pass
|
||||
|
||||
from moto.iam import iam_backends
|
||||
|
@ -433,6 +433,18 @@ def test_request_certificate_with_tags():
|
||||
{"Key": "WithEmptyStr", "Value": ""},
|
||||
],
|
||||
)
|
||||
arn_3 = resp["CertificateArn"]
|
||||
|
||||
assert arn_1 != arn_3 # if tags are matched, ACM would have returned same arn
|
||||
|
||||
resp = client.request_certificate(
|
||||
DomainName="google.com",
|
||||
IdempotencyToken=token,
|
||||
SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"],
|
||||
)
|
||||
arn_4 = resp["CertificateArn"]
|
||||
|
||||
assert arn_1 != arn_4 # if tags are matched, ACM would have returned same arn
|
||||
|
||||
|
||||
@mock_acm
|
||||
|
Loading…
Reference in New Issue
Block a user