TechDebt - MyPy the ACM module (#5540)
This commit is contained in:
parent
e98341fa89
commit
1a8f93dce3
2
Makefile
2
Makefile
@ -28,7 +28,7 @@ lint:
|
|||||||
@echo "Running pylint..."
|
@echo "Running pylint..."
|
||||||
pylint -j 0 moto tests
|
pylint -j 0 moto tests
|
||||||
@echo "Running MyPy..."
|
@echo "Running MyPy..."
|
||||||
mypy --install-types --non-interactive moto/applicationautoscaling/
|
mypy --install-types --non-interactive moto/acm moto/applicationautoscaling/
|
||||||
|
|
||||||
format:
|
format:
|
||||||
black moto/ tests/
|
black moto/ tests/
|
||||||
|
20
moto/acm/exceptions.py
Normal file
20
moto/acm/exceptions.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from moto.core.exceptions import AWSError
|
||||||
|
|
||||||
|
|
||||||
|
class AWSValidationException(AWSError):
|
||||||
|
TYPE = "ValidationException"
|
||||||
|
|
||||||
|
|
||||||
|
class AWSResourceNotFoundException(AWSError):
|
||||||
|
TYPE = "ResourceNotFoundException"
|
||||||
|
|
||||||
|
|
||||||
|
class CertificateNotFound(AWSResourceNotFoundException):
|
||||||
|
def __init__(self, arn: str, account_id: str):
|
||||||
|
super().__init__(
|
||||||
|
message=f"Certificate with arn {arn} not found in account {account_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AWSTooManyTagsException(AWSError):
|
||||||
|
TYPE = "TooManyTagsException"
|
@ -2,13 +2,19 @@ import base64
|
|||||||
import re
|
import re
|
||||||
import datetime
|
import datetime
|
||||||
from moto.core import BaseBackend, BaseModel
|
from moto.core import BaseBackend, BaseModel
|
||||||
from moto.core.exceptions import AWSError
|
|
||||||
from moto.core.utils import BackendDict
|
from moto.core.utils import BackendDict
|
||||||
from moto import settings
|
from moto import settings
|
||||||
|
from typing import Any, Dict, List, Iterable, Optional, Tuple, Set
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
AWSValidationException,
|
||||||
|
AWSTooManyTagsException,
|
||||||
|
CertificateNotFound,
|
||||||
|
)
|
||||||
from .utils import make_arn_for_certificate
|
from .utils import make_arn_for_certificate
|
||||||
|
|
||||||
import cryptography.x509
|
import cryptography.x509
|
||||||
|
from cryptography.x509 import OID_COMMON_NAME, NameOID, DNSName
|
||||||
import cryptography.hazmat.primitives.asymmetric.rsa
|
import cryptography.hazmat.primitives.asymmetric.rsa
|
||||||
from cryptography.hazmat.primitives import serialization, hashes
|
from cryptography.hazmat.primitives import serialization, hashes
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
@ -43,28 +49,16 @@ RnGfN8j8KLDVmWyTYMk8V+6j0LI4+4zFh2upqGMQHL3VFVFWBek6vCDWhB/b
|
|||||||
# so for now a cheap response is just give any old root CA
|
# so for now a cheap response is just give any old root CA
|
||||||
|
|
||||||
|
|
||||||
def datetime_to_epoch(date):
|
def datetime_to_epoch(date: datetime.datetime) -> float:
|
||||||
return date.timestamp()
|
return date.timestamp()
|
||||||
|
|
||||||
|
|
||||||
class AWSValidationException(AWSError):
|
class TagHolder(Dict[str, Optional[str]]):
|
||||||
TYPE = "ValidationException"
|
|
||||||
|
|
||||||
|
|
||||||
class AWSResourceNotFoundException(AWSError):
|
|
||||||
TYPE = "ResourceNotFoundException"
|
|
||||||
|
|
||||||
|
|
||||||
class AWSTooManyTagsException(AWSError):
|
|
||||||
TYPE = "TooManyTagsException"
|
|
||||||
|
|
||||||
|
|
||||||
class TagHolder(dict):
|
|
||||||
MAX_TAG_COUNT = 50
|
MAX_TAG_COUNT = 50
|
||||||
MAX_KEY_LENGTH = 128
|
MAX_KEY_LENGTH = 128
|
||||||
MAX_VALUE_LENGTH = 256
|
MAX_VALUE_LENGTH = 256
|
||||||
|
|
||||||
def _validate_kv(self, key, value, index):
|
def _validate_kv(self, key: str, value: Optional[str], index: int) -> None:
|
||||||
if len(key) > self.MAX_KEY_LENGTH:
|
if len(key) > self.MAX_KEY_LENGTH:
|
||||||
raise AWSValidationException(
|
raise AWSValidationException(
|
||||||
"Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s"
|
"Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s"
|
||||||
@ -81,11 +75,11 @@ class TagHolder(dict):
|
|||||||
% key
|
% key
|
||||||
)
|
)
|
||||||
|
|
||||||
def add(self, tags):
|
def add(self, tags: List[Dict[str, str]]) -> None:
|
||||||
tags_copy = self.copy()
|
tags_copy = self.copy()
|
||||||
for i, tag in enumerate(tags):
|
for i, tag in enumerate(tags):
|
||||||
key = tag["Key"]
|
key = tag["Key"]
|
||||||
value = tag.get("Value", None)
|
value = tag.get("Value")
|
||||||
self._validate_kv(key, value, i + 1)
|
self._validate_kv(key, value, i + 1)
|
||||||
|
|
||||||
tags_copy[key] = value
|
tags_copy[key] = value
|
||||||
@ -97,10 +91,10 @@ class TagHolder(dict):
|
|||||||
|
|
||||||
self.update(tags_copy)
|
self.update(tags_copy)
|
||||||
|
|
||||||
def remove(self, tags):
|
def remove(self, tags: List[Dict[str, str]]) -> None:
|
||||||
for i, tag in enumerate(tags):
|
for i, tag in enumerate(tags):
|
||||||
key = tag["Key"]
|
key = tag["Key"]
|
||||||
value = tag.get("Value", None)
|
value = tag.get("Value")
|
||||||
self._validate_kv(key, value, i + 1)
|
self._validate_kv(key, value, i + 1)
|
||||||
try:
|
try:
|
||||||
# If value isnt provided, just delete key
|
# If value isnt provided, just delete key
|
||||||
@ -112,45 +106,41 @@ class TagHolder(dict):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def equals(self, tags):
|
def equals(self, tags: List[Dict[str, str]]) -> bool:
|
||||||
tags = {t["Key"]: t.get("Value", None) for t in tags} if tags else {}
|
flat_tags = {t["Key"]: t.get("Value") for t in tags} if tags else {}
|
||||||
return self == tags
|
return self == flat_tags
|
||||||
|
|
||||||
|
|
||||||
class CertBundle(BaseModel):
|
class CertBundle(BaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
account_id,
|
account_id: str,
|
||||||
certificate,
|
certificate: bytes,
|
||||||
private_key,
|
private_key: bytes,
|
||||||
chain=None,
|
chain: Optional[bytes] = None,
|
||||||
region="us-east-1",
|
region: str = "us-east-1",
|
||||||
arn=None,
|
arn: Optional[str] = None,
|
||||||
cert_type="IMPORTED",
|
cert_type: str = "IMPORTED",
|
||||||
cert_status="ISSUED",
|
cert_status: str = "ISSUED",
|
||||||
):
|
):
|
||||||
self.created_at = datetime.datetime.utcnow()
|
self.created_at = datetime.datetime.utcnow()
|
||||||
self.cert = certificate
|
self.cert = certificate
|
||||||
self._cert = None
|
|
||||||
self.common_name = None
|
|
||||||
self.key = private_key
|
self.key = private_key
|
||||||
self._key = None
|
# AWS always returns your chain + root CA
|
||||||
self.chain = chain
|
self.chain = chain + b"\n" + AWS_ROOT_CA if chain else AWS_ROOT_CA
|
||||||
self.tags = TagHolder()
|
self.tags = TagHolder()
|
||||||
self._chain = None
|
|
||||||
self.type = cert_type # Should really be an enum
|
self.type = cert_type # Should really be an enum
|
||||||
self.status = cert_status # Should really be an enum
|
self.status = cert_status # Should really be an enum
|
||||||
self.in_use_by = []
|
self.in_use_by: List[str] = []
|
||||||
|
|
||||||
# AWS always returns your chain + root CA
|
|
||||||
if self.chain is None:
|
|
||||||
self.chain = AWS_ROOT_CA
|
|
||||||
else:
|
|
||||||
self.chain += b"\n" + AWS_ROOT_CA
|
|
||||||
|
|
||||||
# Takes care of PEM checking
|
# Takes care of PEM checking
|
||||||
self.validate_pk()
|
self._key = self.validate_pk()
|
||||||
self.validate_certificate()
|
self._cert = self.validate_certificate()
|
||||||
|
# Extracting some common fields for ease of use
|
||||||
|
# Have to search through cert.subject for OIDs
|
||||||
|
self.common_name: Any = self._cert.subject.get_attributes_for_oid(
|
||||||
|
OID_COMMON_NAME
|
||||||
|
)[0].value
|
||||||
if chain is not None:
|
if chain is not None:
|
||||||
self.validate_chain()
|
self.validate_chain()
|
||||||
|
|
||||||
@ -158,57 +148,43 @@ class CertBundle(BaseModel):
|
|||||||
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
|
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
|
||||||
|
|
||||||
# Used for when one wants to overwrite an arn
|
# Used for when one wants to overwrite an arn
|
||||||
if arn is None:
|
self.arn = arn or make_arn_for_certificate(account_id, region)
|
||||||
self.arn = make_arn_for_certificate(account_id, region)
|
|
||||||
else:
|
|
||||||
self.arn = arn
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_cert(cls, domain_name, account_id, region, sans=None):
|
def generate_cert(
|
||||||
if sans is None:
|
cls,
|
||||||
sans = set()
|
domain_name: str,
|
||||||
else:
|
account_id: str,
|
||||||
sans = set(sans)
|
region: str,
|
||||||
|
sans: Optional[List[str]] = None,
|
||||||
|
) -> "CertBundle":
|
||||||
|
unique_sans: Set[str] = set(sans) if sans else set()
|
||||||
|
|
||||||
sans.add(domain_name)
|
unique_sans.add(domain_name)
|
||||||
sans = [cryptography.x509.DNSName(item) for item in sans]
|
unique_dns_names = [DNSName(item) for item in unique_sans]
|
||||||
|
|
||||||
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
|
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
|
||||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
)
|
)
|
||||||
subject = cryptography.x509.Name(
|
subject = cryptography.x509.Name(
|
||||||
[
|
[
|
||||||
|
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
|
||||||
|
cryptography.x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "CA"),
|
||||||
|
cryptography.x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
|
||||||
cryptography.x509.NameAttribute(
|
cryptography.x509.NameAttribute(
|
||||||
cryptography.x509.NameOID.COUNTRY_NAME, "US"
|
NameOID.ORGANIZATION_NAME, "My Company"
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
|
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
|
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
|
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.COMMON_NAME, domain_name
|
|
||||||
),
|
),
|
||||||
|
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, domain_name),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
issuer = cryptography.x509.Name(
|
issuer = cryptography.x509.Name(
|
||||||
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
|
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
|
||||||
|
cryptography.x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
|
||||||
|
cryptography.x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"),
|
||||||
cryptography.x509.NameAttribute(
|
cryptography.x509.NameAttribute(
|
||||||
cryptography.x509.NameOID.COUNTRY_NAME, "US"
|
NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
|
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
|
|
||||||
),
|
|
||||||
cryptography.x509.NameAttribute(
|
|
||||||
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
|
|
||||||
),
|
),
|
||||||
|
cryptography.x509.NameAttribute(NameOID.COMMON_NAME, "Amazon"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
cert = (
|
cert = (
|
||||||
@ -220,7 +196,8 @@ class CertBundle(BaseModel):
|
|||||||
.not_valid_before(datetime.datetime.utcnow())
|
.not_valid_before(datetime.datetime.utcnow())
|
||||||
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
|
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
|
||||||
.add_extension(
|
.add_extension(
|
||||||
cryptography.x509.SubjectAlternativeName(sans), critical=False
|
cryptography.x509.SubjectAlternativeName(unique_dns_names),
|
||||||
|
critical=False,
|
||||||
)
|
)
|
||||||
.sign(key, hashes.SHA512(), default_backend())
|
.sign(key, hashes.SHA512(), default_backend())
|
||||||
)
|
)
|
||||||
@ -241,17 +218,11 @@ class CertBundle(BaseModel):
|
|||||||
region=region,
|
region=region,
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_pk(self):
|
def validate_pk(self) -> Any:
|
||||||
try:
|
try:
|
||||||
self._key = serialization.load_pem_private_key(
|
return serialization.load_pem_private_key(
|
||||||
self.key, password=None, backend=default_backend()
|
self.key, password=None, backend=default_backend()
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._key.key_size > 2048:
|
|
||||||
AWSValidationException(
|
|
||||||
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if isinstance(err, AWSValidationException):
|
if isinstance(err, AWSValidationException):
|
||||||
raise
|
raise
|
||||||
@ -259,48 +230,40 @@ class CertBundle(BaseModel):
|
|||||||
"The private key is not PEM-encoded or is not valid."
|
"The private key is not PEM-encoded or is not valid."
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_certificate(self):
|
def validate_certificate(self) -> cryptography.x509.base.Certificate:
|
||||||
try:
|
try:
|
||||||
self._cert = cryptography.x509.load_pem_x509_certificate(
|
_cert = cryptography.x509.load_pem_x509_certificate(
|
||||||
self.cert, default_backend()
|
self.cert, default_backend()
|
||||||
)
|
)
|
||||||
|
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.datetime.utcnow()
|
||||||
if self._cert.not_valid_after < now:
|
if _cert.not_valid_after < now:
|
||||||
raise AWSValidationException(
|
raise AWSValidationException(
|
||||||
"The certificate has expired, is not valid."
|
"The certificate has expired, is not valid."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._cert.not_valid_before > now:
|
if _cert.not_valid_before > now:
|
||||||
raise AWSValidationException(
|
raise AWSValidationException(
|
||||||
"The certificate is not in effect yet, is not valid."
|
"The certificate is not in effect yet, is not valid."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extracting some common fields for ease of use
|
|
||||||
# Have to search through cert.subject for OIDs
|
|
||||||
self.common_name = self._cert.subject.get_attributes_for_oid(
|
|
||||||
cryptography.x509.OID_COMMON_NAME
|
|
||||||
)[0].value
|
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if isinstance(err, AWSValidationException):
|
if isinstance(err, AWSValidationException):
|
||||||
raise
|
raise
|
||||||
raise AWSValidationException(
|
raise AWSValidationException(
|
||||||
"The certificate is not PEM-encoded or is not valid."
|
"The certificate is not PEM-encoded or is not valid."
|
||||||
)
|
)
|
||||||
|
return _cert
|
||||||
|
|
||||||
def validate_chain(self):
|
def validate_chain(self) -> None:
|
||||||
try:
|
try:
|
||||||
self._chain = []
|
|
||||||
|
|
||||||
for cert_armored in self.chain.split(b"-\n-"):
|
for cert_armored in self.chain.split(b"-\n-"):
|
||||||
# Fix missing -'s on split
|
# Fix missing -'s on split
|
||||||
cert_armored = re.sub(b"^----B", b"-----B", cert_armored)
|
cert_armored = re.sub(b"^----B", b"-----B", cert_armored)
|
||||||
cert_armored = re.sub(b"E----$", b"E-----", cert_armored)
|
cert_armored = re.sub(b"E----$", b"E-----", cert_armored)
|
||||||
cert = cryptography.x509.load_pem_x509_certificate(
|
cryptography.x509.load_pem_x509_certificate(
|
||||||
cert_armored, default_backend()
|
cert_armored, default_backend()
|
||||||
)
|
)
|
||||||
self._chain.append(cert)
|
|
||||||
|
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.datetime.utcnow()
|
||||||
if self._cert.not_valid_after < now:
|
if self._cert.not_valid_after < now:
|
||||||
@ -320,7 +283,7 @@ class CertBundle(BaseModel):
|
|||||||
"The certificate is not PEM-encoded or is not valid."
|
"The certificate is not PEM-encoded or is not valid."
|
||||||
)
|
)
|
||||||
|
|
||||||
def check(self):
|
def check(self) -> None:
|
||||||
# Basically, if the certificate is pending, and then checked again after a
|
# Basically, if the certificate is pending, and then checked again after a
|
||||||
# while, it will appear as if its been validated. The default wait time is 60
|
# while, it will appear as if its been validated. The default wait time is 60
|
||||||
# seconds but you can set an environment to change it.
|
# seconds but you can set an environment to change it.
|
||||||
@ -332,7 +295,7 @@ class CertBundle(BaseModel):
|
|||||||
):
|
):
|
||||||
self.status = "ISSUED"
|
self.status = "ISSUED"
|
||||||
|
|
||||||
def describe(self):
|
def describe(self) -> Dict[str, Any]:
|
||||||
# 'RenewalSummary': {}, # Only when cert is amazon issued
|
# 'RenewalSummary': {}, # Only when cert is amazon issued
|
||||||
if self._key.key_size == 1024:
|
if self._key.key_size == 1024:
|
||||||
key_algo = "RSA_1024"
|
key_algo = "RSA_1024"
|
||||||
@ -343,7 +306,7 @@ class CertBundle(BaseModel):
|
|||||||
|
|
||||||
# Look for SANs
|
# Look for SANs
|
||||||
try:
|
try:
|
||||||
san_obj = self._cert.extensions.get_extension_for_oid(
|
san_obj: Any = self._cert.extensions.get_extension_for_oid(
|
||||||
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
|
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
|
||||||
)
|
)
|
||||||
except cryptography.x509.ExtensionNotFound:
|
except cryptography.x509.ExtensionNotFound:
|
||||||
@ -352,14 +315,14 @@ class CertBundle(BaseModel):
|
|||||||
if san_obj is not None:
|
if san_obj is not None:
|
||||||
sans = [item.value for item in san_obj.value]
|
sans = [item.value for item in san_obj.value]
|
||||||
|
|
||||||
result = {
|
result: Dict[str, Any] = {
|
||||||
"Certificate": {
|
"Certificate": {
|
||||||
"CertificateArn": self.arn,
|
"CertificateArn": self.arn,
|
||||||
"DomainName": self.common_name,
|
"DomainName": self.common_name,
|
||||||
"InUseBy": self.in_use_by,
|
"InUseBy": self.in_use_by,
|
||||||
"Issuer": self._cert.issuer.get_attributes_for_oid(
|
"Issuer": self._cert.issuer.get_attributes_for_oid(OID_COMMON_NAME)[
|
||||||
cryptography.x509.OID_COMMON_NAME
|
0
|
||||||
)[0].value,
|
].value,
|
||||||
"KeyAlgorithm": key_algo,
|
"KeyAlgorithm": key_algo,
|
||||||
"NotAfter": datetime_to_epoch(self._cert.not_valid_after),
|
"NotAfter": datetime_to_epoch(self._cert.not_valid_after),
|
||||||
"NotBefore": datetime_to_epoch(self._cert.not_valid_before),
|
"NotBefore": datetime_to_epoch(self._cert.not_valid_before),
|
||||||
@ -401,7 +364,7 @@ class CertBundle(BaseModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def serialize_pk(self, passphrase_bytes):
|
def serialize_pk(self, passphrase_bytes: bytes) -> str:
|
||||||
pk_bytes = self._key.private_bytes(
|
pk_bytes = self._key.private_bytes(
|
||||||
encoding=serialization.Encoding.PEM,
|
encoding=serialization.Encoding.PEM,
|
||||||
format=serialization.PrivateFormat.PKCS8,
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
@ -411,38 +374,36 @@ class CertBundle(BaseModel):
|
|||||||
)
|
)
|
||||||
return pk_bytes.decode("utf-8")
|
return pk_bytes.decode("utf-8")
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return self.arn
|
return self.arn
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "<Certificate>"
|
return "<Certificate>"
|
||||||
|
|
||||||
|
|
||||||
class AWSCertificateManagerBackend(BaseBackend):
|
class AWSCertificateManagerBackend(BaseBackend):
|
||||||
def __init__(self, region_name, account_id):
|
def __init__(self, region_name: str, account_id: str):
|
||||||
super().__init__(region_name, account_id)
|
super().__init__(region_name, account_id)
|
||||||
self._certificates = {}
|
self._certificates: Dict[str, CertBundle] = {}
|
||||||
self._idempotency_tokens = {}
|
self._idempotency_tokens: Dict[str, Any] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(service_region, zones):
|
def default_vpc_endpoint_service(
|
||||||
|
service_region: str, zones: List[str]
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
"""Default VPC endpoint service."""
|
"""Default VPC endpoint service."""
|
||||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||||
service_region, zones, "acm-pca"
|
service_region, zones, "acm-pca"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _arn_not_found(self, arn):
|
def set_certificate_in_use_by(self, arn: str, load_balancer_name: str) -> None:
|
||||||
msg = f"Certificate with arn {arn} not found in account {self.account_id}"
|
|
||||||
return AWSResourceNotFoundException(msg)
|
|
||||||
|
|
||||||
def set_certificate_in_use_by(self, arn, load_balancer_name):
|
|
||||||
if arn not in self._certificates:
|
if arn not in self._certificates:
|
||||||
raise self._arn_not_found(arn)
|
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||||
|
|
||||||
cert_bundle = self._certificates[arn]
|
cert_bundle = self._certificates[arn]
|
||||||
cert_bundle.in_use_by.append(load_balancer_name)
|
cert_bundle.in_use_by.append(load_balancer_name)
|
||||||
|
|
||||||
def _get_arn_from_idempotency_token(self, token):
|
def _get_arn_from_idempotency_token(self, token: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
If token doesnt exist, return None, later it will be
|
If token doesnt exist, return None, later it will be
|
||||||
set with an expiry and arn.
|
set with an expiry and arn.
|
||||||
@ -465,16 +426,23 @@ class AWSCertificateManagerBackend(BaseBackend):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _set_idempotency_token_arn(self, token, arn):
|
def _set_idempotency_token_arn(self, token: str, arn: str) -> None:
|
||||||
self._idempotency_tokens[token] = {
|
self._idempotency_tokens[token] = {
|
||||||
"arn": arn,
|
"arn": arn,
|
||||||
"expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
|
"expires": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
|
||||||
}
|
}
|
||||||
|
|
||||||
def import_cert(self, certificate, private_key, chain=None, arn=None, tags=None):
|
def import_cert(
|
||||||
|
self,
|
||||||
|
certificate: bytes,
|
||||||
|
private_key: bytes,
|
||||||
|
chain: Optional[bytes],
|
||||||
|
arn: Optional[str],
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
|
) -> str:
|
||||||
if arn is not None:
|
if arn is not None:
|
||||||
if arn not in self._certificates:
|
if arn not in self._certificates:
|
||||||
raise self._arn_not_found(arn)
|
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||||
else:
|
else:
|
||||||
# Will reuse provided ARN
|
# Will reuse provided ARN
|
||||||
bundle = CertBundle(
|
bundle = CertBundle(
|
||||||
@ -502,7 +470,7 @@ class AWSCertificateManagerBackend(BaseBackend):
|
|||||||
|
|
||||||
return bundle.arn
|
return bundle.arn
|
||||||
|
|
||||||
def get_certificates_list(self, statuses):
|
def get_certificates_list(self, statuses: List[str]) -> Iterable[CertBundle]:
|
||||||
"""
|
"""
|
||||||
Get list of certificates
|
Get list of certificates
|
||||||
|
|
||||||
@ -514,27 +482,27 @@ class AWSCertificateManagerBackend(BaseBackend):
|
|||||||
if not statuses or cert.status in statuses:
|
if not statuses or cert.status in statuses:
|
||||||
yield cert
|
yield cert
|
||||||
|
|
||||||
def get_certificate(self, arn):
|
def get_certificate(self, arn: str) -> CertBundle:
|
||||||
if arn not in self._certificates:
|
if arn not in self._certificates:
|
||||||
raise self._arn_not_found(arn)
|
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||||
|
|
||||||
cert_bundle = self._certificates[arn]
|
cert_bundle = self._certificates[arn]
|
||||||
cert_bundle.check()
|
cert_bundle.check()
|
||||||
return cert_bundle
|
return cert_bundle
|
||||||
|
|
||||||
def delete_certificate(self, arn):
|
def delete_certificate(self, arn: str) -> None:
|
||||||
if arn not in self._certificates:
|
if arn not in self._certificates:
|
||||||
raise self._arn_not_found(arn)
|
raise CertificateNotFound(arn=arn, account_id=self.account_id)
|
||||||
|
|
||||||
del self._certificates[arn]
|
del self._certificates[arn]
|
||||||
|
|
||||||
def request_certificate(
|
def request_certificate(
|
||||||
self,
|
self,
|
||||||
domain_name,
|
domain_name: str,
|
||||||
idempotency_token,
|
idempotency_token: str,
|
||||||
subject_alt_names,
|
subject_alt_names: List[str],
|
||||||
tags=None,
|
tags: List[Dict[str, str]],
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
The parameter DomainValidationOptions has not yet been implemented
|
The parameter DomainValidationOptions has not yet been implemented
|
||||||
"""
|
"""
|
||||||
@ -558,17 +526,21 @@ class AWSCertificateManagerBackend(BaseBackend):
|
|||||||
|
|
||||||
return cert.arn
|
return cert.arn
|
||||||
|
|
||||||
def add_tags_to_certificate(self, arn, tags):
|
def add_tags_to_certificate(self, arn: str, tags: List[Dict[str, str]]) -> None:
|
||||||
# get_cert does arn check
|
# get_cert does arn check
|
||||||
cert_bundle = self.get_certificate(arn)
|
cert_bundle = self.get_certificate(arn)
|
||||||
cert_bundle.tags.add(tags)
|
cert_bundle.tags.add(tags)
|
||||||
|
|
||||||
def remove_tags_from_certificate(self, arn, tags):
|
def remove_tags_from_certificate(
|
||||||
|
self, arn: str, tags: List[Dict[str, str]]
|
||||||
|
) -> None:
|
||||||
# get_cert does arn check
|
# get_cert does arn check
|
||||||
cert_bundle = self.get_certificate(arn)
|
cert_bundle = self.get_certificate(arn)
|
||||||
cert_bundle.tags.remove(tags)
|
cert_bundle.tags.remove(tags)
|
||||||
|
|
||||||
def export_certificate(self, certificate_arn, passphrase):
|
def export_certificate(
|
||||||
|
self, certificate_arn: str, passphrase: str
|
||||||
|
) -> Tuple[str, str, str]:
|
||||||
passphrase_bytes = base64.standard_b64decode(passphrase)
|
passphrase_bytes = base64.standard_b64decode(passphrase)
|
||||||
cert_bundle = self.get_certificate(certificate_arn)
|
cert_bundle = self.get_certificate(certificate_arn)
|
||||||
|
|
||||||
|
@ -2,34 +2,23 @@ import json
|
|||||||
import base64
|
import base64
|
||||||
|
|
||||||
from moto.core.responses import BaseResponse
|
from moto.core.responses import BaseResponse
|
||||||
from .models import acm_backends, AWSValidationException
|
from typing import Dict, List, Tuple, Union
|
||||||
|
from .models import acm_backends, AWSCertificateManagerBackend
|
||||||
|
from .exceptions import AWSValidationException
|
||||||
|
|
||||||
|
|
||||||
|
GENERIC_RESPONSE_TYPE = Union[str, Tuple[str, Dict[str, int]]]
|
||||||
|
|
||||||
|
|
||||||
class AWSCertificateManagerResponse(BaseResponse):
|
class AWSCertificateManagerResponse(BaseResponse):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(service_name="acm")
|
super().__init__(service_name="acm")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def acm_backend(self):
|
def acm_backend(self) -> AWSCertificateManagerBackend:
|
||||||
"""
|
|
||||||
ACM Backend
|
|
||||||
|
|
||||||
:return: ACM Backend object
|
|
||||||
:rtype: moto.acm.models.AWSCertificateManagerBackend
|
|
||||||
"""
|
|
||||||
return acm_backends[self.current_account][self.region]
|
return acm_backends[self.current_account][self.region]
|
||||||
|
|
||||||
@property
|
def add_tags_to_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
def request_params(self):
|
|
||||||
try:
|
|
||||||
return json.loads(self.body)
|
|
||||||
except ValueError:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _get_param(self, param_name, if_none=None):
|
|
||||||
return self.request_params.get(param_name, if_none)
|
|
||||||
|
|
||||||
def add_tags_to_certificate(self):
|
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
tags = self._get_param("Tags")
|
tags = self._get_param("Tags")
|
||||||
|
|
||||||
@ -44,7 +33,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def delete_certificate(self):
|
def delete_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
|
|
||||||
if arn is None:
|
if arn is None:
|
||||||
@ -58,7 +47,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def describe_certificate(self):
|
def describe_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
|
|
||||||
if arn is None:
|
if arn is None:
|
||||||
@ -72,7 +61,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return json.dumps(cert_bundle.describe())
|
return json.dumps(cert_bundle.describe())
|
||||||
|
|
||||||
def get_certificate(self):
|
def get_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
|
|
||||||
if arn is None:
|
if arn is None:
|
||||||
@ -90,7 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
|
||||||
def import_certificate(self):
|
def import_certificate(self) -> str:
|
||||||
"""
|
"""
|
||||||
Returns errors on:
|
Returns errors on:
|
||||||
Certificate, PrivateKey or Chain not being properly formatted
|
Certificate, PrivateKey or Chain not being properly formatted
|
||||||
@ -137,7 +126,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return json.dumps({"CertificateArn": arn})
|
return json.dumps({"CertificateArn": arn})
|
||||||
|
|
||||||
def list_certificates(self):
|
def list_certificates(self) -> str:
|
||||||
certs = []
|
certs = []
|
||||||
statuses = self._get_param("CertificateStatuses")
|
statuses = self._get_param("CertificateStatuses")
|
||||||
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
|
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
|
||||||
@ -151,16 +140,18 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
result = {"CertificateSummaryList": certs}
|
result = {"CertificateSummaryList": certs}
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
|
||||||
def list_tags_for_certificate(self):
|
def list_tags_for_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
|
|
||||||
if arn is None:
|
if arn is None:
|
||||||
msg = "A required parameter for the specified action is not supplied."
|
msg = "A required parameter for the specified action is not supplied."
|
||||||
return {"__type": "MissingParameter", "message": msg}, dict(status=400)
|
return json.dumps({"__type": "MissingParameter", "message": msg}), dict(
|
||||||
|
status=400
|
||||||
|
)
|
||||||
|
|
||||||
cert_bundle = self.acm_backend.get_certificate(arn)
|
cert_bundle = self.acm_backend.get_certificate(arn)
|
||||||
|
|
||||||
result = {"Tags": []}
|
result: Dict[str, List[Dict[str, str]]] = {"Tags": []}
|
||||||
# Tag "objects" can not contain the Value part
|
# Tag "objects" can not contain the Value part
|
||||||
for key, value in cert_bundle.tags.items():
|
for key, value in cert_bundle.tags.items():
|
||||||
tag_dict = {"Key": key}
|
tag_dict = {"Key": key}
|
||||||
@ -170,7 +161,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
|
||||||
def remove_tags_from_certificate(self):
|
def remove_tags_from_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
tags = self._get_param("Tags")
|
tags = self._get_param("Tags")
|
||||||
|
|
||||||
@ -185,7 +176,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def request_certificate(self):
|
def request_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
domain_name = self._get_param("DomainName")
|
domain_name = self._get_param("DomainName")
|
||||||
idempotency_token = self._get_param("IdempotencyToken")
|
idempotency_token = self._get_param("IdempotencyToken")
|
||||||
subject_alt_names = self._get_param("SubjectAlternativeNames")
|
subject_alt_names = self._get_param("SubjectAlternativeNames")
|
||||||
@ -210,7 +201,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return json.dumps({"CertificateArn": arn})
|
return json.dumps({"CertificateArn": arn})
|
||||||
|
|
||||||
def resend_validation_email(self):
|
def resend_validation_email(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
arn = self._get_param("CertificateArn")
|
arn = self._get_param("CertificateArn")
|
||||||
domain = self._get_param("Domain")
|
domain = self._get_param("Domain")
|
||||||
# ValidationDomain not used yet.
|
# ValidationDomain not used yet.
|
||||||
@ -235,7 +226,7 @@ class AWSCertificateManagerResponse(BaseResponse):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def export_certificate(self):
|
def export_certificate(self) -> GENERIC_RESPONSE_TYPE:
|
||||||
certificate_arn = self._get_param("CertificateArn")
|
certificate_arn = self._get_param("CertificateArn")
|
||||||
passphrase = self._get_param("Passphrase")
|
passphrase = self._get_param("Passphrase")
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from moto.moto_api._internal import mock_random
|
from moto.moto_api._internal import mock_random
|
||||||
|
|
||||||
|
|
||||||
def make_arn_for_certificate(account_id, region_name):
|
def make_arn_for_certificate(account_id: str, region_name: str) -> str:
|
||||||
# Example
|
# Example
|
||||||
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
|
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
|
||||||
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
|
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
from jinja2 import DictLoader, Environment
|
from jinja2 import DictLoader, Environment
|
||||||
|
from typing import Optional
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# TODO: add "<Type>Sender</Type>" to error responses below?
|
# TODO: add "<Type>Sender</Type>" to error responses below?
|
||||||
@ -133,10 +134,12 @@ class AuthFailureError(RESTError):
|
|||||||
|
|
||||||
|
|
||||||
class AWSError(JsonRESTError):
|
class AWSError(JsonRESTError):
|
||||||
TYPE = None
|
TYPE: Optional[str] = None
|
||||||
STATUS = 400
|
STATUS = 400
|
||||||
|
|
||||||
def __init__(self, message, exception_type=None, status=None):
|
def __init__(
|
||||||
|
self, message: str, exception_type: str = None, status: Optional[int] = None
|
||||||
|
):
|
||||||
super().__init__(exception_type or self.TYPE, message)
|
super().__init__(exception_type or self.TYPE, message)
|
||||||
self.code = status or self.STATUS
|
self.code = status or self.STATUS
|
||||||
|
|
||||||
|
@ -575,12 +575,12 @@ class ELBBackend(BaseBackend):
|
|||||||
return load_balancer
|
return load_balancer
|
||||||
|
|
||||||
def _register_certificate(self, ssl_certificate_id, dns_name):
|
def _register_certificate(self, ssl_certificate_id, dns_name):
|
||||||
from moto.acm.models import acm_backends, AWSResourceNotFoundException
|
from moto.acm.models import acm_backends, CertificateNotFound
|
||||||
|
|
||||||
acm_backend = acm_backends[self.account_id][self.region_name]
|
acm_backend = acm_backends[self.account_id][self.region_name]
|
||||||
try:
|
try:
|
||||||
acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name)
|
acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name)
|
||||||
except AWSResourceNotFoundException:
|
except CertificateNotFound:
|
||||||
raise CertificateNotFoundException()
|
raise CertificateNotFoundException()
|
||||||
|
|
||||||
def enable_availability_zones_for_load_balancer(
|
def enable_availability_zones_for_load_balancer(
|
||||||
|
@ -1548,14 +1548,13 @@ Member must satisfy regular expression pattern: {}".format(
|
|||||||
"""
|
"""
|
||||||
Verify the provided certificate exists in either ACM or IAM
|
Verify the provided certificate exists in either ACM or IAM
|
||||||
"""
|
"""
|
||||||
from moto.acm import acm_backends
|
from moto.acm.models import acm_backends, CertificateNotFound
|
||||||
from moto.acm.models import AWSResourceNotFoundException
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
acm_backend = acm_backends[self.account_id][self.region_name]
|
acm_backend = acm_backends[self.account_id][self.region_name]
|
||||||
acm_backend.get_certificate(certificate_arn)
|
acm_backend.get_certificate(certificate_arn)
|
||||||
return True
|
return True
|
||||||
except AWSResourceNotFoundException:
|
except CertificateNotFound:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from moto.iam import iam_backends
|
from moto.iam import iam_backends
|
||||||
|
@ -433,6 +433,18 @@ def test_request_certificate_with_tags():
|
|||||||
{"Key": "WithEmptyStr", "Value": ""},
|
{"Key": "WithEmptyStr", "Value": ""},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
arn_3 = resp["CertificateArn"]
|
||||||
|
|
||||||
|
assert arn_1 != arn_3 # if tags are matched, ACM would have returned same arn
|
||||||
|
|
||||||
|
resp = client.request_certificate(
|
||||||
|
DomainName="google.com",
|
||||||
|
IdempotencyToken=token,
|
||||||
|
SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"],
|
||||||
|
)
|
||||||
|
arn_4 = resp["CertificateArn"]
|
||||||
|
|
||||||
|
assert arn_1 != arn_4 # if tags are matched, ACM would have returned same arn
|
||||||
|
|
||||||
|
|
||||||
@mock_acm
|
@mock_acm
|
||||||
|
Loading…
Reference in New Issue
Block a user