Improve typing for IAM (#7091)

This commit is contained in:
tungol 2023-12-05 12:55:04 -08:00 committed by GitHub
parent 16b9f319c5
commit ff5256d8e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 133 additions and 108 deletions

View File

@ -263,7 +263,7 @@ __version__ = "4.2.12.dev"
try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import (
from botocore.awsrequest import ( # type: ignore[attr-defined]
HTTPConnection,
HTTPConnectionPool,
HTTPSConnectionPool,

View File

@ -70,7 +70,7 @@ DEFAULT_PAGE_SIZE = 100
CONFIG_RULE_PAGE_SIZE = 25
# Map the Config resource type to a backend:
RESOURCE_MAP: Dict[str, ConfigQueryModel] = {
RESOURCE_MAP: Dict[str, ConfigQueryModel[Any]] = {
"AWS::S3::Bucket": s3_config_query,
"AWS::S3::AccountPublicAccessBlock": s3_account_public_access_block_query,
"AWS::IAM::Role": role_config_query,

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from io import BytesIO
from typing import Any, Callable, Dict, List, Pattern, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
from botocore.awsrequest import AWSResponse
@ -38,7 +38,9 @@ class BotocoreStubber:
matchers = self.methods[method]
matchers.append((pattern, response))
def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
def __call__(
self, event_name: str, request: Any, **kwargs: Any
) -> Optional[AWSResponse]:
if not self.enabled:
return None
@ -70,6 +72,6 @@ class BotocoreStubber:
headers = e.get_headers() # type: ignore[assignment]
body = e.get_body()
raw_response = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, raw_response)
response = AWSResponse(request.url, status, headers, raw_response) # type: ignore[arg-type]
return response

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Generic, List, Optional, Tuple
from .base_backend import InstanceTrackerMeta
from .base_backend import SERVICE_BACKEND, BackendDict, InstanceTrackerMeta
class BaseModel(metaclass=InstanceTrackerMeta):
@ -94,8 +94,8 @@ class CloudFormationModel(BaseModel):
return True
class ConfigQueryModel:
def __init__(self, backends: Any):
class ConfigQueryModel(Generic[SERVICE_BACKEND]):
def __init__(self, backends: BackendDict[SERVICE_BACKEND]):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends

View File

@ -293,7 +293,7 @@ def patch_client(client: botocore.client.BaseClient) -> None:
if isinstance(client, botocore.client.BaseClient):
# Check if our event handler was already registered
try:
event_emitter = client._ruleset_resolver._event_emitter._emitter
event_emitter = client._ruleset_resolver._event_emitter._emitter # type: ignore[attr-defined]
all_handlers = event_emitter._handlers._root["children"]
handler_trie = list(all_handlers["before-send"].values())[1]
handlers_list = handler_trie.first + handler_trie.middle + handler_trie.last

View File

@ -257,7 +257,7 @@ class IAMRequestBase(object, metaclass=ABCMeta):
raise NotImplementedError()
@abstractmethod
def _create_auth(self, credentials: Credentials) -> SigV4Auth: # type: ignore[misc]
def _create_auth(self, credentials: Credentials) -> SigV4Auth:
raise NotImplementedError()
@staticmethod

View File

@ -5,10 +5,10 @@ import boto3
from moto.core.common_models import ConfigQueryModel
from moto.core.exceptions import InvalidNextTokenException
from moto.iam import iam_backends
from moto.iam.models import IAMBackend, iam_backends
class RoleConfigQuery(ConfigQueryModel):
class RoleConfigQuery(ConfigQueryModel[IAMBackend]):
def list_config_service_resources(
self,
account_id: str,
@ -32,26 +32,27 @@ class RoleConfigQuery(ConfigQueryModel):
return [], None
# Filter by resource name or ids
if resource_name or resource_ids:
filtered_roles = []
# resource_name takes precedence over resource_ids
if resource_name:
for role in role_list:
if role.name == resource_name:
filtered_roles = [role]
break
# but if both are passed, it must be a subset
if filtered_roles and resource_ids:
if filtered_roles[0].id not in resource_ids:
return [], None
else:
for role in role_list:
if role.id in resource_ids: # type: ignore[operator]
filtered_roles.append(role)
# resource_name takes precedence over resource_ids
filtered_roles = []
if resource_name:
for role in role_list:
if role.name == resource_name:
filtered_roles = [role]
break
# but if both are passed, it must be a subset
if filtered_roles and resource_ids:
if filtered_roles[0].id not in resource_ids:
return [], None
# Filtered roles are now the subject for the listing
role_list = filtered_roles
elif resource_ids:
for role in role_list:
if role.id in resource_ids:
filtered_roles.append(role)
role_list = filtered_roles
if aggregator:
# IAM is a little special; Roles are created in us-east-1 (which AWS calls the "global" region)
# However, the resource will return in the aggregator (in duplicate) for each region in the aggregator
@ -77,7 +78,6 @@ class RoleConfigQuery(ConfigQueryModel):
duplicate_role_list.append(
{
"_id": f"{role.id}{region}", # this is only for sorting, isn't returned outside of this function
"type": "AWS::IAM::Role",
"id": role.id,
"name": role.name,
"region": region,
@ -89,7 +89,10 @@ class RoleConfigQuery(ConfigQueryModel):
else:
# Non-aggregated queries are in the else block, and we can treat these like a normal config resource
# Pagination logic, sort by role id
sorted_roles = sorted(role_list, key=lambda role: role.id) # type: ignore[attr-defined]
sorted_roles = [
{"_id": role.id, "id": role.id, "name": role.name, "region": "global"}
for role in sorted(role_list, key=lambda role: role.id)
]
new_token = None
@ -102,27 +105,27 @@ class RoleConfigQuery(ConfigQueryModel):
start = next(
index
for (index, r) in enumerate(sorted_roles)
if next_token == (r["_id"] if aggregator else r.id) # type: ignore[attr-defined]
if next_token == r["_id"]
)
except StopIteration:
raise InvalidNextTokenException()
# Get the list of items to collect:
role_list = sorted_roles[start : (start + limit)]
collected_role_list = sorted_roles[start : (start + limit)]
if len(sorted_roles) > (start + limit):
record = sorted_roles[start + limit]
new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined]
new_token = record["_id"]
return (
[
{
"type": "AWS::IAM::Role",
"id": role["id"] if aggregator else role.id, # type: ignore[attr-defined]
"name": role["name"] if aggregator else role.name, # type: ignore[attr-defined]
"region": role["region"] if aggregator else "global",
"id": role["id"],
"name": role["name"],
"region": role["region"],
}
for role in role_list
for role in collected_role_list
],
new_token,
)
@ -136,7 +139,7 @@ class RoleConfigQuery(ConfigQueryModel):
resource_region: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
role = self.backends[account_id]["global"].roles.get(resource_id, {})
role = self.backends[account_id]["global"].roles.get(resource_id)
if not role:
return None
@ -158,7 +161,7 @@ class RoleConfigQuery(ConfigQueryModel):
return config_data
class PolicyConfigQuery(ConfigQueryModel):
class PolicyConfigQuery(ConfigQueryModel[IAMBackend]):
def list_config_service_resources(
self,
account_id: str,
@ -194,27 +197,27 @@ class PolicyConfigQuery(ConfigQueryModel):
return [], None
# Filter by resource name or ids
if resource_name or resource_ids:
filtered_policies = []
# resource_name takes precedence over resource_ids
if resource_name:
for policy in policy_list:
if policy.name == resource_name:
filtered_policies = [policy]
break
# but if both are passed, it must be a subset
if filtered_policies and resource_ids:
if filtered_policies[0].id not in resource_ids:
return [], None
else:
for policy in policy_list:
if policy.id in resource_ids: # type: ignore[operator]
filtered_policies.append(policy)
# resource_name takes precedence over resource_ids
filtered_policies = []
if resource_name:
for policy in policy_list:
if policy.name == resource_name:
filtered_policies = [policy]
break
# but if both are passed, it must be a subset
if filtered_policies and resource_ids:
if filtered_policies[0].id not in resource_ids:
return [], None
# Filtered roles are now the subject for the listing
policy_list = filtered_policies
elif resource_ids:
for policy in policy_list:
if policy.id in resource_ids:
filtered_policies.append(policy)
policy_list = filtered_policies
if aggregator:
# IAM is a little special; Policies are created in us-east-1 (which AWS calls the "global" region)
# However, the resource will return in the aggregator (in duplicate) for each region in the aggregator
@ -240,7 +243,6 @@ class PolicyConfigQuery(ConfigQueryModel):
duplicate_policy_list.append(
{
"_id": f"{policy.id}{region}", # this is only for sorting, isn't returned outside of this function
"type": "AWS::IAM::Policy",
"id": policy.id,
"name": policy.name,
"region": region,
@ -255,7 +257,15 @@ class PolicyConfigQuery(ConfigQueryModel):
else:
# Non-aggregated queries are in the else block, and we can treat these like a normal config resource
# Pagination logic, sort by role id
sorted_policies = sorted(policy_list, key=lambda role: role.id) # type: ignore[attr-defined]
sorted_policies = [
{
"_id": policy.id,
"id": policy.id,
"name": policy.name,
"region": "global",
}
for policy in sorted(policy_list, key=lambda role: role.id)
]
new_token = None
@ -268,27 +278,27 @@ class PolicyConfigQuery(ConfigQueryModel):
start = next(
index
for (index, p) in enumerate(sorted_policies)
if next_token == (p["_id"] if aggregator else p.id) # type: ignore[attr-defined]
if next_token == p["_id"]
)
except StopIteration:
raise InvalidNextTokenException()
# Get the list of items to collect:
policy_list = sorted_policies[start : (start + limit)]
collected_policy_list = sorted_policies[start : (start + limit)]
if len(sorted_policies) > (start + limit):
record = sorted_policies[start + limit]
new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined]
new_token = record["_id"]
return (
[
{
"type": "AWS::IAM::Policy",
"id": policy["id"] if aggregator else policy.id, # type: ignore[attr-defined]
"name": policy["name"] if aggregator else policy.name, # type: ignore[attr-defined]
"region": policy["region"] if aggregator else "global",
"id": policy["id"],
"name": policy["name"],
"region": policy["region"],
}
for policy in policy_list
for policy in collected_policy_list
],
new_token,
)

View File

@ -114,7 +114,7 @@ class MFADevice:
@property
def enabled_iso_8601(self) -> str:
return iso_8601_datetime_without_milliseconds(self.enable_date) # type: ignore[return-value]
return iso_8601_datetime_without_milliseconds(self.enable_date)
class VirtualMfaDevice:
@ -177,7 +177,11 @@ class Policy(CloudFormationModel):
self.next_version_num = 2
self.versions = [
PolicyVersion(
self.arn, document, True, self.default_version_id, update_date # type: ignore
self.arn, # type: ignore[attr-defined]
document,
True,
self.default_version_id,
update_date,
)
]
@ -243,7 +247,7 @@ class OpenIDConnectProvider(BaseModel):
@property
def created_iso_8601(self) -> str:
return iso_8601_datetime_without_milliseconds(self.create_date) # type: ignore[return-value]
return iso_8601_datetime_without_milliseconds(self.create_date)
def _validate(
self, url: str, thumbprint_list: List[str], client_id_list: List[str]
@ -315,7 +319,7 @@ class PolicyVersion:
def __init__(
self,
policy_arn: str,
document: str,
document: Optional[str],
is_default: bool = False,
version_id: str = "v1",
create_date: Optional[datetime] = None,
@ -343,7 +347,7 @@ class ManagedPolicy(Policy, CloudFormationModel):
def attach_to(self, obj: Union["Role", "Group", "User"]) -> None:
self.attachment_count += 1
obj.managed_policies[self.arn] = self # type: ignore[assignment]
obj.managed_policies[self.arn] = self
def detach_from(self, obj: Union["Role", "Group", "User"]) -> None:
self.attachment_count -= 1
@ -412,7 +416,7 @@ class ManagedPolicy(Policy, CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -449,7 +453,7 @@ class ManagedPolicy(Policy, CloudFormationModel):
return policy
def __eq__(self, other: Any) -> bool:
return self.arn == other.arn
return self.arn == other.arn # type: ignore[no-any-return]
def __hash__(self) -> int:
return self.arn.__hash__()
@ -532,7 +536,7 @@ class InlinePolicy(CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -556,9 +560,9 @@ class InlinePolicy(CloudFormationModel):
@classmethod
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource: Any,
original_resource: "InlinePolicy",
new_resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> "InlinePolicy":
@ -601,7 +605,7 @@ class InlinePolicy(CloudFormationModel):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
@ -705,7 +709,7 @@ class Role(CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -736,7 +740,7 @@ class Role(CloudFormationModel):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
@ -850,8 +854,8 @@ class Role(CloudFormationModel):
return self.arn
raise UnformattedGetAttTemplateException()
def get_tags(self) -> List[str]:
return [self.tags[tag] for tag in self.tags] # type: ignore
def get_tags(self) -> List[Dict[str, str]]:
return [self.tags[tag] for tag in self.tags]
@property
def description_escaped(self) -> str:
@ -938,7 +942,7 @@ class InstanceProfile(CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -956,7 +960,7 @@ class InstanceProfile(CloudFormationModel):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
@ -1060,7 +1064,7 @@ class SigningCertificate(BaseModel):
@property
def uploaded_iso_8601(self) -> str:
return iso_8601_datetime_without_milliseconds(self.upload_date) # type: ignore
return iso_8601_datetime_without_milliseconds(self.upload_date)
class AccessKeyLastUsed:
@ -1071,7 +1075,7 @@ class AccessKeyLastUsed:
@property
def timestamp(self) -> str:
return iso_8601_datetime_without_milliseconds(self._timestamp) # type: ignore
return iso_8601_datetime_without_milliseconds(self._timestamp)
def strftime(self, date_format: str) -> str:
return self._timestamp.strftime(date_format)
@ -1105,7 +1109,7 @@ class AccessKey(CloudFormationModel):
@property
def created_iso_8601(self) -> str:
return iso_8601_datetime_without_milliseconds(self.create_date) # type: ignore
return iso_8601_datetime_without_milliseconds(self.create_date)
@classmethod
def has_cfn_attr(cls, attr: str) -> bool:
@ -1130,7 +1134,7 @@ class AccessKey(CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -1146,9 +1150,9 @@ class AccessKey(CloudFormationModel):
@classmethod
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource: Any,
original_resource: "AccessKey",
new_resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> "AccessKey":
@ -1170,14 +1174,16 @@ class AccessKey(CloudFormationModel):
properties = cloudformation_json.get("Properties", {})
status = properties.get("Status")
return iam_backends[account_id]["global"].update_access_key(
original_resource.user_name, original_resource.access_key_id, status
original_resource.user_name, # type: ignore[arg-type]
original_resource.access_key_id,
status,
)
@classmethod
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
@ -1209,7 +1215,7 @@ class SshPublicKey(BaseModel):
@property
def uploaded_iso_8601(self) -> str:
return iso_8601_datetime_without_milliseconds(self.upload_date) # type: ignore
return iso_8601_datetime_without_milliseconds(self.upload_date)
class Group(BaseModel):
@ -1221,7 +1227,7 @@ class Group(BaseModel):
self.create_date = utcnow()
self.users: List[User] = []
self.managed_policies: Dict[str, str] = {}
self.managed_policies: Dict[str, ManagedPolicy] = {}
self.policies: Dict[str, str] = {}
@property
@ -1281,7 +1287,7 @@ class User(CloudFormationModel):
self.create_date = utcnow()
self.mfa_devices: Dict[str, MFADevice] = {}
self.policies: Dict[str, str] = {}
self.managed_policies: Dict[str, Dict[str, str]] = {}
self.managed_policies: Dict[str, ManagedPolicy] = {}
self.access_keys: List[AccessKey] = []
self.ssh_public_keys: List[SshPublicKey] = []
self.password: Optional[str] = None
@ -1510,7 +1516,7 @@ class User(CloudFormationModel):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
@ -1523,9 +1529,9 @@ class User(CloudFormationModel):
@classmethod
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource: Any,
original_resource: "User",
new_resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> "User":
@ -1556,7 +1562,7 @@ class User(CloudFormationModel):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
@ -1796,10 +1802,10 @@ class IAMBackend(BaseBackend):
def __init__(
self,
region_name: str,
account_id: Optional[str] = None,
account_id: str,
aws_policies: Optional[List[ManagedPolicy]] = None,
):
super().__init__(region_name=region_name, account_id=account_id) # type: ignore
super().__init__(region_name=region_name, account_id=account_id)
self.instance_profiles: Dict[str, InstanceProfile] = {}
self.roles: Dict[str, Role] = {}
self.certificates: Dict[str, Certificate] = {}
@ -1840,7 +1846,7 @@ class IAMBackend(BaseBackend):
# Do not reset these policies, as they take a long time to load
aws_policies = self.aws_managed_policies
self.__dict__ = {}
self.__init__(region_name, account_id, aws_policies) # type: ignore[misc]
IAMBackend.__init__(self, region_name, account_id, aws_policies)
def initialize_service_roles(self) -> None:
pass
@ -2828,8 +2834,9 @@ class IAMBackend(BaseBackend):
def delete_access_key_by_name(self, name: str) -> None:
key = self.access_keys[name]
try: # User may have been deleted before their access key...
user = self.get_user(key.user_name) # type: ignore
user.delete_access_key(key.access_key_id)
if key.user_name is not None:
user = self.get_user(key.user_name)
user.delete_access_key(key.access_key_id)
except NoSuchEntity:
pass
del self.access_keys[name]

View File

@ -70,7 +70,7 @@ class MotoRequestHandler:
request = AWSPreparedRequest(
method, full_url, headers, body, stream_output=False
)
request.form_data = form_data
request.form_data = form_data # type: ignore[attr-defined]
return handler(request, full_url, headers)
@ -160,7 +160,7 @@ class ProxyRequestHandler(BaseHTTPRequestHandler):
host=host,
path=path,
headers=req.headers,
body=req_body,
body=req_body, # type: ignore[arg-type]
form_data=form_data,
)
debug("\t=====RESPONSE========")

View File

@ -4,9 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
from moto.core.common_models import ConfigQueryModel
from moto.core.exceptions import InvalidNextTokenException
from moto.s3 import s3_backends
from moto.s3.models import S3Backend
class S3ConfigQuery(ConfigQueryModel):
class S3ConfigQuery(ConfigQueryModel[S3Backend]):
def list_config_service_resources(
self,
account_id: str,
@ -103,7 +104,7 @@ class S3ConfigQuery(ConfigQueryModel):
resource_region: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
# Get the bucket:
bucket = self.backends[account_id]["global"].buckets.get(resource_id, {})
bucket = self.backends[account_id]["global"].buckets.get(resource_id)
if not bucket:
return None

View File

@ -7,9 +7,10 @@ from moto.core.common_models import ConfigQueryModel
from moto.core.exceptions import InvalidNextTokenException
from moto.core.utils import unix_time, utcnow
from moto.s3control import s3control_backends
from moto.s3control.models import S3ControlBackend
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel[S3ControlBackend]):
def list_config_service_resources(
self,
account_id: str,

View File

@ -12,6 +12,10 @@ packaging
build
prompt_toolkit
# type stubs that mypy doesn't install automatically
botocore-stubs
# typing_extensions is currently used for:
# Protocol (3.8+)
# ParamSpec (3.10+)