RDS: db instance identifier validation (#6519)

This commit is contained in:
rafcio19 2023-07-13 22:31:14 +01:00 committed by GitHub
parent 9b641e3c29
commit e0ceec9e48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 214 additions and 133 deletions

View File

@ -192,3 +192,12 @@ class InvalidGlobalClusterStateFault(RDSClientError):
super().__init__( super().__init__(
"InvalidGlobalClusterStateFault", f"Global Cluster {arn} is not empty" "InvalidGlobalClusterStateFault", f"Global Cluster {arn} is not empty"
) )
class InvalidDBInstanceIdentifier(InvalidParameterValue):
def __init__(self) -> None:
super().__init__(
"The parameter DBInstanceIdentifier is not a valid identifier. "
"Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; "
"and must not end with a hyphen or contain two consecutive hyphens."
)

View File

@ -1,6 +1,7 @@
import copy import copy
import datetime import datetime
import os import os
import re
import string import string
from collections import defaultdict from collections import defaultdict
@ -33,6 +34,7 @@ from .exceptions import (
InvalidParameterValue, InvalidParameterValue,
InvalidParameterCombination, InvalidParameterCombination,
InvalidDBClusterStateFault, InvalidDBClusterStateFault,
InvalidDBInstanceIdentifier,
InvalidGlobalClusterStateFault, InvalidGlobalClusterStateFault,
ExportTaskNotFoundError, ExportTaskNotFoundError,
ExportTaskAlreadyExistsError, ExportTaskAlreadyExistsError,
@ -921,7 +923,7 @@ class Database(CloudFormationModel):
"availability_zone": properties.get("AvailabilityZone"), "availability_zone": properties.get("AvailabilityZone"),
"backup_retention_period": properties.get("BackupRetentionPeriod"), "backup_retention_period": properties.get("BackupRetentionPeriod"),
"db_instance_class": properties.get("DBInstanceClass"), "db_instance_class": properties.get("DBInstanceClass"),
"db_instance_identifier": resource_name, "db_instance_identifier": resource_name.replace("_", "-"),
"db_name": properties.get("DBName"), "db_name": properties.get("DBName"),
"preferred_backup_window": properties.get( "preferred_backup_window": properties.get(
"PreferredBackupWindow", "13:14-13:44" "PreferredBackupWindow", "13:14-13:44"
@ -1562,6 +1564,7 @@ class RDSBackend(BaseBackend):
def create_db_instance(self, db_kwargs: Dict[str, Any]) -> Database: def create_db_instance(self, db_kwargs: Dict[str, Any]) -> Database:
database_id = db_kwargs["db_instance_identifier"] database_id = db_kwargs["db_instance_identifier"]
self._validate_db_identifier(database_id)
database = Database(**db_kwargs) database = Database(**db_kwargs)
cluster_id = database.db_cluster_identifier cluster_id = database.db_cluster_identifier
@ -1742,6 +1745,7 @@ class RDSBackend(BaseBackend):
def stop_db_instance( def stop_db_instance(
self, db_instance_identifier: str, db_snapshot_identifier: Optional[str] = None self, db_instance_identifier: str, db_snapshot_identifier: Optional[str] = None
) -> Database: ) -> Database:
self._validate_db_identifier(db_instance_identifier)
database = self.describe_db_instances(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
# todo: certain rds types not allowed to be stopped at this time. # todo: certain rds types not allowed to be stopped at this time.
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations # https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations
@ -1758,6 +1762,7 @@ class RDSBackend(BaseBackend):
return database return database
def start_db_instance(self, db_instance_identifier: str) -> Database: def start_db_instance(self, db_instance_identifier: str) -> Database:
self._validate_db_identifier(db_instance_identifier)
database = self.describe_db_instances(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
# todo: bunch of different error messages to be generated from this api call # todo: bunch of different error messages to be generated from this api call
if database.status != "stopped": if database.status != "stopped":
@ -1780,6 +1785,7 @@ class RDSBackend(BaseBackend):
def delete_db_instance( def delete_db_instance(
self, db_instance_identifier: str, db_snapshot_name: Optional[str] = None self, db_instance_identifier: str, db_snapshot_name: Optional[str] = None
) -> Database: ) -> Database:
self._validate_db_identifier(db_instance_identifier)
if db_instance_identifier in self.databases: if db_instance_identifier in self.databases:
if self.databases[db_instance_identifier].deletion_protection: if self.databases[db_instance_identifier].deletion_protection:
raise InvalidParameterValue( raise InvalidParameterValue(
@ -2588,6 +2594,19 @@ class RDSBackend(BaseBackend):
tags_dict.update({d["Key"]: d["Value"] for d in new_tags}) tags_dict.update({d["Key"]: d["Value"] for d in new_tags})
return [{"Key": k, "Value": v} for k, v in tags_dict.items()] return [{"Key": k, "Value": v} for k, v in tags_dict.items()]
@staticmethod
def _validate_db_identifier(db_identifier: str) -> None:
# https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_CreateDBInstance.html
# Constraints:
# # Must contain from 1 to 63 letters, numbers, or hyphens.
# # First character must be a letter.
# # Can't end with a hyphen or contain two consecutive hyphens.
if re.match(
"^(?!.*--)([a-zA-Z]?[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])$", db_identifier
):
return
raise InvalidDBInstanceIdentifier
def describe_orderable_db_instance_options( def describe_orderable_db_instance_options(
self, engine: str, engine_version: str self, engine: str, engine_version: str
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:

File diff suppressed because it is too large Load Diff

View File

@ -144,7 +144,7 @@ def test_rds_db_parameter_groups():
Parameters=[ Parameters=[
{"ParameterKey": key, "ParameterValue": value} {"ParameterKey": key, "ParameterValue": value}
for key, value in [ for key, value in [
("DBInstanceIdentifier", "master_db"), ("DBInstanceIdentifier", "master-db"),
("DBName", "my_db"), ("DBName", "my_db"),
("DBUser", "my_user"), ("DBUser", "my_user"),
("DBPassword", "my_password"), ("DBPassword", "my_password"),
@ -188,11 +188,12 @@ def test_rds_mysql_with_read_replica():
template_json = json.dumps(rds_mysql_with_read_replica.template) template_json = json.dumps(rds_mysql_with_read_replica.template)
cf = boto3.client("cloudformation", "us-west-1") cf = boto3.client("cloudformation", "us-west-1")
db_identifier = "master-db"
cf.create_stack( cf.create_stack(
StackName="test_stack", StackName="test_stack",
TemplateBody=template_json, TemplateBody=template_json,
Parameters=[ Parameters=[
{"ParameterKey": "DBInstanceIdentifier", "ParameterValue": "master_db"}, {"ParameterKey": "DBInstanceIdentifier", "ParameterValue": db_identifier},
{"ParameterKey": "DBName", "ParameterValue": "my_db"}, {"ParameterKey": "DBName", "ParameterValue": "my_db"},
{"ParameterKey": "DBUser", "ParameterValue": "my_user"}, {"ParameterKey": "DBUser", "ParameterValue": "my_user"},
{"ParameterKey": "DBPassword", "ParameterValue": "my_password"}, {"ParameterKey": "DBPassword", "ParameterValue": "my_password"},
@ -205,27 +206,27 @@ def test_rds_mysql_with_read_replica():
rds = boto3.client("rds", region_name="us-west-1") rds = boto3.client("rds", region_name="us-west-1")
primary = rds.describe_db_instances(DBInstanceIdentifier="master_db")[ primary = rds.describe_db_instances(DBInstanceIdentifier=db_identifier)[
"DBInstances" "DBInstances"
][0] ][0]
primary.should.have.key("MasterUsername").equal("my_user") assert primary["MasterUsername"] == "my_user"
primary.should.have.key("AllocatedStorage").equal(20) assert primary["AllocatedStorage"] == 20
primary.should.have.key("DBInstanceClass").equal("db.m1.medium") assert primary["DBInstanceClass"] == "db.m1.medium"
primary.should.have.key("MultiAZ").equal(True) assert primary["MultiAZ"]
primary.should.have.key("ReadReplicaDBInstanceIdentifiers").being.length_of(1) assert len(primary["ReadReplicaDBInstanceIdentifiers"]) == 1
replica_id = primary["ReadReplicaDBInstanceIdentifiers"][0] replica_id = primary["ReadReplicaDBInstanceIdentifiers"][0]
replica = rds.describe_db_instances(DBInstanceIdentifier=replica_id)["DBInstances"][ replica = rds.describe_db_instances(DBInstanceIdentifier=replica_id)["DBInstances"][
0 0
] ]
replica.should.have.key("DBInstanceClass").equal("db.m1.medium") assert replica["DBInstanceClass"] == "db.m1.medium"
security_group_name = primary["DBSecurityGroups"][0]["DBSecurityGroupName"] security_group_name = primary["DBSecurityGroups"][0]["DBSecurityGroupName"]
security_group = rds.describe_db_security_groups( security_group = rds.describe_db_security_groups(
DBSecurityGroupName=security_group_name DBSecurityGroupName=security_group_name
)["DBSecurityGroups"][0] )["DBSecurityGroups"][0]
security_group["EC2SecurityGroups"][0]["EC2SecurityGroupName"].should.equal( assert (
"application" security_group["EC2SecurityGroups"][0]["EC2SecurityGroupName"] == "application"
) )
@ -235,11 +236,12 @@ def test_rds_mysql_with_read_replica():
def test_rds_mysql_with_read_replica_in_vpc(): def test_rds_mysql_with_read_replica_in_vpc():
template_json = json.dumps(rds_mysql_with_read_replica.template) template_json = json.dumps(rds_mysql_with_read_replica.template)
cf = boto3.client("cloudformation", "eu-central-1") cf = boto3.client("cloudformation", "eu-central-1")
db_identifier = "master-db"
cf.create_stack( cf.create_stack(
StackName="test_stack", StackName="test_stack",
TemplateBody=template_json, TemplateBody=template_json,
Parameters=[ Parameters=[
{"ParameterKey": "DBInstanceIdentifier", "ParameterValue": "master_db"}, {"ParameterKey": "DBInstanceIdentifier", "ParameterValue": db_identifier},
{"ParameterKey": "DBName", "ParameterValue": "my_db"}, {"ParameterKey": "DBName", "ParameterValue": "my_db"},
{"ParameterKey": "DBUser", "ParameterValue": "my_user"}, {"ParameterKey": "DBUser", "ParameterValue": "my_user"},
{"ParameterKey": "DBPassword", "ParameterValue": "my_password"}, {"ParameterKey": "DBPassword", "ParameterValue": "my_password"},
@ -250,7 +252,7 @@ def test_rds_mysql_with_read_replica_in_vpc():
) )
rds = boto3.client("rds", region_name="eu-central-1") rds = boto3.client("rds", region_name="eu-central-1")
primary = rds.describe_db_instances(DBInstanceIdentifier="master_db")[ primary = rds.describe_db_instances(DBInstanceIdentifier=db_identifier)[
"DBInstances" "DBInstances"
][0] ][0]