STS: Support Chinese partitions in ARN's (#7336)

This commit is contained in:
Bert Blommers 2024-02-11 20:06:04 +00:00 committed by GitHub
parent 0ba2561539
commit 2e9b903ab6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 91 additions and 34 deletions

View File

@ -390,3 +390,11 @@ def params_sort_function(item: Tuple[str, Any]) -> Tuple[str, int, str]:
def gzip_decompress(body: bytes) -> bytes: def gzip_decompress(body: bytes) -> bytes:
return decompress(body) return decompress(body)
def get_partition_from_region(region_name: str) -> str:
# Very rough implementation
# In an ideal world we check `boto3.Session.get_partition_for_region`, but that is quite computationally heavy
if region_name.startswith("cn-"):
return "aws-cn"
return "aws"

View File

@ -16,6 +16,7 @@ from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.common_models import BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core.utils import ( from moto.core.utils import (
get_partition_from_region,
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
iso_8601_datetime_without_milliseconds, iso_8601_datetime_without_milliseconds,
unix_time, unix_time,
@ -1271,8 +1272,11 @@ class Group(BaseModel):
class User(CloudFormationModel): class User(CloudFormationModel):
def __init__(self, account_id: str, name: str, path: Optional[str] = None): def __init__(
self, account_id: str, region_name: str, name: str, path: Optional[str] = None
):
self.account_id = account_id self.account_id = account_id
self.region_name = region_name
self.name = name self.name = name
self.id = random_resource_id() self.id = random_resource_id()
self.path = path if path else "/" self.path = path if path else "/"
@ -1291,7 +1295,8 @@ class User(CloudFormationModel):
@property @property
def arn(self) -> str: def arn(self) -> str:
return f"arn:aws:iam::{self.account_id}:user{self.path}{self.name}" partition = get_partition_from_region(self.region_name)
return f"arn:{partition}:iam::{self.account_id}:user{self.path}{self.name}"
@property @property
def created_iso_8601(self) -> str: def created_iso_8601(self) -> str:
@ -1515,7 +1520,9 @@ class User(CloudFormationModel):
) -> "User": ) -> "User":
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
path = properties.get("Path") path = properties.get("Path")
user, _ = iam_backends[account_id]["global"].create_user(resource_name, path) user, _ = iam_backends[account_id]["global"].create_user(
region_name=region_name, user_name=resource_name, path=path
)
return user return user
@classmethod @classmethod
@ -2554,6 +2561,7 @@ class IAMBackend(BaseBackend):
def create_user( def create_user(
self, self,
region_name: str,
user_name: str, user_name: str,
path: str = "/", path: str = "/",
tags: Optional[List[Dict[str, str]]] = None, tags: Optional[List[Dict[str, str]]] = None,
@ -2563,7 +2571,7 @@ class IAMBackend(BaseBackend):
"EntityAlreadyExists", f"User {user_name} already exists" "EntityAlreadyExists", f"User {user_name} already exists"
) )
user = User(self.account_id, user_name, path) user = User(self.account_id, region_name, user_name, path)
self.tagger.tag_resource(user.arn, tags or []) self.tagger.tag_resource(user.arn, tags or [])
self.users[user_name] = user self.users[user_name] = user
return user, self.tagger.list_tags_for_resource(user.arn) return user, self.tagger.list_tags_for_resource(user.arn)

View File

@ -520,7 +520,9 @@ class IamResponse(BaseResponse):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
path = self._get_param("Path") path = self._get_param("Path")
tags = self._get_multi_param("Tags.member") tags = self._get_multi_param("Tags.member")
user, user_tags = self.backend.create_user(user_name, path, tags) user, user_tags = self.backend.create_user(
self.region, user_name=user_name, path=path, tags=tags
)
template = self.response_template(USER_TEMPLATE) template = self.response_template(USER_TEMPLATE)
return template.render(action="Create", user=user, tags=user_tags["Tags"]) return template.render(action="Create", user=user, tags=user_tags["Tags"])
@ -530,7 +532,9 @@ class IamResponse(BaseResponse):
access_key_id = self.get_access_key() access_key_id = self.get_access_key()
user = self.backend.get_user_from_access_key_id(access_key_id) user = self.backend.get_user_from_access_key_id(access_key_id)
if user is None: if user is None:
user = User(self.current_account, "default_user") user = User(
self.current_account, region_name=self.region, name="default_user"
)
else: else:
user = self.backend.get_user(user_name) user = self.backend.get_user(user_name)
tags = self.backend.tagger.list_tags_for_resource(user.arn).get("Tags", []) tags = self.backend.tagger.list_tags_for_resource(user.arn).get("Tags", [])

View File

@ -7,7 +7,11 @@ import xmltodict
from moto.core.base_backend import BackendDict, BaseBackend from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel from moto.core.common_models import BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow from moto.core.utils import (
get_partition_from_region,
iso_8601_datetime_with_milliseconds,
utcnow,
)
from moto.iam.models import AccessKey, iam_backends from moto.iam.models import AccessKey, iam_backends
from moto.sts.utils import ( from moto.sts.utils import (
DEFAULT_STS_SESSION_DURATION, DEFAULT_STS_SESSION_DURATION,
@ -32,6 +36,7 @@ class AssumedRole(BaseModel):
def __init__( def __init__(
self, self,
account_id: str, account_id: str,
region_name: str,
access_key: AccessKey, access_key: AccessKey,
role_session_name: str, role_session_name: str,
role_arn: str, role_arn: str,
@ -40,6 +45,7 @@ class AssumedRole(BaseModel):
external_id: str, external_id: str,
): ):
self.account_id = account_id self.account_id = account_id
self.region_name = region_name
self.session_name = role_session_name self.session_name = role_session_name
self.role_arn = role_arn self.role_arn = role_arn
self.policy = policy self.policy = policy
@ -66,7 +72,8 @@ class AssumedRole(BaseModel):
@property @property
def arn(self) -> str: def arn(self) -> str:
return f"arn:aws:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}" partition = get_partition_from_region(self.region_name)
return f"arn:{partition}:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}"
class STSBackend(BaseBackend): class STSBackend(BaseBackend):
@ -91,6 +98,7 @@ class STSBackend(BaseBackend):
def assume_role( def assume_role(
self, self,
region_name: str,
role_session_name: str, role_session_name: str,
role_arn: str, role_arn: str,
policy: str, policy: str,
@ -102,13 +110,14 @@ class STSBackend(BaseBackend):
""" """
account_id, access_key = self._create_access_key(role=role_arn) account_id, access_key = self._create_access_key(role=role_arn)
role = AssumedRole( role = AssumedRole(
account_id, account_id=account_id,
access_key, region_name=region_name,
role_session_name, access_key=access_key,
role_arn, role_session_name=role_session_name,
policy, role_arn=role_arn,
duration, policy=policy,
external_id, duration=duration,
external_id=external_id,
) )
access_key.role_arn = role_arn access_key.role_arn = role_arn
account_backend = sts_backends[account_id]["global"] account_backend = sts_backends[account_id]["global"]
@ -166,6 +175,7 @@ class STSBackend(BaseBackend):
account_id, access_key = self._create_access_key(role=target_role) # type: ignore account_id, access_key = self._create_access_key(role=target_role) # type: ignore
kwargs["account_id"] = account_id kwargs["account_id"] = account_id
kwargs["region_name"] = self.region_name
kwargs["access_key"] = access_key kwargs["access_key"] = access_key
kwargs["external_id"] = None kwargs["external_id"] = None
@ -174,7 +184,9 @@ class STSBackend(BaseBackend):
self.assumed_roles.append(role) self.assumed_roles.append(role)
return role return role
def get_caller_identity(self, access_key_id: str) -> Tuple[str, str, str]: def get_caller_identity(
self, access_key_id: str, region: str
) -> Tuple[str, str, str]:
assumed_role = self.get_assumed_role_from_access_key(access_key_id) assumed_role = self.get_assumed_role_from_access_key(access_key_id)
if assumed_role: if assumed_role:
return assumed_role.user_id, assumed_role.arn, assumed_role.account_id return assumed_role.user_id, assumed_role.arn, assumed_role.account_id
@ -185,8 +197,9 @@ class STSBackend(BaseBackend):
return user.id, user.arn, user.account_id return user.id, user.arn, user.account_id
# Default values in case the request does not use valid credentials generated by moto # Default values in case the request does not use valid credentials generated by moto
partition = get_partition_from_region(region)
user_id = "AKIAIOSFODNN7EXAMPLE" user_id = "AKIAIOSFODNN7EXAMPLE"
arn = f"arn:aws:sts::{self.account_id}:user/moto" arn = f"arn:{partition}:sts::{self.account_id}:user/moto"
return user_id, arn, self.account_id return user_id, arn, self.account_id
def _create_access_key(self, role: str) -> Tuple[str, AccessKey]: def _create_access_key(self, role: str) -> Tuple[str, AccessKey]:

View File

@ -51,6 +51,7 @@ class TokenResponse(BaseResponse):
external_id = self.querystring.get("ExternalId", [None])[0] external_id = self.querystring.get("ExternalId", [None])[0]
role = self.backend.assume_role( role = self.backend.assume_role(
region_name=self.region,
role_session_name=role_session_name, role_session_name=role_session_name,
role_arn=role_arn, role_arn=role_arn,
policy=policy, policy=policy,
@ -69,6 +70,7 @@ class TokenResponse(BaseResponse):
external_id = self.querystring.get("ExternalId", [None])[0] external_id = self.querystring.get("ExternalId", [None])[0]
role = self.backend.assume_role_with_web_identity( role = self.backend.assume_role_with_web_identity(
region_name=self.region,
role_session_name=role_session_name, role_session_name=role_session_name,
role_arn=role_arn, role_arn=role_arn,
policy=policy, policy=policy,
@ -95,7 +97,9 @@ class TokenResponse(BaseResponse):
template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
access_key_id = self.get_access_key() access_key_id = self.get_access_key()
user_id, arn, account_id = self.backend.get_caller_identity(access_key_id) user_id, arn, account_id = self.backend.get_caller_identity(
access_key_id, self.region
)
return template.render(account_id=account_id, user_id=user_id, arn=arn) return template.render(account_id=account_id, user_id=user_id, arn=arn)

View File

@ -66,9 +66,12 @@ def test_get_federation_token_boto3():
@freeze_time("2012-01-01 12:00:00") @freeze_time("2012-01-01 12:00:00")
@mock_aws @mock_aws
def test_assume_role(): @pytest.mark.parametrize(
client = boto3.client("sts", region_name="us-east-1") "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")]
iam_client = boto3.client("iam", region_name="us-east-1") )
def test_assume_role(region, partition):
client = boto3.client("sts", region_name=region)
iam_client = boto3.client("iam", region_name=region)
session_name = "session-name" session_name = "session-name"
policy = json.dumps( policy = json.dumps(
@ -114,7 +117,7 @@ def test_assume_role():
assert len(credentials["SecretAccessKey"]) == 40 assert len(credentials["SecretAccessKey"]) == 40
assert assume_role_response["AssumedRoleUser"]["Arn"] == ( assert assume_role_response["AssumedRoleUser"]["Arn"] == (
f"arn:aws:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" f"arn:{partition}:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}"
) )
assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].startswith("AROA") assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].startswith("AROA")
assert ( assert (
@ -583,8 +586,11 @@ def test_assume_role_with_saml_when_saml_attribute_not_provided():
@freeze_time("2012-01-01 12:00:00") @freeze_time("2012-01-01 12:00:00")
@mock_aws @mock_aws
def test_assume_role_with_web_identity_boto3(): @pytest.mark.parametrize(
client = boto3.client("sts", region_name="us-east-1") "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")]
)
def test_assume_role_with_web_identity_boto3(region, partition):
client = boto3.client("sts", region_name=region)
policy = json.dumps( policy = json.dumps(
{ {
@ -625,30 +631,37 @@ def test_assume_role_with_web_identity_boto3():
assert len(creds["SecretAccessKey"]) == 40 assert len(creds["SecretAccessKey"]) == 40
assert user["Arn"] == ( assert user["Arn"] == (
f"arn:aws:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" f"arn:{partition}:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}"
) )
assert "session-name" in user["AssumedRoleId"] assert "session-name" in user["AssumedRoleId"]
@mock_aws @mock_aws
def test_get_caller_identity_with_default_credentials(): @pytest.mark.parametrize(
identity = boto3.client("sts", region_name="us-east-1").get_caller_identity() "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")]
)
def test_get_caller_identity_with_default_credentials(region, partition):
identity = boto3.client("sts", region_name=region).get_caller_identity()
assert identity["Arn"] == f"arn:aws:sts::{ACCOUNT_ID}:user/moto" assert identity["Arn"] == f"arn:{partition}:sts::{ACCOUNT_ID}:user/moto"
assert identity["UserId"] == "AKIAIOSFODNN7EXAMPLE" assert identity["UserId"] == "AKIAIOSFODNN7EXAMPLE"
assert identity["Account"] == str(ACCOUNT_ID) assert identity["Account"] == str(ACCOUNT_ID)
@mock_aws @mock_aws
def test_get_caller_identity_with_iam_user_credentials(): @pytest.mark.parametrize(
iam_client = boto3.client("iam", region_name="us-east-1") "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")]
)
def test_get_caller_identity_with_iam_user_credentials(region, partition):
iam_client = boto3.client("iam", region_name=region)
iam_user_name = "new-user" iam_user_name = "new-user"
iam_user = iam_client.create_user(UserName=iam_user_name)["User"] iam_user = iam_client.create_user(UserName=iam_user_name)["User"]
assert iam_user["Arn"] == f"arn:{partition}:iam::123456789012:user/new-user"
access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"] access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"]
identity = boto3.client( identity = boto3.client(
"sts", "sts",
region_name="us-east-1", region_name=region,
aws_access_key_id=access_key["AccessKeyId"], aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"], aws_secret_access_key=access_key["SecretAccessKey"],
).get_caller_identity() ).get_caller_identity()
@ -659,9 +672,12 @@ def test_get_caller_identity_with_iam_user_credentials():
@mock_aws @mock_aws
def test_get_caller_identity_with_assumed_role_credentials(): @pytest.mark.parametrize(
iam_client = boto3.client("iam", region_name="us-east-1") "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")]
sts_client = boto3.client("sts", region_name="us-east-1") )
def test_get_caller_identity_with_assumed_role_credentials(region, partition):
iam_client = boto3.client("iam", region_name=region)
sts_client = boto3.client("sts", region_name=region)
iam_role_name = "new-user" iam_role_name = "new-user"
trust_policy_document = { trust_policy_document = {
"Version": "2012-10-17", "Version": "2012-10-17",
@ -679,6 +695,10 @@ def test_get_caller_identity_with_assumed_role_credentials():
assumed_role = sts_client.assume_role( assumed_role = sts_client.assume_role(
RoleArn=iam_role_arn, RoleSessionName=session_name RoleArn=iam_role_arn, RoleSessionName=session_name
) )
assert (
assumed_role["AssumedRoleUser"]["Arn"]
== f"arn:{partition}:sts::123456789012:assumed-role/new-user/new-session"
)
access_key = assumed_role["Credentials"] access_key = assumed_role["Credentials"]
identity = boto3.client( identity = boto3.client(