From 29b0122facdb4a27969c03d69039287cc716417d Mon Sep 17 00:00:00 2001 From: Gonzalo Saad Date: Wed, 25 Aug 2021 11:16:14 -0300 Subject: [PATCH] feat(ssm): Add ssm documents permissions (#4217) * refactor(ssm): Add Documents class to avoid dictionary handling This also solves the datetime format issue in TestAccAWSCloudWatchEventTarget_ssmDocument --- moto/ssm/exceptions.py | 16 ++ moto/ssm/models.py | 420 ++++++++++++++++++++---------- moto/ssm/responses.py | 29 +++ tests/terraform-tests.success.txt | 4 +- tests/test_ssm/test_ssm_docs.py | 5 +- 5 files changed, 333 insertions(+), 141 deletions(-) diff --git a/moto/ssm/exceptions.py b/moto/ssm/exceptions.py index 0d2fdee3b..be3071229 100644 --- a/moto/ssm/exceptions.py +++ b/moto/ssm/exceptions.py @@ -62,6 +62,22 @@ class DocumentAlreadyExists(JsonRESTError): super(DocumentAlreadyExists, self).__init__("DocumentAlreadyExists", message) +class DocumentPermissionLimit(JsonRESTError): + code = 400 + + def __init__(self, message): + super(DocumentPermissionLimit, self).__init__( + "DocumentPermissionLimit", message + ) + + +class InvalidPermissionType(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidPermissionType, self).__init__("InvalidPermissionType", message) + + class InvalidDocument(JsonRESTError): code = 400 diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 775219ee1..59d4a1ec8 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -1,6 +1,8 @@ from __future__ import unicode_literals import re +from typing import Dict + from boto3 import Session from collections import defaultdict @@ -33,6 +35,8 @@ from .exceptions import ( DuplicateDocumentVersionName, DuplicateDocumentContent, ParameterMaxVersionLimitExceeded, + DocumentPermissionLimit, + InvalidPermissionType, ) @@ -144,6 +148,144 @@ def generate_ssm_doc_param_list(parameters): return param_list +class Documents(BaseModel): + def __init__(self, ssm_document): + version = ssm_document.document_version + self.versions = {version: ssm_document} + self.default_version = version + self.latest_version = version + self.permissions = {} # {AccountID: version } + + def get_default_version(self): + return self.versions.get(self.default_version) + + def get_latest_version(self): + return self.versions.get(self.latest_version) + + def find_by_version_name(self, version_name): + return next( + ( + document + for document in self.versions.values() + if document.version_name == version_name + ), + None, + ) + + def find_by_version(self, version): + return self.versions.get(version) + + def find_by_version_and_version_name(self, version, version_name): + return next( + ( + document + for doc_version, document in self.versions.items() + if doc_version == version and document.version_name == version_name + ), + None, + ) + + def find(self, document_version=None, version_name=None, strict=True): + + if document_version == "$LATEST": + ssm_document = self.get_latest_version() + elif version_name and document_version: + ssm_document = self.find_by_version_and_version_name( + document_version, version_name + ) + elif version_name: + ssm_document = self.find_by_version_name(version_name) + elif document_version: + ssm_document = self.find_by_version(document_version) + else: + ssm_document = self.get_default_version() + + if strict and not ssm_document: + raise InvalidDocument("The specified document does not exist.") + + return ssm_document + + def exists(self, document_version=None, version_name=None): + return self.find(document_version, version_name, strict=False) is not None + + def add_new_version(self, new_document_version): + version = new_document_version.document_version + self.latest_version = version + self.versions[version] = new_document_version + + def update_default_version(self, version): + ssm_document = self.find_by_version(version) + if not ssm_document: + raise InvalidDocument("The specified document does not exist.") + self.default_version = version + return ssm_document + + def delete(self, *versions): + for version in versions: + if version in self.versions: + del self.versions[version] + + if self.versions and self.latest_version not in self.versions: + ordered_versions = sorted(self.versions.keys()) + new_latest_version = ordered_versions[-1] + self.latest_version = new_latest_version + + def describe(self, document_version=None, version_name=None): + document = self.find(document_version, version_name) + base = { + "Hash": document.hash, + "HashType": "Sha256", + "Name": document.name, + "Owner": document.owner, + "CreatedDate": document.created_date.strftime("%Y-%m-%dT%H:%M:%SZ"), + "Status": document.status, + "DocumentVersion": document.document_version, + "Description": document.description, + "Parameters": document.parameter_list, + "PlatformTypes": document.platform_types, + "DocumentType": document.document_type, + "SchemaVersion": document.schema_version, + "LatestVersion": self.latest_version, + "DefaultVersion": self.default_version, + "DocumentFormat": document.document_format, + } + if document.version_name: + base["VersionName"] = document.version_name + if document.target_type: + base["TargetType"] = document.target_type + if document.tags: + base["Tags"] = document.tags + + return base + + def modify_permissions(self, accounts_to_add, accounts_to_remove, version): + if "all" in accounts_to_add: + self.permissions.clear() + else: + self.permissions.pop("all", None) + + new_permissions = {account_id: version for account_id in accounts_to_add} + self.permissions.update(**new_permissions) + + if "all" in accounts_to_remove: + self.permissions.clear() + else: + for account_id in accounts_to_remove: + self.permissions.pop(account_id, None) + + def describe_permissions(self): + return { + "AccountIds": list(self.permissions.keys()), + "AccountSharingInfoList": [ + {"AccountId": account_id, "SharedDocumentVersion": document_version} + for account_id, document_version in self.permissions.items() + ], + } + + def is_shared(self): + return len(self.permissions) > 0 + + class Document(BaseModel): def __init__( self, @@ -171,7 +313,7 @@ class Document(BaseModel): self.status = "Active" self.document_version = document_version self.owner = ACCOUNT_ID - self.created_date = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") + self.created_date = datetime.datetime.utcnow() if document_format == "JSON": try: @@ -220,6 +362,10 @@ class Document(BaseModel): except KeyError: raise InvalidDocumentContent("The content for the document is not valid.") + @property + def hash(self): + return hashlib.sha256(self.content.encode("utf-8")).hexdigest() + class Command(BaseModel): def __init__( @@ -482,7 +628,7 @@ class SimpleSystemManagerBackend(BaseBackend): self._resource_tags = defaultdict(lambda: defaultdict(dict)) self._commands = [] self._errors = [] - self._documents = defaultdict(dict) + self._documents: Dict[str, Documents] = {} self._region = region_name @@ -491,53 +637,17 @@ class SimpleSystemManagerBackend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def _generate_document_description(self, document): - - latest = self._documents[document.name]["latest_version"] - default_version = self._documents[document.name]["default_version"] - base = { - "Hash": hashlib.sha256(document.content.encode("utf-8")).hexdigest(), - "HashType": "Sha256", - "Name": document.name, - "Owner": document.owner, - "CreatedDate": document.created_date, - "Status": document.status, - "DocumentVersion": document.document_version, - "Description": document.description, - "Parameters": document.parameter_list, - "PlatformTypes": document.platform_types, - "DocumentType": document.document_type, - "SchemaVersion": document.schema_version, - "LatestVersion": latest, - "DefaultVersion": default_version, - "DocumentFormat": document.document_format, - } - if document.version_name: - base["VersionName"] = document.version_name - if document.target_type: - base["TargetType"] = document.target_type - if document.tags: - base["Tags"] = document.tags - - return base - def _generate_document_information(self, ssm_document, document_format): + content = self._get_document_content(document_format, ssm_document) base = { "Name": ssm_document.name, "DocumentVersion": ssm_document.document_version, "Status": ssm_document.status, - "Content": ssm_document.content, + "Content": content, "DocumentType": ssm_document.document_type, "DocumentFormat": document_format, } - if document_format == "JSON": - base["Content"] = json.dumps(ssm_document.content_json) - elif document_format == "YAML": - base["Content"] = yaml.dump(ssm_document.content_json) - else: - raise ValidationException("Invalid document format " + str(document_format)) - if ssm_document.version_name: base["VersionName"] = ssm_document.version_name if ssm_document.requires: @@ -547,7 +657,20 @@ class SimpleSystemManagerBackend(BaseBackend): return base - def _generate_document_list_information(self, ssm_document): + @staticmethod + def _get_document_content(document_format, ssm_document): + if document_format == ssm_document.document_format: + content = ssm_document.content + elif document_format == "JSON": + content = json.dumps(ssm_document.content_json) + elif document_format == "YAML": + content = yaml.dump(ssm_document.content_json) + else: + raise ValidationException("Invalid document format " + str(document_format)) + return content + + @staticmethod + def _generate_document_list_information(ssm_document): base = { "Name": ssm_document.name, "Owner": ssm_document.owner, @@ -569,6 +692,12 @@ class SimpleSystemManagerBackend(BaseBackend): return base + def _get_documents(self, name): + documents = self._documents.get(name) + if not documents: + raise InvalidDocument("The specified document does not exist.") + return documents + def create_document( self, content, @@ -603,24 +732,23 @@ class SimpleSystemManagerBackend(BaseBackend): if self._documents.get(ssm_document.name): raise DocumentAlreadyExists("The specified document already exists.") - self._documents[ssm_document.name] = { - "documents": {ssm_document.document_version: ssm_document}, - "default_version": ssm_document.document_version, - "latest_version": ssm_document.document_version, - } + documents = Documents(ssm_document) + self._documents[ssm_document.name] = documents - return self._generate_document_description(ssm_document) + return documents.describe() def delete_document(self, name, document_version, version_name, force): - documents = self._documents.get(name, {}).get("documents", {}) + documents = self._get_documents(name) + + if documents.is_shared(): + raise InvalidDocumentOperation("Must unshare document first before delete") + keys_to_delete = set() if documents: - default_version = self._documents[name]["default_version"] - + default_doc = documents.get_default_version() if ( - documents[default_version].document_type - == "ApplicationConfigurationSchema" + default_doc.document_type == "ApplicationConfigurationSchema" and not force ): raise InvalidDocumentOperation( @@ -628,20 +756,20 @@ class SimpleSystemManagerBackend(BaseBackend): "You must stop sharing the document before you can delete it." ) - if document_version and document_version == default_version: + if document_version and document_version == default_doc.document_version: raise InvalidDocumentOperation( "Default version of the document can't be deleted." ) if document_version or version_name: # We delete only a specific version - delete_doc = self._find_document(name, document_version, version_name) + delete_doc = documents.find(document_version, version_name) # we can't delete only the default version if ( delete_doc - and delete_doc.document_version == default_version - and len(documents) != 1 + and delete_doc.document_version == default_doc.document_version + and len(documents.versions) != 1 ): raise InvalidDocumentOperation( "Default version of the document can't be deleted." @@ -653,64 +781,18 @@ class SimpleSystemManagerBackend(BaseBackend): raise InvalidDocument("The specified document does not exist.") else: # We are deleting all versions - keys_to_delete = set(documents.keys()) + keys_to_delete = set(documents.versions.keys()) - for key in keys_to_delete: - del self._documents[name]["documents"][key] + documents.delete(*keys_to_delete) - if len(self._documents[name]["documents"].keys()) == 0: + if len(documents.versions) == 0: del self._documents[name] - else: - old_latest = self._documents[name]["latest_version"] - if old_latest not in self._documents[name]["documents"].keys(): - leftover_keys = self._documents[name]["documents"].keys() - int_keys = [] - for key in leftover_keys: - int_keys.append(int(key)) - self._documents[name]["latest_version"] = str(sorted(int_keys)[-1]) - else: - raise InvalidDocument("The specified document does not exist.") - - def _find_document( - self, name, document_version=None, version_name=None, strict=True - ): - if not self._documents.get(name): - raise InvalidDocument("The specified document does not exist.") - - documents = self._documents[name]["documents"] - ssm_document = None - - if not version_name and not document_version: - # Retrieve default version - default_version = self._documents[name]["default_version"] - ssm_document = documents.get(default_version) - - elif version_name and document_version: - for doc_version, document in documents.items(): - if ( - doc_version == document_version - and document.version_name == version_name - ): - ssm_document = document - break - - else: - for doc_version, document in documents.items(): - if document_version and doc_version == document_version: - ssm_document = document - break - if version_name and document.version_name == version_name: - ssm_document = document - break - - if strict and not ssm_document: - raise InvalidDocument("The specified document does not exist.") - - return ssm_document def get_document(self, name, document_version, version_name, document_format): - ssm_document = self._find_document(name, document_version, version_name) + documents = self._get_documents(name) + ssm_document = documents.find(document_version, version_name) + if not document_format: document_format = ssm_document.document_format else: @@ -719,18 +801,18 @@ class SimpleSystemManagerBackend(BaseBackend): return self._generate_document_information(ssm_document, document_format) def update_document_default_version(self, name, document_version): + documents = self._get_documents(name) + ssm_document = documents.update_default_version(document_version) - ssm_document = self._find_document(name, document_version=document_version) - self._documents[name]["default_version"] = document_version - base = { + result = { "Name": ssm_document.name, "DefaultVersion": document_version, } if ssm_document.version_name: - base["DefaultVersionName"] = ssm_document.version_name + result["DefaultVersionName"] = ssm_document.version_name - return base + return result def update_document( self, @@ -750,24 +832,27 @@ class SimpleSystemManagerBackend(BaseBackend): strict=False, ) - if not self._documents.get(name): + documents = self._documents.get(name) + if not documents: raise InvalidDocument("The specified document does not exist.") + if ( - self._documents[name]["latest_version"] != document_version + documents.latest_version != document_version and document_version != "$LATEST" ): raise InvalidDocumentVersion( "The document version is not valid or does not exist." ) - if version_name and self._find_document( - name, version_name=version_name, strict=False - ): - raise DuplicateDocumentVersionName( - "The specified version name is a duplicate." - ) - old_ssm_document = self._find_document(name) + if version_name: + if documents.exists(version_name=version_name): + raise DuplicateDocumentVersionName( + "The specified version name is a duplicate." + ) + old_ssm_document = documents.get_default_version() + + new_version = str(int(documents.latest_version) + 1) new_ssm_document = Document( name=name, version_name=version_name, @@ -778,28 +863,22 @@ class SimpleSystemManagerBackend(BaseBackend): attachments=attachments, target_type=target_type, tags=old_ssm_document.tags, - document_version=str(int(self._documents[name]["latest_version"]) + 1), + document_version=new_version, ) - for doc_version, document in self._documents[name]["documents"].items(): + for doc_version, document in documents.versions.items(): if document.content == new_ssm_document.content: raise DuplicateDocumentContent( "The content of the association document matches another document. " "Change the content of the document and try again." ) - self._documents[name]["latest_version"] = str( - int(self._documents[name]["latest_version"]) + 1 - ) - self._documents[name]["documents"][ - new_ssm_document.document_version - ] = new_ssm_document - - return self._generate_document_description(new_ssm_document) + documents.add_new_version(new_ssm_document) + return documents.describe(document_version=new_version) def describe_document(self, name, document_version, version_name): - ssm_document = self._find_document(name, document_version, version_name) - return self._generate_document_description(ssm_document) + documents = self._get_documents(name) + return documents.describe(document_version, version_name) def list_documents( self, document_filter_list, filters, max_results=10, next_token="0" @@ -813,17 +892,16 @@ class SimpleSystemManagerBackend(BaseBackend): results = [] dummy_token_tracker = 0 # Sort to maintain next token adjacency - for document_name, document_bundle in sorted(self._documents.items()): + for document_name, documents in sorted(self._documents.items()): if len(results) == max_results: # There's still more to go so we need a next token return results, str(next_token + len(results)) if dummy_token_tracker < next_token: - dummy_token_tracker = dummy_token_tracker + 1 + dummy_token_tracker += 1 continue - default_version = document_bundle["default_version"] - ssm_doc = self._documents[document_name]["documents"][default_version] + ssm_doc = documents.get_default_version() if filters and not _document_filter_match(filters, ssm_doc): # If we have filters enabled, and we don't match them, continue @@ -833,6 +911,72 @@ class SimpleSystemManagerBackend(BaseBackend): # If we've fallen out of the loop, theres no more documents. No next token. return results, "" + def describe_document_permission( + self, name, max_results=None, permission_type=None, next_token=None + ): + # Parameters max_results, permission_type, and next_token not used because + # this current implementation doesn't support pagination. + document = self._get_documents(name) + return document.describe_permissions() + + def modify_document_permission( + self, + name, + account_ids_to_add, + account_ids_to_remove, + shared_document_version, + permission_type, + ): + + account_id_regex = re.compile(r"(all|[0-9]{12})") + version_regex = re.compile(r"^([$]LATEST|[$]DEFAULT|[$]ALL)$") + + account_ids_to_add = account_ids_to_add or [] + account_ids_to_remove = account_ids_to_remove or [] + + if not version_regex.match(shared_document_version): + raise ValidationException( + f"Value '{shared_document_version}' at 'sharedDocumentVersion' failed to satisfy constraint: " + f"Member must satisfy regular expression pattern: ([$]LATEST|[$]DEFAULT|[$]ALL)." + ) + + for account_id in account_ids_to_add: + if not account_id_regex.match(account_id): + raise ValidationException( + f"Value '[{account_id}]' at 'accountIdsToAdd' failed to satisfy constraint: " + "Member must satisfy regular expression pattern: (all|[0-9]{12}])." + ) + + for account_id in account_ids_to_remove: + if not account_id_regex.match(account_id): + raise ValidationException( + f"Value '[{account_id}]' at 'accountIdsToRemove' failed to satisfy constraint: " + "Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]." + ) + + accounts_to_add = set(account_ids_to_add) + if "all" in accounts_to_add and len(accounts_to_add) > 1: + raise DocumentPermissionLimit( + "Accounts can either be all or a group of AWS accounts" + ) + + accounts_to_remove = set(account_ids_to_remove) + if "all" in accounts_to_remove and len(accounts_to_remove) > 1: + raise DocumentPermissionLimit( + "Accounts can either be all or a group of AWS accounts" + ) + + if permission_type != "Share": + raise InvalidPermissionType( + f"Value '{permission_type}' at 'permissionType' failed to satisfy constraint: " + "Member must satisfy enum value set: [Share]." + ) + + document = self._get_documents(name) + document.modify_permissions( + accounts_to_add, accounts_to_remove, shared_document_version + ) + def delete_parameter(self, name): return self._parameters.pop(name, None) diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index d99140c3a..04b512cd3 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -127,6 +127,35 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps({"DocumentIdentifiers": documents, "NextToken": token}) + def describe_document_permission(self): + name = self._get_param("Name") + max_results = self._get_param("MaxResults") + next_token = self._get_param("NextToken") + permission_type = self._get_param("PermissionType") + + result = self.ssm_backend.describe_document_permission( + name=name, + max_results=max_results, + next_token=next_token, + permission_type=permission_type, + ) + return json.dumps(result) + + def modify_document_permission(self): + account_ids_to_add = self._get_param("AccountIdsToAdd") + account_ids_to_remove = self._get_param("AccountIdsToRemove") + name = self._get_param("Name") + permission_type = self._get_param("PermissionType") + shared_document_version = self._get_param("SharedDocumentVersion") + + self.ssm_backend.modify_document_permission( + name=name, + account_ids_to_add=account_ids_to_add, + account_ids_to_remove=account_ids_to_remove, + shared_document_version=shared_document_version, + permission_type=permission_type, + ) + def _get_param(self, param, default=None): return self.request_params.get(param, default) diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index 94eb79247..682266fda 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -11,6 +11,7 @@ TestAccAWSCloudwatchEventBusPolicy TestAccAWSCloudWatchEventConnection TestAccAWSCloudWatchEventPermission TestAccAWSCloudWatchEventRule +TestAccAWSCloudWatchEventTarget_ssmDocument TestAccAWSCloudwatchLogGroupDataSource TestAccAWSCloudWatchMetricAlarm TestAccAWSDataSourceCloudwatch @@ -87,5 +88,6 @@ TestAccAWSRouteTable_IPv4_To_TransitGateway TestAccAWSRouteTable_IPv4_To_VpcPeeringConnection TestAccAWSRouteTable_disappears TestAccAWSRouteTable_basic +TestAccAWSSsmDocumentDataSource TestAccAwsEc2ManagedPrefixList -TestAccAWSEgressOnlyInternetGateway \ No newline at end of file +TestAccAWSEgressOnlyInternetGateway diff --git a/tests/test_ssm/test_ssm_docs.py b/tests/test_ssm/test_ssm_docs.py index aa0118820..fd2237db2 100644 --- a/tests/test_ssm/test_ssm_docs.py +++ b/tests/test_ssm/test_ssm_docs.py @@ -4,6 +4,7 @@ import boto3 import botocore.exceptions import sure # noqa import datetime +from datetime import timezone import json import yaml import hashlib @@ -42,7 +43,7 @@ def _validate_document_description( doc_description["Name"].should.equal(doc_name) doc_description["Owner"].should.equal(ACCOUNT_ID) - difference = datetime.datetime.utcnow() - doc_description["CreatedDate"] + difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"] if difference.min > datetime.timedelta(minutes=1): assert False @@ -224,7 +225,7 @@ def test_create_document(): doc_description["Name"].should.equal("EmptyParamDoc") doc_description["Owner"].should.equal(ACCOUNT_ID) - difference = datetime.datetime.utcnow() - doc_description["CreatedDate"] + difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"] if difference.min > datetime.timedelta(minutes=1): assert False