diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index a108361d3..9ea4330fa 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -3518,34 +3518,34 @@ ## forecast
-0% implemented +19% implemented - [ ] create_dataset -- [ ] create_dataset_group +- [X] create_dataset_group - [ ] create_dataset_import_job - [ ] create_forecast - [ ] create_forecast_export_job - [ ] create_predictor - [ ] delete_dataset -- [ ] delete_dataset_group +- [X] delete_dataset_group - [ ] delete_dataset_import_job - [ ] delete_forecast - [ ] delete_forecast_export_job - [ ] delete_predictor - [ ] describe_dataset -- [ ] describe_dataset_group +- [X] describe_dataset_group - [ ] describe_dataset_import_job - [ ] describe_forecast - [ ] describe_forecast_export_job - [ ] describe_predictor - [ ] get_accuracy_metrics -- [ ] list_dataset_groups +- [X] list_dataset_groups - [ ] list_dataset_import_jobs - [ ] list_datasets - [ ] list_forecast_export_jobs - [ ] list_forecasts - [ ] list_predictors -- [ ] update_dataset_group +- [X] update_dataset_group
## forecastquery @@ -8125,9 +8125,9 @@ - [ ] send_task_success - [X] start_execution - [X] stop_execution -- [ ] tag_resource -- [ ] untag_resource -- [ ] update_state_machine +- [X] tag_resource +- [X] untag_resource +- [X] update_state_machine ## storagegateway diff --git a/README.md b/README.md index 3915a85cd..784976a4a 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L | ELB | @mock_elb | core endpoints done | | | ELBv2 | @mock_elbv2 | all endpoints done | | | EMR | @mock_emr | core endpoints done | | +| Forecast | @mock_forecast | some core endpoints done | | | Glacier | @mock_glacier | core endpoints done | | | IAM | @mock_iam | core endpoints done | | | IoT | @mock_iot | core endpoints done | | diff --git a/docs/index.rst b/docs/index.rst index 22ac97228..4f2d7e090 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,6 +60,8 @@ Currently implemented Services: +---------------------------+-----------------------+------------------------------------+ | EMR | @mock_emr | core endpoints done | +---------------------------+-----------------------+------------------------------------+ +| Forecast | @mock_forecast | basic endpoints done | ++---------------------------+-----------------------+------------------------------------+ | Glacier | @mock_glacier | core endpoints done | +---------------------------+-----------------------+------------------------------------+ | IAM | @mock_iam | core endpoints done | diff --git a/moto/__init__.py b/moto/__init__.py index c73e111a0..fd467cbf8 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -63,6 +63,7 @@ mock_elbv2 = lazy_load(".elbv2", "mock_elbv2") mock_emr = lazy_load(".emr", "mock_emr") mock_emr_deprecated = lazy_load(".emr", "mock_emr_deprecated") mock_events = lazy_load(".events", "mock_events") +mock_forecast = lazy_load(".forecast", "mock_forecast") mock_glacier = lazy_load(".glacier", "mock_glacier") mock_glacier_deprecated = lazy_load(".glacier", "mock_glacier_deprecated") mock_glue = lazy_load(".glue", "mock_glue") diff --git a/moto/backends.py b/moto/backends.py index e76a89ccb..c8bac72fc 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -75,6 +75,7 @@ BACKENDS = { "kinesisvideoarchivedmedia", "kinesisvideoarchivedmedia_backends", ), + "forecast": ("forecast", "forecast_backends"), } diff --git a/moto/ecr/models.py b/moto/ecr/models.py index 33a0201fd..299ed48a7 100644 --- a/moto/ecr/models.py +++ b/moto/ecr/models.py @@ -164,7 +164,7 @@ class Image(BaseObject): def response_list_object(self): response_object = self.gen_response_object() response_object["imageTag"] = self.image_tag - response_object["imageDigest"] = "i don't know" + response_object["imageDigest"] = self.get_image_digest() return { k: v for k, v in response_object.items() if v is not None and v != [None] } diff --git a/moto/forecast/__init__.py b/moto/forecast/__init__.py new file mode 100644 index 000000000..75b23b94a --- /dev/null +++ b/moto/forecast/__init__.py @@ -0,0 +1,7 @@ +from __future__ import unicode_literals + +from .models import forecast_backends +from ..core.models import base_decorator + +forecast_backend = forecast_backends["us-east-1"] +mock_forecast = base_decorator(forecast_backends) diff --git a/moto/forecast/exceptions.py b/moto/forecast/exceptions.py new file mode 100644 index 000000000..ad86e90fc --- /dev/null +++ b/moto/forecast/exceptions.py @@ -0,0 +1,43 @@ +from __future__ import unicode_literals + +import json + + +class AWSError(Exception): + TYPE = None + STATUS = 400 + + def __init__(self, message, type=None, status=None): + self.message = message + self.type = type if type is not None else self.TYPE + self.status = status if status is not None else self.STATUS + + def response(self): + return ( + json.dumps({"__type": self.type, "message": self.message}), + dict(status=self.status), + ) + + +class InvalidInputException(AWSError): + TYPE = "InvalidInputException" + + +class ResourceAlreadyExistsException(AWSError): + TYPE = "ResourceAlreadyExistsException" + + +class ResourceNotFoundException(AWSError): + TYPE = "ResourceNotFoundException" + + +class ResourceInUseException(AWSError): + TYPE = "ResourceInUseException" + + +class LimitExceededException(AWSError): + TYPE = "LimitExceededException" + + +class ValidationException(AWSError): + TYPE = "ValidationException" diff --git a/moto/forecast/models.py b/moto/forecast/models.py new file mode 100644 index 000000000..c7b18618c --- /dev/null +++ b/moto/forecast/models.py @@ -0,0 +1,173 @@ +import re +from datetime import datetime + +from boto3 import Session +from future.utils import iteritems + +from moto.core import ACCOUNT_ID, BaseBackend +from moto.core.utils import iso_8601_datetime_without_milliseconds +from .exceptions import ( + InvalidInputException, + ResourceAlreadyExistsException, + ResourceNotFoundException, + ValidationException, +) + + +class DatasetGroup: + accepted_dataset_group_name_format = re.compile(r"^[a-zA-Z][a-z-A-Z0-9_]*") + accepted_dataset_group_arn_format = re.compile(r"^[a-zA-Z0-9\-\_\.\/\:]+$") + accepted_dataset_types = [ + "INVENTORY_PLANNING", + "METRICS", + "RETAIL", + "EC2_CAPACITY", + "CUSTOM", + "WEB_TRAFFIC", + "WORK_FORCE", + ] + + def __init__( + self, region_name, dataset_arns, dataset_group_name, domain, tags=None + ): + self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.modified_date = self.creation_date + + self.arn = ( + "arn:aws:forecast:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":dataset-group/" + + dataset_group_name + ) + self.dataset_arns = dataset_arns if dataset_arns else [] + self.dataset_group_name = dataset_group_name + self.domain = domain + self.tags = tags + self._validate() + + def update(self, dataset_arns): + self.dataset_arns = dataset_arns + self.last_modified_date = iso_8601_datetime_without_milliseconds(datetime.now()) + + def _validate(self): + errors = [] + + errors.extend(self._validate_dataset_group_name()) + errors.extend(self._validate_dataset_group_name_len()) + errors.extend(self._validate_dataset_group_domain()) + + if errors: + err_count = len(errors) + message = str(err_count) + " validation error" + message += "s" if err_count > 1 else "" + message += " detected: " + message += "; ".join(errors) + raise ValidationException(message) + + def _validate_dataset_group_name(self): + errors = [] + if not re.match( + self.accepted_dataset_group_name_format, self.dataset_group_name + ): + errors.append( + "Value '" + + self.dataset_group_name + + "' at 'datasetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern " + + self.accepted_dataset_group_name_format.pattern + ) + return errors + + def _validate_dataset_group_name_len(self): + errors = [] + if len(self.dataset_group_name) >= 64: + errors.append( + "Value '" + + self.dataset_group_name + + "' at 'datasetGroupName' failed to satisfy constraint: Member must have length less than or equal to 63" + ) + return errors + + def _validate_dataset_group_domain(self): + errors = [] + if self.domain not in self.accepted_dataset_types: + errors.append( + "Value '" + + self.domain + + "' at 'domain' failed to satisfy constraint: Member must satisfy enum value set " + + str(self.accepted_dataset_types) + ) + return errors + + +class ForecastBackend(BaseBackend): + def __init__(self, region_name): + super(ForecastBackend, self).__init__() + self.dataset_groups = {} + self.datasets = {} + self.region_name = region_name + + def create_dataset_group(self, dataset_group_name, domain, dataset_arns, tags): + dataset_group = DatasetGroup( + region_name=self.region_name, + dataset_group_name=dataset_group_name, + domain=domain, + dataset_arns=dataset_arns, + tags=tags, + ) + + if dataset_arns: + for dataset_arn in dataset_arns: + if dataset_arn not in self.datasets: + raise InvalidInputException( + "Dataset arns: [" + dataset_arn + "] are not found" + ) + + if self.dataset_groups.get(dataset_group.arn): + raise ResourceAlreadyExistsException( + "A dataset group already exists with the arn: " + dataset_group.arn + ) + + self.dataset_groups[dataset_group.arn] = dataset_group + return dataset_group + + def describe_dataset_group(self, dataset_group_arn): + try: + dataset_group = self.dataset_groups[dataset_group_arn] + except KeyError: + raise ResourceNotFoundException("No resource found " + dataset_group_arn) + return dataset_group + + def delete_dataset_group(self, dataset_group_arn): + try: + del self.dataset_groups[dataset_group_arn] + except KeyError: + raise ResourceNotFoundException("No resource found " + dataset_group_arn) + + def update_dataset_group(self, dataset_group_arn, dataset_arns): + try: + dsg = self.dataset_groups[dataset_group_arn] + except KeyError: + raise ResourceNotFoundException("No resource found " + dataset_group_arn) + + for dataset_arn in dataset_arns: + if dataset_arn not in dsg.dataset_arns: + raise InvalidInputException( + "Dataset arns: [" + dataset_arn + "] are not found" + ) + + dsg.update(dataset_arns) + + def list_dataset_groups(self): + return [v for (_, v) in iteritems(self.dataset_groups)] + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + +forecast_backends = {} +for region in Session().get_available_regions("forecast"): + forecast_backends[region] = ForecastBackend(region) diff --git a/moto/forecast/responses.py b/moto/forecast/responses.py new file mode 100644 index 000000000..09d55b0d8 --- /dev/null +++ b/moto/forecast/responses.py @@ -0,0 +1,92 @@ +from __future__ import unicode_literals + +import json + +from moto.core.responses import BaseResponse +from moto.core.utils import amzn_request_id +from .exceptions import AWSError +from .models import forecast_backends + + +class ForecastResponse(BaseResponse): + @property + def forecast_backend(self): + return forecast_backends[self.region] + + @amzn_request_id + def create_dataset_group(self): + dataset_group_name = self._get_param("DatasetGroupName") + domain = self._get_param("Domain") + dataset_arns = self._get_param("DatasetArns") + tags = self._get_param("Tags") + + try: + dataset_group = self.forecast_backend.create_dataset_group( + dataset_group_name=dataset_group_name, + domain=domain, + dataset_arns=dataset_arns, + tags=tags, + ) + response = {"DatasetGroupArn": dataset_group.arn} + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_dataset_group(self): + dataset_group_arn = self._get_param("DatasetGroupArn") + + try: + dataset_group = self.forecast_backend.describe_dataset_group( + dataset_group_arn=dataset_group_arn + ) + response = { + "CreationTime": dataset_group.creation_date, + "DatasetArns": dataset_group.dataset_arns, + "DatasetGroupArn": dataset_group.arn, + "DatasetGroupName": dataset_group.dataset_group_name, + "Domain": dataset_group.domain, + "LastModificationTime": dataset_group.modified_date, + "Status": "ACTIVE", + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def delete_dataset_group(self): + dataset_group_arn = self._get_param("DatasetGroupArn") + try: + self.forecast_backend.delete_dataset_group(dataset_group_arn) + return 200, {}, None + except AWSError as err: + return err.response() + + @amzn_request_id + def update_dataset_group(self): + dataset_group_arn = self._get_param("DatasetGroupArn") + dataset_arns = self._get_param("DatasetArns") + try: + self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns) + return 200, {}, None + except AWSError as err: + return err.response() + + @amzn_request_id + def list_dataset_groups(self): + list_all = self.forecast_backend.list_dataset_groups() + list_all = sorted( + [ + { + "DatasetGroupArn": dsg.arn, + "DatasetGroupName": dsg.dataset_group_name, + "CreationTime": dsg.creation_date, + "LastModificationTime": dsg.creation_date, + } + for dsg in list_all + ], + key=lambda x: x["LastModificationTime"], + reverse=True, + ) + response = {"DatasetGroups": list_all} + return 200, {}, json.dumps(response) diff --git a/moto/forecast/urls.py b/moto/forecast/urls.py new file mode 100644 index 000000000..221659e6f --- /dev/null +++ b/moto/forecast/urls.py @@ -0,0 +1,7 @@ +from __future__ import unicode_literals + +from .responses import ForecastResponse + +url_bases = ["https?://forecast.(.+).amazonaws.com"] + +url_paths = {"{0}/$": ForecastResponse.dispatch} diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index a24c15008..9598c65f9 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -38,3 +38,11 @@ class InvalidToken(AWSError): def __init__(self, message="Invalid token"): super(InvalidToken, self).__init__("Invalid Token: {}".format(message)) + + +class ResourceNotFound(AWSError): + TYPE = "ResourceNotFound" + STATUS = 400 + + def __init__(self, arn): + super(ResourceNotFound, self).__init__("Resource not found: '{}'".format(arn)) diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 9dfa33ba8..86c76c98a 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -13,6 +13,7 @@ from .exceptions import ( InvalidArn, InvalidExecutionInput, InvalidName, + ResourceNotFound, StateMachineDoesNotExist, ) from .utils import paginate @@ -21,11 +22,41 @@ from .utils import paginate class StateMachine(CloudFormationModel): def __init__(self, arn, name, definition, roleArn, tags=None): self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now()) + self.update_date = self.creation_date self.arn = arn self.name = name self.definition = definition self.roleArn = roleArn - self.tags = tags + self.tags = [] + if tags: + self.add_tags(tags) + + def update(self, **kwargs): + for key, value in kwargs.items(): + if value is not None: + setattr(self, key, value) + self.update_date = iso_8601_datetime_with_milliseconds(datetime.now()) + + def add_tags(self, tags): + merged_tags = [] + for tag in self.tags: + replacement_index = next( + (index for (index, d) in enumerate(tags) if d["key"] == tag["key"]), + None, + ) + if replacement_index is not None: + replacement = tags.pop(replacement_index) + merged_tags.append(replacement) + else: + merged_tags.append(tag) + for tag in tags: + merged_tags.append(tag) + self.tags = merged_tags + return self.tags + + def remove_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys] + return self.tags @property def physical_resource_id(self): @@ -249,6 +280,15 @@ class StepFunctionBackend(BaseBackend): if sm: self.state_machines.remove(sm) + def update_state_machine(self, arn, definition=None, role_arn=None): + sm = self.describe_state_machine(arn) + updates = { + "definition": definition, + "roleArn": role_arn, + } + sm.update(**updates) + return sm + def start_execution(self, state_machine_arn, name=None, execution_input=None): state_machine_name = self.describe_state_machine(state_machine_arn).name self._ensure_execution_name_doesnt_exist(name) @@ -296,6 +336,20 @@ class StepFunctionBackend(BaseBackend): raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'") return exctn + def tag_resource(self, resource_arn, tags): + try: + state_machine = self.describe_state_machine(resource_arn) + state_machine.add_tags(tags) + except StateMachineDoesNotExist: + raise ResourceNotFound(resource_arn) + + def untag_resource(self, resource_arn, tag_keys): + try: + state_machine = self.describe_state_machine(resource_arn) + state_machine.remove_tags(tag_keys) + except StateMachineDoesNotExist: + raise ResourceNotFound(resource_arn) + def reset(self): region_name = self.region_name self.__dict__ = {} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index 7106d81d0..7eae8091b 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -83,6 +83,22 @@ class StepFunctionResponse(BaseResponse): except AWSError as err: return err.response() + @amzn_request_id + def update_state_machine(self): + arn = self._get_param("stateMachineArn") + definition = self._get_param("definition") + role_arn = self._get_param("roleArn") + try: + state_machine = self.stepfunction_backend.update_state_machine( + arn=arn, definition=definition, role_arn=role_arn + ) + response = { + "updateDate": state_machine.update_date, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + @amzn_request_id def list_tags_for_resource(self): arn = self._get_param("resourceArn") @@ -94,6 +110,26 @@ class StepFunctionResponse(BaseResponse): response = {"tags": tags} return 200, {}, json.dumps(response) + @amzn_request_id + def tag_resource(self): + arn = self._get_param("resourceArn") + tags = self._get_param("tags", []) + try: + self.stepfunction_backend.tag_resource(arn, tags) + except AWSError as err: + return err.response() + return 200, {}, json.dumps({}) + + @amzn_request_id + def untag_resource(self): + arn = self._get_param("resourceArn") + tag_keys = self._get_param("tagKeys", []) + try: + self.stepfunction_backend.untag_resource(arn, tag_keys) + except AWSError as err: + return err.response() + return 200, {}, json.dumps({}) + @amzn_request_id def start_execution(self): arn = self._get_param("stateMachineArn") diff --git a/tests/test_ecr/test_ecr_boto3.py b/tests/test_ecr/test_ecr_boto3.py index 6c6840a7e..fd678f661 100644 --- a/tests/test_ecr/test_ecr_boto3.py +++ b/tests/test_ecr/test_ecr_boto3.py @@ -318,6 +318,9 @@ def test_list_images(): type(response["imageIds"]).should.be(list) len(response["imageIds"]).should.be(3) + for image in response["imageIds"]: + image["imageDigest"].should.contain("sha") + image_tags = ["latest", "v1", "v2"] set( [ @@ -331,6 +334,7 @@ def test_list_images(): type(response["imageIds"]).should.be(list) len(response["imageIds"]).should.be(1) response["imageIds"][0]["imageTag"].should.equal("oldest") + response["imageIds"][0]["imageDigest"].should.contain("sha") @mock_ecr diff --git a/tests/test_forecast/__init__.py b/tests/test_forecast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_forecast/test_forecast.py b/tests/test_forecast/test_forecast.py new file mode 100644 index 000000000..32af519c7 --- /dev/null +++ b/tests/test_forecast/test_forecast.py @@ -0,0 +1,222 @@ +from __future__ import unicode_literals + +import boto3 +import sure # noqa +from botocore.exceptions import ClientError +from nose.tools import assert_raises +from parameterized import parameterized + +from moto import mock_forecast +from moto.core import ACCOUNT_ID + +region = "us-east-1" +account_id = None +valid_domains = [ + "RETAIL", + "CUSTOM", + "INVENTORY_PLANNING", + "EC2_CAPACITY", + "WORK_FORCE", + "WEB_TRAFFIC", + "METRICS", +] + + +@parameterized(valid_domains) +@mock_forecast +def test_forecast_dataset_group_create(domain): + name = "example_dataset_group" + client = boto3.client("forecast", region_name=region) + response = client.create_dataset_group(DatasetGroupName=name, Domain=domain) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DatasetGroupArn"].should.equal( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/" + name + ) + + +@mock_forecast +def test_forecast_dataset_group_create_invalid_domain(): + name = "example_dataset_group" + client = boto3.client("forecast", region_name=region) + invalid_domain = "INVALID" + + with assert_raises(ClientError) as exc: + client.create_dataset_group(DatasetGroupName=name, Domain=invalid_domain) + exc.exception.response["Error"]["Code"].should.equal("ValidationException") + exc.exception.response["Error"]["Message"].should.equal( + "1 validation error detected: Value '" + + invalid_domain + + "' at 'domain' failed to satisfy constraint: Member must satisfy enum value set ['INVENTORY_PLANNING', 'METRICS', 'RETAIL', 'EC2_CAPACITY', 'CUSTOM', 'WEB_TRAFFIC', 'WORK_FORCE']" + ) + + +@parameterized([" ", "a" * 64]) +@mock_forecast +def test_forecast_dataset_group_create_invalid_name(name): + client = boto3.client("forecast", region_name=region) + + with assert_raises(ClientError) as exc: + client.create_dataset_group(DatasetGroupName=name, Domain="CUSTOM") + exc.exception.response["Error"]["Code"].should.equal("ValidationException") + exc.exception.response["Error"]["Message"].should.contain( + "1 validation error detected: Value '" + + name + + "' at 'datasetGroupName' failed to satisfy constraint: Member must" + ) + + +@mock_forecast +def test_forecast_dataset_group_create_duplicate_fails(): + client = boto3.client("forecast", region_name=region) + client.create_dataset_group(DatasetGroupName="name", Domain="RETAIL") + + with assert_raises(ClientError) as exc: + client.create_dataset_group(DatasetGroupName="name", Domain="RETAIL") + + exc.exception.response["Error"]["Code"].should.equal( + "ResourceAlreadyExistsException" + ) + + +@mock_forecast +def test_forecast_dataset_group_list_default_empty(): + client = boto3.client("forecast", region_name=region) + + list = client.list_dataset_groups() + list["DatasetGroups"].should.be.empty + + +@mock_forecast +def test_forecast_dataset_group_list_some(): + client = boto3.client("forecast", region_name=region) + + client.create_dataset_group(DatasetGroupName="hello", Domain="CUSTOM") + result = client.list_dataset_groups() + + assert len(result["DatasetGroups"]) == 1 + result["DatasetGroups"][0]["DatasetGroupArn"].should.equal( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/hello" + ) + + +@mock_forecast +def test_forecast_delete_dataset_group(): + dataset_group_name = "name" + dataset_group_arn = ( + "arn:aws:forecast:" + + region + + ":" + + ACCOUNT_ID + + ":dataset-group/" + + dataset_group_name + ) + client = boto3.client("forecast", region_name=region) + client.create_dataset_group(DatasetGroupName=dataset_group_name, Domain="CUSTOM") + client.delete_dataset_group(DatasetGroupArn=dataset_group_arn) + + +@mock_forecast +def test_forecast_delete_dataset_group_missing(): + client = boto3.client("forecast", region_name=region) + missing_dsg_arn = ( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/missing" + ) + + with assert_raises(ClientError) as exc: + client.delete_dataset_group(DatasetGroupArn=missing_dsg_arn) + exc.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + exc.exception.response["Error"]["Message"].should.equal( + "No resource found " + missing_dsg_arn + ) + + +@mock_forecast +def test_forecast_update_dataset_arns_empty(): + dataset_group_name = "name" + dataset_group_arn = ( + "arn:aws:forecast:" + + region + + ":" + + ACCOUNT_ID + + ":dataset-group/" + + dataset_group_name + ) + client = boto3.client("forecast", region_name=region) + client.create_dataset_group(DatasetGroupName=dataset_group_name, Domain="CUSTOM") + client.update_dataset_group(DatasetGroupArn=dataset_group_arn, DatasetArns=[]) + + +@mock_forecast +def test_forecast_update_dataset_group_not_found(): + client = boto3.client("forecast", region_name=region) + dataset_group_arn = ( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/" + "test" + ) + with assert_raises(ClientError) as exc: + client.update_dataset_group(DatasetGroupArn=dataset_group_arn, DatasetArns=[]) + exc.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + exc.exception.response["Error"]["Message"].should.equal( + "No resource found " + dataset_group_arn + ) + + +@mock_forecast +def test_describe_dataset_group(): + name = "test" + client = boto3.client("forecast", region_name=region) + dataset_group_arn = ( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/" + name + ) + client.create_dataset_group(DatasetGroupName=name, Domain="CUSTOM") + result = client.describe_dataset_group(DatasetGroupArn=dataset_group_arn) + assert result.get("DatasetGroupArn") == dataset_group_arn + assert result.get("Domain") == "CUSTOM" + assert result.get("DatasetArns") == [] + + +@mock_forecast +def test_describe_dataset_group_missing(): + client = boto3.client("forecast", region_name=region) + dataset_group_arn = ( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/name" + ) + with assert_raises(ClientError) as exc: + client.describe_dataset_group(DatasetGroupArn=dataset_group_arn) + exc.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + exc.exception.response["Error"]["Message"].should.equal( + "No resource found " + dataset_group_arn + ) + + +@mock_forecast +def test_create_dataset_group_missing_datasets(): + client = boto3.client("forecast", region_name=region) + dataset_arn = "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset/name" + with assert_raises(ClientError) as exc: + client.create_dataset_group( + DatasetGroupName="name", Domain="CUSTOM", DatasetArns=[dataset_arn] + ) + exc.exception.response["Error"]["Code"].should.equal("InvalidInputException") + exc.exception.response["Error"]["Message"].should.equal( + "Dataset arns: [" + dataset_arn + "] are not found" + ) + + +@mock_forecast +def test_update_dataset_group_missing_datasets(): + name = "test" + client = boto3.client("forecast", region_name=region) + dataset_group_arn = ( + "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset-group/" + name + ) + client.create_dataset_group(DatasetGroupName=name, Domain="CUSTOM") + dataset_arn = "arn:aws:forecast:" + region + ":" + ACCOUNT_ID + ":dataset/name" + + with assert_raises(ClientError) as exc: + client.update_dataset_group( + DatasetGroupArn=dataset_group_arn, DatasetArns=[dataset_arn] + ) + exc.exception.response["Error"]["Code"].should.equal("InvalidInputException") + exc.exception.response["Error"]["Message"].should.equal( + "Dataset arns: [" + dataset_arn + "] are not found" + ) diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 1c961b882..0bea43084 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -155,6 +155,33 @@ def test_state_machine_creation_requires_valid_role_arn(): ) +@mock_stepfunctions +@mock_sts +def test_update_state_machine(): + client = boto3.client("stepfunctions", region_name=region) + + resp = client.create_state_machine( + name="test", definition=str(simple_definition), roleArn=_get_default_role() + ) + state_machine_arn = resp["stateMachineArn"] + + updated_role = _get_default_role() + "-updated" + updated_definition = str(simple_definition).replace( + "DefaultState", "DefaultStateUpdated" + ) + resp = client.update_state_machine( + stateMachineArn=state_machine_arn, + definition=updated_definition, + roleArn=updated_role, + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["updateDate"].should.be.a(datetime) + + desc = client.describe_state_machine(stateMachineArn=state_machine_arn) + desc["definition"].should.equal(updated_definition) + desc["roleArn"].should.equal(updated_role) + + @mock_stepfunctions def test_state_machine_list_returns_empty_list_by_default(): client = boto3.client("stepfunctions", region_name=region) @@ -326,6 +353,85 @@ def test_state_machine_can_deleted_nonexisting_machine(): sm_list["stateMachines"].should.have.length_of(0) +@mock_stepfunctions +def test_state_machine_tagging_non_existent_resource_fails(): + client = boto3.client("stepfunctions", region_name=region) + non_existent_arn = "arn:aws:states:{region}:{account}:stateMachine:non-existent".format( + region=region, account=ACCOUNT_ID + ) + with assert_raises(ClientError) as ex: + client.tag_resource(resourceArn=non_existent_arn, tags=[]) + ex.exception.response["Error"]["Code"].should.equal("ResourceNotFound") + ex.exception.response["Error"]["Message"].should.contain(non_existent_arn) + + +@mock_stepfunctions +def test_state_machine_untagging_non_existent_resource_fails(): + client = boto3.client("stepfunctions", region_name=region) + non_existent_arn = "arn:aws:states:{region}:{account}:stateMachine:non-existent".format( + region=region, account=ACCOUNT_ID + ) + with assert_raises(ClientError) as ex: + client.untag_resource(resourceArn=non_existent_arn, tagKeys=[]) + ex.exception.response["Error"]["Code"].should.equal("ResourceNotFound") + ex.exception.response["Error"]["Message"].should.contain(non_existent_arn) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_tagging(): + client = boto3.client("stepfunctions", region_name=region) + tags = [ + {"key": "tag_key1", "value": "tag_value1"}, + {"key": "tag_key2", "value": "tag_value2"}, + ] + machine = client.create_state_machine( + name="test", definition=str(simple_definition), roleArn=_get_default_role(), + ) + client.tag_resource(resourceArn=machine["stateMachineArn"], tags=tags) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + resp["tags"].should.equal(tags) + + tags_update = [ + {"key": "tag_key1", "value": "tag_value1_new"}, + {"key": "tag_key3", "value": "tag_value3"}, + ] + client.tag_resource(resourceArn=machine["stateMachineArn"], tags=tags_update) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags_expected = [ + tags_update[0], + tags[1], + tags_update[1], + ] + resp["tags"].should.equal(tags_expected) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_untagging(): + client = boto3.client("stepfunctions", region_name=region) + tags = [ + {"key": "tag_key1", "value": "tag_value1"}, + {"key": "tag_key2", "value": "tag_value2"}, + {"key": "tag_key3", "value": "tag_value3"}, + ] + machine = client.create_state_machine( + name="test", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=tags, + ) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + resp["tags"].should.equal(tags) + tags_to_delete = ["tag_key1", "tag_key2"] + client.untag_resource( + resourceArn=machine["stateMachineArn"], tagKeys=tags_to_delete + ) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + expected_tags = [tag for tag in tags if tag["key"] not in tags_to_delete] + resp["tags"].should.equal(expected_tags) + + @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_created_machine():