From 6ce49cff3b09328fcabced5199f5996c1818a794 Mon Sep 17 00:00:00 2001 From: Matus Kosut Date: Thu, 12 Sep 2024 11:39:45 +0200 Subject: [PATCH] optimize checking of supported args --- jupyter_rsession_proxy/__init__.py | 36 +++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/jupyter_rsession_proxy/__init__.py b/jupyter_rsession_proxy/__init__.py index f3c6734..b42aed4 100644 --- a/jupyter_rsession_proxy/__init__.py +++ b/jupyter_rsession_proxy/__init__.py @@ -1,6 +1,5 @@ import getpass import os -import pathlib import shutil import subprocess import tempfile @@ -46,7 +45,7 @@ def rewrite_netloc(response, request): def get_system_user(): try: user = pwd.getpwuid(os.getuid())[0] - except: + except Exception: user = os.environ.get('NB_USER', getpass.getuser()) return(user) @@ -73,9 +72,16 @@ def db_config(db_dir): f.close() return db_config_name - def _support_arg(arg): + def _support_args(args): ret = subprocess.check_output([get_rstudio_executable('rserver'), '--help']) - return ret.decode().find(arg) != -1 + help_output = ret.decode() + return {arg: (help_output.find(arg) != -1) for arg in args} + + def _get_www_frame_origin(default="same"): + try: + return os.getenv('JUPYTER_RSESSION_PROXY_WWW_FRAME_ORIGIN', default) + except Exception: + return default def _get_cmd(port): ntf = tempfile.NamedTemporaryFile() @@ -88,7 +94,7 @@ def _get_cmd(port): cmd = [ get_rstudio_executable('rserver'), '--auth-none=1', - '--www-frame-origin=same', + '--www-frame-origin=' + _get_www_frame_origin(), '--www-port=' + str(port), '--www-verify-user-agent=0', '--secure-cookie-key-file=' + ntf.name, @@ -96,13 +102,27 @@ def _get_cmd(port): ] # Support at least v1.2.1335 and up - if _support_arg('www-root-path'): + supported_args = _support_args([ + 'www-root-path', + 'server-data-dir', + 'database-config-file', + 'www-thread-pool-size', + ]) + if supported_args['www-root-path']: cmd.append('--www-root-path={base_url}rstudio/') - if _support_arg('server-data-dir'): + if supported_args['server-data-dir']: cmd.append(f'--server-data-dir={server_data_dir}') - if _support_arg('database-config-file'): + if supported_args['database-config-file']: cmd.append(f'--database-config-file={database_config_file}') + if supported_args['www-thread-pool-size']: + try: + thread_pool_size = int(os.getenv('RSERVER_THREAD_POOL_SIZE', "")) + if thread_pool_size > 0: + cmd.append('--www-thread-pool-size=' + str(thread_pool_size)) + except Exception: + pass + return cmd def _get_timeout(default=15):