From e2f6544228b9ee81324840e638ab062caa68b6e3 Mon Sep 17 00:00:00 2001 From: Alex Bainbridge Date: Fri, 26 Jun 2020 10:47:28 -0400 Subject: [PATCH] ssm document code done, testing now --- moto/ssm/exceptions.py | 50 ++++ moto/ssm/models.py | 470 ++++++++++++++++++++++++++++---- moto/ssm/responses.py | 92 +++++++ tests/test_ssm/test_ssm_docs.py | 0 4 files changed, 563 insertions(+), 49 deletions(-) create mode 100644 tests/test_ssm/test_ssm_docs.py diff --git a/moto/ssm/exceptions.py b/moto/ssm/exceptions.py index 83ae26b6c..a1e129002 100644 --- a/moto/ssm/exceptions.py +++ b/moto/ssm/exceptions.py @@ -53,3 +53,53 @@ class ValidationException(JsonRESTError): def __init__(self, message): super(ValidationException, self).__init__("ValidationException", message) + + +class DocumentAlreadyExists(JsonRESTError): + code = 400 + + def __init__(self, message): + super(DocumentAlreadyExists, self).__init__("DocumentAlreadyExists", message) + + +class InvalidDocument(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidDocument, self).__init__("InvalidDocument", message) + + +class InvalidDocumentOperation(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidDocumentOperation, self).__init__("InvalidDocumentOperation", message) + + +class InvalidDocumentContent(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidDocumentContent, self).__init__("InvalidDocumentContent", message) + + +class InvalidDocumentVersion(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidDocumentVersion, self).__init__("InvalidDocumentVersion", message) + + +class DuplicateDocumentVersionName(JsonRESTError): + code = 400 + + def __init__(self, message): + super(DuplicateDocumentVersionName, self).__init__("DuplicateDocumentVersionName", message) + + +class DuplicateDocumentContent(JsonRESTError): + code = 400 + + def __init__(self, message): + super(DuplicateDocumentContent, self).__init__("DuplicateDocumentContent", message) + diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 67216972e..713cbd628 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals import re from collections import defaultdict -from moto.core import BaseBackend, BaseModel +from moto.core import ACCOUNT_ID, BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.ec2 import ec2_backends from moto.cloudformation import cloudformation_backends @@ -12,6 +12,8 @@ import datetime import time import uuid import itertools +import json +import yaml from .utils import parameter_arn from .exceptions import ( @@ -22,20 +24,27 @@ from .exceptions import ( ParameterVersionLabelLimitExceeded, ParameterVersionNotFound, ParameterNotFound, + DocumentAlreadyExists, + InvalidDocumentOperation, + InvalidDocument, + InvalidDocumentContent, + InvalidDocumentVersion, + DuplicateDocumentVersionName, + DuplicateDocumentContent ) class Parameter(BaseModel): def __init__( - self, - name, - value, - type, - description, - allowed_pattern, - keyid, - last_modified_date, - version, + self, + name, + value, + type, + description, + allowed_pattern, + keyid, + last_modified_date, + version, ): self.name = name self.type = type @@ -63,7 +72,7 @@ class Parameter(BaseModel): prefix = "kms:{}:".format(self.keyid or "default") if value.startswith(prefix): - return value[len(prefix) :] + return value[len(prefix):] def response_object(self, decrypt=False, region=None): r = { @@ -102,23 +111,86 @@ class Parameter(BaseModel): MAX_TIMEOUT_SECONDS = 3600 +def generate_ssm_doc_param_list(parameters): + if not parameters: + return None + param_list = [] + for param_name, param_info in parameters.items(): + param_info["Name"] = param_name + param_list.append(param_info) + return param_list + + +class Document(BaseModel): + def __init__(self, name, version_name, content, document_type, document_format, requires, attachments, + target_type, tags, document_version="1"): + self.name = name + self.version_name = version_name + self.content = content + self.document_type = document_type + self.document_format = document_format + self.requires = requires + self.attachments = attachments + self.target_type = target_type + self.tags = tags + + self.status = "Active" + self.document_version = document_version + self.owner = ACCOUNT_ID + self.created_date = datetime.datetime.now() + + if document_format == "JSON": + try: + content_json = json.loads(content) + except json.decoder.JSONDecodeError: + raise InvalidDocumentContent("The content for the document is not valid.") + elif document_format == "YAML": + try: + content_json = yaml.safe_load(content) + except yaml.YAMLError: + raise InvalidDocumentContent("The content for the document is not valid.") + else: + raise ValidationException(f'Invalid document format {document_format}') + + self.content_json = content_json + + try: + self.schema_version = content_json["schemaVersion"] + self.description = content_json.get("description") + self.outputs = content_json.get("outputs") + self.files = content_json.get("files") + # TODO add platformType + self.platform_types = "Not Implemented (moto)" + self.parameter_list = generate_ssm_doc_param_list(content_json.get("parameters")) + + if self.schema_version == "0.3" or self.schema_version == "2.0" or self.schema_version == "2.2": + self.mainSteps = content_json["mainSteps"] + elif self.schema_version == "1.2": + self.runtimeConfig = content_json.get("runtimeConfig") + + except KeyError: + raise InvalidDocumentContent("The content for the document is not valid.") + + + + class Command(BaseModel): def __init__( - self, - comment="", - document_name="", - timeout_seconds=MAX_TIMEOUT_SECONDS, - instance_ids=None, - max_concurrency="", - max_errors="", - notification_config=None, - output_s3_bucket_name="", - output_s3_key_prefix="", - output_s3_region="", - parameters=None, - service_role_arn="", - targets=None, - backend_region="us-east-1", + self, + comment="", + document_name="", + timeout_seconds=MAX_TIMEOUT_SECONDS, + instance_ids=None, + max_concurrency="", + max_errors="", + notification_config=None, + output_s3_bucket_name="", + output_s3_key_prefix="", + output_s3_region="", + parameters=None, + service_role_arn="", + targets=None, + backend_region="us-east-1", ): if instance_ids is None: @@ -269,6 +341,75 @@ class Command(BaseModel): return invocation +def _validate_document_format(document_format): + aws_doc_formats = ["JSON", "YAML"] + if document_format not in aws_doc_formats: + raise ValidationException(f'Invalid document format {document_format}') + + +def _validate_document_info(content, name, document_type, document_format, strict=True): + aws_ssm_name_regex = r'^[a-zA-Z0-9_\-.]{3,128}$' + aws_name_reject_list = ["aws-", "amazon", "amzn"] + aws_doc_types = ["Command", "Policy", "Automation", "Session", "Package", "ApplicationConfiguration", + "ApplicationConfigurationSchema", "DeploymentStrategy", "ChangeCalendar"] + + _validate_document_format(document_format) + + if not content: + raise ValidationException("Content is required") + + if list(filter(name.startswith, aws_name_reject_list)): + raise ValidationException(f'Invalid document name {name}') + ssm_name_pattern = re.compile(aws_ssm_name_regex) + if not ssm_name_pattern.match(name): + raise ValidationException(f'Invalid document name {name}') + + if strict and document_type not in aws_doc_types: + # Update document doesn't use document type + raise ValidationException(f'Invalid document type {document_type}') + + +def _document_filter_equal_comparator(keyed_value, filter): + for v in filter["Values"]: + if keyed_value == v: + return True + return False + + +def _document_filter_list_includes_comparator(keyed_value_list, filter): + for v in filter["Values"]: + if v in keyed_value_list: + return True + return False + + +def _document_filter_match(filters, ssm_doc): + for filter in filters: + if filter["Key"] == "Name" and not _document_filter_equal_comparator(ssm_doc.name, filter): + return False + + elif filter["Key"] == "Owner": + if len(filter["Values"]) != 1: + raise ValidationException("Owner filter can only have one value.") + if filter["Values"][0] == "Self": + # Update to running account ID + filter["Values"][0] = ACCOUNT_ID + if not _document_filter_equal_comparator(ssm_doc.owner, filter): + return False + + elif filter["Key"] == "PlatformTypes" and not \ + _document_filter_list_includes_comparator(ssm_doc.platform_types, filter): + return False + + elif filter["Key"] == "DocumentType" and not _document_filter_equal_comparator(ssm_doc.document_type, filter): + return False + + elif filter["Key"] == "TargetType" and not _document_filter_equal_comparator(ssm_doc.target_type, filter): + return False + + return True + + class SimpleSystemManagerBackend(BaseBackend): def __init__(self): # each value is a list of all of the versions for a parameter @@ -278,12 +419,243 @@ class SimpleSystemManagerBackend(BaseBackend): self._resource_tags = defaultdict(lambda: defaultdict(dict)) self._commands = [] self._errors = [] + self._documents = defaultdict(dict) # figure out what region we're in for region, backend in ssm_backends.items(): if backend == self: self._region = region + def _generate_document_description(self, document): + + latest = self._documents[document.name]['latest_version'] + default_version = self._documents[document.name]["default_version"] + + return { + "Hash": hash, + "HashType": "Sha256", + "Name": document.name, + "Owner": document.owner, + "CreatedDate": document.created_date, + "Status": document.status, + "DocumentVersion": document.document_version, + "Description": document.description, + "Parameters": document.parameter_list, + "PlatformTypes": document.platform_types, + "SchemaVersion": document.schema_version, + "LatestVersion": latest, + "DefaultVersion": default_version, + "DocumentFormat": document.document_format + } + + def _generate_document_information(self, ssm_document, document_format): + base = { + "Name": ssm_document.name, + "DocumentVersion": ssm_document.document_version, + "Status": ssm_document.status, + "Content": ssm_document.content, + "DocumentType": ssm_document.document_type, + "DocumentFormat": ssm_document.document_format + } + + if document_format == "JSON": + base["Content"] = json.dumps(ssm_document.content_json) + elif document_format == "YAML": + base["Content"] = yaml.dump(ssm_document.content_json) + else: + raise ValidationException(f'Invalid document format {document_format}') + + if ssm_document.version_name: + base["VersionName"] = ssm_document.version_name + if ssm_document.requires: + base["Requires"] = ssm_document.requires + if ssm_document.attachments: + base["AttachmentsContent"] = ssm_document.attachments + + return base + + def _generate_document_list_information(self, ssm_document): + base = { + "Name": ssm_document.name, + "Owner": ssm_document.owner, + "DocumentVersion": ssm_document.document_version, + "DocumentType": ssm_document.document_type, + "SchemaVersion": ssm_document.schema_version, + "DocumentFormat": ssm_document.document_format + } + if ssm_document.version_name: + base["VersionName"] = ssm_document.version_name + if ssm_document.platform_types: + base["PlatformTypes"] = ssm_document.platform_types + if ssm_document.target_type: + base["TargetType"] = ssm_document.target_type + if ssm_document.tags: + base["Tags"] = ssm_document.tags + if ssm_document.requires: + base["Requires"] = ssm_document.requires + + return base + + def create_document(self, content, requires, attachments, name, version_name, document_type, document_format, + target_type, tags): + ssm_document = Document(name=name, version_name=version_name, content=content, document_type=document_type, + document_format=document_format, requires=requires, attachments=attachments, + target_type=target_type, tags=tags) + + _validate_document_info(content=content, name=name, document_type=document_type) + + if self._documents.get(ssm_document.Name): + raise DocumentAlreadyExists(f"Document with same name {name} already exists") + + self._documents[ssm_document.Name] = { + "documents": { + ssm_document.document_version: ssm_document + }, + "default_version": ssm_document.document_version, + "latest_version": ssm_document.document_version + } + + return self._generate_document_description(ssm_document) + + def delete_document(self, name, document_version, version_name, force): + documents = self._documents.get(name, {}).get("documents", {}) + keys_to_delete = set() + + if documents: + if documents[0].document_type == "ApplicationConfigurationSchema" and not force: + raise InvalidDocumentOperation("You attempted to delete a document while it is still shared. " + "You must stop sharing the document before you can delete it.") + if document_version and document_version == self._documents[name]["default_version"]: + raise InvalidDocumentOperation("Default version of the document can't be deleted.") + + if document_version or version_name: + for doc_version, document in documents.items(): + if document_version and doc_version == document_version: + keys_to_delete.add(document_version) + continue + if version_name and document.version_name == version_name: + keys_to_delete.add(document_version) + continue + else: + keys_to_delete = set(documents.keys()) + + for key in keys_to_delete: + self._documents[name]["documents"][key] = None + + if len(self._documents[name]["documents"].keys()) == 0: + self._documents[name] = None + else: + raise InvalidDocument("The specified document does not exist.") + + def _find_document(self, name, document_version=None, version_name=None, strict=True): + if not self._documents.get(name): + raise InvalidDocument(f"Document with name {name} does not exist.") + + documents = self._documents[name]["documents"] + ssm_document = None + + if not version_name and not document_version: + # Retrieve default version + default_version = self._documents[name]['default_version'] + ssm_document = documents.get(default_version) + + elif version_name and document_version: + for doc_version, document in documents.items(): + if doc_version == document_version and document.version_name == version_name: + ssm_document = document + break + + else: + for doc_version, document in documents.items(): + if document_version and doc_version == document_version : + ssm_document = document + break + if version_name and document.version_name == version_name: + ssm_document = document + break + + if strict and not ssm_document: + raise InvalidDocument(f"Document with name {name} does not exist.") + + return ssm_document + + def get_document(self, name, document_version, version_name, document_format): + _validate_document_format(document_format=document_format) + + ssm_document = self._find_document(name, document_version, version_name) + + return self._generate_document_information(ssm_document, document_format) + + def update_document_default_version(self, name, document_version): + ssm_document = self._find_document(name, document_version=document_version) + self._documents[name]["default_version"] = document_version + base = { + 'Name': ssm_document.name, + 'DefaultVersion': document_version, + } + + if ssm_document.version_name: + base['DefaultVersionName'] = ssm_document.version_name + + return base + + def update_document(self, content, attachments, name, version_name, document_version, document_format, target_type): + _validate_document_info(content=content, name=name, document_type=None, strict=False) + if not self._documents.get(name): + raise InvalidDocument("The specified document does not exist.") + if self._documents.get[name]['latest_version'] != document_version or document_version != "$LATEST": + raise InvalidDocumentVersion("The document version is not valid or does not exist.") + if self._find_document(name, version_name=version_name, strict=False): + raise DuplicateDocumentVersionName(f"The specified version name is a duplicate.") + + old_ssm_document = self._find_document(name) + + new_ssm_document = Document(name=name, version_name=version_name, content=content, + document_type=old_ssm_document.document_type, document_format=document_format, + requires=old_ssm_document.requires, attachments=attachments, + target_type=target_type, tags=old_ssm_document.tags, + document_version=self._documents.get[name]['latest_version']) + + for doc_version, document in self._documents[name].items(): + if document.content == new_ssm_document.content: + raise DuplicateDocumentContent("The content of the association document matches another document. " + "Change the content of the document and try again.") + + self._documents[name]["documents"][new_ssm_document.document_version] = new_ssm_document + + return self._generate_document_description(new_ssm_document) + + def describe_document(self, name, document_version, version_name): + ssm_document = self._find_document(name, document_version, version_name) + return self._generate_document_description(ssm_document) + + def list_documents(self, document_filter_list, filters, max_results=10, next_token=0): + if document_filter_list: + raise ValidationException( + "DocumentFilterList is deprecated. Instead use Filters." + ) + + results = [] + dummy_token_tracker = 0 + # Sort to maintain next token adjacency + for document_name, document_bundle in sorted(self._documents.items()): + if dummy_token_tracker < next_token: + dummy_token_tracker = dummy_token_tracker + 1 + continue + + default_version = document_bundle['default_version'] + ssm_doc = self._documents[document_name]['documents'][default_version] + if filters and not _document_filter_match(filters, ssm_doc): + # If we have filters enabled, and we don't match them, + continue + else: + results.append(self._generate_document_list_information(ssm_doc)) + + if len(results) == max_results: + return results, next_token + max_results + + return results + def delete_parameter(self, name): return self._parameters.pop(name, None) @@ -449,9 +821,9 @@ class SimpleSystemManagerBackend(BaseBackend): "When using global parameters, please specify within a global namespace." ) if ( - "//" in value - or not value.startswith("/") - or not re.match("^[a-zA-Z0-9_.-/]*$", value) + "//" in value + or not value.startswith("/") + or not re.match("^[a-zA-Z0-9_.-/]*$", value) ): raise ValidationException( 'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". ' @@ -530,13 +902,13 @@ class SimpleSystemManagerBackend(BaseBackend): return result def get_parameters_by_path( - self, - path, - with_decryption, - recursive, - filters=None, - next_token=None, - max_results=10, + self, + path, + with_decryption, + recursive, + filters=None, + next_token=None, + max_results=10, ): """Implement the get-parameters-by-path-API in the backend.""" result = [] @@ -546,10 +918,10 @@ class SimpleSystemManagerBackend(BaseBackend): for param_name in self._parameters: if path != "/" and not param_name.startswith(path): continue - if "/" in param_name[len(path) + 1 :] and not recursive: + if "/" in param_name[len(path) + 1:] and not recursive: continue if not self._match_filters( - self.get_parameter(param_name, with_decryption), filters + self.get_parameter(param_name, with_decryption), filters ): continue result.append(self.get_parameter(param_name, with_decryption)) @@ -561,7 +933,7 @@ class SimpleSystemManagerBackend(BaseBackend): next_token = 0 next_token = int(next_token) max_results = int(max_results) - values = values_list[next_token : next_token + max_results] + values = values_list[next_token: next_token + max_results] if len(values) == max_results: next_token = str(next_token + max_results) else: @@ -599,7 +971,7 @@ class SimpleSystemManagerBackend(BaseBackend): if what is None: return False elif option == "BeginsWith" and not any( - what.startswith(value) for value in values + what.startswith(value) for value in values ): return False elif option == "Equals" and not any(what == value for value in values): @@ -608,10 +980,10 @@ class SimpleSystemManagerBackend(BaseBackend): if any(value == "/" and len(what.split("/")) == 2 for value in values): continue elif any( - value != "/" - and what.startswith(value + "/") - and len(what.split("/")) - 1 == len(value.split("/")) - for value in values + value != "/" + and what.startswith(value + "/") + and len(what.split("/")) - 1 == len(value.split("/")) + for value in values ): continue else: @@ -658,10 +1030,10 @@ class SimpleSystemManagerBackend(BaseBackend): invalid_labels = [] for label in labels: if ( - label.startswith("aws") - or label.startswith("ssm") - or label[:1].isdigit() - or not re.match(r"^[a-zA-z0-9_\.\-]*$", label) + label.startswith("aws") + or label.startswith("ssm") + or label[:1].isdigit() + or not re.match(r"^[a-zA-z0-9_\.\-]*$", label) ): invalid_labels.append(label) continue @@ -691,7 +1063,7 @@ class SimpleSystemManagerBackend(BaseBackend): return [invalid_labels, version] def put_parameter( - self, name, description, value, type, allowed_pattern, keyid, overwrite + self, name, description, value, type, allowed_pattern, keyid, overwrite ): previous_parameter_versions = self._parameters[name] if len(previous_parameter_versions) == 0: diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index 45d2dec0a..c0e35b914 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -17,6 +17,98 @@ class SimpleSystemManagerResponse(BaseResponse): except ValueError: return {} + def create_document(self): + content = self._get_param("Content") + requires = self._get_param("Requires") + attachments = self._get_param("Attachments") + name = self._get_param("Name") + version_name = self._get_param("VersionName") + document_type = self._get_param("DocumentType") + document_format = self._get_param("DocumentFormat") + target_type = self._get_param("TargetType") + tags = self._get_param("Tags") + + result = self.ssm_backend.create_document(content=content, requires=requires, attachments=attachments, + name=name, version_name=version_name, document_type=document_type, + document_format=document_format, target_type=target_type, tags=tags) + + return { + 'DocumentDescription': result + } + + def delete_document(self): + name = self._get_param("Name") + document_version = self._get_param("DocumentVersion") + version_name = self._get_param("VersionName") + force = self._get_param("Force", False) + self.ssm_backend.delete_document(name=name, document_version=document_version, + version_name=version_name, force=force) + + return {} + + def get_document(self): + name = self._get_param("Name") + version_name = self._get_param("VersionName") + document_version = self._get_param("DocumentVersion") + document_format = self._get_param("DocumentFormat") + + document = self.ssm_backend.get_document(name=name, document_version=document_version, + document_format=document_format, version_name=version_name) + + return document + + def describe_document(self): + name = self._get_param("Name") + document_version = self._get_param("DocumentVersion") + version_name = self._get_param("VersionName") + + result = self.ssm_backend.describe_document(name=name, document_version=document_version, + version_name=version_name) + + return { + 'Document': result + } + + def update_document(self): + content = self._get_param("Content") + attachments = self._get_param("Attachments") + name = self._get_param("Name") + version_name = self._get_param("VersionName") + document_version = self._get_param("DocumentVersion") + document_format = self._get_param("DocumentFormat") + target_type = self._get_param("TargetType") + + result = self.ssm_backend.update_document(content=content, attachments=attachments, name=name, + version_name=version_name, document_version=document_version, + document_format=document_format, target_type=target_type) + + return { + 'DocumentDescription': result + } + + def update_document_default_version(self): + name = self._get_param("Name") + document_version = self._get_param("DocumentVersion") + + result = self.ssm_backend.update_document_default_version(name=name, document_version=document_version) + return { + 'Description': result + } + + def list_documents(self): + document_filter_list = self._get_param("DocumentFilterList") + filters = self._get_param("Filters") + max_results = self._get_param("MaxResults", 10) + next_token = self._get_param("NextToken") + + documents, token = self.ssm_backend.list_documents(document_filter_list=document_filter_list, filters=filters, + max_results=max_results, next_token=next_token) + + return { + "DocumentIdentifiers": documents, + "NextToken": token + } + def _get_param(self, param, default=None): return self.request_params.get(param, default) diff --git a/tests/test_ssm/test_ssm_docs.py b/tests/test_ssm/test_ssm_docs.py new file mode 100644 index 000000000..e69de29bb