diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index db77f06e4..8744f4759 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -3031,17 +3031,17 @@ ## ecs
-62% implemented +73% implemented - [ ] create_capacity_provider - [X] create_cluster - [X] create_service -- [ ] create_task_set +- [X] create_task_set - [ ] delete_account_setting - [X] delete_attributes - [X] delete_cluster - [X] delete_service -- [ ] delete_task_set +- [X] delete_task_set - [X] deregister_container_instance - [X] deregister_task_definition - [ ] describe_capacity_providers @@ -3049,7 +3049,7 @@ - [X] describe_container_instances - [X] describe_services - [X] describe_task_definition -- [ ] describe_task_sets +- [X] describe_task_sets - [X] describe_tasks - [ ] discover_poll_endpoint - [ ] list_account_settings @@ -3079,8 +3079,8 @@ - [ ] update_container_agent - [X] update_container_instances_state - [X] update_service -- [ ] update_service_primary_task_set -- [ ] update_task_set +- [X] update_service_primary_task_set +- [X] update_task_set
## efs diff --git a/moto/ecs/exceptions.py b/moto/ecs/exceptions.py index d08066192..72129224e 100644 --- a/moto/ecs/exceptions.py +++ b/moto/ecs/exceptions.py @@ -21,3 +21,22 @@ class TaskDefinitionNotFoundException(JsonRESTError): error_type="ClientException", message="The specified task definition does not exist.", ) + + +class TaskSetNotFoundException(JsonRESTError): + code = 400 + + def __init__(self): + super(TaskSetNotFoundException, self).__init__( + error_type="ClientException", + message="The specified task set does not exist.", + ) + + +class ClusterNotFoundException(JsonRESTError): + code = 400 + + def __init__(self): + super(ClusterNotFoundException, self).__init__( + error_type="ClientException", message="Cluster not found", + ) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 1a385226b..36c7cd44a 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -13,7 +13,12 @@ from moto.core.utils import unix_time from moto.ec2 import ec2_backends from copy import copy -from .exceptions import ServiceNotFoundException, TaskDefinitionNotFoundException +from .exceptions import ( + ServiceNotFoundException, + TaskDefinitionNotFoundException, + TaskSetNotFoundException, + ClusterNotFoundException, +) class BaseObject(BaseModel): @@ -176,7 +181,6 @@ class TaskDefinition(BaseObject): cls, original_resource, new_resource_name, cloudformation_json, region_name ): properties = cloudformation_json["Properties"] - family = properties.get( "Family", "task-definition-{0}".format(int(random() * 10 ** 6)) ) @@ -236,11 +240,12 @@ class Service(BaseObject): self, cluster, service_name, - task_definition, desired_count, + task_definition=None, load_balancers=None, scheduling_strategy=None, tags=None, + deployment_controller=None, ): self.cluster_arn = cluster.arn self.arn = "arn:aws:ecs:{0}:012345678910:service/{1}".format( @@ -249,21 +254,29 @@ class Service(BaseObject): self.name = service_name self.status = "ACTIVE" self.running_count = 0 - self.task_definition = task_definition.arn + if task_definition: + self.task_definition = task_definition.arn + else: + self.task_definition = None self.desired_count = desired_count + self.task_sets = [] + self.deployment_controller = deployment_controller or {"type": "ECS"} self.events = [] - self.deployments = [ - { - "createdAt": datetime.now(pytz.utc), - "desiredCount": self.desired_count, - "id": "ecs-svc/{}".format(randint(0, 32 ** 12)), - "pendingCount": self.desired_count, - "runningCount": 0, - "status": "PRIMARY", - "taskDefinition": task_definition.arn, - "updatedAt": datetime.now(pytz.utc), - } - ] + if self.deployment_controller["type"] == "ECS": + self.deployments = [ + { + "createdAt": datetime.now(pytz.utc), + "desiredCount": self.desired_count, + "id": "ecs-svc/{}".format(randint(0, 32 ** 12)), + "pendingCount": self.desired_count, + "runningCount": 0, + "status": "PRIMARY", + "taskDefinition": self.task_definition, + "updatedAt": datetime.now(pytz.utc), + } + ] + else: + self.deployments = [] self.load_balancers = load_balancers if load_balancers is not None else [] self.scheduling_strategy = ( scheduling_strategy if scheduling_strategy is not None else "REPLICA" @@ -282,6 +295,13 @@ class Service(BaseObject): response_object["serviceName"] = self.name response_object["serviceArn"] = self.arn response_object["schedulingStrategy"] = self.scheduling_strategy + if response_object["deploymentController"]["type"] == "ECS": + del response_object["deploymentController"] + del response_object["taskSets"] + else: + response_object["taskSets"] = [ + t.response_object for t in response_object["taskSets"] + ] for deployment in response_object["deployments"]: if isinstance(deployment["createdAt"], datetime): @@ -315,7 +335,7 @@ class Service(BaseObject): ecs_backend = ecs_backends[region_name] return ecs_backend.create_service( - cluster, service_name, task_definition, desired_count + cluster, service_name, desired_count, task_definition_str=task_definition ) @classmethod @@ -343,7 +363,10 @@ class Service(BaseObject): cluster_name, int(random() * 10 ** 6) ) return ecs_backend.create_service( - cluster_name, new_service_name, task_definition, desired_count + cluster_name, + new_service_name, + desired_count, + task_definition_str=task_definition, ) else: return ecs_backend.update_service( @@ -494,6 +517,73 @@ class ContainerInstanceFailure(BaseObject): return response_object +class TaskSet(BaseObject): + def __init__( + self, + service, + cluster, + task_definition, + region_name, + external_id=None, + network_configuration=None, + load_balancers=None, + service_registries=None, + launch_type=None, + capacity_provider_strategy=None, + platform_version=None, + scale=None, + client_token=None, + tags=None, + ): + self.service = service + self.cluster = cluster + self.status = "ACTIVE" + self.task_definition = task_definition or "" + self.region_name = region_name + self.external_id = external_id or "" + self.network_configuration = network_configuration or {} + self.load_balancers = load_balancers or [] + self.service_registries = service_registries or [] + self.launch_type = launch_type + self.capacity_provider_strategy = capacity_provider_strategy or [] + self.platform_version = platform_version or "" + self.scale = scale or {"value": 100.0, "unit": "PERCENT"} + self.client_token = client_token or "" + self.tags = tags or [] + self.stabilityStatus = "STEADY_STATE" + self.createdAt = datetime.now(pytz.utc) + self.updatedAt = datetime.now(pytz.utc) + self.stabilityStatusAt = datetime.now(pytz.utc) + self.id = "ecs-svc/{}".format(randint(0, 32 ** 12)) + self.service_arn = "" + self.cluster_arn = "" + + cluster_name = self.cluster.split("/")[-1] + service_name = self.service.split("/")[-1] + self.task_set_arn = "arn:aws:ecs:{0}:012345678910:task-set/{1}/{2}/{3}".format( + region_name, cluster_name, service_name, self.id + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + if isinstance(response_object["createdAt"], datetime): + response_object["createdAt"] = unix_time( + self.createdAt.replace(tzinfo=None) + ) + if isinstance(response_object["updatedAt"], datetime): + response_object["updatedAt"] = unix_time( + self.updatedAt.replace(tzinfo=None) + ) + if isinstance(response_object["stabilityStatusAt"], datetime): + response_object["stabilityStatusAt"] = unix_time( + self.stabilityStatusAt.replace(tzinfo=None) + ) + del response_object["service"] + del response_object["cluster"] + return response_object + + class EC2ContainerServiceBackend(BaseBackend): def __init__(self, region_name): super(EC2ContainerServiceBackend, self).__init__() @@ -502,6 +592,7 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks = {} self.services = {} self.container_instances = {} + self.task_sets = {} self.region_name = region_name def reset(self): @@ -871,28 +962,33 @@ class EC2ContainerServiceBackend(BaseBackend): self, cluster_str, service_name, - task_definition_str, desired_count, + task_definition_str=None, load_balancers=None, scheduling_strategy=None, tags=None, + deployment_controller=None, ): cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: raise Exception("{0} is not a cluster".format(cluster_name)) - task_definition = self.describe_task_definition(task_definition_str) + if task_definition_str is not None: + task_definition = self.describe_task_definition(task_definition_str) + else: + task_definition = None desired_count = desired_count if desired_count is not None else 0 service = Service( cluster, service_name, - task_definition, desired_count, + task_definition, load_balancers, scheduling_strategy, tags, + deployment_controller, ) cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) self.services[cluster_service_pair] = service @@ -928,6 +1024,7 @@ class EC2ContainerServiceBackend(BaseBackend): or existing_service_obj.arn == requested_name_or_arn ): result.append(existing_service_obj) + return result def update_service( @@ -1101,9 +1198,7 @@ class EC2ContainerServiceBackend(BaseBackend): def put_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError( - "ClusterNotFoundException", "Cluster not found", status=400 - ) + raise ClusterNotFoundException if attributes is None: raise JsonRESTError( @@ -1192,9 +1287,7 @@ class EC2ContainerServiceBackend(BaseBackend): def delete_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError( - "ClusterNotFoundException", "Cluster not found", status=400 - ) + raise ClusterNotFoundException if attributes is None: raise JsonRESTError( @@ -1327,6 +1420,134 @@ class EC2ContainerServiceBackend(BaseBackend): raise ServiceNotFoundException(service_name=parsed_arn["id"]) raise NotImplementedError() + def create_task_set( + self, + service, + cluster, + task_definition, + external_id=None, + network_configuration=None, + load_balancers=None, + service_registries=None, + launch_type=None, + capacity_provider_strategy=None, + platform_version=None, + scale=None, + client_token=None, + tags=None, + ): + task_set = TaskSet( + service, + cluster, + task_definition, + self.region_name, + external_id=external_id, + network_configuration=network_configuration, + load_balancers=load_balancers, + service_registries=service_registries, + launch_type=launch_type, + capacity_provider_strategy=capacity_provider_strategy, + platform_version=platform_version, + scale=scale, + client_token=client_token, + tags=tags, + ) + + cluster_name = cluster.split("/")[-1] + service_name = service.split("/")[-1] + + service_obj = self.services.get("{0}:{1}".format(cluster_name, service_name)) + if not service_obj: + raise ServiceNotFoundException(service_name=service_name) + + cluster_obj = self.clusters.get(cluster_name) + if not cluster_obj: + raise ClusterNotFoundException + + task_set.task_definition = self.describe_task_definition(task_definition).arn + task_set.service_arn = service_obj.arn + task_set.cluster_arn = cluster_obj.arn + + service_obj.task_sets.append(task_set) + # TODO: validate load balancers + + return task_set + + def describe_task_sets(self, cluster, service, task_sets=None, include=None): + task_sets = task_sets or [] + include = include or [] + + cluster_name = cluster.split("/")[-1] + service_name = service.split("/")[-1] + service_key = "{0}:{1}".format(cluster_name, service_name) + + service_obj = self.services.get(service_key) + if not service_obj: + raise ServiceNotFoundException(service_name=service_name) + + cluster_obj = self.clusters.get(cluster_name) + if not cluster_obj: + raise ClusterNotFoundException + + task_set_results = [] + if task_sets: + for task_set in service_obj.task_sets: + if task_set.task_set_arn in task_sets: + task_set_results.append(task_set) + else: + task_set_results = service_obj.task_sets + + return task_set_results + + def delete_task_set(self, cluster, service, task_set, force=False): + cluster_name = cluster.split("/")[-1] + service_name = service.split("/")[-1] + + service_key = "{0}:{1}".format(cluster_name, service_name) + task_set_element = None + for i, ts in enumerate(self.services[service_key].task_sets): + if task_set == ts.task_set_arn: + task_set_element = i + + if task_set_element is not None: + deleted_task_set = self.services[service_key].task_sets.pop( + task_set_element + ) + else: + raise TaskSetNotFoundException + + # TODO: add logic for `force` to raise an exception if `PRIMARY` task has not been scaled to 0. + + return deleted_task_set + + def update_task_set(self, cluster, service, task_set, scale): + cluster_name = cluster.split("/")[-1] + service_name = service.split("/")[-1] + task_set_obj = self.describe_task_sets( + cluster_name, service_name, task_sets=[task_set] + )[0] + task_set_obj.scale = scale + return task_set_obj + + def update_service_primary_task_set(self, cluster, service, primary_task_set): + """ Updates task sets be PRIMARY or ACTIVE for given cluster:service task sets """ + cluster_name = cluster.split("/")[-1] + service_name = service.split("/")[-1] + task_set_obj = self.describe_task_sets( + cluster_name, service_name, task_sets=[primary_task_set] + )[0] + + service_obj = self.describe_services(cluster, [service])[0] + service_obj.load_balancers = task_set_obj.load_balancers + service_obj.task_definition = task_set_obj.task_definition + + for task_set in service_obj.task_sets: + if task_set.task_set_arn == primary_task_set: + task_set.status = "PRIMARY" + else: + task_set.status = "ACTIVE" + return task_set_obj + ecs_backends = {} for region in Session().get_available_regions("ecs"): diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index c8f1e06ce..e911bb943 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -162,14 +162,16 @@ class EC2ContainerServiceResponse(BaseResponse): load_balancers = self._get_param("loadBalancers") scheduling_strategy = self._get_param("schedulingStrategy") tags = self._get_param("tags") + deployment_controller = self._get_param("deploymentController") service = self.ecs_backend.create_service( cluster_str, service_name, - task_definition_str, desired_count, + task_definition_str, load_balancers, scheduling_strategy, tags, + deployment_controller, ) return json.dumps({"service": service.response_object}) @@ -189,6 +191,7 @@ class EC2ContainerServiceResponse(BaseResponse): cluster_str = self._get_param("cluster") service_names = self._get_param("services") services = self.ecs_backend.describe_services(cluster_str, service_names) + return json.dumps( { "services": [service.response_object for service in services], @@ -347,3 +350,80 @@ class EC2ContainerServiceResponse(BaseResponse): tag_keys = self._get_param("tagKeys") results = self.ecs_backend.untag_resource(resource_arn, tag_keys) return json.dumps(results) + + def create_task_set(self): + service_str = self._get_param("service") + cluster_str = self._get_param("cluster") + task_definition = self._get_param("taskDefinition") + external_id = self._get_param("externalId") + network_configuration = self._get_param("networkConfiguration") + load_balancers = self._get_param("loadBalancers") + service_registries = self._get_param("serviceRegistries") + launch_type = self._get_param("launchType") + capacity_provider_strategy = self._get_param("capacityProviderStrategy") + platform_version = self._get_param("platformVersion") + scale = self._get_param("scale") + client_token = self._get_param("clientToken") + tags = self._get_param("tags") + task_set = self.ecs_backend.create_task_set( + service_str, + cluster_str, + task_definition, + external_id=external_id, + network_configuration=network_configuration, + load_balancers=load_balancers, + service_registries=service_registries, + launch_type=launch_type, + capacity_provider_strategy=capacity_provider_strategy, + platform_version=platform_version, + scale=scale, + client_token=client_token, + tags=tags, + ) + return json.dumps({"taskSet": task_set.response_object}) + + def describe_task_sets(self): + cluster_str = self._get_param("cluster") + service_str = self._get_param("service") + task_sets = self._get_param("taskSets") + include = self._get_param("include", []) + task_set_objs = self.ecs_backend.describe_task_sets( + cluster_str, service_str, task_sets, include + ) + + response_objs = [t.response_object for t in task_set_objs] + if "TAGS" not in include: + for ro in response_objs: + del ro["tags"] + return json.dumps({"taskSets": response_objs}) + + def delete_task_set(self): + cluster_str = self._get_param("cluster") + service_str = self._get_param("service") + task_set = self._get_param("taskSet") + force = self._get_param("force") + task_set = self.ecs_backend.delete_task_set( + cluster_str, service_str, task_set, force + ) + return json.dumps({"taskSet": task_set.response_object}) + + def update_task_set(self): + cluster_str = self._get_param("cluster") + service_str = self._get_param("service") + task_set = self._get_param("taskSet") + scale = self._get_param("scale") + + task_set = self.ecs_backend.update_task_set( + cluster_str, service_str, task_set, scale + ) + return json.dumps({"taskSet": task_set.response_object}) + + def update_service_primary_task_set(self): + cluster_str = self._get_param("cluster") + service_str = self._get_param("service") + primary_task_set = self._get_param("primaryTaskSet") + + task_set = self.ecs_backend.update_service_primary_task_set( + cluster_str, service_str, primary_task_set + ) + return json.dumps({"taskSet": task_set.response_object}) diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index f6de59597..3ef62582e 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -2637,3 +2637,332 @@ def test_ecs_task_definition_placement_constraints(): response["taskDefinition"]["placementConstraints"].should.equal( [{"type": "memberOf", "expression": "attribute:ecs.instance-type =~ t2.*"}] ) + + +@mock_ecs +def test_create_task_set(): + cluster_name = "test_ecs_cluster" + service_name = "test_ecs_service" + task_def_name = "test_ecs_task" + + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName=cluster_name) + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.create_service( + cluster=cluster_name, + serviceName=service_name, + taskDefinition=task_def_name, + desiredCount=2, + deploymentController={"type": "EXTERNAL"}, + ) + load_balancers = [ + { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:01234567890:targetgroup/c26b93c1bc35466ba792d5b08fe6a5bc/ec39113f8831453a", + "containerName": "hello_world", + "containerPort": 8080, + }, + ] + + task_set = client.create_task_set( + cluster=cluster_name, + service=service_name, + taskDefinition=task_def_name, + loadBalancers=load_balancers, + )["taskSet"] + + cluster_arn = client.describe_clusters(clusters=[cluster_name])["clusters"][0][ + "clusterArn" + ] + service_arn = client.describe_services( + cluster=cluster_name, services=[service_name] + )["services"][0]["serviceArn"] + assert task_set["clusterArn"] == cluster_arn + assert task_set["serviceArn"] == service_arn + assert task_set["taskDefinition"].endswith("{0}:1".format(task_def_name)) + assert task_set["scale"] == {"value": 100.0, "unit": "PERCENT"} + assert ( + task_set["loadBalancers"][0]["targetGroupArn"] + == "arn:aws:elasticloadbalancing:us-east-1:01234567890:targetgroup/c26b93c1bc35466ba792d5b08fe6a5bc/ec39113f8831453a" + ) + assert task_set["loadBalancers"][0]["containerPort"] == 8080 + assert task_set["loadBalancers"][0]["containerName"] == "hello_world" + + +@mock_ecs +def test_describe_task_sets(): + cluster_name = "test_ecs_cluster" + service_name = "test_ecs_service" + task_def_name = "test_ecs_task" + + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName=cluster_name) + _ = client.register_task_definition( + family=task_def_name, + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.create_service( + cluster=cluster_name, + serviceName=service_name, + taskDefinition=task_def_name, + desiredCount=2, + deploymentController={"type": "EXTERNAL"}, + ) + + load_balancers = [ + { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:01234567890:targetgroup/c26b93c1bc35466ba792d5b08fe6a5bc/ec39113f8831453a", + "containerName": "hello_world", + "containerPort": 8080, + } + ] + + _ = client.create_task_set( + cluster=cluster_name, + service=service_name, + taskDefinition=task_def_name, + loadBalancers=load_balancers, + ) + task_sets = client.describe_task_sets(cluster=cluster_name, service=service_name)[ + "taskSets" + ] + assert "tags" not in task_sets[0] + + task_sets = client.describe_task_sets( + cluster=cluster_name, service=service_name, include=["TAGS"], + )["taskSets"] + + cluster_arn = client.describe_clusters(clusters=[cluster_name])["clusters"][0][ + "clusterArn" + ] + + service_arn = client.describe_services( + cluster=cluster_name, services=[service_name] + )["services"][0]["serviceArn"] + + assert "tags" in task_sets[0] + assert len(task_sets) == 1 + assert task_sets[0]["taskDefinition"].endswith("{0}:1".format(task_def_name)) + assert task_sets[0]["clusterArn"] == cluster_arn + assert task_sets[0]["serviceArn"] == service_arn + assert task_sets[0]["serviceArn"].endswith(service_name) + assert task_sets[0]["scale"] == {"value": 100.0, "unit": "PERCENT"} + assert task_sets[0]["taskSetArn"].endswith(task_sets[0]["id"]) + assert ( + task_sets[0]["loadBalancers"][0]["targetGroupArn"] + == "arn:aws:elasticloadbalancing:us-east-1:01234567890:targetgroup/c26b93c1bc35466ba792d5b08fe6a5bc/ec39113f8831453a" + ) + assert task_sets[0]["loadBalancers"][0]["containerPort"] == 8080 + assert task_sets[0]["loadBalancers"][0]["containerName"] == "hello_world" + + +@mock_ecs +def test_delete_task_set(): + cluster_name = "test_ecs_cluster" + service_name = "test_ecs_service" + task_def_name = "test_ecs_task" + + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName=cluster_name) + _ = client.register_task_definition( + family=task_def_name, + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.create_service( + cluster=cluster_name, + serviceName=service_name, + taskDefinition=task_def_name, + desiredCount=2, + deploymentController={"type": "EXTERNAL"}, + ) + + task_set = client.create_task_set( + cluster=cluster_name, service=service_name, taskDefinition=task_def_name, + )["taskSet"] + + task_sets = client.describe_task_sets( + cluster=cluster_name, service=service_name, taskSets=[task_set["taskSetArn"]], + )["taskSets"] + + assert len(task_sets) == 1 + + response = client.delete_task_set( + cluster=cluster_name, service=service_name, taskSet=task_set["taskSetArn"], + ) + assert response["taskSet"]["taskSetArn"] == task_set["taskSetArn"] + + task_sets = client.describe_task_sets( + cluster=cluster_name, service=service_name, taskSets=[task_set["taskSetArn"]], + )["taskSets"] + + assert len(task_sets) == 0 + + with assert_raises(ClientError): + _ = client.delete_task_set( + cluster=cluster_name, service=service_name, taskSet=task_set["taskSetArn"], + ) + + +@mock_ecs +def test_update_service_primary_task_set(): + cluster_name = "test_ecs_cluster" + service_name = "test_ecs_service" + task_def_name = "test_ecs_task" + + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName=cluster_name) + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.create_service( + cluster=cluster_name, + serviceName=service_name, + desiredCount=2, + deploymentController={"type": "EXTERNAL"}, + ) + + task_set = client.create_task_set( + cluster=cluster_name, service=service_name, taskDefinition=task_def_name, + )["taskSet"] + + service = client.describe_services(cluster=cluster_name, services=[service_name],)[ + "services" + ][0] + + _ = client.update_service_primary_task_set( + cluster=cluster_name, + service=service_name, + primaryTaskSet=task_set["taskSetArn"], + ) + + service = client.describe_services(cluster=cluster_name, services=[service_name],)[ + "services" + ][0] + assert service["taskSets"][0]["status"] == "PRIMARY" + assert service["taskDefinition"] == service["taskSets"][0]["taskDefinition"] + + another_task_set = client.create_task_set( + cluster=cluster_name, service=service_name, taskDefinition=task_def_name, + )["taskSet"] + service = client.describe_services(cluster=cluster_name, services=[service_name],)[ + "services" + ][0] + assert service["taskSets"][1]["status"] == "ACTIVE" + + _ = client.update_service_primary_task_set( + cluster=cluster_name, + service=service_name, + primaryTaskSet=another_task_set["taskSetArn"], + ) + service = client.describe_services(cluster=cluster_name, services=[service_name],)[ + "services" + ][0] + assert service["taskSets"][0]["status"] == "ACTIVE" + assert service["taskSets"][1]["status"] == "PRIMARY" + assert service["taskDefinition"] == service["taskSets"][1]["taskDefinition"] + + +@mock_ecs +def test_update_task_set(): + cluster_name = "test_ecs_cluster" + service_name = "test_ecs_service" + task_def_name = "test_ecs_task" + + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName=cluster_name) + _ = client.register_task_definition( + family=task_def_name, + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.create_service( + cluster=cluster_name, + serviceName=service_name, + desiredCount=2, + deploymentController={"type": "EXTERNAL"}, + ) + + task_set = client.create_task_set( + cluster=cluster_name, service=service_name, taskDefinition=task_def_name, + )["taskSet"] + + another_task_set = client.create_task_set( + cluster=cluster_name, service=service_name, taskDefinition=task_def_name, + )["taskSet"] + assert another_task_set["scale"]["unit"] == "PERCENT" + assert another_task_set["scale"]["value"] == 100.0 + + client.update_task_set( + cluster=cluster_name, + service=service_name, + taskSet=task_set["taskSetArn"], + scale={"value": 25.0, "unit": "PERCENT"}, + ) + + updated_task_set = client.describe_task_sets( + cluster=cluster_name, service=service_name, taskSets=[task_set["taskSetArn"]], + )["taskSets"][0] + assert updated_task_set["scale"]["value"] == 25.0 + assert updated_task_set["scale"]["unit"] == "PERCENT"