diff --git a/hlink/configs/load_config.py b/hlink/configs/load_config.py
index 73e048a..46b565a 100755
--- a/hlink/configs/load_config.py
+++ b/hlink/configs/load_config.py
@@ -11,7 +11,7 @@
 from hlink.errors import UsageError
 
 
-def load_conf_file(conf_name: str) -> dict[str, Any]:
+def load_conf_file(conf_name: str) -> tuple[Path, dict[str, Any]]:
     """Flexibly load a config file.
 
     Given a path `conf_name`, look for a file at that path. If that file
@@ -20,15 +20,11 @@ def load_conf_file(conf_name: str) -> dict[str, Any]:
     name with a '.toml' extension added and load it if it exists. Then do the
     same for a file with a '.json' extension added.
 
-    After successfully loading a config file, store the absolute path where the
-    config file was found as the value of the "conf_path" key in the returned
-    config dictionary.
-
     Args:
         conf_name: the file to look for
 
     Returns:
-        the contents of the config file
+        a tuple (absolute path to the config file, contents of the config file)
 
     Raises:
         FileNotFoundError: if none of the three checked files exist
@@ -46,14 +42,12 @@ def load_conf_file(conf_name: str) -> dict[str, Any]:
         if file.suffix == ".toml":
             with open(file) as f:
                 conf = toml.load(f)
-                conf["conf_path"] = str(file.resolve())
-                return conf
+                return file.absolute(), conf
 
         if file.suffix == ".json":
             with open(file) as f:
                 conf = json.load(f)
-                conf["conf_path"] = str(file.resolve())
-                return conf
+                return file.absolute(), conf
 
     raise UsageError(
         f"The file {file} exists, but it doesn't have a '.toml' or '.json' extension."
diff --git a/hlink/scripts/main.py b/hlink/scripts/main.py
index 2cea838..d4f59e3 100755
--- a/hlink/scripts/main.py
+++ b/hlink/scripts/main.py
@@ -12,9 +12,10 @@
 import importlib.metadata
 import readline
 import sys
+from timeit import default_timer as timer
 import traceback
+from typing import Any
 import uuid
-from timeit import default_timer as timer
 
 from hlink.spark.session import SparkConnection
 from hlink.configs.load_config import load_conf_file
@@ -28,7 +29,7 @@
 logger = logging.getLogger(__name__)
 
 
-def load_conf(conf_name, user):
+def load_conf(conf_name: str, user: str) -> tuple[Path, dict[str, Any]]:
     """Load and return the hlink config dictionary.
 
     Add the following attributes to the config dictionary:
@@ -50,7 +51,7 @@ def load_conf(conf_name, user):
         base_derby_dir = hlink_dir / "derby"
         base_warehouse_dir = hlink_dir / "warehouse"
         base_spark_tmp_dir = hlink_dir / "spark_tmp_dir"
-        conf = load_conf_file(conf_name)
+        path, conf = load_conf_file(conf_name)
 
         conf["derby_dir"] = base_derby_dir / run_name
         conf["warehouse_dir"] = base_warehouse_dir / run_name
@@ -62,7 +63,7 @@ def load_conf(conf_name, user):
         user_dir_fast = Path(global_conf["users_dir_fast"]) / user
         conf_dir = user_dir / "confs"
         conf_path = conf_dir / conf_name
-        conf = load_conf_file(str(conf_path))
+        path, conf = load_conf_file(str(conf_path))
 
         conf["derby_dir"] = user_dir / "derby" / run_name
         conf["warehouse_dir"] = user_dir_fast / "warehouse" / run_name
@@ -71,8 +72,8 @@ def load_conf(conf_name, user):
 
     conf["python"] = global_conf["python"]
     conf["run_name"] = run_name
-    print(f"*** Using config file {conf['conf_path']}")
-    return conf
+    print(f"*** Using config file {path}")
+    return path, conf
 
 
 def cli():
@@ -85,7 +86,7 @@ def cli():
 
     try:
         if args.conf:
-            run_conf = load_conf(args.conf, args.user)
+            conf_path, run_conf = load_conf(args.conf, args.user)
         else:
             raise Exception(
                 "ERROR: You must specify a config file to use by including either the --run or --conf flag in your program call."
@@ -103,7 +104,7 @@ def cli():
         traceback.print_exception("", err, None)
         sys.exit(1)
 
-    _setup_logging(run_conf)
+    _setup_logging(conf_path, run_conf)
 
     logger.info("Initializing Spark")
     spark_init_start = timer()
@@ -235,14 +236,14 @@ def _cli_loop(spark, args, run_conf, run_name):
             main.cmdloop()
             if main.lastcmd == "reload":
                 logger.info("Reloading config file")
-                run_conf = load_conf(args.conf, args.user)
+                conf_path, run_conf = load_conf(args.conf, args.user)
             else:
                 break
         except Exception as err:
             report_and_log_error("", err)
 
 
-def _setup_logging(conf):
+def _setup_logging(conf_path, conf):
     log_dir = Path(conf["log_dir"])
     log_dir.mkdir(exist_ok=True, parents=True)
 
@@ -260,7 +261,7 @@ def _setup_logging(conf):
 
     logging.basicConfig(filename=log_file, level=logging.INFO, format=format_string)
     logger.info(f"New session {session_id} by user {user}")
-    logger.info(f"Configured with {conf['conf_path']}")
+    logger.info(f"Configured with {conf_path}")
     logger.info(f"Using hlink version {hlink_version}")
     logger.info(
         "-------------------------------------------------------------------------------------"
diff --git a/hlink/tests/conf_validations_test.py b/hlink/tests/conf_validations_test.py
index 9cf896c..387c447 100644
--- a/hlink/tests/conf_validations_test.py
+++ b/hlink/tests/conf_validations_test.py
@@ -22,7 +22,7 @@
 )
 def test_invalid_conf(conf_dir_path, spark, conf_name, error_msg):
     conf_file = os.path.join(conf_dir_path, conf_name)
-    config = load_conf_file(conf_file)
+    _path, config = load_conf_file(conf_file)
     link_run = LinkRun(spark, config)
 
     with pytest.raises(ValueError, match=error_msg):
diff --git a/hlink/tests/config_loader_test.py b/hlink/tests/config_loader_test.py
index 4fd4827..58c497e 100644
--- a/hlink/tests/config_loader_test.py
+++ b/hlink/tests/config_loader_test.py
@@ -9,17 +9,17 @@
 
 def test_load_conf_file_json(conf_dir_path):
     conf_file = os.path.join(conf_dir_path, "test")
-    conf = load_conf_file(conf_file)
+    _path, conf = load_conf_file(conf_file)
     assert conf["id_column"] == "id"
 
 
 def test_load_conf_file_toml(conf_dir_path):
     conf_file = os.path.join(conf_dir_path, "test1")
-    conf = load_conf_file(conf_file)
+    _path, conf = load_conf_file(conf_file)
     assert conf["id_column"] == "id-toml"
 
 
 def test_load_conf_file_json2(conf_dir_path):
     conf_file = os.path.join(conf_dir_path, "test_conf_flag_run")
-    conf = load_conf_file(conf_file)
+    _path, conf = load_conf_file(conf_file)
     assert conf["id_column"] == "id_conf_flag"
diff --git a/hlink/tests/conftest.py b/hlink/tests/conftest.py
index 48db85e..88c99af 100755
--- a/hlink/tests/conftest.py
+++ b/hlink/tests/conftest.py
@@ -158,7 +158,7 @@ def conf(conf_dir_path):
 @pytest.fixture(scope="function")
 def integration_conf(input_data_dir_path, conf_dir_path):
     conf_file = os.path.join(conf_dir_path, "integration")
-    conf = load_conf_file(conf_file)
+    _conf_path, conf = load_conf_file(conf_file)
 
     datasource_a = conf["datasource_a"]
     datasource_b = conf["datasource_b"]
diff --git a/hlink/tests/main_test.py b/hlink/tests/main_test.py
index 2938458..c236a3f 100644
--- a/hlink/tests/main_test.py
+++ b/hlink/tests/main_test.py
@@ -70,8 +70,8 @@ def test_load_conf_json_exists_no_env(monkeypatch, tmp_path, conf_file, user):
     with open(filename, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(filename, user)
-    assert conf["conf_path"] == filename
+    path, _conf = load_conf(filename, user)
+    assert str(path) == filename
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "my_conf.json", "my_conf.toml"))
@@ -85,8 +85,8 @@ def test_load_conf_json_exists_ext_added_no_env(monkeypatch, tmp_path, conf_name
     with open(filename, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(str(tmp_path / conf_name), user)
-    assert conf["conf_path"] == filename
+    path, _conf = load_conf(str(tmp_path / conf_name), user)
+    assert str(path) == filename
 
 
 @pytest.mark.parametrize("conf_file", ("my_conf.toml",))
@@ -100,8 +100,8 @@ def test_load_conf_toml_exists_no_env(monkeypatch, tmp_path, conf_file, user):
     with open(filename, "w") as f:
         toml.dump(contents, f)
 
-    conf = load_conf(filename, user)
-    assert conf["conf_path"] == filename
+    path, _conf = load_conf(filename, user)
+    assert str(path) == filename
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "my_conf.json", "my_conf.toml"))
@@ -115,8 +115,8 @@ def test_load_conf_toml_exists_ext_added_no_env(monkeypatch, tmp_path, conf_name
     with open(filename, "w") as f:
         toml.dump(contents, f)
 
-    conf = load_conf(str(tmp_path / conf_name), user)
-    assert conf["conf_path"] == filename
+    path, _conf = load_conf(str(tmp_path / conf_name), user)
+    assert str(path) == filename
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "testing.txt", "what.yaml"))
@@ -147,13 +147,12 @@ def test_load_conf_keys_set_no_env(monkeypatch, tmp_path):
     with open(filename, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(filename, "test")
+    _path, conf = load_conf(filename, "test")
 
     for key, value in contents.items():
         assert conf[key] == value
 
     # Check for extra keys added by load_conf()
-    assert "conf_path" in conf
     assert "derby_dir" in conf
     assert "warehouse_dir" in conf
     assert "spark_tmp_dir" in conf
@@ -202,8 +201,8 @@ def test_load_conf_json_exists_in_conf_dir_env(
     with open(file, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(conf_file, user)
-    assert conf["conf_path"] == str(file)
+    path, _conf = load_conf(conf_file, user)
+    assert path == file
 
 
 @pytest.mark.parametrize("conf_file", ("my_conf.toml",))
@@ -221,8 +220,8 @@ def test_load_conf_toml_exists_in_conf_dir_env(
     with open(file, "w") as f:
         toml.dump(contents, f)
 
-    conf = load_conf(conf_file, user)
-    assert conf["conf_path"] == str(file)
+    path, _conf = load_conf(conf_file, user)
+    assert path == file
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "test", "testingtesting123.txt"))
@@ -241,8 +240,8 @@ def test_load_conf_json_exists_in_conf_dir_ext_added_env(
     with open(file, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(conf_name, user)
-    assert conf["conf_path"] == str(file)
+    path, _conf = load_conf(conf_name, user)
+    assert path == file
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "test", "testingtesting123.txt"))
@@ -261,8 +260,8 @@ def test_load_conf_toml_exists_in_conf_dir_ext_added_env(
    with open(file, "w") as f:
         toml.dump(contents, f)
 
-    conf = load_conf(conf_name, user)
-    assert conf["conf_path"] == str(file)
+    path, _conf = load_conf(conf_name, user)
+    assert path == file
 
 
 @pytest.mark.parametrize("conf_name", ("my_conf", "testing.txt", "what.yaml"))
@@ -294,13 +293,12 @@ def test_load_conf_keys_set_env(
     with open(file, "w") as f:
         json.dump(contents, f)
 
-    conf = load_conf(filename, user)
+    _path, conf = load_conf(filename, user)
 
     for key, value in contents.items():
         assert conf[key] == value
 
     # Check for extra keys added by load_conf()
-    assert "conf_path" in conf
     assert "derby_dir" in conf
     assert "warehouse_dir" in conf
     assert "spark_tmp_dir" in conf
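
Note on the new calling convention: load_conf_file and load_conf now return a (path, config) tuple instead of stashing the path under a "conf_path" key, so callers unpack the pair and pass the path around explicitly. Below is a minimal sketch of caller code under the new API; the config file name "my_conf" and the "id_column" key are illustrative, not part of this change.

    from hlink.configs.load_config import load_conf_file

    # The loader returns the absolute path it found alongside the parsed config,
    # so the path no longer travels inside the config dict itself.
    # Raises FileNotFoundError if my_conf, my_conf.toml, and my_conf.json all do not exist.
    conf_path, conf = load_conf_file("my_conf")  # hypothetical file name
    print(f"Loaded config from {conf_path}")
    print(conf["id_column"])  # assumes the config defines an id_column key

Callers that only need one of the two values can discard the other with an underscore-prefixed name, as the updated tests do (_path, conf = load_conf_file(...)).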