Techdebt: Type the Paginator-class (#7330)

This commit is contained in:
Bert Blommers 2024-02-09 22:20:16 +00:00 committed by GitHub
parent ff9dda224f
commit b98c17552d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 91 additions and 85 deletions

View File

@ -118,7 +118,7 @@ class PrometheusServiceBackend(BaseBackend):
""" """
self.workspaces.pop(workspace_id, None) self.workspaces.pop(workspace_id, None)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore @paginate(pagination_model=PAGINATION_MODEL)
def list_workspaces(self, alias: str) -> List[Workspace]: def list_workspaces(self, alias: str) -> List[Workspace]:
if alias: if alias:
return [w for w in self.workspaces.values() if w.alias == alias] return [w for w in self.workspaces.values() if w.alias == alias]
@ -175,7 +175,7 @@ class PrometheusServiceBackend(BaseBackend):
ns.update(data) ns.update(data)
return ns return ns
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore @paginate(pagination_model=PAGINATION_MODEL)
def list_rule_groups_namespaces( def list_rule_groups_namespaces(
self, name: str, workspace_id: str self, name: str, workspace_id: str
) -> List[RuleGroupNamespace]: ) -> List[RuleGroupNamespace]:

View File

@ -340,7 +340,7 @@ class AthenaBackend(BaseBackend):
self.data_catalogs[name] = data_catalog self.data_catalogs[name] = data_catalog
return data_catalog return data_catalog
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore @paginate(pagination_model=PAGINATION_MODEL)
def list_named_queries(self, work_group: str) -> List[str]: def list_named_queries(self, work_group: str) -> List[str]:
named_query_ids = [ named_query_ids = [
q.id for q in self.named_queries.values() if q.workgroup.name == work_group q.id for q in self.named_queries.values() if q.workgroup.name == work_group

View File

@ -973,7 +973,7 @@ class CognitoIdpBackend(BaseBackend):
"MfaConfiguration": user_pool.mfa_config, "MfaConfiguration": user_pool.mfa_config,
} }
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_user_pools(self) -> List[CognitoIdpUserPool]: def list_user_pools(self) -> List[CognitoIdpUserPool]:
return list(self.user_pools.values()) return list(self.user_pools.values())
@ -1046,7 +1046,7 @@ class CognitoIdpBackend(BaseBackend):
user_pool.clients[user_pool_client.id] = user_pool_client user_pool.clients[user_pool_client.id] = user_pool_client
return user_pool_client return user_pool_client
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_user_pool_clients( def list_user_pool_clients(
self, user_pool_id: str self, user_pool_id: str
) -> List[CognitoIdpUserPoolClient]: ) -> List[CognitoIdpUserPoolClient]:
@ -1095,7 +1095,7 @@ class CognitoIdpBackend(BaseBackend):
user_pool.identity_providers[name] = identity_provider user_pool.identity_providers[name] = identity_provider
return identity_provider return identity_provider
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_identity_providers( def list_identity_providers(
self, user_pool_id: str self, user_pool_id: str
) -> List[CognitoIdpIdentityProvider]: ) -> List[CognitoIdpIdentityProvider]:
@ -1163,7 +1163,7 @@ class CognitoIdpBackend(BaseBackend):
return user_pool.groups[group_name] return user_pool.groups[group_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_groups(self, user_pool_id: str) -> List[CognitoIdpGroup]: def list_groups(self, user_pool_id: str) -> List[CognitoIdpGroup]:
user_pool = self.describe_user_pool(user_pool_id) user_pool = self.describe_user_pool(user_pool_id)
@ -1204,7 +1204,7 @@ class CognitoIdpBackend(BaseBackend):
group.users.add(user) group.users.add(user)
user.groups.add(group) user.groups.add(group)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_users_in_group( def list_users_in_group(
self, user_pool_id: str, group_name: str self, user_pool_id: str, group_name: str
) -> List[CognitoIdpUser]: ) -> List[CognitoIdpUser]:
@ -1342,7 +1342,7 @@ class CognitoIdpBackend(BaseBackend):
return user return user
raise NotAuthorizedError("Invalid token") raise NotAuthorizedError("Invalid token")
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_users(self, user_pool_id: str) -> List[CognitoIdpUser]: def list_users(self, user_pool_id: str) -> List[CognitoIdpUser]:
user_pool = self.describe_user_pool(user_pool_id) user_pool = self.describe_user_pool(user_pool_id)
@ -1759,7 +1759,7 @@ class CognitoIdpBackend(BaseBackend):
return resource_server return resource_server
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_resource_servers(self, user_pool_id: str) -> List[CognitoResourceServer]: def list_resource_servers(self, user_pool_id: str) -> List[CognitoResourceServer]:
user_pool = self.user_pools[user_pool_id] user_pool = self.user_pools[user_pool_id]
resource_servers = list(user_pool.resource_servers.values()) resource_servers = list(user_pool.resource_servers.values())

View File

@ -131,7 +131,7 @@ class DataBrewBackend(BaseBackend):
recipe = self.recipes[recipe_name] recipe = self.recipes[recipe_name]
recipe.update(recipe_description, recipe_steps) recipe.update(recipe_description, recipe_steps)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_recipes( def list_recipes(
self, recipe_version: Optional[str] = None self, recipe_version: Optional[str] = None
) -> List["FakeRecipeVersion"]: ) -> List["FakeRecipeVersion"]:
@ -148,7 +148,7 @@ class DataBrewBackend(BaseBackend):
recipes = [getattr(self.recipes[key], version) for key in self.recipes] recipes = [getattr(self.recipes[key], version) for key in self.recipes]
return [r for r in recipes if r is not None] return [r for r in recipes if r is not None]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_recipe_versions(self, recipe_name: str) -> List["FakeRecipeVersion"]: def list_recipe_versions(self, recipe_name: str) -> List["FakeRecipeVersion"]:
# https://docs.aws.amazon.com/databrew/latest/dg/API_ListRecipeVersions.html # https://docs.aws.amazon.com/databrew/latest/dg/API_ListRecipeVersions.html
self.validate_length(recipe_name, "name", 255) self.validate_length(recipe_name, "name", 255)
@ -252,7 +252,7 @@ class DataBrewBackend(BaseBackend):
raise RulesetNotFoundException(ruleset_name) raise RulesetNotFoundException(ruleset_name)
return self.rulesets[ruleset_name] return self.rulesets[ruleset_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_rulesets(self) -> List["FakeRuleset"]: def list_rulesets(self) -> List["FakeRuleset"]:
return list(self.rulesets.values()) return list(self.rulesets.values())
@ -287,7 +287,7 @@ class DataBrewBackend(BaseBackend):
self.datasets[dataset_name] = dataset self.datasets[dataset_name] = dataset
return dataset return dataset
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_datasets(self) -> List["FakeDataset"]: def list_datasets(self) -> List["FakeDataset"]:
return list(self.datasets.values()) return list(self.datasets.values())
@ -404,7 +404,7 @@ class DataBrewBackend(BaseBackend):
# https://docs.aws.amazon.com/databrew/latest/dg/API_UpdateProfileJob.html # https://docs.aws.amazon.com/databrew/latest/dg/API_UpdateProfileJob.html
return self.update_job(**kwargs) return self.update_job(**kwargs)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_jobs( def list_jobs(
self, dataset_name: Optional[str] = None, project_name: Optional[str] = None self, dataset_name: Optional[str] = None, project_name: Optional[str] = None
) -> List["FakeJob"]: ) -> List["FakeJob"]:

View File

@ -212,7 +212,7 @@ class DAXBackend(BaseBackend):
self.clusters[cluster_name].delete() self.clusters[cluster_name].delete()
return self.clusters[cluster_name] return self.clusters[cluster_name]
@paginate(PAGINATION_MODEL) # type: ignore[misc] @paginate(PAGINATION_MODEL)
def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]: def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]:
clusters = self.clusters clusters = self.clusters
if not cluster_names: if not cluster_names:

View File

@ -477,7 +477,7 @@ class DirectoryServiceBackend(BaseBackend):
directory = self.directories[directory_id] directory = self.directories[directory_id]
directory.enable_sso(True) directory.enable_sso(True)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def describe_directories( def describe_directories(
self, directory_ids: Optional[List[str]] = None self, directory_ids: Optional[List[str]] = None
) -> List[Directory]: ) -> List[Directory]:
@ -533,7 +533,7 @@ class DirectoryServiceBackend(BaseBackend):
self._validate_directory_id(resource_id) self._validate_directory_id(resource_id)
self.tagger.untag_resource_using_names(resource_id, tag_keys) self.tagger.untag_resource_using_names(resource_id, tag_keys)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource(self, resource_id: str) -> List[Dict[str, str]]: def list_tags_for_resource(self, resource_id: str) -> List[Dict[str, str]]:
"""List all tags on a directory.""" """List all tags on a directory."""
self._validate_directory_id(resource_id) self._validate_directory_id(resource_id)

View File

@ -1183,7 +1183,7 @@ class EventsBackend(BaseBackend):
return False return False
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_rule_names_by_target( def list_rule_names_by_target(
self, target_arn: str, event_bus_arn: Optional[str] self, target_arn: str, event_bus_arn: Optional[str]
) -> List[Rule]: ) -> List[Rule]:
@ -1198,7 +1198,7 @@ class EventsBackend(BaseBackend):
return matching_rules return matching_rules
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_rules( def list_rules(
self, prefix: Optional[str] = None, event_bus_arn: Optional[str] = None self, prefix: Optional[str] = None, event_bus_arn: Optional[str] = None
) -> List[Rule]: ) -> List[Rule]:

View File

@ -331,7 +331,7 @@ class GlueBackend(BaseBackend):
def get_crawlers(self) -> List["FakeCrawler"]: def get_crawlers(self) -> List["FakeCrawler"]:
return [self.crawlers[key] for key in self.crawlers] if self.crawlers else [] return [self.crawlers[key] for key in self.crawlers] if self.crawlers else []
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_crawlers(self) -> List["FakeCrawler"]: def list_crawlers(self) -> List["FakeCrawler"]:
return [crawler for _, crawler in self.crawlers.items()] return [crawler for _, crawler in self.crawlers.items()]
@ -406,7 +406,7 @@ class GlueBackend(BaseBackend):
except KeyError: except KeyError:
raise JobNotFoundException(name) raise JobNotFoundException(name)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def get_jobs(self) -> List["FakeJob"]: def get_jobs(self) -> List["FakeJob"]:
return [job for _, job in self.jobs.items()] return [job for _, job in self.jobs.items()]
@ -418,7 +418,7 @@ class GlueBackend(BaseBackend):
job = self.get_job(name) job = self.get_job(name)
return job.get_job_run(run_id) return job.get_job_run(run_id)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_jobs(self) -> List["FakeJob"]: def list_jobs(self) -> List["FakeJob"]:
return [job for _, job in self.jobs.items()] return [job for _, job in self.jobs.items()]
@ -829,7 +829,7 @@ class GlueBackend(BaseBackend):
except KeyError: except KeyError:
raise SessionNotFoundException(session_id) raise SessionNotFoundException(session_id)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_sessions(self) -> List["FakeSession"]: def list_sessions(self) -> List["FakeSession"]:
return [session for _, session in self.sessions.items()] return [session for _, session in self.sessions.items()]
@ -884,7 +884,7 @@ class GlueBackend(BaseBackend):
trigger = self.get_trigger(name) trigger = self.get_trigger(name)
trigger.stop_trigger() trigger.stop_trigger()
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def get_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]: def get_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]:
if dependent_job_name: if dependent_job_name:
triggers = [] triggers = []
@ -898,7 +898,7 @@ class GlueBackend(BaseBackend):
return list(self.triggers.values()) return list(self.triggers.values())
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]: def list_triggers(self, dependent_job_name: str) -> List["FakeTrigger"]:
if dependent_job_name: if dependent_job_name:
triggers = [] triggers = []

View File

@ -4,7 +4,14 @@ from typing import Any, Dict, List
from moto.core.common_types import TYPE_RESPONSE from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import FakeCrawler, FakeJob, FakeSession, GlueBackend, glue_backends from .models import (
FakeCrawler,
FakeJob,
FakeSession,
FakeTrigger,
GlueBackend,
glue_backends,
)
class GlueResponse(BaseResponse): class GlueResponse(BaseResponse):
@ -417,7 +424,7 @@ class GlueResponse(BaseResponse):
return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)] return [job.get_name() for job in jobs if self.is_tags_match(job.arn, tags)]
def filter_triggers_by_tags( def filter_triggers_by_tags(
self, triggers: List[FakeJob], tags: Dict[str, str] self, triggers: List[FakeTrigger], tags: Dict[str, str]
) -> List[str]: ) -> List[str]:
if not tags: if not tags:
return [trigger.get_name() for trigger in triggers] return [trigger.get_name() for trigger in triggers]

View File

@ -301,7 +301,7 @@ class IdentityStoreBackend(BaseBackend):
return [m._asdict() for m in identity_store.groups.values()] return [m._asdict() for m in identity_store.groups.values()]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore @paginate(pagination_model=PAGINATION_MODEL)
def list_users( def list_users(
self, identity_store_id: str, filters: List[Dict[str, str]] self, identity_store_id: str, filters: List[Dict[str, str]]
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:

View File

@ -157,9 +157,7 @@ class KinesisResponse(BaseResponse):
limit=max_results, limit=max_results,
next_token=next_token, next_token=next_token,
) )
res = {"Shards": shards} res = {"Shards": shards, "NextToken": token}
if token:
res["NextToken"] = token
return json.dumps(res) return json.dumps(res)
def update_shard_count(self) -> str: def update_shard_count(self) -> str:

View File

@ -537,7 +537,7 @@ class OrganizationsBackend(BaseBackend):
next_token = str(len(accounts_resp)) next_token = str(len(accounts_resp))
return dict(CreateAccountStatuses=accounts_resp, NextToken=next_token) return dict(CreateAccountStatuses=accounts_resp, NextToken=next_token)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_accounts(self) -> List[FakeAccount]: def list_accounts(self) -> List[FakeAccount]:
accounts = [account.describe() for account in self.accounts] accounts = [account.describe() for account in self.accounts]
return sorted(accounts, key=lambda x: x["JoinedTimestamp"]) # type: ignore return sorted(accounts, key=lambda x: x["JoinedTimestamp"]) # type: ignore

View File

@ -251,7 +251,7 @@ class PanoramaBackend(BaseBackend):
raise ValidationError(f"Device {device_id} not found") raise ValidationError(f"Device {device_id} not found")
return device return device
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_devices( def list_devices(
self, self,
device_aggregated_status_filter: str, device_aggregated_status_filter: str,

View File

@ -898,9 +898,8 @@ class Route53Backend(BaseBackend):
from moto.logs import logs_backends # pylint: disable=import-outside-toplevel from moto.logs import logs_backends # pylint: disable=import-outside-toplevel
response = logs_backends[self.account_id][region].describe_log_groups() log_groups = logs_backends[self.account_id][region].describe_log_groups()
log_groups = response[0] if response else [] for entry in log_groups[0] if log_groups else []:
for entry in log_groups: # type: ignore
if log_group_arn == entry["arn"]: if log_group_arn == entry["arn"]:
break break
else: else:
@ -935,7 +934,7 @@ class Route53Backend(BaseBackend):
raise NoSuchQueryLoggingConfig() raise NoSuchQueryLoggingConfig()
return self.query_logging_configs[query_logging_config_id] return self.query_logging_configs[query_logging_config_id]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_query_logging_configs( def list_query_logging_configs(
self, hosted_zone_id: Optional[str] = None self, hosted_zone_id: Optional[str] = None
) -> List[QueryLoggingConfig]: ) -> List[QueryLoggingConfig]:

View File

@ -814,7 +814,7 @@ class Route53ResolverBackend(BaseBackend):
return True return True
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_endpoints(self, filters: Any) -> List[ResolverEndpoint]: def list_resolver_endpoints(self, filters: Any) -> List[ResolverEndpoint]:
if not filters: if not filters:
filters = [] filters = []
@ -828,7 +828,7 @@ class Route53ResolverBackend(BaseBackend):
endpoints.append(endpoint) endpoints.append(endpoint)
return endpoints return endpoints
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rules(self, filters: Any) -> List[ResolverRule]: def list_resolver_rules(self, filters: Any) -> List[ResolverRule]:
if not filters: if not filters:
filters = [] filters = []
@ -842,7 +842,7 @@ class Route53ResolverBackend(BaseBackend):
rules.append(rule) rules.append(rule)
return rules return rules
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rule_associations( def list_resolver_rule_associations(
self, filters: Any self, filters: Any
) -> List[ResolverRuleAssociation]: ) -> List[ResolverRuleAssociation]:
@ -872,12 +872,10 @@ class Route53ResolverBackend(BaseBackend):
f"Resolver endpoint with ID '{resource_arn}' does not exist" f"Resolver endpoint with ID '{resource_arn}' does not exist"
) )
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource( def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
self, resource_arn: str
) -> Optional[List[Dict[str, str]]]:
self._matched_arn(resource_arn) self._matched_arn(resource_arn)
return self.tagger.list_tags_for_resource(resource_arn).get("Tags") return self.tagger.list_tags_for_resource(resource_arn)["Tags"]
def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self._matched_arn(resource_arn) self._matched_arn(resource_arn)

View File

@ -1971,7 +1971,7 @@ class SageMakerModelBackend(BaseBackend):
resource.tags.extend(tags) resource.tags.extend(tags)
return resource.tags return resource.tags
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_tags(self, arn: str) -> List[Dict[str, str]]: def list_tags(self, arn: str) -> List[Dict[str, str]]:
resource = self._get_resource_from_arn(arn) resource = self._get_resource_from_arn(arn)
return resource.tags return resource.tags
@ -1980,7 +1980,7 @@ class SageMakerModelBackend(BaseBackend):
resource = self._get_resource_from_arn(arn) resource = self._get_resource_from_arn(arn)
resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys] resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_experiments(self) -> List["FakeExperiment"]: def list_experiments(self) -> List["FakeExperiment"]:
return list(self.experiments.values()) return list(self.experiments.values())
@ -2150,7 +2150,7 @@ class SageMakerModelBackend(BaseBackend):
message=f"Could not find trial configuration '{arn}'." message=f"Could not find trial configuration '{arn}'."
) )
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_trials( def list_trials(
self, self,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
@ -2213,7 +2213,7 @@ class SageMakerModelBackend(BaseBackend):
) -> None: ) -> None:
self.trial_components[trial_component_name].update(details_json) self.trial_components[trial_component_name].update(details_json)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_trial_components( def list_trial_components(
self, trial_name: Optional[str] = None self, trial_name: Optional[str] = None
) -> List["FakeTrialComponent"]: ) -> List["FakeTrialComponent"]:
@ -2337,14 +2337,14 @@ class SageMakerModelBackend(BaseBackend):
raise ValidationError(message=message) raise ValidationError(message=message)
del self.notebook_instances[notebook_instance_name] del self.notebook_instances[notebook_instance_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_notebook_instances( def list_notebook_instances(
self, self,
sort_by: str, sort_by: str,
sort_order: str, sort_order: str,
name_contains: Optional[str], name_contains: Optional[str],
status: Optional[str], status: Optional[str],
) -> Iterable[FakeSagemakerNotebookInstance]: ) -> List[FakeSagemakerNotebookInstance]:
""" """
The following parameters are not yet implemented: The following parameters are not yet implemented:
CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals
@ -3283,8 +3283,8 @@ class SageMakerModelBackend(BaseBackend):
return True return True
raise ValueError(f"Invalid model package type: {model_package_type}") raise ValueError(f"Invalid model package type: {model_package_type}")
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_model_package_groups( # type: ignore[misc] def list_model_package_groups(
self, self,
creation_time_after: Optional[int], creation_time_after: Optional[int],
creation_time_before: Optional[int], creation_time_before: Optional[int],
@ -3346,8 +3346,8 @@ class SageMakerModelBackend(BaseBackend):
) )
return model_package_group return model_package_group
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_model_packages( # type: ignore[misc] def list_model_packages(
self, self,
creation_time_after: Optional[int], creation_time_after: Optional[int],
creation_time_before: Optional[int], creation_time_before: Optional[int],

View File

@ -322,11 +322,9 @@ class SageMakerResponse(BaseResponse):
response = { response = {
"ExperimentSummaries": experiment_summaries, "ExperimentSummaries": experiment_summaries,
"NextToken": next_token,
} }
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
def delete_experiment(self) -> TYPE_RESPONSE: def delete_experiment(self) -> TYPE_RESPONSE:
@ -368,12 +366,7 @@ class SageMakerResponse(BaseResponse):
for trial_data in paged_results for trial_data in paged_results
] ]
response = { response = {"TrialSummaries": trial_summaries, "NextToken": next_token}
"TrialSummaries": trial_summaries,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@ -406,11 +399,9 @@ class SageMakerResponse(BaseResponse):
response = { response = {
"TrialComponentSummaries": trial_component_summaries, "TrialComponentSummaries": trial_component_summaries,
"NextToken": next_token,
} }
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
def create_trial_component(self) -> TYPE_RESPONSE: def create_trial_component(self) -> TYPE_RESPONSE:

View File

@ -337,7 +337,7 @@ class SSOAdminBackend(BaseBackend):
message=f"Could not find PermissionSet with id {ps_id}" message=f"Could not find PermissionSet with id {ps_id}"
) )
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_permission_sets(self, instance_arn: str) -> List[PermissionSet]: def list_permission_sets(self, instance_arn: str) -> List[PermissionSet]:
permission_sets = [] permission_sets = []
for permission_set in self.permission_sets: for permission_set in self.permission_sets:
@ -399,7 +399,7 @@ class SSOAdminBackend(BaseBackend):
permissionset.managed_policies.append(managed_policy) permissionset.managed_policies.append(managed_policy)
permissionset.total_managed_policies_attached += 1 permissionset.total_managed_policies_attached += 1
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_managed_policies_in_permission_set( def list_managed_policies_in_permission_set(
self, self,
instance_arn: str, instance_arn: str,
@ -467,7 +467,7 @@ class SSOAdminBackend(BaseBackend):
permissionset.customer_managed_policies.append(customer_managed_policy) permissionset.customer_managed_policies.append(customer_managed_policy)
permissionset.total_managed_policies_attached += 1 permissionset.total_managed_policies_attached += 1
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_customer_managed_policy_references_in_permission_set( def list_customer_managed_policy_references_in_permission_set(
self, instance_arn: str, permission_set_arn: str self, instance_arn: str, permission_set_arn: str
) -> List[CustomerManagedPolicy]: ) -> List[CustomerManagedPolicy]:

View File

@ -1,7 +1,7 @@
import json import json
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Pattern from typing import Any, Dict, List, Optional, Pattern
from dateutil.tz import tzlocal from dateutil.tz import tzlocal
@ -504,8 +504,8 @@ class StepFunctionBackend(BaseBackend):
self.state_machines.append(state_machine) self.state_machines.append(state_machine)
return state_machine return state_machine
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_state_machines(self) -> Iterable[StateMachine]: def list_state_machines(self) -> List[StateMachine]:
return sorted(self.state_machines, key=lambda x: x.creation_date) return sorted(self.state_machines, key=lambda x: x.creation_date)
def describe_state_machine(self, arn: str) -> StateMachine: def describe_state_machine(self, arn: str) -> StateMachine:
@ -552,10 +552,10 @@ class StepFunctionBackend(BaseBackend):
state_machine = self._get_state_machine_for_execution(execution_arn) state_machine = self._get_state_machine_for_execution(execution_arn)
return state_machine.stop_execution(execution_arn) return state_machine.stop_execution(execution_arn)
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL)
def list_executions( def list_executions(
self, state_machine_arn: str, status_filter: Optional[str] = None self, state_machine_arn: str, status_filter: Optional[str] = None
) -> Iterable[Execution]: ) -> List[Execution]:
""" """
The status of every execution is set to 'RUNNING' by default. The status of every execution is set to 'RUNNING' by default.
Set the following environment variable if you want to get a FAILED status back: Set the following environment variable if you want to get a FAILED status back:
@ -569,8 +569,7 @@ class StepFunctionBackend(BaseBackend):
if status_filter: if status_filter:
executions = list(filter(lambda e: e.status == status_filter, executions)) executions = list(filter(lambda e: e.status == status_filter, executions))
executions = sorted(executions, key=lambda x: x.start_date, reverse=True) return sorted(executions, key=lambda x: x.start_date, reverse=True)
return executions
def describe_execution(self, execution_arn: str) -> Execution: def describe_execution(self, execution_arn: str) -> Execution:
""" """

View File

@ -1,20 +1,34 @@
import inspect import inspect
from copy import deepcopy from copy import deepcopy
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar
from botocore.paginate import TokenDecoder, TokenEncoder from botocore.paginate import TokenDecoder, TokenEncoder
from moto.core.exceptions import InvalidToken from moto.core.exceptions import InvalidToken
# This should be typed using ParamSpec if TYPE_CHECKING:
# https://stackoverflow.com/a/70591060/13245310 from typing_extensions import ParamSpec, Protocol
# This currently does not work for our usecase
# I believe this could be fixed after https://github.com/python/mypy/pull/14903 is accepted P1 = ParamSpec("P1")
P2 = ParamSpec("P2")
else:
Protocol = object
T = TypeVar("T")
def paginate(pagination_model: Dict[str, Any]) -> Any: class GenericFunction(Protocol):
def pagination_decorator(func: Any) -> Any: def __call__(
self, func: "Callable[P1, List[T]]"
) -> "Callable[P2, Tuple[List[T], Optional[str]]]":
...
def paginate(pagination_model: Dict[str, Any]) -> GenericFunction:
def pagination_decorator(
func: Callable[..., List[T]]
) -> Callable[..., Tuple[List[T], Optional[str]]]:
@wraps(func) @wraps(func)
def pagination_wrapper(*args: Any, **kwargs: Any) -> Any: # type: ignore def pagination_wrapper(*args: Any, **kwargs: Any) -> Any: # type: ignore
method = func.__name__ method = func.__name__