diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index b77b986d..dc1fc4ca 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -6,23 +6,16 @@ import json import jwt import os -import re -import string import urllib -import sys from tornado.auth import OAuth2Mixin from tornado.log import app_log -from tornado import web - -from tornado.httputil import url_concat from tornado.httpclient import HTTPRequest, AsyncHTTPClient from jupyterhub.auth import LocalAuthenticator -from traitlets import List, Set, Unicode +from traitlets import Unicode, default -from .common import next_page_from_links from .oauth2 import OAuthLoginHandler, OAuthenticator @@ -51,6 +44,7 @@ class AzureAdOAuthenticator(OAuthenticator): login_handler = AzureAdLoginHandler tenant_id = Unicode(config=True) + username_claim = Unicode(config=True) def get_tenant(self): if hasattr(self, 'tenant_id') and self.tenant_id: @@ -64,6 +58,10 @@ def get_tenant(self): app_log.info('ID4: {0}'.format(tenant_id)) return tenant_id + @default('username_claim') + def _username_claim_default(self): + return 'name' + async def authenticate(self, handler, data=None): code = handler.get_argument("code") http_client = AsyncHTTPClient() @@ -101,7 +99,7 @@ async def authenticate(self, handler, data=None): id_token = resp_json['id_token'] decoded = jwt.decode(id_token, verify=False) - userdict = {"name": decoded['name']} + userdict = {"name": decoded[self.username_claim]} userdict["auth_state"] = auth_state = {} auth_state['access_token'] = access_token # results in a decoded JWT for the user data diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index f6cb2602..11b375a8 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -19,4 +19,4 @@ def test_gettenant_with_tenant_id(): def test_gettenant_from_env(): t_id = AzureAdOAuthenticator.get_tenant(object) - assert t_id.default_value == "some_random_id" + assert t_id.default_value == "some_random_id" \ No newline at end of file