Merge branch 'master' into create-access-key-fix

This commit is contained in:
Bendegúz Ács 2019-07-05 17:11:55 +02:00 committed by GitHub
commit 5594195e28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 605 additions and 48 deletions

224
moto/core/authentication.py Normal file
View File

@ -0,0 +1,224 @@
import json
import re
from abc import ABC, abstractmethod
from enum import Enum
from botocore.auth import SigV4Auth, S3SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from moto.iam.models import ACCOUNT_ID, Policy
from moto.iam import iam_backend
from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError
from moto.s3.exceptions import BucketAccessDeniedError, S3AccessDeniedError
ACCESS_KEY_STORE = {
"AKIAJDULPKHCC4KGTYVA": {
"owner": "avatao-user",
"secret_access_key": "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"
}
}
class IAMRequestBase(ABC):
def __init__(self, method, path, data, headers):
print(f"Creating {self.__class__.__name__} with method={method}, path={path}, data={data}, headers={headers}")
self._method = method
self._path = path
self._data = data
self._headers = headers
credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization'])
credential_data = credential_scope.split('/')
self._access_key = credential_data[0]
self._region = credential_data[2]
self._service = credential_data[3]
self._action = self._service + ":" + self._data["Action"][0]
def check_signature(self):
original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization'])
calculated_signature = self._calculate_signature()
if original_signature != calculated_signature:
raise SignatureDoesNotMatchError()
def check_action_permitted(self):
iam_user_name = ACCESS_KEY_STORE[self._access_key]["owner"]
user_policies = self._collect_policies_for_iam_user(iam_user_name)
permitted = False
for policy in user_policies:
iam_policy = IAMPolicy(policy)
permission_result = iam_policy.is_action_permitted(self._action)
if permission_result == PermissionResult.DENIED:
self._raise_access_denied(iam_user_name)
elif permission_result == PermissionResult.PERMITTED:
permitted = True
if not permitted:
self._raise_access_denied(iam_user_name)
@abstractmethod
def _raise_access_denied(self, iam_user_name):
raise NotImplementedError()
@staticmethod
def _collect_policies_for_iam_user(iam_user_name):
user_policies = []
inline_policy_names = iam_backend.list_user_policies(iam_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(iam_user_name, inline_policy_name)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(iam_user_name)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(iam_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name)
user_policies.append(inline_user_group_policy)
attached_group_policies = iam_backend.list_attached_group_policies(user_group.name)
user_policies += attached_group_policies
return user_policies
@abstractmethod
def _create_auth(self, credentials):
raise NotImplementedError()
@staticmethod
def _create_headers_for_aws_request(signed_headers, original_headers):
headers = {}
for key, value in original_headers.items():
if key.lower() in signed_headers:
headers[key] = value
return headers
def _create_aws_request(self):
signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';')
headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers)
request.context['timestamp'] = headers['X-Amz-Date']
return request
def _calculate_signature(self):
if self._access_key not in ACCESS_KEY_STORE:
raise InvalidClientTokenIdError()
secret_key = ACCESS_KEY_STORE[self._access_key]["secret_access_key"]
credentials = Credentials(self._access_key, secret_key)
auth = self._create_auth(credentials)
request = self._create_aws_request()
canonical_request = auth.canonical_request(request)
string_to_sign = auth.string_to_sign(request, canonical_request)
return auth.signature(string_to_sign, request)
@staticmethod
def _get_string_between(first_separator, second_separator, string):
return string.partition(first_separator)[2].partition(second_separator)[0]
class IAMRequest(IAMRequestBase):
def _create_auth(self, credentials):
return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self, iam_user_name):
raise AccessDeniedError(
account_id=ACCOUNT_ID,
iam_user_name=iam_user_name,
action=self._action
)
class S3IAMRequest(IAMRequestBase):
def _create_auth(self, credentials):
return S3SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self, _):
if "BucketName" in self._data:
raise BucketAccessDeniedError(bucket=self._data["BucketName"])
else:
raise S3AccessDeniedError()
class IAMPolicy:
def __init__(self, policy):
self._policy = policy
def is_action_permitted(self, action):
if isinstance(self._policy, Policy):
default_version = next(policy_version for policy_version in self._policy.versions if policy_version.is_default)
policy_document = default_version.document
else:
policy_document = self._policy["policy_document"]
policy_json = json.loads(policy_document)
permitted = False
for policy_statement in policy_json["Statement"]:
iam_policy_statement = IAMPolicyStatement(policy_statement)
permission_result = iam_policy_statement.is_action_permitted(action)
if permission_result == PermissionResult.DENIED:
return permission_result
elif permission_result == PermissionResult.PERMITTED:
permitted = True
if permitted:
return PermissionResult.PERMITTED
else:
return PermissionResult.NEUTRAL
class IAMPolicyStatement:
def __init__(self, statement):
self._statement = statement
def is_action_permitted(self, action):
is_action_concerned = False
if "NotAction" in self._statement:
if not self._check_element_matches("NotAction", action):
is_action_concerned = True
else: # Action is present
if self._check_element_matches("Action", action):
is_action_concerned = True
# TODO: check Resource/NotResource and Condition
if is_action_concerned:
if self._statement["Effect"] == "Allow":
return PermissionResult.PERMITTED
else: # Deny
return PermissionResult.DENIED
else:
return PermissionResult.NEUTRAL
def _check_element_matches(self, statement_element, value):
if isinstance(self._statement[statement_element], list):
for statement_element_value in self._statement[statement_element]:
if self._match(statement_element_value, value):
return True
return False
else: # string
return self._match(self._statement[statement_element], value)
@staticmethod
def _match(pattern, string):
pattern = pattern.replace("*", ".*")
pattern = f"^{pattern}$"
return re.match(pattern, string)
class PermissionResult(Enum):
PERMITTED = 1
DENIED = 2
NEUTRAL = 3

View File

@ -65,3 +65,34 @@ class JsonRESTError(RESTError):
def get_body(self, *args, **kwargs): def get_body(self, *args, **kwargs):
return self.description return self.description
class SignatureDoesNotMatchError(RESTError):
code = 400
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.")
class InvalidClientTokenIdError(RESTError):
code = 400
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId',
"The security token included in the request is invalid.")
class AccessDeniedError(RESTError):
code = 403
def __init__(self, account_id, iam_user_name, action):
super(AccessDeniedError, self).__init__(
'AccessDenied',
"User: arn:aws:iam::{account_id}:user/{iam_user_name} is not authorized to perform: {operation}".format(
account_id=account_id,
iam_user_name=iam_user_name,
operation=action
))

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import json import json
@ -8,6 +9,8 @@ import re
import io import io
import pytz import pytz
from moto.core.authentication import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound from jinja2 import Environment, DictLoader, TemplateNotFound
@ -103,7 +106,29 @@ class _TemplateEnvironmentMixin(object):
return self.environment.get_template(template_id) return self.environment.get_template(template_id)
class BaseResponse(_TemplateEnvironmentMixin): class ActionAuthenticatorMixin(object):
INITIAL_NO_AUTH_ACTION_COUNT = int(os.environ.get("INITIAL_NO_AUTH_ACTION_COUNT", 999999999))
request_count = 0
def _authenticate_action(self, iam_request):
iam_request.check_signature()
if ActionAuthenticatorMixin.request_count >= ActionAuthenticatorMixin.INITIAL_NO_AUTH_ACTION_COUNT:
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_normal_action(self):
iam_request = IAMRequest(method=self.method, path=self.path, data=self.data, headers=self.headers)
self._authenticate_action(iam_request)
def _authenticate_s3_action(self):
iam_request = S3IAMRequest(method=self.method, path=self.path, data=self.data, headers=self.headers)
self._authenticate_action(iam_request)
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1' default_region = 'us-east-1'
# to extract region, use [^.] # to extract region, use [^.]
@ -167,6 +192,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.uri = full_url self.uri = full_url
self.path = urlparse(full_url).path self.path = urlparse(full_url).path
self.querystring = querystring self.querystring = querystring
self.data = querystring
self.method = request.method self.method = request.method
self.region = self.get_region_from_url(request, full_url) self.region = self.get_region_from_url(request, full_url)
self.uri_match = None self.uri_match = None
@ -273,6 +299,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
def call_action(self): def call_action(self):
headers = self.response_headers headers = self.response_headers
try:
self._authenticate_normal_action()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
return self._send_response(headers, response)
action = camelcase_to_underscores(self._get_action()) action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__) method_names = method_names_from_class(self.__class__)
if action in method_names: if action in method_names:
@ -285,16 +318,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, headers, response return 200, headers, response
else: else:
if len(response) == 2: return self._send_response(headers, response)
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
if not action: if not action:
return 404, headers, '' return 404, headers, ''
@ -302,6 +326,19 @@ class BaseResponse(_TemplateEnvironmentMixin):
raise NotImplementedError( raise NotImplementedError(
"The {0} action has not been implemented".format(action)) "The {0} action has not been implemented".format(action))
@staticmethod
def _send_response(headers, response):
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
def _get_param(self, param_name, if_none=None): def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name) val = self.querystring.get(param_name)
if val is not None: if val is not None:

View File

@ -26,6 +26,14 @@ class IAMReportNotPresentException(RESTError):
"ReportNotPresent", message) "ReportNotPresent", message)
class IAMLimitExceededException(RESTError):
code = 400
def __init__(self, message):
super(IAMLimitExceededException, self).__init__(
"LimitExceeded", message)
class MalformedCertificate(RESTError): class MalformedCertificate(RESTError):
code = 400 code = 400

View File

@ -14,8 +14,8 @@ from moto.core.utils import iso_8601_datetime_without_milliseconds, iso_8601_dat
from moto.iam.policy_validation import IAMPolicyDocumentValidator from moto.iam.policy_validation import IAMPolicyDocumentValidator
from .aws_managed_policies import aws_managed_policies_data from .aws_managed_policies import aws_managed_policies_data
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, MalformedCertificate, \ from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, IAMLimitExceededException, \
DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig MalformedCertificate, DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig
from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id
ACCOUNT_ID = 123456789012 ACCOUNT_ID = 123456789012
@ -67,6 +67,13 @@ class Policy(BaseModel):
self.create_date = create_date if create_date is not None else datetime.utcnow() self.create_date = create_date if create_date is not None else datetime.utcnow()
self.update_date = update_date if update_date is not None else datetime.utcnow() self.update_date = update_date if update_date is not None else datetime.utcnow()
def update_default_version(self, new_default_version_id):
for version in self.versions:
if version.version_id == self.default_version_id:
version.is_default = False
break
self.default_version_id = new_default_version_id
@property @property
def created_iso_8601(self): def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date) return iso_8601_datetime_with_milliseconds(self.create_date)
@ -770,13 +777,16 @@ class IAMBackend(BaseBackend):
policy = self.get_policy(policy_arn) policy = self.get_policy(policy_arn)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
if len(policy.versions) >= 5:
raise IAMLimitExceededException("A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version.")
set_as_default = (set_as_default == "true") # convert it to python bool
version = PolicyVersion(policy_arn, policy_document, set_as_default) version = PolicyVersion(policy_arn, policy_document, set_as_default)
policy.versions.append(version) policy.versions.append(version)
version.version_id = 'v{0}'.format(policy.next_version_num) version.version_id = 'v{0}'.format(policy.next_version_num)
policy.next_version_num += 1 policy.next_version_num += 1
if set_as_default: if set_as_default:
policy.default_version_id = version.version_id policy.update_default_version(version.version_id)
return version return version
def get_policy_version(self, policy_arn, version_id): def get_policy_version(self, policy_arn, version_id):
@ -799,8 +809,8 @@ class IAMBackend(BaseBackend):
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
if version_id == policy.default_version_id: if version_id == policy.default_version_id:
raise IAMConflictException( raise IAMConflictException(code="DeleteConflict",
"Cannot delete the default version of a policy") message="Cannot delete the default version of a policy.")
for i, v in enumerate(policy.versions): for i, v in enumerate(policy.versions):
if v.version_id == version_id: if v.version_id == version_id:
del policy.versions[i] del policy.versions[i]

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import iam_backend, User from .models import iam_backend
AVATAO_USER_NAME = "avatao-user"
class IamResponse(BaseResponse): class IamResponse(BaseResponse):
@ -425,11 +427,10 @@ class IamResponse(BaseResponse):
def get_user(self): def get_user(self):
user_name = self._get_param('UserName') user_name = self._get_param('UserName')
if user_name: if not user_name:
user = iam_backend.get_user(user_name) user_name = AVATAO_USER_NAME
else: # If no user is specified, IAM returns the current user
user = User(name='default_user') user = iam_backend.get_user(user_name)
# If no user is specific, IAM returns the current user
template = self.response_template(USER_TEMPLATE) template = self.response_template(USER_TEMPLATE)
return template.render(action='Get', user=user) return template.render(action='Get', user=user)
@ -457,7 +458,6 @@ class IamResponse(BaseResponse):
def create_login_profile(self): def create_login_profile(self):
user_name = self._get_param('UserName') user_name = self._get_param('UserName')
password = self._get_param('Password') password = self._get_param('Password')
password = self._get_param('Password')
user = iam_backend.create_login_profile(user_name, password) user = iam_backend.create_login_profile(user_name, password)
template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE)
@ -1144,7 +1144,7 @@ CREATE_POLICY_VERSION_TEMPLATE = """<CreatePolicyVersionResponse xmlns="https://
<PolicyVersion> <PolicyVersion>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion> </PolicyVersion>
</CreatePolicyVersionResult> </CreatePolicyVersionResult>
@ -1158,7 +1158,7 @@ GET_POLICY_VERSION_TEMPLATE = """<GetPolicyVersionResponse xmlns="https://iam.am
<PolicyVersion> <PolicyVersion>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion> </PolicyVersion>
</GetPolicyVersionResult> </GetPolicyVersionResult>
@ -1175,7 +1175,7 @@ LIST_POLICY_VERSIONS_TEMPLATE = """<ListPolicyVersionsResponse xmlns="https://ia
<member> <member>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
@ -1787,7 +1787,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
{% for policy_version in policy.versions %} {% for policy_version in policy.versions %}
<member> <member>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member> </member>

View File

@ -199,3 +199,17 @@ class DuplicateTagKeys(S3ClientError):
"InvalidTag", "InvalidTag",
"Cannot provide multiple Tags with the same key", "Cannot provide multiple Tags with the same key",
*args, **kwargs) *args, **kwargs)
class S3AccessDeniedError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3AccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)
class BucketAccessDeniedError(BucketError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketAccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)

View File

@ -3,13 +3,15 @@ from __future__ import unicode_literals
import re import re
import six import six
from werkzeug.exceptions import HTTPException
from moto.core.utils import str_to_rfc_1123_datetime from moto.core.utils import str_to_rfc_1123_datetime
from six.moves.urllib.parse import parse_qs, urlparse, unquote from six.moves.urllib.parse import parse_qs, urlparse, unquote
import xmltodict import xmltodict
from moto.packages.httpretty.core import HTTPrettyRequest from moto.packages.httpretty.core import HTTPrettyRequest
from moto.core.responses import _TemplateEnvironmentMixin from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin
from moto.core.utils import path_url from moto.core.utils import path_url
from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \ from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \
@ -25,6 +27,72 @@ from xml.dom import minidom
DEFAULT_REGION_NAME = 'us-east-1' DEFAULT_REGION_NAME = 'us-east-1'
ACTION_MAP = {
"BUCKET": {
"GET": {
"uploads": "ListBucketMultipartUploads",
"location": "GetBucketLocation",
"lifecycle": "GetLifecycleConfiguration",
"versioning": "GetBucketVersioning",
"policy": "GetBucketPolicy",
"website": "GetBucketWebsite",
"acl": "GetBucketAcl",
"tagging": "GetBucketTagging",
"logging": "GetBucketLogging",
"cors": "GetBucketCORS",
"notification": "GetBucketNotification",
"accelerate": "GetAccelerateConfiguration",
"versions": "ListBucketVersions",
"DEFAULT": "ListBucket"
},
"PUT": {
"lifecycle": "PutLifecycleConfiguration",
"versioning": "PutBucketVersioning",
"policy": "PutBucketPolicy",
"website": "PutBucketWebsite",
"acl": "PutBucketAcl",
"tagging": "PutBucketTagging",
"logging": "PutBucketLogging",
"cors": "PutBucketCORS",
"notification": "PutBucketNotification",
"accelerate": "PutAccelerateConfiguration",
"DEFAULT": "CreateBucket"
},
"DELETE": {
"lifecycle": "PutLifecycleConfiguration",
"policy": "DeleteBucketPolicy",
"tagging": "PutBucketTagging",
"cors": "PutBucketCORS",
"DEFAULT": "DeleteBucket"
}
},
"KEY": {
"GET": {
"uploadId": "ListMultipartUploadParts",
"acl": "GetObjectAcl",
"tagging": "GetObjectTagging",
"versionId": "GetObjectVersion",
"DEFAULT": "GetObject"
},
"PUT": {
"acl": "PutObjectAcl",
"tagging": "PutObjectTagging",
"DEFAULT": "PutObject"
},
"DELETE": {
"uploadId": "AbortMultipartUpload",
"versionId": "DeleteObjectVersion",
"DEFAULT": " DeleteObject"
},
"POST": {
"uploads": "PutObject",
"restore": "RestoreObject",
"uploadId": "PutObject"
}
}
}
def parse_key_name(pth): def parse_key_name(pth):
return pth.lstrip("/") return pth.lstrip("/")
@ -37,17 +105,27 @@ def is_delete_keys(request, path, bucket_name):
) )
class ResponseObject(_TemplateEnvironmentMixin): class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def __init__(self, backend): def __init__(self, backend):
super(ResponseObject, self).__init__() super(ResponseObject, self).__init__()
self.backend = backend self.backend = backend
self.method = ""
self.path = ""
self.data = {}
self.headers = {}
@property @property
def should_autoescape(self): def should_autoescape(self):
return True return True
def all_buckets(self): def all_buckets(self, headers):
try:
self.data["Action"] = "ListAllMyBuckets"
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
# No bucket specified. Listing all buckets # No bucket specified. Listing all buckets
all_buckets = self.backend.get_all_buckets() all_buckets = self.backend.get_all_buckets()
template = self.response_template(S3_ALL_BUCKETS) template = self.response_template(S3_ALL_BUCKETS)
@ -112,11 +190,18 @@ class ResponseObject(_TemplateEnvironmentMixin):
return self.bucket_response(request, full_url, headers) return self.bucket_response(request, full_url, headers)
def bucket_response(self, request, full_url, headers): def bucket_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
try: try:
response = self._bucket_response(request, full_url, headers) response = self._bucket_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, {}, s3error.description response = s3error.code, {}, s3error.description
return self._send_response(response)
@staticmethod
def _send_response(response):
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, {}, response.encode("utf-8") return 200, {}, response.encode("utf-8")
else: else:
@ -127,15 +212,16 @@ class ResponseObject(_TemplateEnvironmentMixin):
return status_code, headers, response_content return status_code, headers, response_content
def _bucket_response(self, request, full_url, headers): def _bucket_response(self, request, full_url, headers):
parsed_url = urlparse(full_url) querystring = self._get_querystring(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method method = request.method
region_name = parse_region_from_url(full_url) region_name = parse_region_from_url(full_url)
bucket_name = self.parse_bucket_name_from_url(request, full_url) bucket_name = self.parse_bucket_name_from_url(request, full_url)
if not bucket_name: if not bucket_name:
# If no bucket specified, list all buckets # If no bucket specified, list all buckets
return self.all_buckets() return self.all_buckets(headers)
self.data["BucketName"] = bucket_name
if hasattr(request, 'body'): if hasattr(request, 'body'):
# Boto # Boto
@ -163,6 +249,12 @@ class ResponseObject(_TemplateEnvironmentMixin):
raise NotImplementedError( raise NotImplementedError(
"Method {0} has not been impelemented in the S3 backend yet".format(method)) "Method {0} has not been impelemented in the S3 backend yet".format(method))
@staticmethod
def _get_querystring(full_url):
parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
return querystring
def _bucket_response_head(self, bucket_name, headers): def _bucket_response_head(self, bucket_name, headers):
try: try:
self.backend.get_bucket(bucket_name) self.backend.get_bucket(bucket_name)
@ -175,6 +267,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, {}, "" return 200, {}, ""
def _bucket_response_get(self, bucket_name, querystring, headers): def _bucket_response_get(self, bucket_name, querystring, headers):
self._set_action("BUCKET", "GET", querystring)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
if 'uploads' in querystring: if 'uploads' in querystring:
for unsup in ('delimiter', 'max-uploads'): for unsup in ('delimiter', 'max-uploads'):
if unsup in querystring: if unsup in querystring:
@ -333,6 +433,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
max_keys=max_keys max_keys=max_keys
) )
def _set_action(self, action_resource_type, method, querystring):
action_set = False
for action_in_querystring, action in ACTION_MAP[action_resource_type][method].items():
if action_in_querystring in querystring:
self.data["Action"] = action
action_set = True
if not action_set:
self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]
def _handle_list_objects_v2(self, bucket_name, querystring): def _handle_list_objects_v2(self, bucket_name, querystring):
template = self.response_template(S3_BUCKET_GET_RESPONSE_V2) template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
@ -396,6 +505,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
def _bucket_response_put(self, request, body, region_name, bucket_name, querystring, headers): def _bucket_response_put(self, request, body, region_name, bucket_name, querystring, headers):
if not request.headers.get('Content-Length'): if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
self._set_action("BUCKET", "PUT", querystring)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
if 'versioning' in querystring: if 'versioning' in querystring:
ver = re.search('<Status>([A-Za-z]+)</Status>', body.decode()) ver = re.search('<Status>([A-Za-z]+)</Status>', body.decode())
if ver: if ver:
@ -495,6 +613,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, {}, template.render(bucket=new_bucket) return 200, {}, template.render(bucket=new_bucket)
def _bucket_response_delete(self, body, bucket_name, querystring, headers): def _bucket_response_delete(self, body, bucket_name, querystring, headers):
self._set_action("BUCKET", "DELETE", querystring)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
if 'policy' in querystring: if 'policy' in querystring:
self.backend.delete_bucket_policy(bucket_name, body) self.backend.delete_bucket_policy(bucket_name, body)
return 204, {}, "" return 204, {}, ""
@ -525,14 +651,27 @@ class ResponseObject(_TemplateEnvironmentMixin):
if not request.headers.get('Content-Length'): if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
if isinstance(request, HTTPrettyRequest): path = self._get_path(request)
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
if self.is_delete_keys(request, path, bucket_name): if self.is_delete_keys(request, path, bucket_name):
self.data["Action"] = "DeleteObject"
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
return self._bucket_response_delete_keys(request, body, bucket_name, headers) return self._bucket_response_delete_keys(request, body, bucket_name, headers)
self.data["Action"] = "PutObject"
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
# POST to bucket-url should create file from form # POST to bucket-url should create file from form
if hasattr(request, 'form'): if hasattr(request, 'form'):
# Not HTTPretty # Not HTTPretty
@ -560,6 +699,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, {}, "" return 200, {}, ""
@staticmethod
def _get_path(request):
if isinstance(request, HTTPrettyRequest):
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
return path
def _bucket_response_delete_keys(self, request, body, bucket_name, headers): def _bucket_response_delete_keys(self, request, body, bucket_name, headers):
template = self.response_template(S3_DELETE_KEYS_RESPONSE) template = self.response_template(S3_DELETE_KEYS_RESPONSE)
@ -604,6 +751,9 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 206, response_headers, response_content[begin:end + 1] return 206, response_headers, response_content[begin:end + 1]
def key_response(self, request, full_url, headers): def key_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
response_headers = {} response_headers = {}
try: try:
response = self._key_response(request, full_url, headers) response = self._key_response(request, full_url, headers)
@ -671,6 +821,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
"Method {0} has not been implemented in the S3 backend yet".format(method)) "Method {0} has not been implemented in the S3 backend yet".format(method))
def _key_response_get(self, bucket_name, query, key_name, headers): def _key_response_get(self, bucket_name, query, key_name, headers):
self._set_action("KEY", "GET", query)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
response_headers = {} response_headers = {}
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
@ -700,6 +858,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, response_headers, key.value return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name, headers): def _key_response_put(self, request, body, bucket_name, query, key_name, headers):
self._set_action("KEY", "PUT", query)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
response_headers = {} response_headers = {}
if query.get('uploadId') and query.get('partNumber'): if query.get('uploadId') and query.get('partNumber'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
@ -1067,6 +1233,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
return config['Status'] return config['Status']
def _key_response_delete(self, bucket_name, query, key_name, headers): def _key_response_delete(self, bucket_name, query, key_name, headers):
self._set_action("KEY", "DELETE", query)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
self.backend.cancel_multipart(bucket_name, upload_id) self.backend.cancel_multipart(bucket_name, upload_id)
@ -1087,6 +1261,14 @@ class ResponseObject(_TemplateEnvironmentMixin):
yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText) yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText)
def _key_response_post(self, request, body, bucket_name, query, key_name, headers): def _key_response_post(self, request, body, bucket_name, query, key_name, headers):
self._set_action("KEY", "POST", query)
try:
self._authenticate_s3_action()
except HTTPException as http_error:
response = http_error.code, headers, http_error.description
return self._send_response(response)
if body == b'' and 'uploads' in query: if body == b'' and 'uploads' in query:
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
multipart = self.backend.initiate_multipart( multipart = self.backend.initiate_multipart(

View File

@ -7,15 +7,6 @@ url_bases = [
r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com" r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
] ]
def ambiguous_response1(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
def ambiguous_response2(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
url_paths = { url_paths = {
# subdomain bucket # subdomain bucket
'{0}/$': S3ResponseInstance.bucket_response, '{0}/$': S3ResponseInstance.bucket_response,

View File

@ -345,6 +345,7 @@ def test_create_policy_versions():
SetAsDefault=True) SetAsDefault=True)
version.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) version.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY))
version.get('PolicyVersion').get('VersionId').should.equal("v2") version.get('PolicyVersion').get('VersionId').should.equal("v2")
version.get('PolicyVersion').get('IsDefaultVersion').should.be.ok
conn.delete_policy_version( conn.delete_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion",
VersionId="v1") VersionId="v1")
@ -352,6 +353,47 @@ def test_create_policy_versions():
PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion",
PolicyDocument=MOCK_POLICY) PolicyDocument=MOCK_POLICY)
version.get('PolicyVersion').get('VersionId').should.equal("v3") version.get('PolicyVersion').get('VersionId').should.equal("v3")
version.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok
@mock_iam
def test_create_many_policy_versions():
conn = boto3.client('iam', region_name='us-east-1')
conn.create_policy(
PolicyName="TestCreateManyPolicyVersions",
PolicyDocument='{"some":"policy"}')
for _ in range(0, 4):
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions",
PolicyDocument='{"some":"policy"}')
with assert_raises(ClientError):
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions",
PolicyDocument='{"some":"policy"}')
@mock_iam
def test_set_default_policy_version():
conn = boto3.client('iam', region_name='us-east-1')
conn.create_policy(
PolicyName="TestSetDefaultPolicyVersion",
PolicyDocument='{"first":"policy"}')
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion",
PolicyDocument='{"second":"policy"}',
SetAsDefault=True)
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion",
PolicyDocument='{"third":"policy"}',
SetAsDefault=True)
versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion")
versions.get('Versions')[0].get('Document').should.equal({'first': 'policy'})
versions.get('Versions')[0].get('IsDefaultVersion').shouldnt.be.ok
versions.get('Versions')[1].get('Document').should.equal({'second': 'policy'})
versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok
versions.get('Versions')[2].get('Document').should.equal({'third': 'policy'})
versions.get('Versions')[2].get('IsDefaultVersion').should.be.ok
@mock_iam @mock_iam
@ -393,6 +435,7 @@ def test_get_policy_version():
PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion",
VersionId=version.get('PolicyVersion').get('VersionId')) VersionId=version.get('PolicyVersion').get('VersionId'))
retrieved.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) retrieved.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY))
retrieved.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok
@mock_iam @mock_iam
@ -439,6 +482,7 @@ def test_list_policy_versions():
versions = conn.list_policy_versions( versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions")
versions.get('Versions')[0].get('VersionId').should.equal('v1') versions.get('Versions')[0].get('VersionId').should.equal('v1')
versions.get('Versions')[0].get('IsDefaultVersion').should.be.ok
conn.create_policy_version( conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions",
@ -448,9 +492,10 @@ def test_list_policy_versions():
PolicyDocument=MOCK_POLICY_3) PolicyDocument=MOCK_POLICY_3)
versions = conn.list_policy_versions( versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions")
print(versions.get('Versions'))
versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2)) versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2))
versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok
versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3)) versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3))
versions.get('Versions')[2].get('IsDefaultVersion').shouldnt.be.ok
@mock_iam @mock_iam
@ -474,6 +519,21 @@ def test_delete_policy_version():
len(versions.get('Versions')).should.equal(1) len(versions.get('Versions')).should.equal(1)
@mock_iam
def test_delete_default_policy_version():
conn = boto3.client('iam', region_name='us-east-1')
conn.create_policy(
PolicyName="TestDeletePolicyVersion",
PolicyDocument='{"first":"policy"}')
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion",
PolicyDocument='{"second":"policy"}')
with assert_raises(ClientError):
conn.delete_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion",
VersionId='v1')
@mock_iam_deprecated() @mock_iam_deprecated()
def test_create_user(): def test_create_user():
conn = boto.connect_iam() conn = boto.connect_iam()