| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | import boto3 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | from moto import mock_sagemaker | 
					
						
							| 
									
										
										
										
											2022-08-13 09:49:43 +00:00
										 |  |  | from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | TEST_REGION_NAME = "us-east-1" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | TEST_EXPERIMENT_NAME = "MyExperimentName" | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 16:28:30 +00:00
										 |  |  | @pytest.fixture(name="sagemaker_client") | 
					
						
							|  |  |  | def fixture_sagemaker_client(): | 
					
						
							|  |  |  |     with mock_sagemaker(): | 
					
						
							|  |  |  |         yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_create_experiment(sagemaker_client): | 
					
						
							|  |  |  |     resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_experiments() | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert len(resp["ExperimentSummaries"]) == 1 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     assert resp["ExperimentSummaries"][0]["ExperimentName"] == TEST_EXPERIMENT_NAME | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ExperimentSummaries"][0]["ExperimentArn"] | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{TEST_EXPERIMENT_NAME}" | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_experiments(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-12-15 11:32:19 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     experiment_names = [f"some-experiment-name-{i}" for i in range(10)] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for experiment_name in experiment_names: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         resp = sagemaker_client.create_experiment(ExperimentName=experiment_name) | 
					
						
							| 
									
										
										
										
											2021-12-15 11:32:19 +01:00
										 |  |  |         assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_experiments(MaxResults=1) | 
					
						
							| 
									
										
										
										
											2021-12-15 11:32:19 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert len(resp["ExperimentSummaries"]) == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     next_token = resp["NextToken"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_experiments(MaxResults=2, NextToken=next_token) | 
					
						
							| 
									
										
										
										
											2021-12-15 11:32:19 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert len(resp["ExperimentSummaries"]) == 2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     next_token = resp["NextToken"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_experiments(NextToken=next_token) | 
					
						
							| 
									
										
										
										
											2021-12-15 11:32:19 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert len(resp["ExperimentSummaries"]) == 7 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert resp.get("NextToken") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_delete_experiment(sagemaker_client): | 
					
						
							|  |  |  |     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.delete_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_experiments() | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert len(resp["ExperimentSummaries"]) == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_add_tags_to_experiment(sagemaker_client): | 
					
						
							|  |  |  |     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     arn = resp["ExperimentArn"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [{"Key": "name", "Value": "value"}] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     sagemaker_client.add_tags(ResourceArn=arn, Tags=tags) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_tags(ResourceArn=arn) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["Tags"] == tags | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_delete_tags_to_experiment(sagemaker_client): | 
					
						
							|  |  |  |     sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     arn = resp["ExperimentArn"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [{"Key": "name", "Value": "value"}] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     sagemaker_client.add_tags(ResourceArn=arn, Tags=tags) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     sagemaker_client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_tags(ResourceArn=arn) | 
					
						
							| 
									
										
										
										
											2021-10-28 22:21:20 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     assert resp["Tags"] == [] |