refactor(ssm): Refactor document permisisons and conditions (#4243)

This commit is contained in:
Gonzalo Saad 2021-09-08 02:56:20 -03:00 committed by GitHub
parent b3795d312a
commit e6c6ce5942
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 87 deletions

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
from dataclasses import dataclass
from typing import Dict from typing import Dict
from boto3 import Session from boto3 import Session
@ -131,33 +132,43 @@ def generate_ssm_doc_param_list(parameters):
return None return None
param_list = [] param_list = []
for param_name, param_info in parameters.items(): for param_name, param_info in parameters.items():
final_dict = {} final_dict = {
"Name": param_name,
}
final_dict["Name"] = param_name description = param_info.get("description")
final_dict["Type"] = param_info["type"] if description:
final_dict["Description"] = param_info["description"] final_dict["Description"] = description
if ( param_type = param_info["type"]
param_info["type"] == "StringList" final_dict["Type"] = param_type
or param_info["type"] == "StringMap"
or param_info["type"] == "MapList" default_value = param_info.get("default")
): if default_value is not None:
final_dict["DefaultValue"] = json.dumps(param_info["default"]) if param_type in {"StringList", "StringMap", "MapList"}:
final_dict["DefaultValue"] = json.dumps(default_value)
else: else:
final_dict["DefaultValue"] = str(param_info["default"]) final_dict["DefaultValue"] = str(default_value)
param_list.append(final_dict) param_list.append(final_dict)
return param_list return param_list
@dataclass(frozen=True)
class AccountPermission:
account_id: str
version: str
created_at: datetime
class Documents(BaseModel): class Documents(BaseModel):
def __init__(self, ssm_document): def __init__(self, ssm_document):
version = ssm_document.document_version version = ssm_document.document_version
self.versions = {version: ssm_document} self.versions = {version: ssm_document}
self.default_version = version self.default_version = version
self.latest_version = version self.latest_version = version
self.permissions = {} # {AccountID: version } self.permissions = {} # {AccountID: AccountPermission }
def get_default_version(self): def get_default_version(self):
return self.versions.get(self.default_version) return self.versions.get(self.default_version)
@ -233,7 +244,7 @@ class Documents(BaseModel):
new_latest_version = ordered_versions[-1] new_latest_version = ordered_versions[-1]
self.latest_version = new_latest_version self.latest_version = new_latest_version
def describe(self, document_version=None, version_name=None): def describe(self, document_version=None, version_name=None, tags=None):
document = self.find(document_version, version_name) document = self.find(document_version, version_name)
base = { base = {
"Hash": document.hash, "Hash": document.hash,
@ -256,20 +267,28 @@ class Documents(BaseModel):
base["VersionName"] = document.version_name base["VersionName"] = document.version_name
if document.target_type: if document.target_type:
base["TargetType"] = document.target_type base["TargetType"] = document.target_type
if document.tags: if tags:
base["Tags"] = document.tags base["Tags"] = tags
return base return base
def modify_permissions(self, accounts_to_add, accounts_to_remove, version): def modify_permissions(self, accounts_to_add, accounts_to_remove, version):
version = version or "$DEFAULT"
if accounts_to_add:
if "all" in accounts_to_add: if "all" in accounts_to_add:
self.permissions.clear() self.permissions.clear()
else: else:
self.permissions.pop("all", None) self.permissions.pop("all", None)
new_permissions = {account_id: version for account_id in accounts_to_add} new_permissions = {
account_id: AccountPermission(
account_id, version, datetime.datetime.now()
)
for account_id in accounts_to_add
}
self.permissions.update(**new_permissions) self.permissions.update(**new_permissions)
if accounts_to_remove:
if "all" in accounts_to_remove: if "all" in accounts_to_remove:
self.permissions.clear() self.permissions.clear()
else: else:
@ -277,11 +296,16 @@ class Documents(BaseModel):
self.permissions.pop(account_id, None) self.permissions.pop(account_id, None)
def describe_permissions(self): def describe_permissions(self):
permissions_ordered_by_date = sorted(
self.permissions.values(), key=lambda p: p.created_at
)
return { return {
"AccountIds": list(self.permissions.keys()), "AccountIds": [p.account_id for p in permissions_ordered_by_date],
"AccountSharingInfoList": [ "AccountSharingInfoList": [
{"AccountId": account_id, "SharedDocumentVersion": document_version} {"AccountId": p.account_id, "SharedDocumentVersion": p.version}
for account_id, document_version in self.permissions.items() for p in permissions_ordered_by_date
], ],
} }
@ -300,7 +324,6 @@ class Document(BaseModel):
requires, requires,
attachments, attachments,
target_type, target_type,
tags,
document_version="1", document_version="1",
): ):
self.name = name self.name = name
@ -311,7 +334,6 @@ class Document(BaseModel):
self.requires = requires self.requires = requires
self.attachments = attachments self.attachments = attachments
self.target_type = target_type self.target_type = target_type
self.tags = tags
self.status = "Active" self.status = "Active"
self.document_version = document_version self.document_version = document_version
@ -353,12 +375,8 @@ class Document(BaseModel):
content_json.get("parameters") content_json.get("parameters")
) )
if ( if self.schema_version in {"0.3", "2.0", "2.2"}:
self.schema_version == "0.3" self.mainSteps = content_json.get("mainSteps")
or self.schema_version == "2.0"
or self.schema_version == "2.2"
):
self.mainSteps = content_json["mainSteps"]
elif self.schema_version == "1.2": elif self.schema_version == "1.2":
self.runtimeConfig = content_json.get("runtimeConfig") self.runtimeConfig = content_json.get("runtimeConfig")
@ -369,6 +387,28 @@ class Document(BaseModel):
def hash(self): def hash(self):
return hashlib.sha256(self.content.encode("utf-8")).hexdigest() return hashlib.sha256(self.content.encode("utf-8")).hexdigest()
def list_describe(self, tags=None):
base = {
"Name": self.name,
"Owner": self.owner,
"DocumentVersion": self.document_version,
"DocumentType": self.document_type,
"SchemaVersion": self.schema_version,
"DocumentFormat": self.document_format,
}
if self.version_name:
base["VersionName"] = self.version_name
if self.platform_types:
base["PlatformTypes"] = self.platform_types
if self.target_type:
base["TargetType"] = self.target_type
if self.requires:
base["Requires"] = self.requires
if tags:
base["Tags"] = tags
return base
class Command(BaseModel): class Command(BaseModel):
def __init__( def __init__(
@ -680,35 +720,21 @@ class SimpleSystemManagerBackend(BaseBackend):
raise ValidationException("Invalid document format " + str(document_format)) raise ValidationException("Invalid document format " + str(document_format))
return content return content
@staticmethod
def _generate_document_list_information(ssm_document):
base = {
"Name": ssm_document.name,
"Owner": ssm_document.owner,
"DocumentVersion": ssm_document.document_version,
"DocumentType": ssm_document.document_type,
"SchemaVersion": ssm_document.schema_version,
"DocumentFormat": ssm_document.document_format,
}
if ssm_document.version_name:
base["VersionName"] = ssm_document.version_name
if ssm_document.platform_types:
base["PlatformTypes"] = ssm_document.platform_types
if ssm_document.target_type:
base["TargetType"] = ssm_document.target_type
if ssm_document.tags:
base["Tags"] = ssm_document.tags
if ssm_document.requires:
base["Requires"] = ssm_document.requires
return base
def _get_documents(self, name): def _get_documents(self, name):
documents = self._documents.get(name) documents = self._documents.get(name)
if not documents: if not documents:
raise InvalidDocument("The specified document does not exist.") raise InvalidDocument("The specified document does not exist.")
return documents return documents
def _get_documents_tags(self, name):
docs_tags = self._resource_tags.get("Document")
if docs_tags:
document_tags = docs_tags.get(name, {})
return [
{"Key": tag, "Value": value} for tag, value in document_tags.items()
]
return []
def create_document( def create_document(
self, self,
content, content,
@ -730,7 +756,6 @@ class SimpleSystemManagerBackend(BaseBackend):
requires=requires, requires=requires,
attachments=attachments, attachments=attachments,
target_type=target_type, target_type=target_type,
tags=tags,
) )
_validate_document_info( _validate_document_info(
@ -746,7 +771,11 @@ class SimpleSystemManagerBackend(BaseBackend):
documents = Documents(ssm_document) documents = Documents(ssm_document)
self._documents[ssm_document.name] = documents self._documents[ssm_document.name] = documents
return documents.describe() if tags:
document_tags = {t["Key"]: t["Value"] for t in tags}
self.add_tags_to_resource("Document", name, document_tags)
return documents.describe(tags=tags)
def delete_document(self, name, document_version, version_name, force): def delete_document(self, name, document_version, version_name, force):
documents = self._get_documents(name) documents = self._get_documents(name)
@ -873,23 +902,25 @@ class SimpleSystemManagerBackend(BaseBackend):
requires=old_ssm_document.requires, requires=old_ssm_document.requires,
attachments=attachments, attachments=attachments,
target_type=target_type, target_type=target_type,
tags=old_ssm_document.tags,
document_version=new_version, document_version=new_version,
) )
for doc_version, document in documents.versions.items(): for doc_version, document in documents.versions.items():
if document.content == new_ssm_document.content: if document.content == new_ssm_document.content:
if not target_type or target_type == document.target_type:
raise DuplicateDocumentContent( raise DuplicateDocumentContent(
"The content of the association document matches another document. " "The content of the association document matches another document. "
"Change the content of the document and try again." "Change the content of the document and try again."
) )
documents.add_new_version(new_ssm_document) documents.add_new_version(new_ssm_document)
return documents.describe(document_version=new_version) tags = self._get_documents_tags(name)
return documents.describe(document_version=new_version, tags=tags)
def describe_document(self, name, document_version, version_name): def describe_document(self, name, document_version, version_name):
documents = self._get_documents(name) documents = self._get_documents(name)
return documents.describe(document_version, version_name) tags = self._get_documents_tags(name)
return documents.describe(document_version, version_name, tags=tags)
def list_documents( def list_documents(
self, document_filter_list, filters, max_results=10, next_token="0" self, document_filter_list, filters, max_results=10, next_token="0"
@ -917,7 +948,9 @@ class SimpleSystemManagerBackend(BaseBackend):
# If we have filters enabled, and we don't match them, # If we have filters enabled, and we don't match them,
continue continue
else: else:
results.append(self._generate_document_list_information(ssm_doc)) tags = self._get_documents_tags(ssm_doc.name)
doc_describe = ssm_doc.list_describe(tags=tags)
results.append(doc_describe)
# If we've fallen out of the loop, theres no more documents. No next token. # If we've fallen out of the loop, theres no more documents. No next token.
return results, "" return results, ""
@ -965,14 +998,12 @@ class SimpleSystemManagerBackend(BaseBackend):
"Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]." "Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]."
) )
accounts_to_add = set(account_ids_to_add) if "all" in account_ids_to_add and len(account_ids_to_add) > 1:
if "all" in accounts_to_add and len(accounts_to_add) > 1:
raise DocumentPermissionLimit( raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts" "Accounts can either be all or a group of AWS accounts"
) )
accounts_to_remove = set(account_ids_to_remove) if "all" in account_ids_to_remove and len(account_ids_to_remove) > 1:
if "all" in accounts_to_remove and len(accounts_to_remove) > 1:
raise DocumentPermissionLimit( raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts" "Accounts can either be all or a group of AWS accounts"
) )
@ -985,7 +1016,7 @@ class SimpleSystemManagerBackend(BaseBackend):
document = self._get_documents(name) document = self._get_documents(name)
document.modify_permissions( document.modify_permissions(
accounts_to_add, accounts_to_remove, shared_document_version account_ids_to_add, account_ids_to_remove, shared_document_version
) )
def delete_parameter(self, name): def delete_parameter(self, name):
@ -1566,7 +1597,7 @@ class SimpleSystemManagerBackend(BaseBackend):
if tags: if tags:
tags = {t["Key"]: t["Value"] for t in tags} tags = {t["Key"]: t["Value"] for t in tags}
self.add_tags_to_resource(name, "Parameter", tags) self.add_tags_to_resource("Parameter", name, tags)
return version return version

View File

@ -261,7 +261,7 @@ class SimpleSystemManagerResponse(BaseResponse):
for parameter in result[token:]: for parameter in result[token:]:
response["Parameters"].append(parameter.describe_response_object(False)) response["Parameters"].append(parameter.describe_response_object(False))
token = token + 1 token += 1
if len(response["Parameters"]) == page_size: if len(response["Parameters"]) == page_size:
response["NextToken"] = str(end) response["NextToken"] = str(end)
break break
@ -346,20 +346,26 @@ class SimpleSystemManagerResponse(BaseResponse):
resource_id = self._get_param("ResourceId") resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType") resource_type = self._get_param("ResourceType")
tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")} tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")}
self.ssm_backend.add_tags_to_resource(resource_id, resource_type, tags) self.ssm_backend.add_tags_to_resource(
resource_type=resource_type, resource_id=resource_id, tags=tags
)
return json.dumps({}) return json.dumps({})
def remove_tags_from_resource(self): def remove_tags_from_resource(self):
resource_id = self._get_param("ResourceId") resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType") resource_type = self._get_param("ResourceType")
keys = self._get_param("TagKeys") keys = self._get_param("TagKeys")
self.ssm_backend.remove_tags_from_resource(resource_id, resource_type, keys) self.ssm_backend.remove_tags_from_resource(
resource_type=resource_type, resource_id=resource_id, keys=keys
)
return json.dumps({}) return json.dumps({})
def list_tags_for_resource(self): def list_tags_for_resource(self):
resource_id = self._get_param("ResourceId") resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType") resource_type = self._get_param("ResourceType")
tags = self.ssm_backend.list_tags_for_resource(resource_id, resource_type) tags = self.ssm_backend.list_tags_for_resource(
resource_type=resource_type, resource_id=resource_id
)
tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()] tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()]
response = {"TagList": tag_list} response = {"TagList": tag_list}
return json.dumps(response) return json.dumps(response)

View File

@ -5,3 +5,4 @@ TestAccAWSEc2TransitGatewayVpcAttachment
TestAccAWSFms TestAccAWSFms
TestAccAWSIAMRolePolicy TestAccAWSIAMRolePolicy
TestAccAWSSecurityGroup_forceRevokeRules_ TestAccAWSSecurityGroup_forceRevokeRules_
TestAccAWSSSMDocument_package

View File

@ -72,6 +72,23 @@ TestAccAWSRolePolicyAttachment
TestAccAWSSNSSMSPreferences TestAccAWSSNSSMSPreferences
TestAccAWSSageMakerPrebuiltECRImage TestAccAWSSageMakerPrebuiltECRImage
TestAccAWSSQSQueuePolicy TestAccAWSSQSQueuePolicy
TestAccAWSSSMDocument_basic
TestAccAWSSSMDocument_Name
TestAccAWSSSMDocument_target_type
TestAccAWSSSMDocument_VersionName
TestAccAWSSSMDocument_update
TestAccAWSSSMDocument_permission_public
TestAccAWSSSMDocument_permission_private
TestAccAWSSSMDocument_permission_batching
TestAccAWSSSMDocument_permission_change
TestAccAWSSSMDocument_params
TestAccAWSSSMDocument_automation
TestAccAWSSSMDocument_SchemaVersion_1
TestAccAWSSSMDocument_session
TestAccAWSSSMDocument_DocumentFormat_YAML
TestAccAWSSSMDocument_Tags
TestAccAWSSSMDocument_disappears
TestValidateSSMDocumentPermissions
TestAccAWSSsmParameterDataSource TestAccAWSSsmParameterDataSource
TestAccAWSUserGroupMembership TestAccAWSUserGroupMembership
TestAccAWSUserPolicyAttachment TestAccAWSUserPolicyAttachment

View File

@ -63,8 +63,11 @@ def test_modify_document_permission_add_account_id(ids):
res.should.have.key("AccountIds") res.should.have.key("AccountIds")
set(res["AccountIds"]).should.equal(set(ids)) set(res["AccountIds"]).should.equal(set(ids))
res.should.have.key("AccountSharingInfoList").length_of(len(ids)) res.should.have.key("AccountSharingInfoList").length_of(len(ids))
for entry in [{"AccountId": _id} for _id in ids]:
res["AccountSharingInfoList"].should.contain(entry) expected_account_sharing = [
{"AccountId": _id, "SharedDocumentVersion": "$DEFAULT"} for _id in ids
]
res.should.have.key("AccountSharingInfoList").equal(expected_account_sharing)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -96,9 +99,12 @@ def test_modify_document_permission_remove_account_id(initial, to_remove):
res.should.have.key("AccountIds") res.should.have.key("AccountIds")
expected_new_list = set([x for x in initial if x not in to_remove]) expected_new_list = set([x for x in initial if x not in to_remove])
set(res["AccountIds"]).should.equal(expected_new_list) set(res["AccountIds"]).should.equal(expected_new_list)
res.should.have.key("AccountSharingInfoList").equal(
[{"AccountId": _id} for _id in expected_new_list] expected_account_sharing = [
) {"AccountId": _id, "SharedDocumentVersion": "$DEFAULT"}
for _id in expected_new_list
]
res.should.have.key("AccountSharingInfoList").equal(expected_account_sharing)
@mock_ssm @mock_ssm