Merge pull request #1756 from ferruvich/add_extra_attributes_in_token_payload
Add extra attributes in token payload
This commit is contained in:
commit
7b9bb15d28
@ -84,7 +84,11 @@ class CognitoIdpUserPool(BaseModel):
|
|||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
def create_access_token(self, client_id, username):
|
def create_access_token(self, client_id, username):
|
||||||
access_token, expires_in = self.create_jwt(client_id, username)
|
extra_data = self.get_user_extra_data_by_client_id(
|
||||||
|
client_id, username
|
||||||
|
)
|
||||||
|
access_token, expires_in = self.create_jwt(client_id, username,
|
||||||
|
extra_data=extra_data)
|
||||||
self.access_tokens[access_token] = (client_id, username)
|
self.access_tokens[access_token] = (client_id, username)
|
||||||
return access_token, expires_in
|
return access_token, expires_in
|
||||||
|
|
||||||
@ -97,6 +101,21 @@ class CognitoIdpUserPool(BaseModel):
|
|||||||
id_token, _ = self.create_id_token(client_id, username)
|
id_token, _ = self.create_id_token(client_id, username)
|
||||||
return access_token, id_token, expires_in
|
return access_token, id_token, expires_in
|
||||||
|
|
||||||
|
def get_user_extra_data_by_client_id(self, client_id, username):
|
||||||
|
extra_data = {}
|
||||||
|
current_client = self.clients.get(client_id, None)
|
||||||
|
if current_client:
|
||||||
|
for readable_field in current_client.get_readable_fields():
|
||||||
|
attribute = list(filter(
|
||||||
|
lambda f: f['Name'] == readable_field,
|
||||||
|
self.users.get(username).attributes
|
||||||
|
))
|
||||||
|
if len(attribute) > 0:
|
||||||
|
extra_data.update({
|
||||||
|
attribute[0]['Name']: attribute[0]['Value']
|
||||||
|
})
|
||||||
|
return extra_data
|
||||||
|
|
||||||
|
|
||||||
class CognitoIdpUserPoolDomain(BaseModel):
|
class CognitoIdpUserPoolDomain(BaseModel):
|
||||||
|
|
||||||
@ -138,6 +157,9 @@ class CognitoIdpUserPoolClient(BaseModel):
|
|||||||
|
|
||||||
return user_pool_client_json
|
return user_pool_client_json
|
||||||
|
|
||||||
|
def get_readable_fields(self):
|
||||||
|
return self.extended_config.get('ReadAttributes', [])
|
||||||
|
|
||||||
|
|
||||||
class CognitoIdpIdentityProvider(BaseModel):
|
class CognitoIdpIdentityProvider(BaseModel):
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from jose import jws
|
from jose import jws
|
||||||
|
|
||||||
from moto import mock_cognitoidp
|
from moto import mock_cognitoidp
|
||||||
import sure # noqa
|
import sure # noqa
|
||||||
|
|
||||||
@ -400,15 +401,22 @@ def authentication_flow(conn):
|
|||||||
username = str(uuid.uuid4())
|
username = str(uuid.uuid4())
|
||||||
temporary_password = str(uuid.uuid4())
|
temporary_password = str(uuid.uuid4())
|
||||||
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
|
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
|
||||||
|
user_attribute_name = str(uuid.uuid4())
|
||||||
|
user_attribute_value = str(uuid.uuid4())
|
||||||
client_id = conn.create_user_pool_client(
|
client_id = conn.create_user_pool_client(
|
||||||
UserPoolId=user_pool_id,
|
UserPoolId=user_pool_id,
|
||||||
ClientName=str(uuid.uuid4()),
|
ClientName=str(uuid.uuid4()),
|
||||||
|
ReadAttributes=[user_attribute_name]
|
||||||
)["UserPoolClient"]["ClientId"]
|
)["UserPoolClient"]["ClientId"]
|
||||||
|
|
||||||
conn.admin_create_user(
|
conn.admin_create_user(
|
||||||
UserPoolId=user_pool_id,
|
UserPoolId=user_pool_id,
|
||||||
Username=username,
|
Username=username,
|
||||||
TemporaryPassword=temporary_password,
|
TemporaryPassword=temporary_password,
|
||||||
|
UserAttributes=[{
|
||||||
|
'Name': user_attribute_name,
|
||||||
|
'Value': user_attribute_value
|
||||||
|
}]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = conn.admin_initiate_auth(
|
result = conn.admin_initiate_auth(
|
||||||
@ -447,6 +455,9 @@ def authentication_flow(conn):
|
|||||||
"access_token": result["AuthenticationResult"]["AccessToken"],
|
"access_token": result["AuthenticationResult"]["AccessToken"],
|
||||||
"username": username,
|
"username": username,
|
||||||
"password": new_password,
|
"password": new_password,
|
||||||
|
"additional_fields": {
|
||||||
|
user_attribute_name: user_attribute_value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -476,6 +487,8 @@ def test_token_legitimacy():
|
|||||||
access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256"))
|
access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256"))
|
||||||
access_claims["iss"].should.equal(issuer)
|
access_claims["iss"].should.equal(issuer)
|
||||||
access_claims["aud"].should.equal(client_id)
|
access_claims["aud"].should.equal(client_id)
|
||||||
|
for k, v in outputs["additional_fields"].items():
|
||||||
|
access_claims[k].should.equal(v)
|
||||||
|
|
||||||
|
|
||||||
@mock_cognitoidp
|
@mock_cognitoidp
|
||||||
|
Loading…
Reference in New Issue
Block a user