From 20d935988b35372dfa4a7b8075838ebb2f646fba Mon Sep 17 00:00:00 2001 From: Jeff Sternberg Date: Mon, 5 Aug 2019 16:37:24 -0400 Subject: [PATCH 1/3] add configurable username claim --- oauthenticator/azuread.py | 13 ++++++++++++- oauthenticator/tests/test_azuread.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index b77b986d..5eb1ae0b 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -51,6 +51,7 @@ class AzureAdOAuthenticator(OAuthenticator): login_handler = AzureAdLoginHandler tenant_id = Unicode(config=True) + username_claim = Unicode(config=True) def get_tenant(self): if hasattr(self, 'tenant_id') and self.tenant_id: @@ -64,6 +65,16 @@ def get_tenant(self): app_log.info('ID4: {0}'.format(tenant_id)) return tenant_id + def get_username_claim(self): + """ + The claim to map to the jupyter username, such as `upn` or `unique_name` + See https://docs.microsoft.com/en-gb/azure/active-directory/develop/id-tokens + """ + if hasattr(self, 'username_claim') and self.username_claim: + app_log.info('ID5: {0}'.format(self.username_claim)) + return self.username_claim + return 'oid' + async def authenticate(self, handler, data=None): code = handler.get_argument("code") http_client = AsyncHTTPClient() @@ -101,7 +112,7 @@ async def authenticate(self, handler, data=None): id_token = resp_json['id_token'] decoded = jwt.decode(id_token, verify=False) - userdict = {"name": decoded['name']} + userdict = {"name": decoded[self.get_username_claim()]} userdict["auth_state"] = auth_state = {} auth_state['access_token'] = access_token # results in a decoded JWT for the user data diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index f6cb2602..cd9621e9 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -2,10 +2,12 @@ from ..azuread import AzureAdOAuthenticator _t_id = 'XXX-XXX-XXXX' +_t_username_claim = 'upn' class Config(object): tenant_id = _t_id + username_claim = _t_username_claim def test_gettenant_with_tenant_id(): @@ -20,3 +22,17 @@ def test_gettenant_with_tenant_id(): def test_gettenant_from_env(): t_id = AzureAdOAuthenticator.get_tenant(object) assert t_id.default_value == "some_random_id" + + +def test_username_claim_config(): + t_username_claim = AzureAdOAuthenticator.get_username_claim(Config()) + assert t_username_claim == _t_username_claim + + +def test_username_claim_default(): + + class Config(object): + tenant_id = _t_id + + t_username_claim = AzureAdOAuthenticator.get_username_claim(Config()) + assert t_username_claim == 'oid' \ No newline at end of file From 622277ea985f74462ca8f4692935b8cbcc889c7d Mon Sep 17 00:00:00 2001 From: Jeff Sternberg Date: Tue, 6 Aug 2019 08:21:25 -0400 Subject: [PATCH 2/3] default to 'name' --- oauthenticator/azuread.py | 2 +- oauthenticator/tests/test_azuread.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index 5eb1ae0b..fb42492f 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -73,7 +73,7 @@ def get_username_claim(self): if hasattr(self, 'username_claim') and self.username_claim: app_log.info('ID5: {0}'.format(self.username_claim)) return self.username_claim - return 'oid' + return 'name' async def authenticate(self, handler, data=None): code = handler.get_argument("code") diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index cd9621e9..e798e4e4 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -35,4 +35,4 @@ class Config(object): tenant_id = _t_id t_username_claim = AzureAdOAuthenticator.get_username_claim(Config()) - assert t_username_claim == 'oid' \ No newline at end of file + assert t_username_claim == 'name' \ No newline at end of file From 2898b746a044119705f1d3b3e230b2e741a6d330 Mon Sep 17 00:00:00 2001 From: Jeff Sternberg Date: Tue, 6 Aug 2019 08:34:09 -0400 Subject: [PATCH 3/3] use traitlet.default for username_claim --- oauthenticator/azuread.py | 21 ++++----------------- oauthenticator/tests/test_azuread.py | 18 +----------------- 2 files changed, 5 insertions(+), 34 deletions(-) diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index fb42492f..dc1fc4ca 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -6,23 +6,16 @@ import json import jwt import os -import re -import string import urllib -import sys from tornado.auth import OAuth2Mixin from tornado.log import app_log -from tornado import web - -from tornado.httputil import url_concat from tornado.httpclient import HTTPRequest, AsyncHTTPClient from jupyterhub.auth import LocalAuthenticator -from traitlets import List, Set, Unicode +from traitlets import Unicode, default -from .common import next_page_from_links from .oauth2 import OAuthLoginHandler, OAuthenticator @@ -65,14 +58,8 @@ def get_tenant(self): app_log.info('ID4: {0}'.format(tenant_id)) return tenant_id - def get_username_claim(self): - """ - The claim to map to the jupyter username, such as `upn` or `unique_name` - See https://docs.microsoft.com/en-gb/azure/active-directory/develop/id-tokens - """ - if hasattr(self, 'username_claim') and self.username_claim: - app_log.info('ID5: {0}'.format(self.username_claim)) - return self.username_claim + @default('username_claim') + def _username_claim_default(self): return 'name' async def authenticate(self, handler, data=None): @@ -112,7 +99,7 @@ async def authenticate(self, handler, data=None): id_token = resp_json['id_token'] decoded = jwt.decode(id_token, verify=False) - userdict = {"name": decoded[self.get_username_claim()]} + userdict = {"name": decoded[self.username_claim]} userdict["auth_state"] = auth_state = {} auth_state['access_token'] = access_token # results in a decoded JWT for the user data diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index e798e4e4..11b375a8 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -2,12 +2,10 @@ from ..azuread import AzureAdOAuthenticator _t_id = 'XXX-XXX-XXXX' -_t_username_claim = 'upn' class Config(object): tenant_id = _t_id - username_claim = _t_username_claim def test_gettenant_with_tenant_id(): @@ -21,18 +19,4 @@ def test_gettenant_with_tenant_id(): def test_gettenant_from_env(): t_id = AzureAdOAuthenticator.get_tenant(object) - assert t_id.default_value == "some_random_id" - - -def test_username_claim_config(): - t_username_claim = AzureAdOAuthenticator.get_username_claim(Config()) - assert t_username_claim == _t_username_claim - - -def test_username_claim_default(): - - class Config(object): - tenant_id = _t_id - - t_username_claim = AzureAdOAuthenticator.get_username_claim(Config()) - assert t_username_claim == 'name' \ No newline at end of file + assert t_id.default_value == "some_random_id" \ No newline at end of file