diff --git a/AUTHORS.md b/AUTHORS.md
index 6b7c96291..0a152505a 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -53,3 +53,4 @@ Moto is written by Steve Pulec with contributions from:
* [Jim Shields](https://github.com/jimjshields)
* [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni)
+* [Jon Beilke](https://github.com/jrbeilke)
diff --git a/moto/rds/models.py b/moto/rds/models.py
index 77deff09d..feecefe0c 100644
--- a/moto/rds/models.py
+++ b/moto/rds/models.py
@@ -48,6 +48,10 @@ class Database(BaseModel):
if self.publicly_accessible is None:
self.publicly_accessible = True
+ self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
+ if self.copy_tags_to_snapshot is None:
+ self.copy_tags_to_snapshot = False
+
self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None:
self.backup_retention_period = 1
@@ -137,6 +141,7 @@ class Database(BaseModel):
"multi_az": properties.get("MultiAZ"),
"port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"),
+ "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name,
"security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"),
@@ -217,6 +222,7 @@ class Database(BaseModel):
{% endif %}
{{ database.publicly_accessible }}
+ {{ database.copy_tags_to_snapshot }}
{{ database.auto_minor_version_upgrade }}
{{ database.allocated_storage }}
{{ database.storage_encrypted }}
diff --git a/moto/rds2/models.py b/moto/rds2/models.py
index 3fc4b6d65..fee004f76 100644
--- a/moto/rds2/models.py
+++ b/moto/rds2/models.py
@@ -73,6 +73,9 @@ class Database(BaseModel):
self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None:
self.publicly_accessible = True
+ self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
+ if self.copy_tags_to_snapshot is None:
+ self.copy_tags_to_snapshot = False
self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None:
self.backup_retention_period = 1
@@ -208,6 +211,7 @@ class Database(BaseModel):
{% endif %}
{{ database.publicly_accessible }}
+ {{ database.copy_tags_to_snapshot }}
{{ database.auto_minor_version_upgrade }}
{{ database.allocated_storage }}
{{ database.storage_encrypted }}
@@ -304,6 +308,7 @@ class Database(BaseModel):
"db_parameter_group_name": properties.get('DBParameterGroupName'),
"port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"),
+ "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name,
"security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"),
@@ -362,6 +367,7 @@ class Database(BaseModel):
"PreferredBackupWindow": "{{ database.preferred_backup_window }}",
"PreferredMaintenanceWindow": "{{ database.preferred_maintenance_window }}",
"PubliclyAccessible": "{{ database.publicly_accessible }}",
+ "CopyTagsToSnapshot": "{{ database.copy_tags_to_snapshot }}",
"AllocatedStorage": "{{ database.allocated_storage }}",
"Endpoint": {
"Address": "{{ database.address }}",
@@ -411,10 +417,10 @@ class Database(BaseModel):
class Snapshot(BaseModel):
- def __init__(self, database, snapshot_id, tags=None):
+ def __init__(self, database, snapshot_id, tags):
self.database = database
self.snapshot_id = snapshot_id
- self.tags = tags or []
+ self.tags = tags
self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@property
@@ -456,6 +462,20 @@ class Snapshot(BaseModel):
""")
return template.render(snapshot=self, database=self.database)
+ def get_tags(self):
+ return self.tags
+
+ def add_tags(self, tags):
+ new_keys = [tag_set['Key'] for tag_set in tags]
+ self.tags = [tag_set for tag_set in self.tags if tag_set[
+ 'Key'] not in new_keys]
+ self.tags.extend(tags)
+ return self.tags
+
+ def remove_tags(self, tag_keys):
+ self.tags = [tag_set for tag_set in self.tags if tag_set[
+ 'Key'] not in tag_keys]
+
class SecurityGroup(BaseModel):
@@ -691,6 +711,10 @@ class RDS2Backend(BaseBackend):
raise DBSnapshotAlreadyExistsError(db_snapshot_identifier)
if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')):
raise SnapshotQuotaExceededError()
+ if tags is None:
+ tags = list()
+ if database.copy_tags_to_snapshot and not tags:
+ tags = database.get_tags()
snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot
return snapshot
@@ -787,13 +811,13 @@ class RDS2Backend(BaseBackend):
def delete_database(self, db_instance_identifier, db_snapshot_name=None):
if db_instance_identifier in self.databases:
+ if db_snapshot_name:
+ self.create_snapshot(db_instance_identifier, db_snapshot_name)
database = self.databases.pop(db_instance_identifier)
if database.is_replica:
primary = self.find_db_from_id(database.source_db_identifier)
primary.remove_replica(database)
database.status = 'deleting'
- if db_snapshot_name:
- self.snapshots[db_snapshot_name] = Snapshot(database, db_snapshot_name)
return database
else:
raise DBInstanceNotFoundError(db_instance_identifier)
@@ -1028,8 +1052,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups:
return self.security_groups[resource_name].get_tags()
elif resource_type == 'snapshot': # DB Snapshot
- # TODO: Complete call to tags on resource type DB Snapshot
- return []
+ if resource_name in self.snapshots:
+ return self.snapshots[resource_name].get_tags()
elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
@@ -1059,7 +1083,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups:
return self.security_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'snapshot': # DB Snapshot
- return None
+ if resource_name in self.snapshots:
+ return self.snapshots[resource_name].remove_tags(tag_keys)
elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].remove_tags(tag_keys)
@@ -1088,7 +1113,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups:
return self.security_groups[resource_name].add_tags(tags)
elif resource_type == 'snapshot': # DB Snapshot
- return []
+ if resource_name in self.snapshots:
+ return self.snapshots[resource_name].add_tags(tags)
elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py
index eddb0042b..66d4e0c52 100644
--- a/moto/rds2/responses.py
+++ b/moto/rds2/responses.py
@@ -19,6 +19,7 @@ class RDS2Response(BaseResponse):
"allocated_storage": self._get_int_param('AllocatedStorage'),
"availability_zone": self._get_param("AvailabilityZone"),
"backup_retention_period": self._get_param("BackupRetentionPeriod"),
+ "copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"),
"db_instance_class": self._get_param('DBInstanceClass'),
"db_instance_identifier": self._get_param('DBInstanceIdentifier'),
"db_name": self._get_param("DBName"),
@@ -159,7 +160,7 @@ class RDS2Response(BaseResponse):
def create_db_snapshot(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
- tags = self._get_param('Tags', [])
+ tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags)
template = self.response_template(CREATE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py
index 80dcd4f53..cf9805444 100644
--- a/tests/test_rds2/test_rds2.py
+++ b/tests/test_rds2/test_rds2.py
@@ -33,6 +33,7 @@ def test_create_database():
db_instance['DBInstanceIdentifier'].should.equal("db-master-1")
db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False)
db_instance['DbiResourceId'].should.contain("db-")
+ db_instance['CopyTagsToSnapshot'].should.equal(False)
@mock_rds2
@@ -339,6 +340,49 @@ def test_create_db_snapshots():
snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
+ result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn'])
+ result['TagList'].should.equal([])
+
+
+@mock_rds2
+def test_create_db_snapshots_copy_tags():
+ conn = boto3.client('rds', region_name='us-west-2')
+ conn.create_db_snapshot.when.called_with(
+ DBInstanceIdentifier='db-primary-1',
+ DBSnapshotIdentifier='snapshot-1').should.throw(ClientError)
+
+ conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ MasterUsername='root',
+ MasterUserPassword='hunter2',
+ Port=1234,
+ DBSecurityGroups=["my_sg"],
+ CopyTagsToSnapshot=True,
+ Tags=[
+ {
+ 'Key': 'foo',
+ 'Value': 'bar',
+ },
+ {
+ 'Key': 'foo1',
+ 'Value': 'bar1',
+ },
+ ])
+
+ snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
+ DBSnapshotIdentifier='g-1').get('DBSnapshot')
+
+ snapshot.get('Engine').should.equal('postgres')
+ snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
+ snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
+ result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn'])
+ result['TagList'].should.equal([{'Value': 'bar',
+ 'Key': 'foo'},
+ {'Value': 'bar1',
+ 'Key': 'foo1'}])
@mock_rds2
@@ -656,6 +700,117 @@ def test_remove_tags_db():
len(result['TagList']).should.equal(1)
+@mock_rds2
+def test_list_tags_snapshot():
+ conn = boto3.client('rds', region_name='us-west-2')
+ result = conn.list_tags_for_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:foo')
+ result['TagList'].should.equal([])
+ conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ MasterUsername='root',
+ MasterUserPassword='hunter2',
+ Port=1234,
+ DBSecurityGroups=["my_sg"])
+ snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
+ DBSnapshotIdentifier='snapshot-with-tags',
+ Tags=[
+ {
+ 'Key': 'foo',
+ 'Value': 'bar',
+ },
+ {
+ 'Key': 'foo1',
+ 'Value': 'bar1',
+ },
+ ])
+ result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn'])
+ result['TagList'].should.equal([{'Value': 'bar',
+ 'Key': 'foo'},
+ {'Value': 'bar1',
+ 'Key': 'foo1'}])
+
+
+@mock_rds2
+def test_add_tags_snapshot():
+ conn = boto3.client('rds', region_name='us-west-2')
+ conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ MasterUsername='root',
+ MasterUserPassword='hunter2',
+ Port=1234,
+ DBSecurityGroups=["my_sg"])
+ snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
+ DBSnapshotIdentifier='snapshot-without-tags',
+ Tags=[
+ {
+ 'Key': 'foo',
+ 'Value': 'bar',
+ },
+ {
+ 'Key': 'foo1',
+ 'Value': 'bar1',
+ },
+ ])
+ result = conn.list_tags_for_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags')
+ list(result['TagList']).should.have.length_of(2)
+ conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags',
+ Tags=[
+ {
+ 'Key': 'foo',
+ 'Value': 'fish',
+ },
+ {
+ 'Key': 'foo2',
+ 'Value': 'bar2',
+ },
+ ])
+ result = conn.list_tags_for_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags')
+ list(result['TagList']).should.have.length_of(3)
+
+
+@mock_rds2
+def test_remove_tags_snapshot():
+ conn = boto3.client('rds', region_name='us-west-2')
+ conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ MasterUsername='root',
+ MasterUserPassword='hunter2',
+ Port=1234,
+ DBSecurityGroups=["my_sg"])
+ snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
+ DBSnapshotIdentifier='snapshot-with-tags',
+ Tags=[
+ {
+ 'Key': 'foo',
+ 'Value': 'bar',
+ },
+ {
+ 'Key': 'foo1',
+ 'Value': 'bar1',
+ },
+ ])
+ result = conn.list_tags_for_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags')
+ list(result['TagList']).should.have.length_of(2)
+ conn.remove_tags_from_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags', TagKeys=['foo'])
+ result = conn.list_tags_for_resource(
+ ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags')
+ len(result['TagList']).should.equal(1)
+
+
@mock_rds2
def test_add_tags_option_group():
conn = boto3.client('rds', region_name='us-west-2')