diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index 519dcb35e..93822318a 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -23,6 +23,15 @@ class StreamNotFoundError(ResourceNotFoundError): super().__init__(f"Stream {stream_name} under account {account_id} not found.") +class StreamCannotBeUpdatedError(BadRequest): + def __init__(self, stream_name, account_id): + super().__init__() + message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode." + self.description = json.dumps( + {"message": message, "__type": "ValidationException"} + ) + + class ShardNotFoundError(ResourceNotFoundError): def __init__(self, shard_id, stream, account_id): super().__init__( diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index f53f10279..c6e206366 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -12,6 +12,7 @@ from moto.utilities.utils import md5_hash from .exceptions import ( ConsumerNotFound, StreamNotFoundError, + StreamCannotBeUpdatedError, ShardNotFoundError, ResourceInUseError, ResourceNotFoundError, @@ -164,7 +165,13 @@ class Shard(BaseModel): class Stream(CloudFormationModel): def __init__( - self, stream_name, shard_count, retention_period_hours, account_id, region_name + self, + stream_name, + shard_count, + stream_mode, + retention_period_hours, + account_id, + region_name, ): self.stream_name = stream_name self.creation_datetime = datetime.datetime.now().strftime( @@ -177,10 +184,11 @@ class Stream(CloudFormationModel): self.tags = {} self.status = "ACTIVE" self.shard_count = None + self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"} + if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": + shard_count = 4 self.init_shards(shard_count) - self.retention_period_hours = ( - retention_period_hours if retention_period_hours else 24 - ) + self.retention_period_hours = retention_period_hours or 24 self.shard_level_metrics = [] self.encryption_type = "NONE" self.key_id = None @@ -289,6 +297,10 @@ class Stream(CloudFormationModel): ) def update_shard_count(self, target_shard_count): + if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": + raise StreamCannotBeUpdatedError( + stream_name=self.stream_name, account_id=self.account_id + ) current_shard_count = len([s for s in self.shards.values() if s.is_open]) if current_shard_count == target_shard_count: return @@ -393,8 +405,12 @@ class Stream(CloudFormationModel): "StreamARN": self.arn, "StreamName": self.stream_name, "StreamStatus": self.status, + "StreamModeDetails": self.stream_mode, + "RetentionPeriodHours": self.retention_period_hours, "StreamCreationTimestamp": self.creation_datetime, + "EnhancedMonitoring": [{"ShardLevelMetrics": self.shard_level_metrics}], "OpenShardCount": self.shard_count, + "EncryptionType": self.encryption_type, } } @@ -421,7 +437,7 @@ class Stream(CloudFormationModel): backend = kinesis_backends[account_id][region_name] stream = backend.create_stream( - resource_name, shard_count, retention_period_hours + resource_name, shard_count, retention_period_hours=retention_period_hours ) if any(tags): backend.add_tags_to_stream(stream.stream_name, tags) @@ -510,15 +526,18 @@ class KinesisBackend(BaseBackend): service_region, zones, "kinesis", special_service_name="kinesis-streams" ) - def create_stream(self, stream_name, shard_count, retention_period_hours): + def create_stream( + self, stream_name, shard_count, stream_mode=None, retention_period_hours=None + ): if stream_name in self.streams: raise ResourceInUseError(stream_name) stream = Stream( stream_name, shard_count, - retention_period_hours, - self.account_id, - self.region_name, + stream_mode=stream_mode, + retention_period_hours=retention_period_hours, + account_id=self.account_id, + region_name=self.region_name, ) self.streams[stream_name] = stream return stream diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index d27475094..69f00204f 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -19,9 +19,9 @@ class KinesisResponse(BaseResponse): def create_stream(self): stream_name = self.parameters.get("StreamName") shard_count = self.parameters.get("ShardCount") - retention_period_hours = self.parameters.get("RetentionPeriodHours") + stream_mode = self.parameters.get("StreamModeDetails") self.kinesis_backend.create_stream( - stream_name, shard_count, retention_period_hours + stream_name, shard_count, stream_mode=stream_mode ) return "" diff --git a/tests/terraformtests/terraform-tests.success.txt b/tests/terraformtests/terraform-tests.success.txt index 8fae5fadf..d5875c21a 100644 --- a/tests/terraformtests/terraform-tests.success.txt +++ b/tests/terraformtests/terraform-tests.success.txt @@ -134,6 +134,9 @@ iam: - TestAccIAMUserSSHKeyDataSource_ iot: - TestAccIoTEndpointDataSource +kinesis: + - TestAccKinesisStream_basic + - TestAccKinesisStream_disappear kms: - TestAccKMSAlias - TestAccKMSGrant_arn diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py index a5acfa1b7..aafd2685c 100644 --- a/tests/test_kinesis/test_kinesis.py +++ b/tests/test_kinesis/test_kinesis.py @@ -13,6 +13,29 @@ from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import sure # noqa # pylint: disable=unused-import +@mock_kinesis +def test_stream_creation_on_demand(): + client = boto3.client("kinesis", region_name="eu-west-1") + client.create_stream( + StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"} + ) + + # AWS starts with 4 shards by default + shard_list = client.list_shards(StreamName="my_stream")["Shards"] + shard_list.should.have.length_of(4) + + # Cannot update-shard-count when we're in on-demand mode + with pytest.raises(ClientError) as exc: + client.update_shard_count( + StreamName="my_stream", TargetShardCount=3, ScalingType="UNIFORM_SCALING" + ) + err = exc.value.response["Error"] + err["Code"].should.equal("ValidationException") + err["Message"].should.equal( + f"Request is invalid. Stream my_stream under account {ACCOUNT_ID} is in On-Demand mode." + ) + + @mock_kinesis def test_describe_non_existent_stream_boto3(): client = boto3.client("kinesis", region_name="us-west-2")