diff --git a/moto/rds/exceptions.py b/moto/rds/exceptions.py index 487162a8a..518ff401d 100644 --- a/moto/rds/exceptions.py +++ b/moto/rds/exceptions.py @@ -22,3 +22,10 @@ class DBInstanceNotFoundError(RDSClientError): super(DBInstanceNotFoundError, self).__init__( 'DBInstanceNotFound', "Database {0} not found.".format(database_identifier)) + + +class DBSecurityGroupNotFoundError(RDSClientError): + def __init__(self, security_group_name): + super(DBSecurityGroupNotFoundError, self).__init__( + 'DBSecurityGroupNotFound', + "Security Group {0} not found.".format(security_group_name)) diff --git a/moto/rds/models.py b/moto/rds/models.py index 27f4d10aa..06ada38c4 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -4,7 +4,7 @@ import boto.rds from jinja2 import Template from moto.core import BaseBackend -from .exceptions import DBInstanceNotFoundError +from .exceptions import DBInstanceNotFoundError, DBSecurityGroupNotFoundError class Database(object): @@ -35,18 +35,24 @@ class Database(object): self.multi_az = kwargs.get("multi_az") self.db_subnet_group_name = kwargs.get("db_subnet_group_name") + self.security_groups = kwargs.get('security_groups', []) + # PreferredBackupWindow # PreferredMaintenanceWindow # backup_retention_period = self._get_param("BackupRetentionPeriod") # OptionGroupName # DBParameterGroupName - # DBSecurityGroups.member.N # VpcSecurityGroupIds.member.N @property def address(self): return "{}.aaaaaaaaaa.{}.rds.amazonaws.com".format(self.db_instance_identifier, self.region) + def update(self, db_kwargs): + for key, value in db_kwargs.items(): + if value is not None: + setattr(self, key, value) + def to_xml(self): template = Template(""" {{ database.backup_retention_period }} @@ -65,10 +71,12 @@ class Database(object): + {% for security_group in database.security_groups %} active - default + {{ security_group }} + {% endfor %} {{ database.publicly_accessible }} {{ database.auto_minor_version_upgrade }} @@ -83,10 +91,38 @@ class Database(object): return template.render(database=self) +class SecurityGroup(object): + def __init__(self, group_name, description): + self.group_name = group_name + self.description = description + self.ip_ranges = [] + + def to_xml(self): + template = Template(""" + + {{ security_group.description }} + + {% for ip_range in security_group.ip_ranges %} + + {{ ip_range }} + authorized + + {% endfor %} + + {{ security_group.ownder_id }} + {{ security_group.group_name }} + """) + return template.render(security_group=self) + + def authorize(self, cidr_ip): + self.ip_ranges.append(cidr_ip) + + class RDSBackend(BaseBackend): def __init__(self): self.databases = {} + self.security_groups = {} def create_database(self, db_kwargs): database_id = db_kwargs['db_instance_identifier'] @@ -102,12 +138,40 @@ class RDSBackend(BaseBackend): raise DBInstanceNotFoundError(db_instance_identifier) return self.databases.values() + def modify_database(self, db_instance_identifier, db_kwargs): + database = self.describe_databases(db_instance_identifier)[0] + database.update(db_kwargs) + return database + def delete_database(self, db_instance_identifier): if db_instance_identifier in self.databases: return self.databases.pop(db_instance_identifier) else: raise DBInstanceNotFoundError(db_instance_identifier) + def create_security_group(self, group_name, description): + security_group = SecurityGroup(group_name, description) + self.security_groups[group_name] = security_group + return security_group + + def describe_security_groups(self, security_group_name): + if security_group_name: + if security_group_name in self.security_groups: + return [self.security_groups[security_group_name]] + else: + raise DBSecurityGroupNotFoundError(security_group_name) + return self.security_groups.values() + + def delete_security_group(self, security_group_name): + if security_group_name in self.security_groups: + return self.security_groups.pop(security_group_name) + else: + raise DBSecurityGroupNotFoundError(security_group_name) + + def authorize_security_group(self, security_group_name, cidr_ip): + security_group = self.describe_security_groups(security_group_name)[0] + security_group.authorize(cidr_ip) + return security_group rds_backends = {} for region in boto.rds.regions(): diff --git a/moto/rds/responses.py b/moto/rds/responses.py index c6ed9707b..4631ff6b0 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -10,8 +10,8 @@ class RDSResponse(BaseResponse): def backend(self): return rds_backends[self.region] - def create_dbinstance(self): - db_kwargs = { + def _get_db_kwargs(self): + return { "engine": self._get_param("Engine"), "engine_version": self._get_param("EngineVersion"), "region": self.region, @@ -34,7 +34,7 @@ class RDSResponse(BaseResponse): # OptionGroupName # DBParameterGroupName - # DBSecurityGroups.member.N + "security_groups": self._get_multi_param('DBSecurityGroups.member'), # VpcSecurityGroupIds.member.N "availability_zone": self._get_param("AvailabilityZone"), @@ -42,6 +42,9 @@ class RDSResponse(BaseResponse): "db_subnet_group_name": self._get_param("DBSubnetGroupName"), } + def create_dbinstance(self): + db_kwargs = self._get_db_kwargs() + database = self.backend.create_database(db_kwargs) template = self.response_template(CREATE_DATABASE_TEMPLATE) return template.render(database=database) @@ -52,12 +55,45 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DATABASES_TEMPLATE) return template.render(databases=databases) + def modify_dbinstance(self): + db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_kwargs = self._get_db_kwargs() + database = self.backend.modify_database(db_instance_identifier, db_kwargs) + template = self.response_template(MODIFY_DATABASE_TEMPLATE) + return template.render(database=database) + def delete_dbinstance(self): db_instance_identifier = self._get_param('DBInstanceIdentifier') database = self.backend.delete_database(db_instance_identifier) template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) + def create_dbsecurity_group(self): + group_name = self._get_param('DBSecurityGroupName') + description = self._get_param('DBSecurityGroupDescription') + security_group = self.backend.create_security_group(group_name, description) + template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) + return template.render(security_group=security_group) + + def describe_dbsecurity_groups(self): + security_group_name = self._get_param('DBSecurityGroupName') + security_groups = self.backend.describe_security_groups(security_group_name) + template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) + return template.render(security_groups=security_groups) + + def delete_dbsecurity_group(self): + security_group_name = self._get_param('DBSecurityGroupName') + security_group = self.backend.delete_security_group(security_group_name) + template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) + return template.render(security_group=security_group) + + def authorize_dbsecurity_group_ingress(self): + security_group_name = self._get_param('DBSecurityGroupName') + cidr_ip = self._get_param('CIDRIP') + security_group = self.backend.authorize_security_group(security_group_name, cidr_ip) + template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) + return template.render(security_group=security_group) + CREATE_DATABASE_TEMPLATE = """ @@ -81,6 +117,15 @@ DESCRIBE_DATABASES_TEMPLATE = """ + + {{ database.to_xml() }} + + + f643f1ac-bbfe-11d3-f4c6-37db295f7674 + +""" + DELETE_DATABASE_TEMPLATE = """ {{ database.to_xml() }} @@ -89,3 +134,40 @@ DELETE_DATABASE_TEMPLATE = """ + + {{ security_group.to_xml() }} + + + e68ef6fa-afc1-11c3-845a-476777009d19 + +""" + +DESCRIBE_SECURITY_GROUPS_TEMPLATE = """ + + + {% for security_group in security_groups %} + {{ security_group.to_xml() }} + {% endfor %} + + + + b76e692c-b98c-11d3-a907-5a2c468b9cb0 + +""" + +DELETE_SECURITY_GROUP_TEMPLATE = """ + + 7aec7454-ba25-11d3-855b-576787000e19 + +""" + +AUTHORIZE_SECURITY_GROUP_TEMPLATE = """ + + {{ security_group.to_xml() }} + + + 6176b5f8-bfed-11d3-f92b-31fa5e8dbc99 + +""" diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index 518f698c4..1df5bf9ad 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -11,7 +11,8 @@ from moto import mock_rds def test_create_database(): conn = boto.rds.connect_to_region("us-west-2") - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2', + security_groups=["my_sg"]) database.status.should.equal('available') database.id.should.equal("db-master-1") @@ -19,6 +20,7 @@ def test_create_database(): database.instance_class.should.equal("db.m1.small") database.master_username.should.equal("root") database.endpoint.should.equal(('db-master-1.aaaaaaaaaa.us-west-2.rds.amazonaws.com', 3306)) + database.security_groups[0].name.should.equal('my_sg') @mock_rds @@ -60,3 +62,79 @@ def test_delete_database(): def test_delete_non_existant_database(): conn = boto.rds.connect_to_region("us-west-2") conn.delete_dbinstance.when.called_with("not-a-db").should.throw(BotoServerError) + + +@mock_rds +def test_create_database_security_group(): + conn = boto.rds.connect_to_region("us-west-2") + + security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + security_group.name.should.equal('db_sg') + security_group.description.should.equal("DB Security Group") + list(security_group.ip_ranges).should.equal([]) + + +@mock_rds +def test_get_security_groups(): + conn = boto.rds.connect_to_region("us-west-2") + + list(conn.get_all_dbsecurity_groups()).should.have.length_of(0) + + conn.create_dbsecurity_group('db_sg1', 'DB Security Group') + conn.create_dbsecurity_group('db_sg2', 'DB Security Group') + + list(conn.get_all_dbsecurity_groups()).should.have.length_of(2) + + databases = conn.get_all_dbsecurity_groups("db_sg1") + list(databases).should.have.length_of(1) + + databases[0].name.should.equal("db_sg1") + + +@mock_rds +def test_get_non_existant_security_group(): + conn = boto.rds.connect_to_region("us-west-2") + conn.get_all_dbsecurity_groups.when.called_with("not-a-sg").should.throw(BotoServerError) + + +@mock_rds +def test_delete_database_security_group(): + conn = boto.rds.connect_to_region("us-west-2") + conn.create_dbsecurity_group('db_sg', 'DB Security Group') + + list(conn.get_all_dbsecurity_groups()).should.have.length_of(1) + + conn.delete_dbsecurity_group("db_sg") + list(conn.get_all_dbsecurity_groups()).should.have.length_of(0) + + +@mock_rds +def test_delete_non_existant_security_group(): + conn = boto.rds.connect_to_region("us-west-2") + conn.delete_dbsecurity_group.when.called_with("not-a-db").should.throw(BotoServerError) + + +@mock_rds +def test_security_group_authorize(): + conn = boto.rds.connect_to_region("us-west-2") + security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + list(security_group.ip_ranges).should.equal([]) + + security_group.authorize(cidr_ip='10.3.2.45/32') + security_group = conn.get_all_dbsecurity_groups()[0] + list(security_group.ip_ranges).should.have.length_of(1) + security_group.ip_ranges[0].cidr_ip.should.equal('10.3.2.45/32') + + +@mock_rds +def test_add_security_group_to_database(): + conn = boto.rds.connect_to_region("us-west-2") + + database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + database.modify(security_groups=[security_group]) + + database = conn.get_all_dbinstances()[0] + list(database.security_groups).should.have.length_of(1) + + database.security_groups[0].name.should.equal("db_sg")