update RDS models to include CopyTagsToSnapshot

This commit is contained in:
Jon Beilke 2018-09-21 08:31:31 -05:00
parent e8c65d3d85
commit 881afc8f4a
4 changed files with 59 additions and 0 deletions

View File

@ -53,3 +53,4 @@ Moto is written by Steve Pulec with contributions from:
* [Jim Shields](https://github.com/jimjshields) * [Jim Shields](https://github.com/jimjshields)
* [William Richard](https://github.com/william-richard) * [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni) * [Alex Casalboni](https://github.com/alexcasalboni)
* [Jon Beilke](https://github.com/jrbeilke)

View File

@ -48,6 +48,10 @@ class Database(BaseModel):
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
if self.copy_tags_to_snapshot is None:
self.copy_tags_to_snapshot = False
self.backup_retention_period = kwargs.get("backup_retention_period") self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None: if self.backup_retention_period is None:
self.backup_retention_period = 1 self.backup_retention_period = 1
@ -137,6 +141,7 @@ class Database(BaseModel):
"multi_az": properties.get("MultiAZ"), "multi_az": properties.get("MultiAZ"),
"port": properties.get('Port', 3306), "port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"), "publicly_accessible": properties.get("PubliclyAccessible"),
"copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name, "region": region_name,
"security_groups": security_groups, "security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
@ -217,6 +222,7 @@ class Database(BaseModel):
</DBSubnetGroup> </DBSubnetGroup>
{% endif %} {% endif %}
<PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible> <PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible>
<CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade> <AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage> <AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted> <StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>

View File

@ -73,6 +73,9 @@ class Database(BaseModel):
self.publicly_accessible = kwargs.get("publicly_accessible") self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
if self.copy_tags_to_snapshot is None:
self.copy_tags_to_snapshot = False
self.backup_retention_period = kwargs.get("backup_retention_period") self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None: if self.backup_retention_period is None:
self.backup_retention_period = 1 self.backup_retention_period = 1
@ -208,6 +211,7 @@ class Database(BaseModel):
</DBSubnetGroup> </DBSubnetGroup>
{% endif %} {% endif %}
<PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible> <PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible>
<CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade> <AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage> <AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted> <StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>
@ -304,6 +308,7 @@ class Database(BaseModel):
"db_parameter_group_name": properties.get('DBParameterGroupName'), "db_parameter_group_name": properties.get('DBParameterGroupName'),
"port": properties.get('Port', 3306), "port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"), "publicly_accessible": properties.get("PubliclyAccessible"),
"copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name, "region": region_name,
"security_groups": security_groups, "security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
@ -362,6 +367,7 @@ class Database(BaseModel):
"PreferredBackupWindow": "{{ database.preferred_backup_window }}", "PreferredBackupWindow": "{{ database.preferred_backup_window }}",
"PreferredMaintenanceWindow": "{{ database.preferred_maintenance_window }}", "PreferredMaintenanceWindow": "{{ database.preferred_maintenance_window }}",
"PubliclyAccessible": "{{ database.publicly_accessible }}", "PubliclyAccessible": "{{ database.publicly_accessible }}",
"CopyTagsToSnapshot": "{{ database.copy_tags_to_snapshot }}",
"AllocatedStorage": "{{ database.allocated_storage }}", "AllocatedStorage": "{{ database.allocated_storage }}",
"Endpoint": { "Endpoint": {
"Address": "{{ database.address }}", "Address": "{{ database.address }}",
@ -691,6 +697,8 @@ class RDS2Backend(BaseBackend):
raise DBSnapshotAlreadyExistsError(db_snapshot_identifier) raise DBSnapshotAlreadyExistsError(db_snapshot_identifier)
if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')): if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')):
raise SnapshotQuotaExceededError() raise SnapshotQuotaExceededError()
if not database.copy_tags_to_snapshot:
tags = None
snapshot = Snapshot(database, db_snapshot_identifier, tags) snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot self.snapshots[db_snapshot_identifier] = snapshot
return snapshot return snapshot

View File

@ -33,6 +33,7 @@ def test_create_database():
db_instance['DBInstanceIdentifier'].should.equal("db-master-1") db_instance['DBInstanceIdentifier'].should.equal("db-master-1")
db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False)
db_instance['DbiResourceId'].should.contain("db-") db_instance['DbiResourceId'].should.contain("db-")
db_instance['CopyTagsToSnapshot'].should.equal(False)
@mock_rds2 @mock_rds2
@ -339,6 +340,49 @@ def test_create_db_snapshots():
snapshot.get('Engine').should.equal('postgres') snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1') snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn'])
result['TagList'].should.equal([])
@mock_rds2
def test_create_db_snapshots_copy_tags():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_snapshot.when.called_with(
DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-1').should.throw(ClientError)
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"],
CopyTagsToSnapshot=True,
Tags=[
{
'Key': 'foo',
'Value': 'bar',
},
{
'Key': 'foo1',
'Value': 'bar1',
},
])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='g-1').get('DBSnapshot')
snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn'])
result['TagList'].should.equal([{'Value': 'bar',
'Key': 'foo'},
{'Value': 'bar1',
'Key': 'foo1'}])
@mock_rds2 @mock_rds2