diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index 78025627a..96b23a404 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -127,7 +127,10 @@ class CognitoIdpUserPool(BaseModel): return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in def create_id_token(self, client_id, username): - id_token, expires_in = self.create_jwt(client_id, username, "id") + extra_data = self.get_user_extra_data_by_client_id(client_id, username) + id_token, expires_in = self.create_jwt( + client_id, username, "id", extra_data=extra_data + ) self.id_tokens[id_token] = (client_id, username) return id_token, expires_in @@ -137,10 +140,7 @@ class CognitoIdpUserPool(BaseModel): return refresh_token def create_access_token(self, client_id, username): - extra_data = self.get_user_extra_data_by_client_id(client_id, username) - access_token, expires_in = self.create_jwt( - client_id, username, "access", extra_data=extra_data - ) + access_token, expires_in = self.create_jwt(client_id, username, "access") self.access_tokens[access_token] = (client_id, username) return access_token, expires_in diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index 79e6dbbb8..6a13683f0 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -1143,11 +1143,11 @@ def test_token_legitimacy(): id_claims["iss"].should.equal(issuer) id_claims["aud"].should.equal(client_id) id_claims["token_use"].should.equal("id") + for k, v in outputs["additional_fields"].items(): + id_claims[k].should.equal(v) access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256")) access_claims["iss"].should.equal(issuer) access_claims["aud"].should.equal(client_id) - for k, v in outputs["additional_fields"].items(): - access_claims[k].should.equal(v) access_claims["token_use"].should.equal("access")