feat(ssm): Add ssm documents permissions (#4217)

* refactor(ssm): Add Documents class to avoid dictionary handling
This also solves the datetime format issue in TestAccAWSCloudWatchEventTarget_ssmDocument
This commit is contained in:
Gonzalo Saad 2021-08-25 11:16:14 -03:00 committed by GitHub
parent f038859a37
commit 29b0122fac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 333 additions and 141 deletions

View File

@ -62,6 +62,22 @@ class DocumentAlreadyExists(JsonRESTError):
super(DocumentAlreadyExists, self).__init__("DocumentAlreadyExists", message)
class DocumentPermissionLimit(JsonRESTError):
code = 400
def __init__(self, message):
super(DocumentPermissionLimit, self).__init__(
"DocumentPermissionLimit", message
)
class InvalidPermissionType(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidPermissionType, self).__init__("InvalidPermissionType", message)
class InvalidDocument(JsonRESTError):
code = 400

View File

@ -1,6 +1,8 @@
from __future__ import unicode_literals
import re
from typing import Dict
from boto3 import Session
from collections import defaultdict
@ -33,6 +35,8 @@ from .exceptions import (
DuplicateDocumentVersionName,
DuplicateDocumentContent,
ParameterMaxVersionLimitExceeded,
DocumentPermissionLimit,
InvalidPermissionType,
)
@ -144,6 +148,144 @@ def generate_ssm_doc_param_list(parameters):
return param_list
class Documents(BaseModel):
def __init__(self, ssm_document):
version = ssm_document.document_version
self.versions = {version: ssm_document}
self.default_version = version
self.latest_version = version
self.permissions = {} # {AccountID: version }
def get_default_version(self):
return self.versions.get(self.default_version)
def get_latest_version(self):
return self.versions.get(self.latest_version)
def find_by_version_name(self, version_name):
return next(
(
document
for document in self.versions.values()
if document.version_name == version_name
),
None,
)
def find_by_version(self, version):
return self.versions.get(version)
def find_by_version_and_version_name(self, version, version_name):
return next(
(
document
for doc_version, document in self.versions.items()
if doc_version == version and document.version_name == version_name
),
None,
)
def find(self, document_version=None, version_name=None, strict=True):
if document_version == "$LATEST":
ssm_document = self.get_latest_version()
elif version_name and document_version:
ssm_document = self.find_by_version_and_version_name(
document_version, version_name
)
elif version_name:
ssm_document = self.find_by_version_name(version_name)
elif document_version:
ssm_document = self.find_by_version(document_version)
else:
ssm_document = self.get_default_version()
if strict and not ssm_document:
raise InvalidDocument("The specified document does not exist.")
return ssm_document
def exists(self, document_version=None, version_name=None):
return self.find(document_version, version_name, strict=False) is not None
def add_new_version(self, new_document_version):
version = new_document_version.document_version
self.latest_version = version
self.versions[version] = new_document_version
def update_default_version(self, version):
ssm_document = self.find_by_version(version)
if not ssm_document:
raise InvalidDocument("The specified document does not exist.")
self.default_version = version
return ssm_document
def delete(self, *versions):
for version in versions:
if version in self.versions:
del self.versions[version]
if self.versions and self.latest_version not in self.versions:
ordered_versions = sorted(self.versions.keys())
new_latest_version = ordered_versions[-1]
self.latest_version = new_latest_version
def describe(self, document_version=None, version_name=None):
document = self.find(document_version, version_name)
base = {
"Hash": document.hash,
"HashType": "Sha256",
"Name": document.name,
"Owner": document.owner,
"CreatedDate": document.created_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"Status": document.status,
"DocumentVersion": document.document_version,
"Description": document.description,
"Parameters": document.parameter_list,
"PlatformTypes": document.platform_types,
"DocumentType": document.document_type,
"SchemaVersion": document.schema_version,
"LatestVersion": self.latest_version,
"DefaultVersion": self.default_version,
"DocumentFormat": document.document_format,
}
if document.version_name:
base["VersionName"] = document.version_name
if document.target_type:
base["TargetType"] = document.target_type
if document.tags:
base["Tags"] = document.tags
return base
def modify_permissions(self, accounts_to_add, accounts_to_remove, version):
if "all" in accounts_to_add:
self.permissions.clear()
else:
self.permissions.pop("all", None)
new_permissions = {account_id: version for account_id in accounts_to_add}
self.permissions.update(**new_permissions)
if "all" in accounts_to_remove:
self.permissions.clear()
else:
for account_id in accounts_to_remove:
self.permissions.pop(account_id, None)
def describe_permissions(self):
return {
"AccountIds": list(self.permissions.keys()),
"AccountSharingInfoList": [
{"AccountId": account_id, "SharedDocumentVersion": document_version}
for account_id, document_version in self.permissions.items()
],
}
def is_shared(self):
return len(self.permissions) > 0
class Document(BaseModel):
def __init__(
self,
@ -171,7 +313,7 @@ class Document(BaseModel):
self.status = "Active"
self.document_version = document_version
self.owner = ACCOUNT_ID
self.created_date = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
self.created_date = datetime.datetime.utcnow()
if document_format == "JSON":
try:
@ -220,6 +362,10 @@ class Document(BaseModel):
except KeyError:
raise InvalidDocumentContent("The content for the document is not valid.")
@property
def hash(self):
return hashlib.sha256(self.content.encode("utf-8")).hexdigest()
class Command(BaseModel):
def __init__(
@ -482,7 +628,7 @@ class SimpleSystemManagerBackend(BaseBackend):
self._resource_tags = defaultdict(lambda: defaultdict(dict))
self._commands = []
self._errors = []
self._documents = defaultdict(dict)
self._documents: Dict[str, Documents] = {}
self._region = region_name
@ -491,53 +637,17 @@ class SimpleSystemManagerBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
def _generate_document_description(self, document):
latest = self._documents[document.name]["latest_version"]
default_version = self._documents[document.name]["default_version"]
base = {
"Hash": hashlib.sha256(document.content.encode("utf-8")).hexdigest(),
"HashType": "Sha256",
"Name": document.name,
"Owner": document.owner,
"CreatedDate": document.created_date,
"Status": document.status,
"DocumentVersion": document.document_version,
"Description": document.description,
"Parameters": document.parameter_list,
"PlatformTypes": document.platform_types,
"DocumentType": document.document_type,
"SchemaVersion": document.schema_version,
"LatestVersion": latest,
"DefaultVersion": default_version,
"DocumentFormat": document.document_format,
}
if document.version_name:
base["VersionName"] = document.version_name
if document.target_type:
base["TargetType"] = document.target_type
if document.tags:
base["Tags"] = document.tags
return base
def _generate_document_information(self, ssm_document, document_format):
content = self._get_document_content(document_format, ssm_document)
base = {
"Name": ssm_document.name,
"DocumentVersion": ssm_document.document_version,
"Status": ssm_document.status,
"Content": ssm_document.content,
"Content": content,
"DocumentType": ssm_document.document_type,
"DocumentFormat": document_format,
}
if document_format == "JSON":
base["Content"] = json.dumps(ssm_document.content_json)
elif document_format == "YAML":
base["Content"] = yaml.dump(ssm_document.content_json)
else:
raise ValidationException("Invalid document format " + str(document_format))
if ssm_document.version_name:
base["VersionName"] = ssm_document.version_name
if ssm_document.requires:
@ -547,7 +657,20 @@ class SimpleSystemManagerBackend(BaseBackend):
return base
def _generate_document_list_information(self, ssm_document):
@staticmethod
def _get_document_content(document_format, ssm_document):
if document_format == ssm_document.document_format:
content = ssm_document.content
elif document_format == "JSON":
content = json.dumps(ssm_document.content_json)
elif document_format == "YAML":
content = yaml.dump(ssm_document.content_json)
else:
raise ValidationException("Invalid document format " + str(document_format))
return content
@staticmethod
def _generate_document_list_information(ssm_document):
base = {
"Name": ssm_document.name,
"Owner": ssm_document.owner,
@ -569,6 +692,12 @@ class SimpleSystemManagerBackend(BaseBackend):
return base
def _get_documents(self, name):
documents = self._documents.get(name)
if not documents:
raise InvalidDocument("The specified document does not exist.")
return documents
def create_document(
self,
content,
@ -603,24 +732,23 @@ class SimpleSystemManagerBackend(BaseBackend):
if self._documents.get(ssm_document.name):
raise DocumentAlreadyExists("The specified document already exists.")
self._documents[ssm_document.name] = {
"documents": {ssm_document.document_version: ssm_document},
"default_version": ssm_document.document_version,
"latest_version": ssm_document.document_version,
}
documents = Documents(ssm_document)
self._documents[ssm_document.name] = documents
return self._generate_document_description(ssm_document)
return documents.describe()
def delete_document(self, name, document_version, version_name, force):
documents = self._documents.get(name, {}).get("documents", {})
documents = self._get_documents(name)
if documents.is_shared():
raise InvalidDocumentOperation("Must unshare document first before delete")
keys_to_delete = set()
if documents:
default_version = self._documents[name]["default_version"]
default_doc = documents.get_default_version()
if (
documents[default_version].document_type
== "ApplicationConfigurationSchema"
default_doc.document_type == "ApplicationConfigurationSchema"
and not force
):
raise InvalidDocumentOperation(
@ -628,20 +756,20 @@ class SimpleSystemManagerBackend(BaseBackend):
"You must stop sharing the document before you can delete it."
)
if document_version and document_version == default_version:
if document_version and document_version == default_doc.document_version:
raise InvalidDocumentOperation(
"Default version of the document can't be deleted."
)
if document_version or version_name:
# We delete only a specific version
delete_doc = self._find_document(name, document_version, version_name)
delete_doc = documents.find(document_version, version_name)
# we can't delete only the default version
if (
delete_doc
and delete_doc.document_version == default_version
and len(documents) != 1
and delete_doc.document_version == default_doc.document_version
and len(documents.versions) != 1
):
raise InvalidDocumentOperation(
"Default version of the document can't be deleted."
@ -653,64 +781,18 @@ class SimpleSystemManagerBackend(BaseBackend):
raise InvalidDocument("The specified document does not exist.")
else:
# We are deleting all versions
keys_to_delete = set(documents.keys())
keys_to_delete = set(documents.versions.keys())
for key in keys_to_delete:
del self._documents[name]["documents"][key]
documents.delete(*keys_to_delete)
if len(self._documents[name]["documents"].keys()) == 0:
if len(documents.versions) == 0:
del self._documents[name]
else:
old_latest = self._documents[name]["latest_version"]
if old_latest not in self._documents[name]["documents"].keys():
leftover_keys = self._documents[name]["documents"].keys()
int_keys = []
for key in leftover_keys:
int_keys.append(int(key))
self._documents[name]["latest_version"] = str(sorted(int_keys)[-1])
else:
raise InvalidDocument("The specified document does not exist.")
def _find_document(
self, name, document_version=None, version_name=None, strict=True
):
if not self._documents.get(name):
raise InvalidDocument("The specified document does not exist.")
documents = self._documents[name]["documents"]
ssm_document = None
if not version_name and not document_version:
# Retrieve default version
default_version = self._documents[name]["default_version"]
ssm_document = documents.get(default_version)
elif version_name and document_version:
for doc_version, document in documents.items():
if (
doc_version == document_version
and document.version_name == version_name
):
ssm_document = document
break
else:
for doc_version, document in documents.items():
if document_version and doc_version == document_version:
ssm_document = document
break
if version_name and document.version_name == version_name:
ssm_document = document
break
if strict and not ssm_document:
raise InvalidDocument("The specified document does not exist.")
return ssm_document
def get_document(self, name, document_version, version_name, document_format):
ssm_document = self._find_document(name, document_version, version_name)
documents = self._get_documents(name)
ssm_document = documents.find(document_version, version_name)
if not document_format:
document_format = ssm_document.document_format
else:
@ -719,18 +801,18 @@ class SimpleSystemManagerBackend(BaseBackend):
return self._generate_document_information(ssm_document, document_format)
def update_document_default_version(self, name, document_version):
documents = self._get_documents(name)
ssm_document = documents.update_default_version(document_version)
ssm_document = self._find_document(name, document_version=document_version)
self._documents[name]["default_version"] = document_version
base = {
result = {
"Name": ssm_document.name,
"DefaultVersion": document_version,
}
if ssm_document.version_name:
base["DefaultVersionName"] = ssm_document.version_name
result["DefaultVersionName"] = ssm_document.version_name
return base
return result
def update_document(
self,
@ -750,24 +832,27 @@ class SimpleSystemManagerBackend(BaseBackend):
strict=False,
)
if not self._documents.get(name):
documents = self._documents.get(name)
if not documents:
raise InvalidDocument("The specified document does not exist.")
if (
self._documents[name]["latest_version"] != document_version
documents.latest_version != document_version
and document_version != "$LATEST"
):
raise InvalidDocumentVersion(
"The document version is not valid or does not exist."
)
if version_name and self._find_document(
name, version_name=version_name, strict=False
):
raise DuplicateDocumentVersionName(
"The specified version name is a duplicate."
)
old_ssm_document = self._find_document(name)
if version_name:
if documents.exists(version_name=version_name):
raise DuplicateDocumentVersionName(
"The specified version name is a duplicate."
)
old_ssm_document = documents.get_default_version()
new_version = str(int(documents.latest_version) + 1)
new_ssm_document = Document(
name=name,
version_name=version_name,
@ -778,28 +863,22 @@ class SimpleSystemManagerBackend(BaseBackend):
attachments=attachments,
target_type=target_type,
tags=old_ssm_document.tags,
document_version=str(int(self._documents[name]["latest_version"]) + 1),
document_version=new_version,
)
for doc_version, document in self._documents[name]["documents"].items():
for doc_version, document in documents.versions.items():
if document.content == new_ssm_document.content:
raise DuplicateDocumentContent(
"The content of the association document matches another document. "
"Change the content of the document and try again."
)
self._documents[name]["latest_version"] = str(
int(self._documents[name]["latest_version"]) + 1
)
self._documents[name]["documents"][
new_ssm_document.document_version
] = new_ssm_document
return self._generate_document_description(new_ssm_document)
documents.add_new_version(new_ssm_document)
return documents.describe(document_version=new_version)
def describe_document(self, name, document_version, version_name):
ssm_document = self._find_document(name, document_version, version_name)
return self._generate_document_description(ssm_document)
documents = self._get_documents(name)
return documents.describe(document_version, version_name)
def list_documents(
self, document_filter_list, filters, max_results=10, next_token="0"
@ -813,17 +892,16 @@ class SimpleSystemManagerBackend(BaseBackend):
results = []
dummy_token_tracker = 0
# Sort to maintain next token adjacency
for document_name, document_bundle in sorted(self._documents.items()):
for document_name, documents in sorted(self._documents.items()):
if len(results) == max_results:
# There's still more to go so we need a next token
return results, str(next_token + len(results))
if dummy_token_tracker < next_token:
dummy_token_tracker = dummy_token_tracker + 1
dummy_token_tracker += 1
continue
default_version = document_bundle["default_version"]
ssm_doc = self._documents[document_name]["documents"][default_version]
ssm_doc = documents.get_default_version()
if filters and not _document_filter_match(filters, ssm_doc):
# If we have filters enabled, and we don't match them,
continue
@ -833,6 +911,72 @@ class SimpleSystemManagerBackend(BaseBackend):
# If we've fallen out of the loop, theres no more documents. No next token.
return results, ""
def describe_document_permission(
self, name, max_results=None, permission_type=None, next_token=None
):
# Parameters max_results, permission_type, and next_token not used because
# this current implementation doesn't support pagination.
document = self._get_documents(name)
return document.describe_permissions()
def modify_document_permission(
self,
name,
account_ids_to_add,
account_ids_to_remove,
shared_document_version,
permission_type,
):
account_id_regex = re.compile(r"(all|[0-9]{12})")
version_regex = re.compile(r"^([$]LATEST|[$]DEFAULT|[$]ALL)$")
account_ids_to_add = account_ids_to_add or []
account_ids_to_remove = account_ids_to_remove or []
if not version_regex.match(shared_document_version):
raise ValidationException(
f"Value '{shared_document_version}' at 'sharedDocumentVersion' failed to satisfy constraint: "
f"Member must satisfy regular expression pattern: ([$]LATEST|[$]DEFAULT|[$]ALL)."
)
for account_id in account_ids_to_add:
if not account_id_regex.match(account_id):
raise ValidationException(
f"Value '[{account_id}]' at 'accountIdsToAdd' failed to satisfy constraint: "
"Member must satisfy regular expression pattern: (all|[0-9]{12}])."
)
for account_id in account_ids_to_remove:
if not account_id_regex.match(account_id):
raise ValidationException(
f"Value '[{account_id}]' at 'accountIdsToRemove' failed to satisfy constraint: "
"Member must satisfy regular expression pattern: (?i)all|[0-9]{12}]."
)
accounts_to_add = set(account_ids_to_add)
if "all" in accounts_to_add and len(accounts_to_add) > 1:
raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts"
)
accounts_to_remove = set(account_ids_to_remove)
if "all" in accounts_to_remove and len(accounts_to_remove) > 1:
raise DocumentPermissionLimit(
"Accounts can either be all or a group of AWS accounts"
)
if permission_type != "Share":
raise InvalidPermissionType(
f"Value '{permission_type}' at 'permissionType' failed to satisfy constraint: "
"Member must satisfy enum value set: [Share]."
)
document = self._get_documents(name)
document.modify_permissions(
accounts_to_add, accounts_to_remove, shared_document_version
)
def delete_parameter(self, name):
return self._parameters.pop(name, None)

View File

@ -127,6 +127,35 @@ class SimpleSystemManagerResponse(BaseResponse):
return json.dumps({"DocumentIdentifiers": documents, "NextToken": token})
def describe_document_permission(self):
name = self._get_param("Name")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken")
permission_type = self._get_param("PermissionType")
result = self.ssm_backend.describe_document_permission(
name=name,
max_results=max_results,
next_token=next_token,
permission_type=permission_type,
)
return json.dumps(result)
def modify_document_permission(self):
account_ids_to_add = self._get_param("AccountIdsToAdd")
account_ids_to_remove = self._get_param("AccountIdsToRemove")
name = self._get_param("Name")
permission_type = self._get_param("PermissionType")
shared_document_version = self._get_param("SharedDocumentVersion")
self.ssm_backend.modify_document_permission(
name=name,
account_ids_to_add=account_ids_to_add,
account_ids_to_remove=account_ids_to_remove,
shared_document_version=shared_document_version,
permission_type=permission_type,
)
def _get_param(self, param, default=None):
return self.request_params.get(param, default)

View File

@ -11,6 +11,7 @@ TestAccAWSCloudwatchEventBusPolicy
TestAccAWSCloudWatchEventConnection
TestAccAWSCloudWatchEventPermission
TestAccAWSCloudWatchEventRule
TestAccAWSCloudWatchEventTarget_ssmDocument
TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSCloudWatchMetricAlarm
TestAccAWSDataSourceCloudwatch
@ -87,5 +88,6 @@ TestAccAWSRouteTable_IPv4_To_TransitGateway
TestAccAWSRouteTable_IPv4_To_VpcPeeringConnection
TestAccAWSRouteTable_disappears
TestAccAWSRouteTable_basic
TestAccAWSSsmDocumentDataSource
TestAccAwsEc2ManagedPrefixList
TestAccAWSEgressOnlyInternetGateway
TestAccAWSEgressOnlyInternetGateway

View File

@ -4,6 +4,7 @@ import boto3
import botocore.exceptions
import sure # noqa
import datetime
from datetime import timezone
import json
import yaml
import hashlib
@ -42,7 +43,7 @@ def _validate_document_description(
doc_description["Name"].should.equal(doc_name)
doc_description["Owner"].should.equal(ACCOUNT_ID)
difference = datetime.datetime.utcnow() - doc_description["CreatedDate"]
difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"]
if difference.min > datetime.timedelta(minutes=1):
assert False
@ -224,7 +225,7 @@ def test_create_document():
doc_description["Name"].should.equal("EmptyParamDoc")
doc_description["Owner"].should.equal(ACCOUNT_ID)
difference = datetime.datetime.utcnow() - doc_description["CreatedDate"]
difference = datetime.datetime.now(tz=timezone.utc) - doc_description["CreatedDate"]
if difference.min > datetime.timedelta(minutes=1):
assert False