diff --git a/moto/identitystore/models.py b/moto/identitystore/models.py index 167705355..8ec0d2613 100644 --- a/moto/identitystore/models.py +++ b/moto/identitystore/models.py @@ -82,6 +82,12 @@ class IdentityStoreBackend(BaseBackend): "limit_default": 100, "unique_attribute": "MembershipId", }, + "list_group_memberships_for_member": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "MembershipId", + }, "list_groups": { "input_token": "next_token", "limit_key": "max_results", @@ -264,6 +270,19 @@ class IdentityStoreBackend(BaseBackend): if m["GroupId"] == group_id ] + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_group_memberships_for_member( + self, identity_store_id: str, member_id: Dict[str, str] + ) -> List[Any]: + identity_store = self.__get_identity_store(identity_store_id) + user_id = member_id["UserId"] + + return [ + m + for m in identity_store.group_memberships.values() + if m["MemberId"]["UserId"] == user_id + ] + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore def list_groups( self, identity_store_id: str, filters: List[Dict[str, str]] diff --git a/moto/identitystore/responses.py b/moto/identitystore/responses.py index a54dfd66e..6c76e040b 100644 --- a/moto/identitystore/responses.py +++ b/moto/identitystore/responses.py @@ -137,6 +137,25 @@ class IdentityStoreResponse(BaseResponse): dict(GroupMemberships=group_memberships, NextToken=next_token) ) + def list_group_memberships_for_member(self) -> str: + identity_store_id = self._get_param("IdentityStoreId") + member_id = self._get_param("MemberId") + max_results = self._get_param("MaxResults") + next_token = self._get_param("NextToken") + ( + group_memberships, + next_token, + ) = self.identitystore_backend.list_group_memberships_for_member( + identity_store_id=identity_store_id, + member_id=member_id, + max_results=max_results, + next_token=next_token, + ) + + return json.dumps( + dict(GroupMemberships=group_memberships, NextToken=next_token) + ) + def list_groups(self) -> str: identity_store_id = self._get_param("IdentityStoreId") max_results = self._get_param("MaxResults") diff --git a/tests/test_identitystore/test_identitystore.py b/tests/test_identitystore/test_identitystore.py index b5ec13b32..09d8b518f 100644 --- a/tests/test_identitystore/test_identitystore.py +++ b/tests/test_identitystore/test_identitystore.py @@ -560,6 +560,61 @@ def test_list_group_memberships(): next_token = list_response["NextToken"] +@mock_identitystore +def test_list_group_memberships_for_member(): + client = boto3.client("identitystore", region_name="us-east-2") + identity_store_id = get_identity_store_id() + + start = 0 + end = 5000 + batch_size = 321 + next_token = None + membership_ids = [] + + user_id = __create_and_verify_sparse_user(client, identity_store_id)["UserId"] + for i in range(end): + group_id = client.create_group( + IdentityStoreId=identity_store_id, + DisplayName=f"test_group_{i}", + Description="description", + )["GroupId"] + create_response = client.create_group_membership( + IdentityStoreId=identity_store_id, + GroupId=group_id, + MemberId={"UserId": user_id}, + ) + membership_ids.append((create_response["MembershipId"], user_id)) + + for iteration in range(start, end, batch_size): + last_iteration = end - iteration <= batch_size + expected_size = batch_size if not last_iteration else end - iteration + end_index = iteration + expected_size + + if next_token is not None: + list_response = client.list_group_memberships_for_member( + IdentityStoreId=identity_store_id, + MemberId={"UserId": user_id}, + MaxResults=batch_size, + NextToken=next_token, + ) + else: + list_response = client.list_group_memberships_for_member( + IdentityStoreId=identity_store_id, + MemberId={"UserId": user_id}, + MaxResults=batch_size, + ) + + assert len(list_response["GroupMemberships"]) == expected_size + __check_membership_list_values( + list_response["GroupMemberships"], membership_ids[iteration:end_index] + ) + if last_iteration: + assert "NextToken" not in list_response + else: + assert "NextToken" in list_response + next_token = list_response["NextToken"] + + def __check_membership_list_values(members, expected): assert len(members) == len(expected) for i in range(len(expected)):