From db75c9e25ca49da7c1bb7e330579db695fbeff3b Mon Sep 17 00:00:00 2001 From: Franz See Date: Sun, 5 Jan 2020 23:13:36 +0800 Subject: [PATCH 1/2] moto/issues/2670 | Moved population of user attributes from accessToken to idToken --- moto/cognitoidp/models.py | 6 +++--- tests/test_cognitoidp/test_cognitoidp.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index 78025627a..9f39d7a5f 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -127,7 +127,8 @@ class CognitoIdpUserPool(BaseModel): return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in def create_id_token(self, client_id, username): - id_token, expires_in = self.create_jwt(client_id, username, "id") + extra_data = self.get_user_extra_data_by_client_id(client_id, username) + id_token, expires_in = self.create_jwt(client_id, username, "id", extra_data=extra_data) self.id_tokens[id_token] = (client_id, username) return id_token, expires_in @@ -137,9 +138,8 @@ class CognitoIdpUserPool(BaseModel): return refresh_token def create_access_token(self, client_id, username): - extra_data = self.get_user_extra_data_by_client_id(client_id, username) access_token, expires_in = self.create_jwt( - client_id, username, "access", extra_data=extra_data + client_id, username, "access" ) self.access_tokens[access_token] = (client_id, username) return access_token, expires_in diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index 79e6dbbb8..6a13683f0 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -1143,11 +1143,11 @@ def test_token_legitimacy(): id_claims["iss"].should.equal(issuer) id_claims["aud"].should.equal(client_id) id_claims["token_use"].should.equal("id") + for k, v in outputs["additional_fields"].items(): + id_claims[k].should.equal(v) access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256")) access_claims["iss"].should.equal(issuer) access_claims["aud"].should.equal(client_id) - for k, v in outputs["additional_fields"].items(): - access_claims[k].should.equal(v) access_claims["token_use"].should.equal("access") From 44e92f58ec44250c0701209549104d4545304ae8 Mon Sep 17 00:00:00 2001 From: Franz See Date: Wed, 15 Jan 2020 23:33:26 +0800 Subject: [PATCH 2/2] moto/issues/2670 | Used black to format the code --- moto/cognitoidp/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index 9f39d7a5f..96b23a404 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -128,7 +128,9 @@ class CognitoIdpUserPool(BaseModel): def create_id_token(self, client_id, username): extra_data = self.get_user_extra_data_by_client_id(client_id, username) - id_token, expires_in = self.create_jwt(client_id, username, "id", extra_data=extra_data) + id_token, expires_in = self.create_jwt( + client_id, username, "id", extra_data=extra_data + ) self.id_tokens[id_token] = (client_id, username) return id_token, expires_in @@ -138,9 +140,7 @@ class CognitoIdpUserPool(BaseModel): return refresh_token def create_access_token(self, client_id, username): - access_token, expires_in = self.create_jwt( - client_id, username, "access" - ) + access_token, expires_in = self.create_jwt(client_id, username, "access") self.access_tokens[access_token] = (client_id, username) return access_token, expires_in