Consistent _get_multi_param() function in responses

This abstracts _get_multi_param() into BaseResponse and makes it
always ensure that the string it has been given ends with a '.'.  It
had been implemented in three different places, and in use it rarely
postpended a trailing period, which could make it match parameters it
shouldn't have.
This commit is contained in:
Chris St. Pierre 2014-05-08 10:41:28 -04:00
parent 745368242e
commit fab37942c4
4 changed files with 10 additions and 16 deletions

View File

@ -6,18 +6,11 @@ from .models import autoscaling_backend
class AutoScalingResponse(BaseResponse): class AutoScalingResponse(BaseResponse):
def _get_param(self, param_name):
return self.querystring.get(param_name, [None])[0]
def _get_int_param(self, param_name): def _get_int_param(self, param_name):
value = self._get_param(param_name) value = self._get_param(param_name)
if value is not None: if value is not None:
return int(value) return int(value)
def _get_multi_param(self, param_prefix):
return [value[0] for key, value in self.querystring.items() if key.startswith(param_prefix)]
def _get_list_prefix(self, param_prefix): def _get_list_prefix(self, param_prefix):
results = [] results = []
param_index = 1 param_index = 1
@ -43,7 +36,7 @@ class AutoScalingResponse(BaseResponse):
name=self._get_param('LaunchConfigurationName'), name=self._get_param('LaunchConfigurationName'),
image_id=self._get_param('ImageId'), image_id=self._get_param('ImageId'),
key_name=self._get_param('KeyName'), key_name=self._get_param('KeyName'),
security_groups=self._get_multi_param('SecurityGroups.member.'), security_groups=self._get_multi_param('SecurityGroups.member'),
user_data=self._get_param('UserData'), user_data=self._get_param('UserData'),
instance_type=self._get_param('InstanceType'), instance_type=self._get_param('InstanceType'),
instance_monitoring=instance_monitoring, instance_monitoring=instance_monitoring,

View File

@ -66,6 +66,14 @@ class BaseResponse(object):
def _get_param(self, param_name): def _get_param(self, param_name):
return self.querystring.get(param_name, [None])[0] return self.querystring.get(param_name, [None])[0]
def _get_multi_param(self, param_prefix):
if param_prefix.endswith("."):
prefix = param_prefix
else:
prefix = param_prefix + "."
return [value[0] for key, value in self.querystring.items()
if key.startswith(prefix)]
def metadata_response(request, full_url, headers): def metadata_response(request, full_url, headers):
""" """

View File

@ -8,10 +8,6 @@ from moto.ec2.exceptions import InvalidIdError
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
def _get_multi_param(self, param_prefix):
return [value[0] for key, value in self.querystring.items()
if key.startswith(param_prefix + ".")]
def describe_instances(self): def describe_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
if instance_ids: if instance_ids:

View File

@ -13,9 +13,6 @@ class SpotInstances(BaseResponse):
if value is not None: if value is not None:
return int(value) return int(value)
def _get_multi_param(self, param_prefix):
return [value[0] for key, value in self.querystring.items() if key.startswith(param_prefix)]
def cancel_spot_instance_requests(self): def cancel_spot_instance_requests(self):
request_ids = self._get_multi_param('SpotInstanceRequestId') request_ids = self._get_multi_param('SpotInstanceRequestId')
requests = ec2_backend.cancel_spot_instance_requests(request_ids) requests = ec2_backend.cancel_spot_instance_requests(request_ids)
@ -49,7 +46,7 @@ class SpotInstances(BaseResponse):
launch_group = self._get_param('LaunchGroup') launch_group = self._get_param('LaunchGroup')
availability_zone_group = self._get_param('AvailabilityZoneGroup') availability_zone_group = self._get_param('AvailabilityZoneGroup')
key_name = self._get_param('LaunchSpecification.KeyName') key_name = self._get_param('LaunchSpecification.KeyName')
security_groups = self._get_multi_param('LaunchSpecification.SecurityGroup.') security_groups = self._get_multi_param('LaunchSpecification.SecurityGroup')
user_data = self._get_param('LaunchSpecification.UserData') user_data = self._get_param('LaunchSpecification.UserData')
instance_type = self._get_param('LaunchSpecification.InstanceType') instance_type = self._get_param('LaunchSpecification.InstanceType')
placement = self._get_param('LaunchSpecification.Placement.AvailabilityZone') placement = self._get_param('LaunchSpecification.Placement.AvailabilityZone')