diff --git a/moto/ec2/models/spot_requests.py b/moto/ec2/models/spot_requests.py index 1b2c9f8d9..bb1e92dbf 100644 --- a/moto/ec2/models/spot_requests.py +++ b/moto/ec2/models/spot_requests.py @@ -466,8 +466,10 @@ class SpotFleetBackend: if terminate_instances: spot_fleet.target_capacity = 0 spot_fleet.terminate_instances() + del self.spot_fleet_requests[spot_fleet_request_id] + else: + spot_fleet.state = "cancelled_running" spot_requests.append(spot_fleet) - del self.spot_fleet_requests[spot_fleet_request_id] return spot_requests def modify_spot_fleet_request( diff --git a/moto/ec2/responses/spot_fleets.py b/moto/ec2/responses/spot_fleets.py index c32ad9e19..504611054 100644 --- a/moto/ec2/responses/spot_fleets.py +++ b/moto/ec2/responses/spot_fleets.py @@ -4,7 +4,7 @@ from moto.core.responses import BaseResponse class SpotFleets(BaseResponse): def cancel_spot_fleet_requests(self): spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") - terminate_instances = self._get_param("TerminateInstances") + terminate_instances = self._get_bool_param("TerminateInstances") spot_fleets = self.ec2_backend.cancel_spot_fleet_requests( spot_fleet_request_ids, terminate_instances ) diff --git a/tests/test_ec2/test_spot_fleet.py b/tests/test_ec2/test_spot_fleet.py index b6483972f..5c2c32b1d 100644 --- a/tests/test_ec2/test_spot_fleet.py +++ b/tests/test_ec2/test_spot_fleet.py @@ -118,8 +118,7 @@ def test_create_spot_fleet_with_lowest_price(): launch_spec["UserData"].should.equal("some user data") launch_spec["WeightedCapacity"].should.equal(2.0) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(3) @@ -132,8 +131,7 @@ def test_create_diversified_spot_fleet(): spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=diversified_config) spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(2) instance_types = set([instance["InstanceType"] for instance in instances]) instance_types.should.equal(set(["t2.small", "t2.large"])) @@ -180,8 +178,7 @@ def test_request_spot_fleet_using_launch_template_config__name(allocation_strate spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=template_config) spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(1) instance_types = set([instance["InstanceType"] for instance in instances]) instance_types.should.equal(set(["t2.medium"])) @@ -223,8 +220,7 @@ def test_request_spot_fleet_using_launch_template_config__id(): spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=template_config) spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(1) instance_types = set([instance["InstanceType"] for instance in instances]) instance_types.should.equal(set(["t2.medium"])) @@ -277,8 +273,7 @@ def test_request_spot_fleet_using_launch_template_config__overrides(): spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=template_config) spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) instances.should.have.length_of(1) instances[0].should.have.key("InstanceType").equals("t2.nano") @@ -347,6 +342,42 @@ def test_cancel_spot_fleet_request(): len(spot_fleet_requests).should.equal(0) +@mock_ec2 +def test_cancel_spot_fleet_request__but_dont_terminate_instances(): + conn = boto3.client("ec2", region_name="us-west-2") + subnet_id = get_subnet_id(conn) + + spot_fleet_res = conn.request_spot_fleet( + SpotFleetRequestConfig=spot_config(subnet_id) + ) + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] + + get_active_instances(conn, spot_fleet_id).should.have.length_of(3) + + conn.cancel_spot_fleet_requests( + SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=False + ) + + spot_fleet_requests = conn.describe_spot_fleet_requests( + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] + spot_fleet_requests.should.have.length_of(1) + spot_fleet_requests[0]["SpotFleetRequestState"].should.equal("cancelled_running") + + get_active_instances(conn, spot_fleet_id).should.have.length_of(3) + + # Cancel again and terminate instances + conn.cancel_spot_fleet_requests( + SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=True + ) + + get_active_instances(conn, spot_fleet_id).should.have.length_of(0) + spot_fleet_requests = conn.describe_spot_fleet_requests( + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] + spot_fleet_requests.should.have.length_of(0) + + @mock_ec2 def test_modify_spot_fleet_request_up(): conn = boto3.client("ec2", region_name="us-west-2") @@ -359,8 +390,7 @@ def test_modify_spot_fleet_request_up(): conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=20) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(10) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -382,8 +412,7 @@ def test_modify_spot_fleet_request_up_diversified(): conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=19) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(7) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -409,8 +438,7 @@ def test_modify_spot_fleet_request_down_no_terminate(): ExcessCapacityTerminationPolicy="noTermination", ) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -433,8 +461,7 @@ def test_modify_spot_fleet_request_down_odd(): conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=7) conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=5) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -456,8 +483,7 @@ def test_modify_spot_fleet_request_down(): conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=1) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -477,8 +503,7 @@ def test_modify_spot_fleet_request_down_no_terminate_after_custom_terminate(): ) spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) conn.terminate_instances(InstanceIds=[i["InstanceId"] for i in instances[1:]]) conn.modify_spot_fleet_request( @@ -487,8 +512,7 @@ def test_modify_spot_fleet_request_down_no_terminate_after_custom_terminate(): ExcessCapacityTerminationPolicy="noTermination", ) - instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) - instances = instance_res["ActiveInstances"] + instances = get_active_instances(conn, spot_fleet_id) len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( @@ -526,3 +550,8 @@ def test_create_spot_fleet_without_spot_price(): # AWS will figure out the price assert "SpotPrice" not in launch_spec1 assert "SpotPrice" not in launch_spec2 + + +def get_active_instances(conn, spot_fleet_id): + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + return instance_res["ActiveInstances"]