diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 5f524d4b3..a4764fd36 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -1,14 +1,20 @@ import os import json import base64 +from typing import Any, Dict, Optional from moto.core import BaseBackend, BackendDict, BaseModel -from moto.dynamodb.models import dynamodb_backends, DynamoJsonEncoder +from moto.dynamodb.models import dynamodb_backends, DynamoJsonEncoder, DynamoDBBackend +from moto.dynamodb.models import Table, StreamShard class ShardIterator(BaseModel): def __init__( - self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None + self, + streams_backend: "DynamoDBStreamsBackend", + stream_shard: StreamShard, + shard_iterator_type: str, + sequence_number: Optional[int] = None, ): self.id = base64.b64encode(os.urandom(472)).decode("utf-8") self.streams_backend = streams_backend @@ -21,19 +27,19 @@ class ShardIterator(BaseModel): stream_shard.items ) elif shard_iterator_type == "AT_SEQUENCE_NUMBER": - self.sequence_number = sequence_number + self.sequence_number = sequence_number # type: ignore[assignment] elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": - self.sequence_number = sequence_number + 1 + self.sequence_number = sequence_number + 1 # type: ignore[operator] @property - def arn(self): + def arn(self) -> str: return f"{self.stream_shard.table.table_arn}/stream/{self.stream_shard.table.latest_stream_label}|1|{self.id}" - def to_json(self): + def to_json(self) -> Dict[str, str]: return {"ShardIterator": self.arn} - def get(self, limit=1000): - items = self.stream_shard.get(self.sequence_number, limit) + def get(self, limit: int = 1000) -> Dict[str, Any]: + items = self.stream_shard.get(self.sequence_number, limit) # type: ignore[no-untyped-call] try: last_sequence_number = max( int(i["dynamodb"]["SequenceNumber"]) for i in items @@ -59,19 +65,19 @@ class ShardIterator(BaseModel): class DynamoDBStreamsBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.shard_iterators = {} + self.shard_iterators: Dict[str, ShardIterator] = {} @property - def dynamodb(self): + def dynamodb(self) -> DynamoDBBackend: return dynamodb_backends[self.account_id][self.region_name] - def _get_table_from_arn(self, arn): + def _get_table_from_arn(self, arn: str) -> Table: table_name = arn.split(":", 6)[5].split("/")[1] - return self.dynamodb.get_table(table_name) + return self.dynamodb.get_table(table_name) # type: ignore[no-untyped-call] - def describe_stream(self, arn): + def describe_stream(self, arn: str) -> str: table = self._get_table_from_arn(arn) resp = { "StreamDescription": { @@ -81,7 +87,7 @@ class DynamoDBStreamsBackend(BaseBackend): "ENABLED" if table.latest_stream_label else "DISABLED" ), "StreamViewType": table.stream_specification["StreamViewType"], - "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), + "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), # type: ignore[union-attr] "TableName": table.name, "KeySchema": table.schema, "Shards": ( @@ -92,7 +98,7 @@ class DynamoDBStreamsBackend(BaseBackend): return json.dumps(resp) - def list_streams(self, table_name=None): + def list_streams(self, table_name: Optional[str] = None) -> str: streams = [] for table in self.dynamodb.tables.values(): if table_name is not None and table.name != table_name: @@ -110,19 +116,23 @@ class DynamoDBStreamsBackend(BaseBackend): return json.dumps({"Streams": streams}) def get_shard_iterator( - self, arn, shard_id, shard_iterator_type, sequence_number=None - ): + self, + arn: str, + shard_id: str, + shard_iterator_type: str, + sequence_number: Optional[str] = None, + ) -> str: table = self._get_table_from_arn(arn) - assert table.stream_shard.id == shard_id + assert table.stream_shard.id == shard_id # type: ignore[union-attr] shard_iterator = ShardIterator( - self, table.stream_shard, shard_iterator_type, sequence_number + self, table.stream_shard, shard_iterator_type, sequence_number # type: ignore[arg-type] ) self.shard_iterators[shard_iterator.arn] = shard_iterator return json.dumps(shard_iterator.to_json()) - def get_records(self, iterator_arn, limit): + def get_records(self, iterator_arn: str, limit: int) -> str: shard_iterator = self.shard_iterators[iterator_arn] return json.dumps(shard_iterator.get(limit), cls=DynamoJsonEncoder) diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index ff8a78650..f9a0dc990 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -1,25 +1,25 @@ from moto.core.responses import BaseResponse -from .models import dynamodbstreams_backends +from .models import dynamodbstreams_backends, DynamoDBStreamsBackend class DynamoDBStreamsHandler(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="dynamodb-streams") @property - def backend(self): + def backend(self) -> DynamoDBStreamsBackend: return dynamodbstreams_backends[self.current_account][self.region] - def describe_stream(self): + def describe_stream(self) -> str: arn = self._get_param("StreamArn") return self.backend.describe_stream(arn) - def list_streams(self): + def list_streams(self) -> str: table_name = self._get_param("TableName") return self.backend.list_streams(table_name) - def get_shard_iterator(self): + def get_shard_iterator(self) -> str: arn = self._get_param("StreamArn") shard_id = self._get_param("ShardId") shard_iterator_type = self._get_param("ShardIteratorType") @@ -32,7 +32,7 @@ class DynamoDBStreamsHandler(BaseResponse): arn, shard_id, shard_iterator_type, sequence_number ) - def get_records(self): + def get_records(self) -> str: arn = self._get_param("ShardIterator") limit = self._get_param("Limit") if limit is None: diff --git a/setup.cfg b/setup.cfg index 9e857dfdb..df0761ab6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodbstreams,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract