Add multi-region support to EC2 Instances

This commit is contained in:
Hugo Lopes Tavares 2014-08-26 16:42:08 -04:00
parent 348d1803ed
commit 82eef28937
8 changed files with 92 additions and 24 deletions

View File

@ -9,8 +9,8 @@ from .utils import convert_regex_to_flask_path
class MockAWS(object): class MockAWS(object):
nested_count = 0 nested_count = 0
def __init__(self, backend): def __init__(self, backends):
self.backend = backend self.backends = backends
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
HTTPretty.reset() HTTPretty.reset()
@ -26,13 +26,15 @@ class MockAWS(object):
def start(self): def start(self):
self.__class__.nested_count += 1 self.__class__.nested_count += 1
self.backend.reset() for backend in self.backends.values():
backend.reset()
if not HTTPretty.is_enabled(): if not HTTPretty.is_enabled():
HTTPretty.enable() HTTPretty.enable()
for method in HTTPretty.METHODS: for method in HTTPretty.METHODS:
for key, value in self.backend.urls.iteritems(): backend = self.backends.values()[0]
for key, value in backend.urls.iteritems():
HTTPretty.register_uri( HTTPretty.register_uri(
method=method, method=method,
uri=re.compile(key), uri=re.compile(key),
@ -151,6 +153,6 @@ class BaseBackend(object):
def decorator(self, func=None): def decorator(self, func=None):
if func: if func:
return MockAWS(self)(func) return MockAWS({'global': self})(func)
else: else:
return MockAWS(self) return MockAWS({'global': self})

View File

@ -1,5 +1,6 @@
import datetime import datetime
import json import json
import re
from urlparse import parse_qs, urlparse from urlparse import parse_qs, urlparse
@ -9,6 +10,8 @@ from moto.core.utils import camelcase_to_underscores, method_names_from_class
class BaseResponse(object): class BaseResponse(object):
region = 'us-east-1'
def dispatch(self, request, full_url, headers): def dispatch(self, request, full_url, headers):
querystring = {} querystring = {}
@ -38,6 +41,9 @@ class BaseResponse(object):
self.path = urlparse(full_url).path self.path = urlparse(full_url).path
self.querystring = querystring self.querystring = querystring
self.method = request.method self.method = request.method
region = re.search(r'\.(.+?)\.amazonaws\.com', full_url)
if region:
self.region = region.group(1)
self.headers = dict(request.headers) self.headers = dict(request.headers)
self.response_headers = headers self.response_headers = headers

View File

@ -1,2 +1,8 @@
from .models import ec2_backend from .models import ec2_backends, ec2_backend
mock_ec2 = ec2_backend.decorator from ..core.models import MockAWS
def mock_ec2(func=None):
if func:
return MockAWS(ec2_backends)(func)
else:
return MockAWS(ec2_backends)

View File

@ -2,6 +2,7 @@ import copy
import itertools import itertools
from collections import defaultdict from collections import defaultdict
import boto
from boto.ec2.instance import Instance as BotoInstance, Reservation from boto.ec2.instance import Instance as BotoInstance, Reservation
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
from boto.ec2.spotinstancerequest import SpotInstanceRequest as BotoSpotRequest from boto.ec2.spotinstancerequest import SpotInstanceRequest as BotoSpotRequest
@ -1360,4 +1361,8 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
raise EC2ClientError(code, message) raise EC2ClientError(code, message)
ec2_backend = EC2Backend() ec2_backends = {}
for region_name in boto.regioninfo.load_regions()['ec2'].keys():
ec2_backends[region_name] = EC2Backend()
ec2_backend = ec2_backends['us-east-1']

View File

@ -60,4 +60,7 @@ class EC2Response(
VPNConnections, VPNConnections,
Windows, Windows,
): ):
pass @property
def ec2_backend(self):
from moto.ec2.models import ec2_backends
return ec2_backends[self.region]

View File

@ -2,7 +2,6 @@ from jinja2 import Template
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.models import ec2_backend
from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, filter_reservations from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, filter_reservations
@ -10,9 +9,9 @@ class InstanceResponse(BaseResponse):
def describe_instances(self): def describe_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
if instance_ids: if instance_ids:
reservations = ec2_backend.get_reservations_by_instance_ids(instance_ids) reservations = self.ec2_backend.get_reservations_by_instance_ids(instance_ids)
else: else:
reservations = ec2_backend.all_reservations(make_copy=True) reservations = self.ec2_backend.all_reservations(make_copy=True)
filter_dict = filters_from_querystring(self.querystring) filter_dict = filters_from_querystring(self.querystring)
reservations = filter_reservations(reservations, filter_dict) reservations = filter_reservations(reservations, filter_dict)
@ -29,7 +28,7 @@ class InstanceResponse(BaseResponse):
instance_type = self.querystring.get("InstanceType", ["m1.small"])[0] instance_type = self.querystring.get("InstanceType", ["m1.small"])[0]
subnet_id = self.querystring.get("SubnetId", [None])[0] subnet_id = self.querystring.get("SubnetId", [None])[0]
key_name = self.querystring.get("KeyName", [None])[0] key_name = self.querystring.get("KeyName", [None])[0]
new_reservation = ec2_backend.add_instances( new_reservation = self.ec2_backend.add_instances(
image_id, min_count, user_data, security_group_names, image_id, min_count, user_data, security_group_names,
instance_type=instance_type, subnet_id=subnet_id, instance_type=instance_type, subnet_id=subnet_id,
key_name=key_name, security_group_ids=security_group_ids) key_name=key_name, security_group_ids=security_group_ids)
@ -38,25 +37,25 @@ class InstanceResponse(BaseResponse):
def terminate_instances(self): def terminate_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instances = ec2_backend.terminate_instances(instance_ids) instances = self.ec2_backend.terminate_instances(instance_ids)
template = Template(EC2_TERMINATE_INSTANCES) template = Template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def reboot_instances(self): def reboot_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instances = ec2_backend.reboot_instances(instance_ids) instances = self.ec2_backend.reboot_instances(instance_ids)
template = Template(EC2_REBOOT_INSTANCES) template = Template(EC2_REBOOT_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def stop_instances(self): def stop_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instances = ec2_backend.stop_instances(instance_ids) instances = self.ec2_backend.stop_instances(instance_ids)
template = Template(EC2_STOP_INSTANCES) template = Template(EC2_STOP_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def start_instances(self): def start_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instances = ec2_backend.start_instances(instance_ids) instances = self.ec2_backend.start_instances(instance_ids)
template = Template(EC2_START_INSTANCES) template = Template(EC2_START_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
@ -64,9 +63,9 @@ class InstanceResponse(BaseResponse):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
if instance_ids: if instance_ids:
instances = ec2_backend.get_multi_instances_by_id(instance_ids) instances = self.ec2_backend.get_multi_instances_by_id(instance_ids)
else: else:
instances = ec2_backend.all_instances() instances = self.ec2_backend.all_instances()
template = Template(EC2_INSTANCE_STATUS) template = Template(EC2_INSTANCE_STATUS)
return template.render(instances=instances) return template.render(instances=instances)
@ -78,7 +77,7 @@ class InstanceResponse(BaseResponse):
key = camelcase_to_underscores(attribute) key = camelcase_to_underscores(attribute)
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
instance, value = ec2_backend.describe_instance_attribute(instance_id, key) instance, value = self.ec2_backend.describe_instance_attribute(instance_id, key)
template = Template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) template = Template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE)
return template.render(instance=instance, attribute=attribute, value=value) return template.render(instance=instance, attribute=attribute, value=value)
@ -126,7 +125,7 @@ class InstanceResponse(BaseResponse):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
instance = ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
block_device_type = instance.block_device_mapping[device_name_value] block_device_type = instance.block_device_mapping[device_name_value]
block_device_type.delete_on_termination = del_on_term_value block_device_type.delete_on_termination = del_on_term_value
@ -151,7 +150,7 @@ class InstanceResponse(BaseResponse):
normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0]) normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0])
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0] instance_id = instance_ids[0]
ec2_backend.modify_instance_attribute(instance_id, normalized_attribute, value) self.ec2_backend.modify_instance_attribute(instance_id, normalized_attribute, value)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE

View File

@ -0,0 +1,44 @@
import boto.ec2
import sure
from moto import mock_ec2
def add_servers_to_region(ami_id, count, region):
conn = boto.ec2.connect_to_region(region)
for index in range(count):
conn.run_instances(ami_id)
@mock_ec2
def test_add_servers_to_a_single_region():
region = 'ap-northeast-1'
add_servers_to_region('ami-1234abcd', 1, region)
add_servers_to_region('ami-5678efgh', 1, region)
conn = boto.ec2.connect_to_region(region)
instances = conn.get_only_instances()
len(instances).should.equal(2)
instances.sort(key=lambda x: x.image_id)
instances[0].image_id.should.equal('ami-1234abcd')
instances[1].image_id.should.equal('ami-5678efgh')
@mock_ec2
def test_add_servers_to_multiple_regions():
region1 = 'us-east-1'
region2 = 'ap-northeast-1'
add_servers_to_region('ami-1234abcd', 1, region1)
add_servers_to_region('ami-5678efgh', 1, region2)
us_conn = boto.ec2.connect_to_region(region1)
ap_conn = boto.ec2.connect_to_region(region2)
us_instances = us_conn.get_only_instances()
ap_instances = ap_conn.get_only_instances()
len(us_instances).should.equal(1)
len(ap_instances).should.equal(1)
us_instances[0].image_id.should.equal('ami-1234abcd')
ap_instances[0].image_id.should.equal('ami-5678efgh')

View File

@ -12,7 +12,10 @@ def test_ec2_server_get():
backend = server.create_backend_app("ec2") backend = server.create_backend_app("ec2")
test_client = backend.test_client() test_client = backend.test_client()
res = test_client.get('/?Action=RunInstances&ImageId=ami-60a54009') res = test_client.get(
'/?Action=RunInstances&ImageId=ami-60a54009',
headers={"Host": "ec2.us-east-1.amazonaws.com"}
)
groups = re.search("<instanceId>(.*)</instanceId>", res.data) groups = re.search("<instanceId>(.*)</instanceId>", res.data)
instance_id = groups.groups()[0] instance_id = groups.groups()[0]