diff --git a/binderhub/base.py b/binderhub/base.py index eeb079b9d..ba83bce87 100644 --- a/binderhub/base.py +++ b/binderhub/base.py @@ -1,12 +1,15 @@ """Base classes for request handlers""" import json +import os import urllib.parse import jwt from http.client import responses from tornado import web from tornado.log import app_log +from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError + from jupyterhub.services.auth import HubOAuthenticated, HubOAuth from . import __version__ as binder_version @@ -21,6 +24,7 @@ def initialize(self): super().initialize() if self.settings['auth_enabled']: self.hub_auth = HubOAuth.instance(config=self.settings['traitlets_config']) + self.current_user_model = None def prepare(self): super().prepare() @@ -156,14 +160,46 @@ def get_spec_from_request(self, prefix): spec = self.request.path[idx + len(prefix) + 1:] return spec - def get_provider(self, provider_prefix, spec): + async def get_provider(self, provider_prefix, spec): """Construct a provider object""" providers = self.settings['repo_providers'] if provider_prefix not in providers: raise web.HTTPError(404, "No provider found for prefix %s" % provider_prefix) + async def api_request(url, *args, **kwargs): + headers = kwargs.setdefault('headers', {}) + headers.update({'Authorization': 'token %s' % self.hub_auth.api_token}) + hub_api_url = os.getenv('JUPYTERHUB_API_URL', '') or self.hub_auth.api_url + request_url = hub_api_url + url + req = HTTPRequest(request_url, *args, **kwargs) + + try: + return await AsyncHTTPClient().fetch(req) + except HTTPError as e: + app_log.error("Error accessing Hub API (using %s): %s", request_url, e) + + async def get_current_user_model(): + """Get the current user model. + The user auth_state is only accessible to admin users. + """ + if not self.settings['auth_enabled']: + return None + + if self.current_user_model is None: + username = self.get_current_user()['name'] + resp = await api_request( + f'/users/{username}', + method='GET', + ) + self.current_user_model = json.loads(resp.body.decode('utf-8')) + + return self.current_user_model + return providers[provider_prefix]( - config=self.settings['traitlets_config'], spec=spec) + config=self.settings['traitlets_config'], + spec=spec, + user_model=await get_current_user_model() + ) def get_badge_base_url(self): badge_base_url = self.settings['badge_base_url'] diff --git a/binderhub/builder.py b/binderhub/builder.py index fcd25b4e7..6f53a6c56 100644 --- a/binderhub/builder.py +++ b/binderhub/builder.py @@ -259,7 +259,7 @@ async def get(self, provider_prefix, _unescaped_spec): # get a provider object that encapsulates the provider and the spec try: - provider = self.get_provider(provider_prefix, spec=spec) + provider = await self.get_provider(provider_prefix, spec=spec) except Exception as e: app_log.exception("Failed to get provider for %s", key) await self.fail(str(e)) diff --git a/binderhub/main.py b/binderhub/main.py index 976cd12e4..dbf1787c8 100755 --- a/binderhub/main.py +++ b/binderhub/main.py @@ -50,7 +50,7 @@ async def get(self, provider_prefix, _unescaped_spec): spec = self.get_spec_from_request(prefix) spec = spec.rstrip("/") try: - self.get_provider(provider_prefix, spec=spec) + await self.get_provider(provider_prefix, spec=spec) except HTTPError: raise except Exception as e: