From 3b9635b3c72acc1770ef9666dadec6c3c650e712 Mon Sep 17 00:00:00 2001 From: Brian Pandola Date: Sun, 8 Nov 2020 00:06:35 -0800 Subject: [PATCH] Add ssm:SendCommand support for instance tag Targets Replace the special-case code to handle Cloud Formation tags with a more generic implementation that covers all instance tags. Supersedes #2863 Closes #2862 --- moto/ssm/models.py | 48 +++++++++++++++----------------- tests/test_ssm/test_ssm_boto3.py | 35 ++++++++++++++++++++++- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 07812c316..538e700f8 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -6,12 +6,11 @@ from collections import defaultdict from moto.core import ACCOUNT_ID, BaseBackend, BaseModel from moto.core.exceptions import RESTError -from moto.cloudformation import cloudformation_backends +from moto.ec2 import ec2_backends import datetime import time import uuid -import itertools import json import yaml import hashlib @@ -246,9 +245,6 @@ class Command(BaseModel): if targets is None: targets = [] - self.error_count = 0 - self.completed_count = len(instance_ids) - self.target_count = len(instance_ids) self.command_id = str(uuid.uuid4()) self.status = "Success" self.status_details = "Details placeholder" @@ -262,7 +258,6 @@ class Command(BaseModel): self.comment = comment self.document_name = document_name - self.instance_ids = instance_ids self.max_concurrency = max_concurrency self.max_errors = max_errors self.notification_config = notification_config @@ -274,14 +269,19 @@ class Command(BaseModel): self.targets = targets self.backend_region = backend_region - # Get instance ids from a cloud formation stack target. - stack_instance_ids = [ - self.get_instance_ids_by_stack_ids(target["Values"]) - for target in self.targets - if target["Key"] == "tag:aws:cloudformation:stack-name" - ] + self.instance_ids = instance_ids + self.instance_ids += self._get_instance_ids_from_targets() + # Ensure no duplicate instance_ids + self.instance_ids = list(set(self.instance_ids)) - self.instance_ids += list(itertools.chain.from_iterable(stack_instance_ids)) + # NOTE: All of these counts are 0 in the ssm:SendCommand response + # received from a real AWS backend. The counts are correct when + # making subsequent calls to ssm:DescribeCommand or ssm:ListCommands. + # Not likely to cause any problems, but perhaps an area for future + # improvement. + self.error_count = 0 + self.completed_count = len(instance_ids) + self.target_count = len(instance_ids) # Create invocations with a single run command plugin. self.invocations = [] @@ -290,19 +290,15 @@ class Command(BaseModel): self.invocation_response(instance_id, "aws:runShellScript") ) - def get_instance_ids_by_stack_ids(self, stack_ids): - instance_ids = [] - cloudformation_backend = cloudformation_backends[self.backend_region] - for stack_id in stack_ids: - stack_resources = cloudformation_backend.list_stack_resources(stack_id) - instance_resources = [ - instance.id - for instance in stack_resources - if instance.type == "AWS::EC2::Instance" - ] - instance_ids.extend(instance_resources) - - return instance_ids + def _get_instance_ids_from_targets(self): + target_instance_ids = [] + ec2_backend = ec2_backends[self.backend_region] + ec2_filters = {target["Key"]: target["Values"] for target in self.targets} + reservations = ec2_backend.all_reservations(filters=ec2_filters) + for reservation in reservations: + for instance in reservation.instances: + target_instance_ids.append(instance.id) + return target_instance_ids def response_object(self): r = { diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index 2f74759e9..c590e75b7 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -11,7 +11,7 @@ import uuid from botocore.exceptions import ClientError, ParamValidationError from nose.tools import assert_raises -from moto import mock_ssm +from moto import mock_ec2, mock_ssm @mock_ssm @@ -1713,3 +1713,36 @@ def test_get_command_invocation(): invocation_response = client.get_command_invocation( CommandId=cmd_id, InstanceId=instance_id, PluginName="FAKE" ) + + +@mock_ec2 +@mock_ssm +def test_get_command_invocations_by_instance_tag(): + ec2 = boto3.client("ec2", region_name="us-east-1") + ssm = boto3.client("ssm", region_name="us-east-1") + tag_specifications = [ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": "test-tag"}]} + ] + num_instances = 3 + resp = ec2.run_instances( + ImageId="ami-1234abcd", + MaxCount=num_instances, + MinCount=num_instances, + TagSpecifications=tag_specifications, + ) + instance_ids = [] + for instance in resp["Instances"]: + instance_ids.append(instance["InstanceId"]) + instance_ids.should.have.length_of(num_instances) + + command_id = ssm.send_command( + DocumentName="AWS-RunShellScript", + Targets=[{"Key": "tag:Name", "Values": ["test-tag"]}], + )["Command"]["CommandId"] + + resp = ssm.list_commands(CommandId=command_id) + resp["Commands"][0]["TargetCount"].should.equal(num_instances) + + for instance_id in instance_ids: + resp = ssm.get_command_invocation(CommandId=command_id, InstanceId=instance_id) + resp["Status"].should.equal("Success")