diff --git a/moto/core/utils.py b/moto/core/utils.py index 8e817dda2..7735e30a7 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -390,3 +390,11 @@ def params_sort_function(item: Tuple[str, Any]) -> Tuple[str, int, str]: def gzip_decompress(body: bytes) -> bytes: return decompress(body) + + +def get_partition_from_region(region_name: str) -> str: + # Very rough implementation + # In an ideal world we check `boto3.Session.get_partition_for_region`, but that is quite computationally heavy + if region_name.startswith("cn-"): + return "aws-cn" + return "aws" diff --git a/moto/iam/models.py b/moto/iam/models.py index f2164fe99..485c7ba5d 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -16,6 +16,7 @@ from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.exceptions import RESTError from moto.core.utils import ( + get_partition_from_region, iso_8601_datetime_with_milliseconds, iso_8601_datetime_without_milliseconds, unix_time, @@ -1271,8 +1272,11 @@ class Group(BaseModel): class User(CloudFormationModel): - def __init__(self, account_id: str, name: str, path: Optional[str] = None): + def __init__( + self, account_id: str, region_name: str, name: str, path: Optional[str] = None + ): self.account_id = account_id + self.region_name = region_name self.name = name self.id = random_resource_id() self.path = path if path else "/" @@ -1291,7 +1295,8 @@ class User(CloudFormationModel): @property def arn(self) -> str: - return f"arn:aws:iam::{self.account_id}:user{self.path}{self.name}" + partition = get_partition_from_region(self.region_name) + return f"arn:{partition}:iam::{self.account_id}:user{self.path}{self.name}" @property def created_iso_8601(self) -> str: @@ -1515,7 +1520,9 @@ class User(CloudFormationModel): ) -> "User": properties = cloudformation_json.get("Properties", {}) path = properties.get("Path") - user, _ = iam_backends[account_id]["global"].create_user(resource_name, path) + user, _ = iam_backends[account_id]["global"].create_user( + region_name=region_name, user_name=resource_name, path=path + ) return user @classmethod @@ -2554,6 +2561,7 @@ class IAMBackend(BaseBackend): def create_user( self, + region_name: str, user_name: str, path: str = "/", tags: Optional[List[Dict[str, str]]] = None, @@ -2563,7 +2571,7 @@ class IAMBackend(BaseBackend): "EntityAlreadyExists", f"User {user_name} already exists" ) - user = User(self.account_id, user_name, path) + user = User(self.account_id, region_name, user_name, path) self.tagger.tag_resource(user.arn, tags or []) self.users[user_name] = user return user, self.tagger.list_tags_for_resource(user.arn) diff --git a/moto/iam/responses.py b/moto/iam/responses.py index e30cb14eb..fde3c10ae 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -520,7 +520,9 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") path = self._get_param("Path") tags = self._get_multi_param("Tags.member") - user, user_tags = self.backend.create_user(user_name, path, tags) + user, user_tags = self.backend.create_user( + self.region, user_name=user_name, path=path, tags=tags + ) template = self.response_template(USER_TEMPLATE) return template.render(action="Create", user=user, tags=user_tags["Tags"]) @@ -530,7 +532,9 @@ class IamResponse(BaseResponse): access_key_id = self.get_access_key() user = self.backend.get_user_from_access_key_id(access_key_id) if user is None: - user = User(self.current_account, "default_user") + user = User( + self.current_account, region_name=self.region, name="default_user" + ) else: user = self.backend.get_user(user_name) tags = self.backend.tagger.list_tags_for_resource(user.arn).get("Tags", []) diff --git a/moto/sts/models.py b/moto/sts/models.py index d6a3351ea..4a237551a 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -7,7 +7,11 @@ import xmltodict from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel -from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow +from moto.core.utils import ( + get_partition_from_region, + iso_8601_datetime_with_milliseconds, + utcnow, +) from moto.iam.models import AccessKey, iam_backends from moto.sts.utils import ( DEFAULT_STS_SESSION_DURATION, @@ -32,6 +36,7 @@ class AssumedRole(BaseModel): def __init__( self, account_id: str, + region_name: str, access_key: AccessKey, role_session_name: str, role_arn: str, @@ -40,6 +45,7 @@ class AssumedRole(BaseModel): external_id: str, ): self.account_id = account_id + self.region_name = region_name self.session_name = role_session_name self.role_arn = role_arn self.policy = policy @@ -66,7 +72,8 @@ class AssumedRole(BaseModel): @property def arn(self) -> str: - return f"arn:aws:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}" + partition = get_partition_from_region(self.region_name) + return f"arn:{partition}:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}" class STSBackend(BaseBackend): @@ -91,6 +98,7 @@ class STSBackend(BaseBackend): def assume_role( self, + region_name: str, role_session_name: str, role_arn: str, policy: str, @@ -102,13 +110,14 @@ class STSBackend(BaseBackend): """ account_id, access_key = self._create_access_key(role=role_arn) role = AssumedRole( - account_id, - access_key, - role_session_name, - role_arn, - policy, - duration, - external_id, + account_id=account_id, + region_name=region_name, + access_key=access_key, + role_session_name=role_session_name, + role_arn=role_arn, + policy=policy, + duration=duration, + external_id=external_id, ) access_key.role_arn = role_arn account_backend = sts_backends[account_id]["global"] @@ -166,6 +175,7 @@ class STSBackend(BaseBackend): account_id, access_key = self._create_access_key(role=target_role) # type: ignore kwargs["account_id"] = account_id + kwargs["region_name"] = self.region_name kwargs["access_key"] = access_key kwargs["external_id"] = None @@ -174,7 +184,9 @@ class STSBackend(BaseBackend): self.assumed_roles.append(role) return role - def get_caller_identity(self, access_key_id: str) -> Tuple[str, str, str]: + def get_caller_identity( + self, access_key_id: str, region: str + ) -> Tuple[str, str, str]: assumed_role = self.get_assumed_role_from_access_key(access_key_id) if assumed_role: return assumed_role.user_id, assumed_role.arn, assumed_role.account_id @@ -185,8 +197,9 @@ class STSBackend(BaseBackend): return user.id, user.arn, user.account_id # Default values in case the request does not use valid credentials generated by moto + partition = get_partition_from_region(region) user_id = "AKIAIOSFODNN7EXAMPLE" - arn = f"arn:aws:sts::{self.account_id}:user/moto" + arn = f"arn:{partition}:sts::{self.account_id}:user/moto" return user_id, arn, self.account_id def _create_access_key(self, role: str) -> Tuple[str, AccessKey]: diff --git a/moto/sts/responses.py b/moto/sts/responses.py index f597aeb51..317e0f3a0 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -51,6 +51,7 @@ class TokenResponse(BaseResponse): external_id = self.querystring.get("ExternalId", [None])[0] role = self.backend.assume_role( + region_name=self.region, role_session_name=role_session_name, role_arn=role_arn, policy=policy, @@ -69,6 +70,7 @@ class TokenResponse(BaseResponse): external_id = self.querystring.get("ExternalId", [None])[0] role = self.backend.assume_role_with_web_identity( + region_name=self.region, role_session_name=role_session_name, role_arn=role_arn, policy=policy, @@ -95,7 +97,9 @@ class TokenResponse(BaseResponse): template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) access_key_id = self.get_access_key() - user_id, arn, account_id = self.backend.get_caller_identity(access_key_id) + user_id, arn, account_id = self.backend.get_caller_identity( + access_key_id, self.region + ) return template.render(account_id=account_id, user_id=user_id, arn=arn) diff --git a/tests/test_sts/test_sts.py b/tests/test_sts/test_sts.py index ec2f6c401..083932c13 100644 --- a/tests/test_sts/test_sts.py +++ b/tests/test_sts/test_sts.py @@ -66,9 +66,12 @@ def test_get_federation_token_boto3(): @freeze_time("2012-01-01 12:00:00") @mock_aws -def test_assume_role(): - client = boto3.client("sts", region_name="us-east-1") - iam_client = boto3.client("iam", region_name="us-east-1") +@pytest.mark.parametrize( + "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")] +) +def test_assume_role(region, partition): + client = boto3.client("sts", region_name=region) + iam_client = boto3.client("iam", region_name=region) session_name = "session-name" policy = json.dumps( @@ -114,7 +117,7 @@ def test_assume_role(): assert len(credentials["SecretAccessKey"]) == 40 assert assume_role_response["AssumedRoleUser"]["Arn"] == ( - f"arn:aws:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" + f"arn:{partition}:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" ) assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].startswith("AROA") assert ( @@ -583,8 +586,11 @@ def test_assume_role_with_saml_when_saml_attribute_not_provided(): @freeze_time("2012-01-01 12:00:00") @mock_aws -def test_assume_role_with_web_identity_boto3(): - client = boto3.client("sts", region_name="us-east-1") +@pytest.mark.parametrize( + "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")] +) +def test_assume_role_with_web_identity_boto3(region, partition): + client = boto3.client("sts", region_name=region) policy = json.dumps( { @@ -625,30 +631,37 @@ def test_assume_role_with_web_identity_boto3(): assert len(creds["SecretAccessKey"]) == 40 assert user["Arn"] == ( - f"arn:aws:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" + f"arn:{partition}:sts::{ACCOUNT_ID}:assumed-role/{role_name}/{session_name}" ) assert "session-name" in user["AssumedRoleId"] @mock_aws -def test_get_caller_identity_with_default_credentials(): - identity = boto3.client("sts", region_name="us-east-1").get_caller_identity() +@pytest.mark.parametrize( + "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")] +) +def test_get_caller_identity_with_default_credentials(region, partition): + identity = boto3.client("sts", region_name=region).get_caller_identity() - assert identity["Arn"] == f"arn:aws:sts::{ACCOUNT_ID}:user/moto" + assert identity["Arn"] == f"arn:{partition}:sts::{ACCOUNT_ID}:user/moto" assert identity["UserId"] == "AKIAIOSFODNN7EXAMPLE" assert identity["Account"] == str(ACCOUNT_ID) @mock_aws -def test_get_caller_identity_with_iam_user_credentials(): - iam_client = boto3.client("iam", region_name="us-east-1") +@pytest.mark.parametrize( + "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")] +) +def test_get_caller_identity_with_iam_user_credentials(region, partition): + iam_client = boto3.client("iam", region_name=region) iam_user_name = "new-user" iam_user = iam_client.create_user(UserName=iam_user_name)["User"] + assert iam_user["Arn"] == f"arn:{partition}:iam::123456789012:user/new-user" access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"] identity = boto3.client( "sts", - region_name="us-east-1", + region_name=region, aws_access_key_id=access_key["AccessKeyId"], aws_secret_access_key=access_key["SecretAccessKey"], ).get_caller_identity() @@ -659,9 +672,12 @@ def test_get_caller_identity_with_iam_user_credentials(): @mock_aws -def test_get_caller_identity_with_assumed_role_credentials(): - iam_client = boto3.client("iam", region_name="us-east-1") - sts_client = boto3.client("sts", region_name="us-east-1") +@pytest.mark.parametrize( + "region,partition", [("us-east-1", "aws"), ("cn-north-1", "aws-cn")] +) +def test_get_caller_identity_with_assumed_role_credentials(region, partition): + iam_client = boto3.client("iam", region_name=region) + sts_client = boto3.client("sts", region_name=region) iam_role_name = "new-user" trust_policy_document = { "Version": "2012-10-17", @@ -679,6 +695,10 @@ def test_get_caller_identity_with_assumed_role_credentials(): assumed_role = sts_client.assume_role( RoleArn=iam_role_arn, RoleSessionName=session_name ) + assert ( + assumed_role["AssumedRoleUser"]["Arn"] + == f"arn:{partition}:sts::123456789012:assumed-role/new-user/new-session" + ) access_key = assumed_role["Credentials"] identity = boto3.client(