Techdebt: Type the Paginator-class (#7330)

This commit is contained in:
Bert Blommers 2024-02-09 22:20:16 +00:00 committed by GitHub
parent ff9dda224f
commit b98c17552d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 91 additions and 85 deletions

View File

@ -118,7 +118,7 @@ class PrometheusServiceBackend(BaseBackend):
"""
self.workspaces.pop(workspace_id, None)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
@paginate(pagination_model=PAGINATION_MODEL)
def list_workspaces(self, alias: str) -> List[Workspace]:
if alias:
return [w for w in self.workspaces.values() if w.alias == alias]
@ -175,7 +175,7 @@ class PrometheusServiceBackend(BaseBackend):
ns.update(data)
return ns
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
@paginate(pagination_model=PAGINATION_MODEL)
def list_rule_groups_namespaces(
self, name: str, workspace_id: str
) -> List[RuleGroupNamespace]:

View File

@ -340,7 +340,7 @@ class AthenaBackend(BaseBackend):
self.data_catalogs[name] = data_catalog
return data_catalog
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
@paginate(pagination_model=PAGINATION_MODEL)
def list_named_queries(self, work_group: str) -> List[str]:
named_query_ids = [
q.id for q in self.named_queries.values() if q.workgroup.name == work_group

View File

@ -973,7 +973,7 @@ class CognitoIdpBackend(BaseBackend):
"MfaConfiguration": user_pool.mfa_config,
}
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_user_pools(self) -> List[CognitoIdpUserPool]:
return list(self.user_pools.values())
@ -1046,7 +1046,7 @@ class CognitoIdpBackend(BaseBackend):
user_pool.clients[user_pool_client.id] = user_pool_client
return user_pool_client
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_user_pool_clients(
self, user_pool_id: str
) -> List[CognitoIdpUserPoolClient]:
@ -1095,7 +1095,7 @@ class CognitoIdpBackend(BaseBackend):
user_pool.identity_providers[name] = identity_provider
return identity_provider
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_identity_providers(
self, user_pool_id: str
) -> List[CognitoIdpIdentityProvider]:
@ -1163,7 +1163,7 @@ class CognitoIdpBackend(BaseBackend):
return user_pool.groups[group_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_groups(self, user_pool_id: str) -> List[CognitoIdpGroup]:
user_pool = self.describe_user_pool(user_pool_id)
@ -1204,7 +1204,7 @@ class CognitoIdpBackend(BaseBackend):
group.users.add(user)
user.groups.add(group)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_users_in_group(
self, user_pool_id: str, group_name: str
) -> List[CognitoIdpUser]:
@ -1342,7 +1342,7 @@ class CognitoIdpBackend(BaseBackend):
return user
raise NotAuthorizedError("Invalid token")
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_users(self, user_pool_id: str) -> List[CognitoIdpUser]:
user_pool = self.describe_user_pool(user_pool_id)
@ -1759,7 +1759,7 @@ class CognitoIdpBackend(BaseBackend):
return resource_server
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_resource_servers(self, user_pool_id: str) -> List[CognitoResourceServer]:
user_pool = self.user_pools[user_pool_id]
resource_servers = list(user_pool.resource_servers.values())

View File

@ -131,7 +131,7 @@ class DataBrewBackend(BaseBackend):
recipe = self.recipes[recipe_name]
recipe.update(recipe_description, recipe_steps)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_recipes(
self, recipe_version: Optional[str] = None
) -> List["FakeRecipeVersion"]:
@ -148,7 +148,7 @@ class DataBrewBackend(BaseBackend):
recipes = [getattr(self.recipes[key], version) for key in self.recipes]
return [r for r in recipes if r is not None]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_recipe_versions(self, recipe_name: str) -> List["FakeRecipeVersion"]:
# https://docs.aws.amazon.com/databrew/latest/dg/API_ListRecipeVersions.html
self.validate_length(recipe_name, "name", 255)
@ -252,7 +252,7 @@ class DataBrewBackend(BaseBackend):
raise RulesetNotFoundException(ruleset_name)
return self.rulesets[ruleset_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_rulesets(self) -> List["FakeRuleset"]:
return list(self.rulesets.values())
@ -287,7 +287,7 @@ class DataBrewBackend(BaseBackend):
self.datasets[dataset_name] = dataset
return dataset
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_datasets(self) -> List["FakeDataset"]:
return list(self.datasets.values())
@ -404,7 +404,7 @@ class DataBrewBackend(BaseBackend):
# https://docs.aws.amazon.com/databrew/latest/dg/API_UpdateProfileJob.html
return self.update_job(**kwargs)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_jobs(
self, dataset_name: Optional[str] = None, project_name: Optional[str] = None
) -> List["FakeJob"]:

View File

@ -212,7 +212,7 @@ class DAXBackend(BaseBackend):
self.clusters[cluster_name].delete()
return self.clusters[cluster_name]
@paginate(PAGINATION_MODEL) # type: ignore[misc]
@paginate(PAGINATION_MODEL)
def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]:
clusters = self.clusters
if not cluster_names:

View File

@ -477,7 +477,7 @@ class DirectoryServiceBackend(BaseBackend):
directory = self.directories[directory_id]
directory.enable_sso(True)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def describe_directories(
self, directory_ids: Optional[List[str]] = None
) -> List[Directory]:
@ -533,7 +533,7 @@ class DirectoryServiceBackend(BaseBackend):
self._validate_directory_id(resource_id)
self.tagger.untag_resource_using_names(resource_id, tag_keys)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource(self, resource_id: str) -> List[Dict[str, str]]:
"""List all tags on a directory."""
self._validate_directory_id(resource_id)

View File

@ -1183,7 +1183,7 @@ class EventsBackend(BaseBackend):
return False
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_rule_names_by_target(
self, target_arn: str, event_bus_arn: Optional[str]
) -> List[Rule]:
@ -1198,7 +1198,7 @@ class EventsBackend(BaseBackend):
return matching_rules
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_rules(
self, prefix: Optional[str] = None, event_bus_arn: Optional[str] = None
) -> List[Rule]:

View File

@ -331,7 +331,7 @@ class GlueBackend(BaseBackend):
def get_crawlers(self) -> List["FakeCrawler"]:
return [self.crawlers[key] for key in self.crawlers] if self.crawlers else []
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_crawlers(self) -> List["FakeCrawler"]:
return [crawler for _, crawler in self.crawlers.items()]
@ -406,7 +406,7 @@ class GlueBackend(BaseBackend):
except KeyError:
raise JobNotFoundException(name)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def get_jobs(self) -> List["FakeJob"]:
return [job for _, job in self.jobs.items()]
@ -418,7 +418,7 @@ class GlueBackend(BaseBackend):
job = self.get_job(name)
return job.get_job_run(run_id)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_jobs(self) -> List["FakeJob"]:
return [job for _, job in self.jobs.items()]
@ -829,7 +829,7 @@ class GlueBackend(BaseBackend):
except KeyError:
raise SessionNotFoundException(session_id)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_sessions(self) -> List["FakeSession"]:
return [session for _, session in self.sessions.items()]
@ -884,7 +884,7 @@ class GlueBackend(BaseBackend):
trigger = self.get_trigger(name)
trigger.stop_trigger()
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def get_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]:
if dependent_job_name:
triggers = []
@ -898,7 +898,7 @@ class GlueBackend(BaseBackend):
return list(self.triggers.values())
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]:
if dependent_job_name:
triggers = []

View File

@ -4,7 +4,14 @@ from typing import Any, Dict, List
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import FakeCrawler, FakeJob, FakeSession, GlueBackend, glue_backends
from .models import (
FakeCrawler,
FakeJob,
FakeSession,
FakeTrigger,
GlueBackend,
glue_backends,
)
class GlueResponse(BaseResponse):
@ -417,7 +424,7 @@ class GlueResponse(BaseResponse):
return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)]
def filter_triggers_by_tags(
self, triggers: List[FakeJob], tags: Dict[str, str]
self, triggers: List[FakeTrigger], tags: Dict[str, str]
) -> List[str]:
if not tags:
return [trigger.get_name() for trigger in triggers]

View File

@ -301,7 +301,7 @@ class IdentityStoreBackend(BaseBackend):
return [m._asdict() for m in identity_store.groups.values()]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
@paginate(pagination_model=PAGINATION_MODEL)
def list_users(
self, identity_store_id: str, filters: List[Dict[str, str]]
) -> List[Dict[str, str]]:

View File

@ -157,9 +157,7 @@ class KinesisResponse(BaseResponse):
limit=max_results,
next_token=next_token,
)
res = {"Shards": shards}
if token:
res["NextToken"] = token
res = {"Shards": shards, "NextToken": token}
return json.dumps(res)
def update_shard_count(self) -> str:

View File

@ -537,7 +537,7 @@ class OrganizationsBackend(BaseBackend):
next_token = str(len(accounts_resp))
return dict(CreateAccountStatuses=accounts_resp, NextToken=next_token)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_accounts(self) -> List[FakeAccount]:
accounts = [account.describe() for account in self.accounts]
return sorted(accounts, key=lambda x: x["JoinedTimestamp"]) # type: ignore

View File

@ -251,7 +251,7 @@ class PanoramaBackend(BaseBackend):
raise ValidationError(f"Device {device_id} not found")
return device
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_devices(
self,
device_aggregated_status_filter: str,

View File

@ -898,9 +898,8 @@ class Route53Backend(BaseBackend):
from moto.logs import logs_backends # pylint: disable=import-outside-toplevel
response = logs_backends[self.account_id][region].describe_log_groups()
log_groups = response[0] if response else []
for entry in log_groups: # type: ignore
log_groups = logs_backends[self.account_id][region].describe_log_groups()
for entry in log_groups[0] if log_groups else []:
if log_group_arn == entry["arn"]:
break
else:
@ -935,7 +934,7 @@ class Route53Backend(BaseBackend):
raise NoSuchQueryLoggingConfig()
return self.query_logging_configs[query_logging_config_id]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_query_logging_configs(
self, hosted_zone_id: Optional[str] = None
) -> List[QueryLoggingConfig]:

View File

@ -814,7 +814,7 @@ class Route53ResolverBackend(BaseBackend):
return True
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_endpoints(self, filters: Any) -> List[ResolverEndpoint]:
if not filters:
filters = []
@ -828,7 +828,7 @@ class Route53ResolverBackend(BaseBackend):
endpoints.append(endpoint)
return endpoints
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rules(self, filters: Any) -> List[ResolverRule]:
if not filters:
filters = []
@ -842,7 +842,7 @@ class Route53ResolverBackend(BaseBackend):
rules.append(rule)
return rules
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rule_associations(
self, filters: Any
) -> List[ResolverRuleAssociation]:
@ -872,12 +872,10 @@ class Route53ResolverBackend(BaseBackend):
f"Resolver endpoint with ID '{resource_arn}' does not exist"
)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_tags_for_resource(
self, resource_arn: str
) -> Optional[List[Dict[str, str]]]:
@paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
self._matched_arn(resource_arn)
return self.tagger.list_tags_for_resource(resource_arn).get("Tags")
return self.tagger.list_tags_for_resource(resource_arn)["Tags"]
def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self._matched_arn(resource_arn)

View File

@ -1971,7 +1971,7 @@ class SageMakerModelBackend(BaseBackend):
resource.tags.extend(tags)
return resource.tags
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_tags(self, arn: str) -> List[Dict[str, str]]:
resource = self._get_resource_from_arn(arn)
return resource.tags
@ -1980,7 +1980,7 @@ class SageMakerModelBackend(BaseBackend):
resource = self._get_resource_from_arn(arn)
resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_experiments(self) -> List["FakeExperiment"]:
return list(self.experiments.values())
@ -2150,7 +2150,7 @@ class SageMakerModelBackend(BaseBackend):
message=f"Could not find trial configuration '{arn}'."
)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_trials(
self,
experiment_name: Optional[str] = None,
@ -2213,7 +2213,7 @@ class SageMakerModelBackend(BaseBackend):
) -> None:
self.trial_components[trial_component_name].update(details_json)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_trial_components(
self, trial_name: Optional[str] = None
) -> List["FakeTrialComponent"]:
@ -2337,14 +2337,14 @@ class SageMakerModelBackend(BaseBackend):
raise ValidationError(message=message)
del self.notebook_instances[notebook_instance_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_notebook_instances(
self,
sort_by: str,
sort_order: str,
name_contains: Optional[str],
status: Optional[str],
) -> Iterable[FakeSagemakerNotebookInstance]:
) -> List[FakeSagemakerNotebookInstance]:
"""
The following parameters are not yet implemented:
CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals
@ -3283,8 +3283,8 @@ class SageMakerModelBackend(BaseBackend):
return True
raise ValueError(f"Invalid model package type: {model_package_type}")
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_model_package_groups( # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_model_package_groups(
self,
creation_time_after: Optional[int],
creation_time_before: Optional[int],
@ -3346,8 +3346,8 @@ class SageMakerModelBackend(BaseBackend):
)
return model_package_group
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_model_packages( # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_model_packages(
self,
creation_time_after: Optional[int],
creation_time_before: Optional[int],

View File

@ -322,11 +322,9 @@ class SageMakerResponse(BaseResponse):
response = {
"ExperimentSummaries": experiment_summaries,
"NextToken": next_token,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response)
def delete_experiment(self) -> TYPE_RESPONSE:
@ -368,12 +366,7 @@ class SageMakerResponse(BaseResponse):
for trial_data in paged_results
]
response = {
"TrialSummaries": trial_summaries,
}
if next_token:
response["NextToken"] = next_token
response = {"TrialSummaries": trial_summaries, "NextToken": next_token}
return 200, {}, json.dumps(response)
@ -406,11 +399,9 @@ class SageMakerResponse(BaseResponse):
response = {
"TrialComponentSummaries": trial_component_summaries,
"NextToken": next_token,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response)
def create_trial_component(self) -> TYPE_RESPONSE:

View File

@ -337,7 +337,7 @@ class SSOAdminBackend(BaseBackend):
message=f"Could not find PermissionSet with id {ps_id}"
)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_permission_sets(self, instance_arn: str) -> List[PermissionSet]:
permission_sets = []
for permission_set in self.permission_sets:
@ -399,7 +399,7 @@ class SSOAdminBackend(BaseBackend):
permissionset.managed_policies.append(managed_policy)
permissionset.total_managed_policies_attached += 1
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_managed_policies_in_permission_set(
self,
instance_arn: str,
@ -467,7 +467,7 @@ class SSOAdminBackend(BaseBackend):
permissionset.customer_managed_policies.append(customer_managed_policy)
permissionset.total_managed_policies_attached += 1
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_customer_managed_policy_references_in_permission_set(
self, instance_arn: str, permission_set_arn: str
) -> List[CustomerManagedPolicy]:

View File

@ -1,7 +1,7 @@
import json
import re
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Pattern
from typing import Any, Dict, List, Optional, Pattern
from dateutil.tz import tzlocal
@ -504,8 +504,8 @@ class StepFunctionBackend(BaseBackend):
self.state_machines.append(state_machine)
return state_machine
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_state_machines(self) -> Iterable[StateMachine]:
@paginate(pagination_model=PAGINATION_MODEL)
def list_state_machines(self) -> List[StateMachine]:
return sorted(self.state_machines, key=lambda x: x.creation_date)
def describe_state_machine(self, arn: str) -> StateMachine:
@ -552,10 +552,10 @@ class StepFunctionBackend(BaseBackend):
state_machine = self._get_state_machine_for_execution(execution_arn)
return state_machine.stop_execution(execution_arn)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
@paginate(pagination_model=PAGINATION_MODEL)
def list_executions(
self, state_machine_arn: str, status_filter: Optional[str] = None
) -> Iterable[Execution]:
) -> List[Execution]:
"""
The status of every execution is set to 'RUNNING' by default.
Set the following environment variable if you want to get a FAILED status back:
@ -569,8 +569,7 @@ class StepFunctionBackend(BaseBackend):
if status_filter:
executions = list(filter(lambda e: e.status == status_filter, executions))
executions = sorted(executions, key=lambda x: x.start_date, reverse=True)
return executions
return sorted(executions, key=lambda x: x.start_date, reverse=True)
def describe_execution(self, execution_arn: str) -> Execution:
"""

View File

@ -1,20 +1,34 @@
import inspect
from copy import deepcopy
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar
from botocore.paginate import TokenDecoder, TokenEncoder
from moto.core.exceptions import InvalidToken
# This should be typed using ParamSpec
# https://stackoverflow.com/a/70591060/13245310
# This currently does not work for our usecase
# I believe this could be fixed after https://github.com/python/mypy/pull/14903 is accepted
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Protocol
P1 = ParamSpec("P1")
P2 = ParamSpec("P2")
else:
Protocol = object
T = TypeVar("T")
def paginate(pagination_model: Dict[str, Any]) -> Any:
def pagination_decorator(func: Any) -> Any:
class GenericFunction(Protocol):
def __call__(
self, func: "Callable[P1, List[T]]"
) -> "Callable[P2, Tuple[List[T], Optional[str]]]":
...
def paginate(pagination_model: Dict[str, Any]) -> GenericFunction:
def pagination_decorator(
func: Callable[..., List[T]]
) -> Callable[..., Tuple[List[T], Optional[str]]]:
@wraps(func)
def pagination_wrapper(*args: Any, **kwargs: Any) -> Any: # type: ignore
method = func.__name__