feat(ssm): Add ssm documents permissions (#4217)
* refactor(ssm): Add Documents class to avoid dictionary handling This also solves the datetime format issue in TestAccAWSCloudWatchEventTarget_ssmDocument
This commit is contained in:
		
							parent
							
								
									f038859a37
								
							
						
					
					
						commit
						29b0122fac
					
				| @ -62,6 +62,22 @@ class DocumentAlreadyExists(JsonRESTError): | |||||||
|         super(DocumentAlreadyExists, self).__init__("DocumentAlreadyExists", message) |         super(DocumentAlreadyExists, self).__init__("DocumentAlreadyExists", message) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class DocumentPermissionLimit(JsonRESTError): | ||||||
|  |     code = 400 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, message): | ||||||
|  |         super(DocumentPermissionLimit, self).__init__( | ||||||
|  |             "DocumentPermissionLimit", message | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class InvalidPermissionType(JsonRESTError): | ||||||
|  |     code = 400 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, message): | ||||||
|  |         super(InvalidPermissionType, self).__init__("InvalidPermissionType", message) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class InvalidDocument(JsonRESTError): | class InvalidDocument(JsonRESTError): | ||||||
|     code = 400 |     code = 400 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| 
 | 
 | ||||||
| import re | import re | ||||||
|  | from typing import Dict | ||||||
|  | 
 | ||||||
| from boto3 import Session | from boto3 import Session | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| 
 | 
 | ||||||
| @ -33,6 +35,8 @@ from .exceptions import ( | |||||||
|     DuplicateDocumentVersionName, |     DuplicateDocumentVersionName, | ||||||
|     DuplicateDocumentContent, |     DuplicateDocumentContent, | ||||||
|     ParameterMaxVersionLimitExceeded, |     ParameterMaxVersionLimitExceeded, | ||||||
|  |     DocumentPermissionLimit, | ||||||
|  |     InvalidPermissionType, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -144,6 +148,144 @@ def generate_ssm_doc_param_list(parameters): | |||||||
|     return param_list |     return param_list | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class Documents(BaseModel): | ||||||
|  |     def __init__(self, ssm_document): | ||||||
|  |         version = ssm_document.document_version | ||||||
|  |         self.versions = {version: ssm_document} | ||||||
|  |         self.default_version = version | ||||||
|  |         self.latest_version = version | ||||||
|  |         self.permissions = {}  # {AccountID: version } | ||||||
|  | 
 | ||||||
|  |     def get_default_version(self): | ||||||
|  |         return self.versions.get(self.default_version) | ||||||
|  | 
 | ||||||
|  |     def get_latest_version(self): | ||||||
|  |         return self.versions.get(self.latest_version) | ||||||
|  | 
 | ||||||
|  |     def find_by_version_name(self, version_name): | ||||||
|  |         return next( | ||||||
|  |             ( | ||||||
|  |                 document | ||||||
|  |                 for document in self.versions.values() | ||||||
|  |                 if document.version_name == version_name | ||||||
|  |             ), | ||||||
|  |             None, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def find_by_version(self, version): | ||||||
|  |         return self.versions.get(version) | ||||||
|  | 
 | ||||||
|  |     def find_by_version_and_version_name(self, version, version_name): | ||||||
|  |         return next( | ||||||
|  |             ( | ||||||
|  |                 document | ||||||
|  |                 for doc_version, document in self.versions.items() | ||||||
|  |                 if doc_version == version and document.version_name == version_name | ||||||
|  |             ), | ||||||
|  |             None, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def find(self, document_version=None, version_name=None, strict=True): | ||||||
|  | 
 | ||||||
|  |         if document_version == "$LATEST": | ||||||
|  |             ssm_document = self.get_latest_version() | ||||||
|  |         elif version_name and document_version: | ||||||
|  |             ssm_document = self.find_by_version_and_version_name( | ||||||
|  |                 document_version, version_name | ||||||
|  |             ) | ||||||
|  |         elif version_name: | ||||||
|  |             ssm_document = self.find_by_version_name(version_name) | ||||||
|  |         elif document_version: | ||||||
|  |             ssm_document = self.find_by_version(document_version) | ||||||
|  |         else: | ||||||
|  |             ssm_document = self.get_default_version() | ||||||
|  | 
 | ||||||
|  |         if strict and not ssm_document: | ||||||
|  |             raise InvalidDocument("The specified document does not exist.") | ||||||
|  | 
 | ||||||
|  |         return ssm_document | ||||||
|  | 
 | ||||||
|  |     def exists(self, document_version=None, version_name=None): | ||||||
|  |         return self.find(document_version, version_name, strict=False) is not None | ||||||
|  | 
 | ||||||
|  |     def add_new_version(self, new_document_version): | ||||||
|  |         version = new_document_version.document_version | ||||||
|  |         self.latest_version = version | ||||||
|  |         self.versions[version] = new_document_version | ||||||
|  | 
 | ||||||
|  |     def update_default_version(self, version): | ||||||
|  |         ssm_document = self.find_by_version(version) | ||||||
|  |         if not ssm_document: | ||||||
|  |             raise InvalidDocument("The specified document does not exist.") | ||||||
|  |         self.default_version = version | ||||||
|  |         return ssm_document | ||||||
|  | 
 | ||||||
|  |     def delete(self, *versions): | ||||||
|  |         for version in versions: | ||||||
|  |             if version in self.versions: | ||||||
|  |                 del self.versions[version] | ||||||
|  | 
 | ||||||
|  |         if self.versions and self.latest_version not in self.versions: | ||||||
|  |             ordered_versions = sorted(self.versions.keys()) | ||||||
|  |             new_latest_version = ordered_versions[-1] | ||||||
|  |             self.latest_version = new_latest_version | ||||||
|  | 
 | ||||||
|  |     def describe(self, document_version=None, version_name=None): | ||||||
|  |         document = self.find(document_version, version_name) | ||||||
|  |         base = { | ||||||
|  |             "Hash": document.hash, | ||||||
|  |             "HashType": "Sha256", | ||||||
|  |             "Name": document.name, | ||||||
|  |             "Owner": document.owner, | ||||||
|  |             "CreatedDate": document.created_date.strftime("%Y-%m-%dT%H:%M:%SZ"), | ||||||
|  |             "Status": document.status, | ||||||
|  |             "DocumentVersion": document.document_version, | ||||||
|  |             "Description": document.description, | ||||||
|  |             "Parameters": document.parameter_list, | ||||||
|  |             "PlatformTypes": document.platform_types, | ||||||
|  |             "DocumentType": document.document_type, | ||||||
|  |             "SchemaVersion": document.schema_version, | ||||||
|  |             "LatestVersion": self.latest_version, | ||||||
|  |             "DefaultVersion": self.default_version, | ||||||
|  |             "DocumentFormat": document.document_format, | ||||||
|  |         } | ||||||
|  |         if document.version_name: | ||||||
|  |             base["VersionName"] = document.version_name | ||||||
|  |         if document.target_type: | ||||||
|  |             base["TargetType"] = document.target_type | ||||||
|  |         if document.tags: | ||||||
|  |             base["Tags"] = document.tags | ||||||
|  | 
 | ||||||
|  |         return base | ||||||
|  | 
 | ||||||
|  |     def modify_permissions(self, accounts_to_add, accounts_to_remove, version): | ||||||
|  |         if "all" in accounts_to_add: | ||||||
|  |             self.permissions.clear() | ||||||
|  |         else: | ||||||
|  |             self.permissions.pop("all", None) | ||||||
|  | 
 | ||||||
|  |         new_permissions = {account_id: version for account_id in accounts_to_add} | ||||||
|  |         self.permissions.update(**new_permissions) | ||||||
|  | 
 | ||||||
|  |         if "all" in accounts_to_remove: | ||||||
|  |             self.permissions.clear() | ||||||
|  |         else: | ||||||
|  |             for account_id in accounts_to_remove: | ||||||
|  |                 self.permissions.pop(account_id, None) | ||||||
|  | 
 | ||||||
|  |     def describe_permissions(self): | ||||||
|  |         return { | ||||||
|  |             "AccountIds": list(self.permissions.keys()), | ||||||
|  |             "AccountSharingInfoList": [ | ||||||
|  |                 {"AccountId": account_id, "SharedDocumentVersion": document_version} | ||||||
|  |                 for account_id, document_version in self.permissions.items() | ||||||
|  |             ], | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     def is_shared(self): | ||||||
|  |         return len(self.permissions) > 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Document(BaseModel): | class Document(BaseModel): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| @ -171,7 +313,7 @@ class Document(BaseModel): | |||||||
|         self.status = "Active" |         self.status = "Active" | ||||||
|         self.document_version = document_version |         self.document_version = document_version | ||||||
|         self.owner = ACCOUNT_ID |         self.owner = ACCOUNT_ID | ||||||
|         self.created_date = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") |         self.created_date = datetime.datetime.utcnow() | ||||||
| 
 | 
 | ||||||
|         if document_format == "JSON": |         if document_format == "JSON": | ||||||
|             try: |             try: | ||||||
| @ -220,6 +362,10 @@ class Document(BaseModel): | |||||||
|         except KeyError: |         except KeyError: | ||||||
|             raise InvalidDocumentContent("The content for the document is not valid.") |             raise InvalidDocumentContent("The content for the document is not valid.") | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def hash(self): | ||||||
|  |         return hashlib.sha256(self.content.encode("utf-8")).hexdigest() | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class Command(BaseModel): | class Command(BaseModel): | ||||||
|     def __init__( |     def __init__( | ||||||
| @ -482,7 +628,7 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         self._resource_tags = defaultdict(lambda: defaultdict(dict)) |         self._resource_tags = defaultdict(lambda: defaultdict(dict)) | ||||||
|         self._commands = [] |         self._commands = [] | ||||||
|         self._errors = [] |         self._errors = [] | ||||||
|         self._documents = defaultdict(dict) |         self._documents: Dict[str, Documents] = {} | ||||||
| 
 | 
 | ||||||
|         self._region = region_name |         self._region = region_name | ||||||
| 
 | 
 | ||||||
| @ -491,53 +637,17 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         self.__dict__ = {} |         self.__dict__ = {} | ||||||
|         self.__init__(region_name) |         self.__init__(region_name) | ||||||
| 
 | 
 | ||||||
|     def _generate_document_description(self, document): |  | ||||||
| 
 |  | ||||||
|         latest = self._documents[document.name]["latest_version"] |  | ||||||
|         default_version = self._documents[document.name]["default_version"] |  | ||||||
|         base = { |  | ||||||
|             "Hash": hashlib.sha256(document.content.encode("utf-8")).hexdigest(), |  | ||||||
|             "HashType": "Sha256", |  | ||||||
|             "Name": document.name, |  | ||||||
|             "Owner": document.owner, |  | ||||||
|             "CreatedDate": document.created_date, |  | ||||||
|             "Status": document.status, |  | ||||||
|             "DocumentVersion": document.document_version, |  | ||||||
|             "Description": document.description, |  | ||||||
|             "Parameters": document.parameter_list, |  | ||||||
|             "PlatformTypes": document.platform_types, |  | ||||||
|             "DocumentType": document.document_type, |  | ||||||
|             "SchemaVersion": document.schema_version, |  | ||||||
|             "LatestVersion": latest, |  | ||||||
|             "DefaultVersion": default_version, |  | ||||||
|             "DocumentFormat": document.document_format, |  | ||||||
|         } |  | ||||||
|         if document.version_name: |  | ||||||
|             base["VersionName"] = document.version_name |  | ||||||
|         if document.target_type: |  | ||||||
|             base["TargetType"] = document.target_type |  | ||||||
|         if document.tags: |  | ||||||
|             base["Tags"] = document.tags |  | ||||||
| 
 |  | ||||||
|         return base |  | ||||||
| 
 |  | ||||||
|     def _generate_document_information(self, ssm_document, document_format): |     def _generate_document_information(self, ssm_document, document_format): | ||||||
|  |         content = self._get_document_content(document_format, ssm_document) | ||||||
|         base = { |         base = { | ||||||
|             "Name": ssm_document.name, |             "Name": ssm_document.name, | ||||||
|             "DocumentVersion": ssm_document.document_version, |             "DocumentVersion": ssm_document.document_version, | ||||||
|             "Status": ssm_document.status, |             "Status": ssm_document.status, | ||||||
|             "Content": ssm_document.content, |             "Content": content, | ||||||
|             "DocumentType": ssm_document.document_type, |             "DocumentType": ssm_document.document_type, | ||||||
|             "DocumentFormat": document_format, |             "DocumentFormat": document_format, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if document_format == "JSON": |  | ||||||
|             base["Content"] = json.dumps(ssm_document.content_json) |  | ||||||
|         elif document_format == "YAML": |  | ||||||
|             base["Content"] = yaml.dump(ssm_document.content_json) |  | ||||||
|         else: |  | ||||||
|             raise ValidationException("Invalid document format " + str(document_format)) |  | ||||||
| 
 |  | ||||||
|         if ssm_document.version_name: |         if ssm_document.version_name: | ||||||
|             base["VersionName"] = ssm_document.version_name |             base["VersionName"] = ssm_document.version_name | ||||||
|         if ssm_document.requires: |         if ssm_document.requires: | ||||||
| @ -547,7 +657,20 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
| 
 | 
 | ||||||
|         return base |         return base | ||||||
| 
 | 
 | ||||||
|     def _generate_document_list_information(self, ssm_document): |     @staticmethod | ||||||
|  |     def _get_document_content(document_format, ssm_document): | ||||||
|  |         if document_format == ssm_document.document_format: | ||||||
|  |             content = ssm_document.content | ||||||
|  |         elif document_format == "JSON": | ||||||
|  |             content = json.dumps(ssm_document.content_json) | ||||||
|  |         elif document_format == "YAML": | ||||||
|  |             content = yaml.dump(ssm_document.content_json) | ||||||
|  |         else: | ||||||
|  |             raise ValidationException("Invalid document format " + str(document_format)) | ||||||
|  |         return content | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _generate_document_list_information(ssm_document): | ||||||
|         base = { |         base = { | ||||||
|             "Name": ssm_document.name, |             "Name": ssm_document.name, | ||||||
|             "Owner": ssm_document.owner, |             "Owner": ssm_document.owner, | ||||||
| @ -569,6 +692,12 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
| 
 | 
 | ||||||
|         return base |         return base | ||||||
| 
 | 
 | ||||||
|  |     def _get_documents(self, name): | ||||||
|  |         documents = self._documents.get(name) | ||||||
|  |         if not documents: | ||||||
|  |             raise InvalidDocument("The specified document does not exist.") | ||||||
|  |         return documents | ||||||
|  | 
 | ||||||
|     def create_document( |     def create_document( | ||||||
|         self, |         self, | ||||||
|         content, |         content, | ||||||
| @ -603,24 +732,23 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         if self._documents.get(ssm_document.name): |         if self._documents.get(ssm_document.name): | ||||||
|             raise DocumentAlreadyExists("The specified document already exists.") |             raise DocumentAlreadyExists("The specified document already exists.") | ||||||
| 
 | 
 | ||||||
|         self._documents[ssm_document.name] = { |         documents = Documents(ssm_document) | ||||||
|             "documents": {ssm_document.document_version: ssm_document}, |         self._documents[ssm_document.name] = documents | ||||||
|             "default_version": ssm_document.document_version, |  | ||||||
|             "latest_version": ssm_document.document_version, |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         return self._generate_document_description(ssm_document) |         return documents.describe() | ||||||
| 
 | 
 | ||||||
|     def delete_document(self, name, document_version, version_name, force): |     def delete_document(self, name, document_version, version_name, force): | ||||||
|         documents = self._documents.get(name, {}).get("documents", {}) |         documents = self._get_documents(name) | ||||||
|  | 
 | ||||||
|  |         if documents.is_shared(): | ||||||
|  |             raise InvalidDocumentOperation("Must unshare document first before delete") | ||||||
|  | 
 | ||||||
|         keys_to_delete = set() |         keys_to_delete = set() | ||||||
| 
 | 
 | ||||||
|         if documents: |         if documents: | ||||||
|             default_version = self._documents[name]["default_version"] |             default_doc = documents.get_default_version() | ||||||
| 
 |  | ||||||
|             if ( |             if ( | ||||||
|                 documents[default_version].document_type |                 default_doc.document_type == "ApplicationConfigurationSchema" | ||||||
|                 == "ApplicationConfigurationSchema" |  | ||||||
|                 and not force |                 and not force | ||||||
|             ): |             ): | ||||||
|                 raise InvalidDocumentOperation( |                 raise InvalidDocumentOperation( | ||||||
| @ -628,20 +756,20 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|                     "You must stop sharing the document before you can delete it." |                     "You must stop sharing the document before you can delete it." | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             if document_version and document_version == default_version: |             if document_version and document_version == default_doc.document_version: | ||||||
|                 raise InvalidDocumentOperation( |                 raise InvalidDocumentOperation( | ||||||
|                     "Default version of the document can't be deleted." |                     "Default version of the document can't be deleted." | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             if document_version or version_name: |             if document_version or version_name: | ||||||
|                 # We delete only a specific version |                 # We delete only a specific version | ||||||
|                 delete_doc = self._find_document(name, document_version, version_name) |                 delete_doc = documents.find(document_version, version_name) | ||||||
| 
 | 
 | ||||||
|                 # we can't delete only the default version |                 # we can't delete only the default version | ||||||
|                 if ( |                 if ( | ||||||
|                     delete_doc |                     delete_doc | ||||||
|                     and delete_doc.document_version == default_version |                     and delete_doc.document_version == default_doc.document_version | ||||||
|                     and len(documents) != 1 |                     and len(documents.versions) != 1 | ||||||
|                 ): |                 ): | ||||||
|                     raise InvalidDocumentOperation( |                     raise InvalidDocumentOperation( | ||||||
|                         "Default version of the document can't be deleted." |                         "Default version of the document can't be deleted." | ||||||
| @ -653,64 +781,18 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|                     raise InvalidDocument("The specified document does not exist.") |                     raise InvalidDocument("The specified document does not exist.") | ||||||
|             else: |             else: | ||||||
|                 # We are deleting all versions |                 # We are deleting all versions | ||||||
|                 keys_to_delete = set(documents.keys()) |                 keys_to_delete = set(documents.versions.keys()) | ||||||
| 
 | 
 | ||||||
|             for key in keys_to_delete: |             documents.delete(*keys_to_delete) | ||||||
|                 del self._documents[name]["documents"][key] |  | ||||||
| 
 | 
 | ||||||
|             if len(self._documents[name]["documents"].keys()) == 0: |             if len(documents.versions) == 0: | ||||||
|                 del self._documents[name] |                 del self._documents[name] | ||||||
|             else: |  | ||||||
|                 old_latest = self._documents[name]["latest_version"] |  | ||||||
|                 if old_latest not in self._documents[name]["documents"].keys(): |  | ||||||
|                     leftover_keys = self._documents[name]["documents"].keys() |  | ||||||
|                     int_keys = [] |  | ||||||
|                     for key in leftover_keys: |  | ||||||
|                         int_keys.append(int(key)) |  | ||||||
|                     self._documents[name]["latest_version"] = str(sorted(int_keys)[-1]) |  | ||||||
|         else: |  | ||||||
|             raise InvalidDocument("The specified document does not exist.") |  | ||||||
| 
 |  | ||||||
|     def _find_document( |  | ||||||
|         self, name, document_version=None, version_name=None, strict=True |  | ||||||
|     ): |  | ||||||
|         if not self._documents.get(name): |  | ||||||
|             raise InvalidDocument("The specified document does not exist.") |  | ||||||
| 
 |  | ||||||
|         documents = self._documents[name]["documents"] |  | ||||||
|         ssm_document = None |  | ||||||
| 
 |  | ||||||
|         if not version_name and not document_version: |  | ||||||
|             # Retrieve default version |  | ||||||
|             default_version = self._documents[name]["default_version"] |  | ||||||
|             ssm_document = documents.get(default_version) |  | ||||||
| 
 |  | ||||||
|         elif version_name and document_version: |  | ||||||
|             for doc_version, document in documents.items(): |  | ||||||
|                 if ( |  | ||||||
|                     doc_version == document_version |  | ||||||
|                     and document.version_name == version_name |  | ||||||
|                 ): |  | ||||||
|                     ssm_document = document |  | ||||||
|                     break |  | ||||||
| 
 |  | ||||||
|         else: |  | ||||||
|             for doc_version, document in documents.items(): |  | ||||||
|                 if document_version and doc_version == document_version: |  | ||||||
|                     ssm_document = document |  | ||||||
|                     break |  | ||||||
|                 if version_name and document.version_name == version_name: |  | ||||||
|                     ssm_document = document |  | ||||||
|                     break |  | ||||||
| 
 |  | ||||||
|         if strict and not ssm_document: |  | ||||||
|             raise InvalidDocument("The specified document does not exist.") |  | ||||||
| 
 |  | ||||||
|         return ssm_document |  | ||||||
| 
 | 
 | ||||||
|     def get_document(self, name, document_version, version_name, document_format): |     def get_document(self, name, document_version, version_name, document_format): | ||||||
| 
 | 
 | ||||||
|         ssm_document = self._find_document(name, document_version, version_name) |         documents = self._get_documents(name) | ||||||
|  |         ssm_document = documents.find(document_version, version_name) | ||||||
|  | 
 | ||||||
|         if not document_format: |         if not document_format: | ||||||
|             document_format = ssm_document.document_format |             document_format = ssm_document.document_format | ||||||
|         else: |         else: | ||||||
| @ -719,18 +801,18 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         return self._generate_document_information(ssm_document, document_format) |         return self._generate_document_information(ssm_document, document_format) | ||||||
| 
 | 
 | ||||||
|     def update_document_default_version(self, name, document_version): |     def update_document_default_version(self, name, document_version): | ||||||
|  |         documents = self._get_documents(name) | ||||||
|  |         ssm_document = documents.update_default_version(document_version) | ||||||
| 
 | 
 | ||||||
|         ssm_document = self._find_document(name, document_version=document_version) |         result = { | ||||||
|         self._documents[name]["default_version"] = document_version |  | ||||||
|         base = { |  | ||||||
|             "Name": ssm_document.name, |             "Name": ssm_document.name, | ||||||
|             "DefaultVersion": document_version, |             "DefaultVersion": document_version, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if ssm_document.version_name: |         if ssm_document.version_name: | ||||||
|             base["DefaultVersionName"] = ssm_document.version_name |             result["DefaultVersionName"] = ssm_document.version_name | ||||||
| 
 | 
 | ||||||
|         return base |         return result | ||||||
| 
 | 
 | ||||||
|     def update_document( |     def update_document( | ||||||
|         self, |         self, | ||||||
| @ -750,24 +832,27 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|             strict=False, |             strict=False, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         if not self._documents.get(name): |         documents = self._documents.get(name) | ||||||
|  |         if not documents: | ||||||
|             raise InvalidDocument("The specified document does not exist.") |             raise InvalidDocument("The specified document does not exist.") | ||||||
|  | 
 | ||||||
|         if ( |         if ( | ||||||
|             self._documents[name]["latest_version"] != document_version |             documents.latest_version != document_version | ||||||
|             and document_version != "$LATEST" |             and document_version != "$LATEST" | ||||||
|         ): |         ): | ||||||
|             raise InvalidDocumentVersion( |             raise InvalidDocumentVersion( | ||||||
|                 "The document version is not valid or does not exist." |                 "The document version is not valid or does not exist." | ||||||
|             ) |             ) | ||||||
|         if version_name and self._find_document( |  | ||||||
|             name, version_name=version_name, strict=False |  | ||||||
|         ): |  | ||||||
|             raise DuplicateDocumentVersionName( |  | ||||||
|                 "The specified version name is a duplicate." |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|         old_ssm_document = self._find_document(name) |         if version_name: | ||||||
|  |             if documents.exists(version_name=version_name): | ||||||
|  |                 raise DuplicateDocumentVersionName( | ||||||
|  |                     "The specified version name is a duplicate." | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|  |         old_ssm_document = documents.get_default_version() | ||||||
|  | 
 | ||||||
|  |         new_version = str(int(documents.latest_version) + 1) | ||||||
|         new_ssm_document = Document( |         new_ssm_document = Document( | ||||||
|             name=name, |             name=name, | ||||||
|             version_name=version_name, |             version_name=version_name, | ||||||
| @ -778,28 +863,22 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|             attachments=attachments, |             attachments=attachments, | ||||||
|             target_type=target_type, |             target_type=target_type, | ||||||
|             tags=old_ssm_document.tags, |             tags=old_ssm_document.tags, | ||||||
|             document_version=str(int(self._documents[name]["latest_version"]) + 1), |             document_version=new_version, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         for doc_version, document in self._documents[name]["documents"].items(): |         for doc_version, document in documents.versions.items(): | ||||||
|             if document.content == new_ssm_document.content: |             if document.content == new_ssm_document.content: | ||||||
|                 raise DuplicateDocumentContent( |                 raise DuplicateDocumentContent( | ||||||
|                     "The content of the association document matches another document. " |                     "The content of the association document matches another document. " | ||||||
|                     "Change the content of the document and try again." |                     "Change the content of the document and try again." | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|         self._documents[name]["latest_version"] = str( |         documents.add_new_version(new_ssm_document) | ||||||
|             int(self._documents[name]["latest_version"]) + 1 |         return documents.describe(document_version=new_version) | ||||||
|         ) |  | ||||||
|         self._documents[name]["documents"][ |  | ||||||
|             new_ssm_document.document_version |  | ||||||
|         ] = new_ssm_document |  | ||||||
| 
 |  | ||||||
|         return self._generate_document_description(new_ssm_document) |  | ||||||
| 
 | 
 | ||||||
|     def describe_document(self, name, document_version, version_name): |     def describe_document(self, name, document_version, version_name): | ||||||
|         ssm_document = self._find_document(name, document_version, version_name) |         documents = self._get_documents(name) | ||||||
|         return self._generate_document_description(ssm_document) |         return documents.describe(document_version, version_name) | ||||||
| 
 | 
 | ||||||
|     def list_documents( |     def list_documents( | ||||||
|         self, document_filter_list, filters, max_results=10, next_token="0" |         self, document_filter_list, filters, max_results=10, next_token="0" | ||||||
| @ -813,17 +892,16 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         results = [] |         results = [] | ||||||
|         dummy_token_tracker = 0 |         dummy_token_tracker = 0 | ||||||
|         # Sort to maintain next token adjacency |         # Sort to maintain next token adjacency | ||||||
|         for document_name, document_bundle in sorted(self._documents.items()): |         for document_name, documents in sorted(self._documents.items()): | ||||||
|             if len(results) == max_results: |             if len(results) == max_results: | ||||||
|                 # There's still more to go so we need a next token |                 # There's still more to go so we need a next token | ||||||
|                 return results, str(next_token + len(results)) |                 return results, str(next_token + len(results)) | ||||||
| 
 | 
 | ||||||
|             if dummy_token_tracker < next_token: |             if dummy_token_tracker < next_token: | ||||||
|                 dummy_token_tracker = dummy_token_tracker + 1 |                 dummy_token_tracker += 1 | ||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             default_version = document_bundle["default_version"] |             ssm_doc = documents.get_default_version() | ||||||
|             ssm_doc = self._documents[document_name]["documents"][default_version] |  | ||||||
|             if filters and not _document_filter_match(filters, ssm_doc): |             if filters and not _document_filter_match(filters, ssm_doc): | ||||||
|                 # If we have filters enabled, and we don't match them, |                 # If we have filters enabled, and we don't match them, | ||||||
|                 continue |                 continue | ||||||
| @ -833,6 +911,72 @@ class SimpleSystemManagerBackend(BaseBackend): | |||||||
|         # If we've fallen out of the loop, theres no more documents. No next token. |         # If we've fallen out of the loop, theres no more documents. No next token. | ||||||
|         return results, "" |         return results, "" | ||||||
| 
 | 
 | ||||||
|  |     def describe_document_permission( | ||||||
|  |         self, name, max_results=None, permission_type=None, next_token=None | ||||||
|  |     ): | ||||||
|  |         # Parameters max_results, permission_type, and next_token not used because | ||||||
|  |         # this current implementation doesn't support pagination. | ||||||
|  |         document = self._get_documents(name) | ||||||
|  |         return document.describe_permissions() | ||||||
|  | 
 | ||||||
|  |     def modify_document_permission( | ||||||
|  |         self, | ||||||
|  |         name, | ||||||
|  |         account_ids_to_add, | ||||||
|  |         account_ids_to_remove, | ||||||
|  |         shared_document_version, | ||||||
|  |         permission_type, | ||||||
|  |     ): | ||||||
|  | 
 | ||||||
|  |         account_id_regex = re.compile(r"(all|[0-9]{12})") | ||||||
|  |         version_regex = re.compile(r"^([$]LATEST|[$]DEFAULT|[$]ALL)$") | ||||||
|  | 
 | ||||||
|  |         account_ids_to_add = account_ids_to_add or [] | ||||||
|  |         account_ids_to_remove = account_ids_to_remove or [] | ||||||
|  | 
 | ||||||
|  |         if not version_regex.match(shared_document_version): | ||||||
|  |             raise ValidationException( | ||||||
|  |                 f"Value '{shared_document_version}' at 'sharedDocumentVersion' failed to satisfy constraint: " | ||||||
|  |                 f"Member must satisfy regular expression pattern: ([$]LATEST|[$]DEFAULT|[$]ALL)." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         for account_id in account_ids_to_add: | ||||||
|  |             if not account_id_regex.match(account_id): | ||||||
|  |                 raise ValidationException( | ||||||
|  |                     f"Value '[{account_id}]' at 'accountIdsToAdd' failed to satisfy constraint: " | ||||||
|  |                     "Member must satisfy regular expression pattern: (all|[0-9]{12}])." | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |         for account_id in account_ids_to_remove: | ||||||
|  |             if not account_id_regex.match(account_id): | ||||||
|  |                 raise ValidationException( | ||||||
|  |                     f"Value '[{account_id}]' at 'accountIdsToRemove' failed to satisfy constraint: " | ||||||
|  |                     "Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]." | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |         accounts_to_add = set(account_ids_to_add) | ||||||
|  |         if "all" in accounts_to_add and len(accounts_to_add) > 1: | ||||||
|  |             raise DocumentPermissionLimit( | ||||||
|  |                 "Accounts can either be all or a group of AWS accounts" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         accounts_to_remove = set(account_ids_to_remove) | ||||||
|  |         if "all" in accounts_to_remove and len(accounts_to_remove) > 1: | ||||||
|  |             raise DocumentPermissionLimit( | ||||||
|  |                 "Accounts can either be all or a group of AWS accounts" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         if permission_type != "Share": | ||||||
|  |             raise InvalidPermissionType( | ||||||
|  |                 f"Value '{permission_type}' at 'permissionType' failed to satisfy constraint: " | ||||||
|  |                 "Member must satisfy enum value set: [Share]." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         document = self._get_documents(name) | ||||||
|  |         document.modify_permissions( | ||||||
|  |             accounts_to_add, accounts_to_remove, shared_document_version | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     def delete_parameter(self, name): |     def delete_parameter(self, name): | ||||||
|         return self._parameters.pop(name, None) |         return self._parameters.pop(name, None) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -127,6 +127,35 @@ class SimpleSystemManagerResponse(BaseResponse): | |||||||
| 
 | 
 | ||||||
|         return json.dumps({"DocumentIdentifiers": documents, "NextToken": token}) |         return json.dumps({"DocumentIdentifiers": documents, "NextToken": token}) | ||||||
| 
 | 
 | ||||||
|  |     def describe_document_permission(self): | ||||||
|  |         name = self._get_param("Name") | ||||||
|  |         max_results = self._get_param("MaxResults") | ||||||
|  |         next_token = self._get_param("NextToken") | ||||||
|  |         permission_type = self._get_param("PermissionType") | ||||||
|  | 
 | ||||||
|  |         result = self.ssm_backend.describe_document_permission( | ||||||
|  |             name=name, | ||||||
|  |             max_results=max_results, | ||||||
|  |             next_token=next_token, | ||||||
|  |             permission_type=permission_type, | ||||||
|  |         ) | ||||||
|  |         return json.dumps(result) | ||||||
|  | 
 | ||||||
|  |     def modify_document_permission(self): | ||||||
|  |         account_ids_to_add = self._get_param("AccountIdsToAdd") | ||||||
|  |         account_ids_to_remove = self._get_param("AccountIdsToRemove") | ||||||
|  |         name = self._get_param("Name") | ||||||
|  |         permission_type = self._get_param("PermissionType") | ||||||
|  |         shared_document_version = self._get_param("SharedDocumentVersion") | ||||||
|  | 
 | ||||||
|  |         self.ssm_backend.modify_document_permission( | ||||||
|  |             name=name, | ||||||
|  |             account_ids_to_add=account_ids_to_add, | ||||||
|  |             account_ids_to_remove=account_ids_to_remove, | ||||||
|  |             shared_document_version=shared_document_version, | ||||||
|  |             permission_type=permission_type, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     def _get_param(self, param, default=None): |     def _get_param(self, param, default=None): | ||||||
|         return self.request_params.get(param, default) |         return self.request_params.get(param, default) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -11,6 +11,7 @@ TestAccAWSCloudwatchEventBusPolicy | |||||||
| TestAccAWSCloudWatchEventConnection | TestAccAWSCloudWatchEventConnection | ||||||
| TestAccAWSCloudWatchEventPermission | TestAccAWSCloudWatchEventPermission | ||||||
| TestAccAWSCloudWatchEventRule | TestAccAWSCloudWatchEventRule | ||||||
|  | TestAccAWSCloudWatchEventTarget_ssmDocument | ||||||
| TestAccAWSCloudwatchLogGroupDataSource | TestAccAWSCloudwatchLogGroupDataSource | ||||||
| TestAccAWSCloudWatchMetricAlarm | TestAccAWSCloudWatchMetricAlarm | ||||||
| TestAccAWSDataSourceCloudwatch | TestAccAWSDataSourceCloudwatch | ||||||
| @ -87,5 +88,6 @@ TestAccAWSRouteTable_IPv4_To_TransitGateway | |||||||
| TestAccAWSRouteTable_IPv4_To_VpcPeeringConnection | TestAccAWSRouteTable_IPv4_To_VpcPeeringConnection | ||||||
| TestAccAWSRouteTable_disappears | TestAccAWSRouteTable_disappears | ||||||
| TestAccAWSRouteTable_basic | TestAccAWSRouteTable_basic | ||||||
|  | TestAccAWSSsmDocumentDataSource | ||||||
| TestAccAwsEc2ManagedPrefixList | TestAccAwsEc2ManagedPrefixList | ||||||
| TestAccAWSEgressOnlyInternetGateway | TestAccAWSEgressOnlyInternetGateway | ||||||
| @ -4,6 +4,7 @@ import boto3 | |||||||
| import botocore.exceptions | import botocore.exceptions | ||||||
| import sure  # noqa | import sure  # noqa | ||||||
| import datetime | import datetime | ||||||
|  | from datetime import timezone | ||||||
| import json | import json | ||||||
| import yaml | import yaml | ||||||
| import hashlib | import hashlib | ||||||
| @ -42,7 +43,7 @@ def _validate_document_description( | |||||||
|     doc_description["Name"].should.equal(doc_name) |     doc_description["Name"].should.equal(doc_name) | ||||||
|     doc_description["Owner"].should.equal(ACCOUNT_ID) |     doc_description["Owner"].should.equal(ACCOUNT_ID) | ||||||
| 
 | 
 | ||||||
|     difference = datetime.datetime.utcnow() - doc_description["CreatedDate"] |     difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"] | ||||||
|     if difference.min > datetime.timedelta(minutes=1): |     if difference.min > datetime.timedelta(minutes=1): | ||||||
|         assert False |         assert False | ||||||
| 
 | 
 | ||||||
| @ -224,7 +225,7 @@ def test_create_document(): | |||||||
|     doc_description["Name"].should.equal("EmptyParamDoc") |     doc_description["Name"].should.equal("EmptyParamDoc") | ||||||
|     doc_description["Owner"].should.equal(ACCOUNT_ID) |     doc_description["Owner"].should.equal(ACCOUNT_ID) | ||||||
| 
 | 
 | ||||||
|     difference = datetime.datetime.utcnow() - doc_description["CreatedDate"] |     difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"] | ||||||
|     if difference.min > datetime.timedelta(minutes=1): |     if difference.min > datetime.timedelta(minutes=1): | ||||||
|         assert False |         assert False | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user