refactor(ssm): Refactor document permisisons and conditions (#4243)

This commit is contained in:
Gonzalo Saad 2021-09-08 02:56:20 -03:00 committed by GitHub
parent b3795d312a
commit e6c6ce5942
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 87 deletions

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import re
from dataclasses import dataclass
from typing import Dict
from boto3 import Session
@ -131,33 +132,43 @@ def generate_ssm_doc_param_list(parameters):
return None
param_list = []
for param_name, param_info in parameters.items():
final_dict = {}
final_dict = {
"Name": param_name,
}
final_dict["Name"] = param_name
final_dict["Type"] = param_info["type"]
final_dict["Description"] = param_info["description"]
description = param_info.get("description")
if description:
final_dict["Description"] = description
if (
param_info["type"] == "StringList"
or param_info["type"] == "StringMap"
or param_info["type"] == "MapList"
):
final_dict["DefaultValue"] = json.dumps(param_info["default"])
else:
final_dict["DefaultValue"] = str(param_info["default"])
param_type = param_info["type"]
final_dict["Type"] = param_type
default_value = param_info.get("default")
if default_value is not None:
if param_type in {"StringList", "StringMap", "MapList"}:
final_dict["DefaultValue"] = json.dumps(default_value)
else:
final_dict["DefaultValue"] = str(default_value)
param_list.append(final_dict)
return param_list
@dataclass(frozen=True)
class AccountPermission:
account_id: str
version: str
created_at: datetime
class Documents(BaseModel):
def __init__(self, ssm_document):
version = ssm_document.document_version
self.versions = {version: ssm_document}
self.default_version = version
self.latest_version = version
self.permissions = {} # {AccountID: version }
self.permissions = {} # {AccountID: AccountPermission }
def get_default_version(self):
return self.versions.get(self.default_version)
@ -233,7 +244,7 @@ class Documents(BaseModel):
new_latest_version = ordered_versions[-1]
self.latest_version = new_latest_version
def describe(self, document_version=None, version_name=None):
def describe(self, document_version=None, version_name=None, tags=None):
document = self.find(document_version, version_name)
base = {
"Hash": document.hash,
@ -256,32 +267,45 @@ class Documents(BaseModel):
base["VersionName"] = document.version_name
if document.target_type:
base["TargetType"] = document.target_type
if document.tags:
base["Tags"] = document.tags
if tags:
base["Tags"] = tags
return base
def modify_permissions(self, accounts_to_add, accounts_to_remove, version):
if "all" in accounts_to_add:
self.permissions.clear()
else:
self.permissions.pop("all", None)
version = version or "$DEFAULT"
if accounts_to_add:
if "all" in accounts_to_add:
self.permissions.clear()
else:
self.permissions.pop("all", None)
new_permissions = {account_id: version for account_id in accounts_to_add}
self.permissions.update(**new_permissions)
new_permissions = {
account_id: AccountPermission(
account_id, version, datetime.datetime.now()
)
for account_id in accounts_to_add
}
self.permissions.update(**new_permissions)
if "all" in accounts_to_remove:
self.permissions.clear()
else:
for account_id in accounts_to_remove:
self.permissions.pop(account_id, None)
if accounts_to_remove:
if "all" in accounts_to_remove:
self.permissions.clear()
else:
for account_id in accounts_to_remove:
self.permissions.pop(account_id, None)
def describe_permissions(self):
permissions_ordered_by_date = sorted(
self.permissions.values(), key=lambda p: p.created_at
)
return {
"AccountIds": list(self.permissions.keys()),
"AccountIds": [p.account_id for p in permissions_ordered_by_date],
"AccountSharingInfoList": [
{"AccountId": account_id, "SharedDocumentVersion": document_version}
for account_id, document_version in self.permissions.items()
{"AccountId": p.account_id, "SharedDocumentVersion": p.version}
for p in permissions_ordered_by_date
],
}
@ -300,7 +324,6 @@ class Document(BaseModel):
requires,
attachments,
target_type,
tags,
document_version="1",
):
self.name = name
@ -311,7 +334,6 @@ class Document(BaseModel):
self.requires = requires
self.attachments = attachments
self.target_type = target_type
self.tags = tags
self.status = "Active"
self.document_version = document_version
@ -353,12 +375,8 @@ class Document(BaseModel):
content_json.get("parameters")
)
if (
self.schema_version == "0.3"
or self.schema_version == "2.0"
or self.schema_version == "2.2"
):
self.mainSteps = content_json["mainSteps"]
if self.schema_version in {"0.3", "2.0", "2.2"}:
self.mainSteps = content_json.get("mainSteps")
elif self.schema_version == "1.2":
self.runtimeConfig = content_json.get("runtimeConfig")
@ -369,6 +387,28 @@ class Document(BaseModel):
def hash(self):
return hashlib.sha256(self.content.encode("utf-8")).hexdigest()
def list_describe(self, tags=None):
base = {
"Name": self.name,
"Owner": self.owner,
"DocumentVersion": self.document_version,
"DocumentType": self.document_type,
"SchemaVersion": self.schema_version,
"DocumentFormat": self.document_format,
}
if self.version_name:
base["VersionName"] = self.version_name
if self.platform_types:
base["PlatformTypes"] = self.platform_types
if self.target_type:
base["TargetType"] = self.target_type
if self.requires:
base["Requires"] = self.requires
if tags:
base["Tags"] = tags
return base
class Command(BaseModel):
def __init__(
@ -680,35 +720,21 @@ class SimpleSystemManagerBackend(BaseBackend):
raise ValidationException("Invalid document format " + str(document_format))
return content
@staticmethod
def _generate_document_list_information(ssm_document):
base = {
"Name": ssm_document.name,
"Owner": ssm_document.owner,
"DocumentVersion": ssm_document.document_version,
"DocumentType": ssm_document.document_type,
"SchemaVersion": ssm_document.schema_version,
"DocumentFormat": ssm_document.document_format,
}
if ssm_document.version_name:
base["VersionName"] = ssm_document.version_name
if ssm_document.platform_types:
base["PlatformTypes"] = ssm_document.platform_types
if ssm_document.target_type:
base["TargetType"] = ssm_document.target_type
if ssm_document.tags:
base["Tags"] = ssm_document.tags
if ssm_document.requires:
base["Requires"] = ssm_document.requires
return base
def _get_documents(self, name):
documents = self._documents.get(name)
if not documents:
raise InvalidDocument("The specified document does not exist.")
return documents
def _get_documents_tags(self, name):
docs_tags = self._resource_tags.get("Document")
if docs_tags:
document_tags = docs_tags.get(name, {})
return [
{"Key": tag, "Value": value} for tag, value in document_tags.items()
]
return []
def create_document(
self,
content,
@ -730,7 +756,6 @@ class SimpleSystemManagerBackend(BaseBackend):
requires=requires,
attachments=attachments,
target_type=target_type,
tags=tags,
)
_validate_document_info(
@ -746,7 +771,11 @@ class SimpleSystemManagerBackend(BaseBackend):
documents = Documents(ssm_document)
self._documents[ssm_document.name] = documents
return documents.describe()
if tags:
document_tags = {t["Key"]: t["Value"] for t in tags}
self.add_tags_to_resource("Document", name, document_tags)
return documents.describe(tags=tags)
def delete_document(self, name, document_version, version_name, force):
documents = self._get_documents(name)
@ -873,23 +902,25 @@ class SimpleSystemManagerBackend(BaseBackend):
requires=old_ssm_document.requires,
attachments=attachments,
target_type=target_type,
tags=old_ssm_document.tags,
document_version=new_version,
)
for doc_version, document in documents.versions.items():
if document.content == new_ssm_document.content:
raise DuplicateDocumentContent(
"The content of the association document matches another document. "
"Change the content of the document and try again."
)
if not target_type or target_type == document.target_type:
raise DuplicateDocumentContent(
"The content of the association document matches another document. "
"Change the content of the document and try again."
)
documents.add_new_version(new_ssm_document)
return documents.describe(document_version=new_version)
tags = self._get_documents_tags(name)
return documents.describe(document_version=new_version, tags=tags)
def describe_document(self, name, document_version, version_name):
documents = self._get_documents(name)
return documents.describe(document_version, version_name)
tags = self._get_documents_tags(name)
return documents.describe(document_version, version_name, tags=tags)
def list_documents(
self, document_filter_list, filters, max_results=10, next_token="0"
@ -917,7 +948,9 @@ class SimpleSystemManagerBackend(BaseBackend):
# If we have filters enabled, and we don't match them,
continue
else:
results.append(self._generate_document_list_information(ssm_doc))
tags = self._get_documents_tags(ssm_doc.name)
doc_describe = ssm_doc.list_describe(tags=tags)
results.append(doc_describe)
# If we've fallen out of the loop, theres no more documents. No next token.
return results, ""
@ -965,14 +998,12 @@ class SimpleSystemManagerBackend(BaseBackend):
"Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]."
)
accounts_to_add = set(account_ids_to_add)
if "all" in accounts_to_add and len(accounts_to_add) > 1:
if "all" in account_ids_to_add and len(account_ids_to_add) > 1:
raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts"
)
accounts_to_remove = set(account_ids_to_remove)
if "all" in accounts_to_remove and len(accounts_to_remove) > 1:
if "all" in account_ids_to_remove and len(account_ids_to_remove) > 1:
raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts"
)
@ -985,7 +1016,7 @@ class SimpleSystemManagerBackend(BaseBackend):
document = self._get_documents(name)
document.modify_permissions(
accounts_to_add, accounts_to_remove, shared_document_version
account_ids_to_add, account_ids_to_remove, shared_document_version
)
def delete_parameter(self, name):
@ -1566,7 +1597,7 @@ class SimpleSystemManagerBackend(BaseBackend):
if tags:
tags = {t["Key"]: t["Value"] for t in tags}
self.add_tags_to_resource(name, "Parameter", tags)
self.add_tags_to_resource("Parameter", name, tags)
return version

View File

@ -261,7 +261,7 @@ class SimpleSystemManagerResponse(BaseResponse):
for parameter in result[token:]:
response["Parameters"].append(parameter.describe_response_object(False))
token = token + 1
token += 1
if len(response["Parameters"]) == page_size:
response["NextToken"] = str(end)
break
@ -346,20 +346,26 @@ class SimpleSystemManagerResponse(BaseResponse):
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")}
self.ssm_backend.add_tags_to_resource(resource_id, resource_type, tags)
self.ssm_backend.add_tags_to_resource(
resource_type=resource_type, resource_id=resource_id, tags=tags
)
return json.dumps({})
def remove_tags_from_resource(self):
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
keys = self._get_param("TagKeys")
self.ssm_backend.remove_tags_from_resource(resource_id, resource_type, keys)
self.ssm_backend.remove_tags_from_resource(
resource_type=resource_type, resource_id=resource_id, keys=keys
)
return json.dumps({})
def list_tags_for_resource(self):
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
tags = self.ssm_backend.list_tags_for_resource(resource_id, resource_type)
tags = self.ssm_backend.list_tags_for_resource(
resource_type=resource_type, resource_id=resource_id
)
tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()]
response = {"TagList": tag_list}
return json.dumps(response)

View File

@ -4,4 +4,5 @@ TestAccAWSEc2TransitGatewayRouteTableAssociation
TestAccAWSEc2TransitGatewayVpcAttachment
TestAccAWSFms
TestAccAWSIAMRolePolicy
TestAccAWSSecurityGroup_forceRevokeRules_
TestAccAWSSecurityGroup_forceRevokeRules_
TestAccAWSSSMDocument_package

View File

@ -72,6 +72,23 @@ TestAccAWSRolePolicyAttachment
TestAccAWSSNSSMSPreferences
TestAccAWSSageMakerPrebuiltECRImage
TestAccAWSSQSQueuePolicy
TestAccAWSSSMDocument_basic
TestAccAWSSSMDocument_Name
TestAccAWSSSMDocument_target_type
TestAccAWSSSMDocument_VersionName
TestAccAWSSSMDocument_update
TestAccAWSSSMDocument_permission_public
TestAccAWSSSMDocument_permission_private
TestAccAWSSSMDocument_permission_batching
TestAccAWSSSMDocument_permission_change
TestAccAWSSSMDocument_params
TestAccAWSSSMDocument_automation
TestAccAWSSSMDocument_SchemaVersion_1
TestAccAWSSSMDocument_session
TestAccAWSSSMDocument_DocumentFormat_YAML
TestAccAWSSSMDocument_Tags
TestAccAWSSSMDocument_disappears
TestValidateSSMDocumentPermissions
TestAccAWSSsmParameterDataSource
TestAccAWSUserGroupMembership
TestAccAWSUserPolicyAttachment

View File

@ -63,8 +63,11 @@ def test_modify_document_permission_add_account_id(ids):
res.should.have.key("AccountIds")
set(res["AccountIds"]).should.equal(set(ids))
res.should.have.key("AccountSharingInfoList").length_of(len(ids))
for entry in [{"AccountId": _id} for _id in ids]:
res["AccountSharingInfoList"].should.contain(entry)
expected_account_sharing = [
{"AccountId": _id, "SharedDocumentVersion": "$DEFAULT"} for _id in ids
]
res.should.have.key("AccountSharingInfoList").equal(expected_account_sharing)
@pytest.mark.parametrize(
@ -96,9 +99,12 @@ def test_modify_document_permission_remove_account_id(initial, to_remove):
res.should.have.key("AccountIds")
expected_new_list = set([x for x in initial if x not in to_remove])
set(res["AccountIds"]).should.equal(expected_new_list)
res.should.have.key("AccountSharingInfoList").equal(
[{"AccountId": _id} for _id in expected_new_list]
)
expected_account_sharing = [
{"AccountId": _id, "SharedDocumentVersion": "$DEFAULT"}
for _id in expected_new_list
]
res.should.have.key("AccountSharingInfoList").equal(expected_account_sharing)
@mock_ssm