Skip to content

Commit

Permalink
[#181] Return a tuple (path, config) from load_conf_file
Browse files Browse the repository at this point in the history
This eliminates the need to set a new "conf_path" attribute on the
configuration dictionary before returning it.
  • Loading branch information
riley-harper committed Dec 13, 2024
1 parent 7f8b49d commit 4c6e602
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 46 deletions.
14 changes: 4 additions & 10 deletions hlink/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from hlink.errors import UsageError


def load_conf_file(conf_name: str) -> dict[str, Any]:
def load_conf_file(conf_name: str) -> tuple[Path, dict[str, Any]]:
"""Flexibly load a config file.
Given a path `conf_name`, look for a file at that path. If that file
Expand All @@ -20,15 +20,11 @@ def load_conf_file(conf_name: str) -> dict[str, Any]:
name with a '.toml' extension added and load it if it exists. Then do the
same for a file with a '.json' extension added.
After successfully loading a config file, store the absolute path where the
config file was found as the value of the "conf_path" key in the returned
config dictionary.
Args:
conf_name: the file to look for
Returns:
the contents of the config file
a tuple (absolute path to the config file, contents of the config file)
Raises:
FileNotFoundError: if none of the three checked files exist
Expand All @@ -46,14 +42,12 @@ def load_conf_file(conf_name: str) -> dict[str, Any]:
if file.suffix == ".toml":
with open(file) as f:
conf = toml.load(f)
conf["conf_path"] = str(file.resolve())
return conf
return file.absolute(), conf

if file.suffix == ".json":
with open(file) as f:
conf = json.load(f)
conf["conf_path"] = str(file.resolve())
return conf
return file.absolute(), conf

raise UsageError(
f"The file {file} exists, but it doesn't have a '.toml' or '.json' extension."
Expand Down
23 changes: 12 additions & 11 deletions hlink/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import importlib.metadata
import readline
import sys
from timeit import default_timer as timer
import traceback
from typing import Any
import uuid
from timeit import default_timer as timer

from hlink.spark.session import SparkConnection
from hlink.configs.load_config import load_conf_file
Expand All @@ -28,7 +29,7 @@
logger = logging.getLogger(__name__)


def load_conf(conf_name, user):
def load_conf(conf_name: str, user: str) -> tuple[Path, dict[str, Any]]:
"""Load and return the hlink config dictionary.
Add the following attributes to the config dictionary:
Expand All @@ -50,7 +51,7 @@ def load_conf(conf_name, user):
base_derby_dir = hlink_dir / "derby"
base_warehouse_dir = hlink_dir / "warehouse"
base_spark_tmp_dir = hlink_dir / "spark_tmp_dir"
conf = load_conf_file(conf_name)
path, conf = load_conf_file(conf_name)

conf["derby_dir"] = base_derby_dir / run_name
conf["warehouse_dir"] = base_warehouse_dir / run_name
Expand All @@ -62,7 +63,7 @@ def load_conf(conf_name, user):
user_dir_fast = Path(global_conf["users_dir_fast"]) / user
conf_dir = user_dir / "confs"
conf_path = conf_dir / conf_name
conf = load_conf_file(str(conf_path))
path, conf = load_conf_file(str(conf_path))

conf["derby_dir"] = user_dir / "derby" / run_name
conf["warehouse_dir"] = user_dir_fast / "warehouse" / run_name
Expand All @@ -71,8 +72,8 @@ def load_conf(conf_name, user):
conf["python"] = global_conf["python"]

conf["run_name"] = run_name
print(f"*** Using config file {conf['conf_path']}")
return conf
print(f"*** Using config file {path}")
return path, conf


def cli():
Expand All @@ -85,7 +86,7 @@ def cli():

try:
if args.conf:
run_conf = load_conf(args.conf, args.user)
conf_path, run_conf = load_conf(args.conf, args.user)
else:
raise Exception(
"ERROR: You must specify a config file to use by including either the --run or --conf flag in your program call."
Expand All @@ -103,7 +104,7 @@ def cli():
traceback.print_exception("", err, None)
sys.exit(1)

_setup_logging(run_conf)
_setup_logging(conf_path, run_conf)

logger.info("Initializing Spark")
spark_init_start = timer()
Expand Down Expand Up @@ -235,14 +236,14 @@ def _cli_loop(spark, args, run_conf, run_name):
main.cmdloop()
if main.lastcmd == "reload":
logger.info("Reloading config file")
run_conf = load_conf(args.conf, args.user)
conf_path, run_conf = load_conf(args.conf, args.user)
else:
break
except Exception as err:
report_and_log_error("", err)


def _setup_logging(conf):
def _setup_logging(conf_path, conf):
log_dir = Path(conf["log_dir"])
log_dir.mkdir(exist_ok=True, parents=True)

Expand All @@ -260,7 +261,7 @@ def _setup_logging(conf):
logging.basicConfig(filename=log_file, level=logging.INFO, format=format_string)

logger.info(f"New session {session_id} by user {user}")
logger.info(f"Configured with {conf['conf_path']}")
logger.info(f"Configured with {conf_path}")
logger.info(f"Using hlink version {hlink_version}")
logger.info(
"-------------------------------------------------------------------------------------"
Expand Down
2 changes: 1 addition & 1 deletion hlink/tests/conf_validations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
def test_invalid_conf(conf_dir_path, spark, conf_name, error_msg):
conf_file = os.path.join(conf_dir_path, conf_name)
config = load_conf_file(conf_file)
_path, config = load_conf_file(conf_file)
link_run = LinkRun(spark, config)

with pytest.raises(ValueError, match=error_msg):
Expand Down
6 changes: 3 additions & 3 deletions hlink/tests/config_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@

def test_load_conf_file_json(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test")
conf = load_conf_file(conf_file)
_path, conf = load_conf_file(conf_file)
assert conf["id_column"] == "id"


def test_load_conf_file_toml(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test1")
conf = load_conf_file(conf_file)
_path, conf = load_conf_file(conf_file)
assert conf["id_column"] == "id-toml"


def test_load_conf_file_json2(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test_conf_flag_run")
conf = load_conf_file(conf_file)
_path, conf = load_conf_file(conf_file)
assert conf["id_column"] == "id_conf_flag"
2 changes: 1 addition & 1 deletion hlink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def conf(conf_dir_path):
@pytest.fixture(scope="function")
def integration_conf(input_data_dir_path, conf_dir_path):
conf_file = os.path.join(conf_dir_path, "integration")
conf = load_conf_file(conf_file)
_conf_path, conf = load_conf_file(conf_file)

datasource_a = conf["datasource_a"]
datasource_b = conf["datasource_b"]
Expand Down
38 changes: 18 additions & 20 deletions hlink/tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_load_conf_json_exists_no_env(monkeypatch, tmp_path, conf_file, user):
with open(filename, "w") as f:
json.dump(contents, f)

conf = load_conf(filename, user)
assert conf["conf_path"] == filename
path, _conf = load_conf(filename, user)
assert str(path) == filename


@pytest.mark.parametrize("conf_name", ("my_conf", "my_conf.json", "my_conf.toml"))
Expand All @@ -85,8 +85,8 @@ def test_load_conf_json_exists_ext_added_no_env(monkeypatch, tmp_path, conf_name
with open(filename, "w") as f:
json.dump(contents, f)

conf = load_conf(str(tmp_path / conf_name), user)
assert conf["conf_path"] == filename
path, _conf = load_conf(str(tmp_path / conf_name), user)
assert str(path) == filename


@pytest.mark.parametrize("conf_file", ("my_conf.toml",))
Expand All @@ -100,8 +100,8 @@ def test_load_conf_toml_exists_no_env(monkeypatch, tmp_path, conf_file, user):
with open(filename, "w") as f:
toml.dump(contents, f)

conf = load_conf(filename, user)
assert conf["conf_path"] == filename
path, _conf = load_conf(filename, user)
assert str(path) == filename


@pytest.mark.parametrize("conf_name", ("my_conf", "my_conf.json", "my_conf.toml"))
Expand All @@ -115,8 +115,8 @@ def test_load_conf_toml_exists_ext_added_no_env(monkeypatch, tmp_path, conf_name
with open(filename, "w") as f:
toml.dump(contents, f)

conf = load_conf(str(tmp_path / conf_name), user)
assert conf["conf_path"] == filename
path, _conf = load_conf(str(tmp_path / conf_name), user)
assert str(path) == filename


@pytest.mark.parametrize("conf_name", ("my_conf", "testing.txt", "what.yaml"))
Expand Down Expand Up @@ -147,13 +147,12 @@ def test_load_conf_keys_set_no_env(monkeypatch, tmp_path):
with open(filename, "w") as f:
json.dump(contents, f)

conf = load_conf(filename, "test")
_path, conf = load_conf(filename, "test")

for key, value in contents.items():
assert conf[key] == value

# Check for extra keys added by load_conf()
assert "conf_path" in conf
assert "derby_dir" in conf
assert "warehouse_dir" in conf
assert "spark_tmp_dir" in conf
Expand Down Expand Up @@ -202,8 +201,8 @@ def test_load_conf_json_exists_in_conf_dir_env(
with open(file, "w") as f:
json.dump(contents, f)

conf = load_conf(conf_file, user)
assert conf["conf_path"] == str(file)
path, _conf = load_conf(conf_file, user)
assert path == file


@pytest.mark.parametrize("conf_file", ("my_conf.toml",))
Expand All @@ -221,8 +220,8 @@ def test_load_conf_toml_exists_in_conf_dir_env(
with open(file, "w") as f:
toml.dump(contents, f)

conf = load_conf(conf_file, user)
assert conf["conf_path"] == str(file)
path, _conf = load_conf(conf_file, user)
assert path == file


@pytest.mark.parametrize("conf_name", ("my_conf", "test", "testingtesting123.txt"))
Expand All @@ -241,8 +240,8 @@ def test_load_conf_json_exists_in_conf_dir_ext_added_env(
with open(file, "w") as f:
json.dump(contents, f)

conf = load_conf(conf_name, user)
assert conf["conf_path"] == str(file)
path, _conf = load_conf(conf_name, user)
assert path == file


@pytest.mark.parametrize("conf_name", ("my_conf", "test", "testingtesting123.txt"))
Expand All @@ -261,8 +260,8 @@ def test_load_conf_toml_exists_in_conf_dir_ext_added_env(
with open(file, "w") as f:
toml.dump(contents, f)

conf = load_conf(conf_name, user)
assert conf["conf_path"] == str(file)
path, _conf = load_conf(conf_name, user)
assert path == file


@pytest.mark.parametrize("conf_name", ("my_conf", "testing.txt", "what.yaml"))
Expand Down Expand Up @@ -294,13 +293,12 @@ def test_load_conf_keys_set_env(
with open(file, "w") as f:
json.dump(contents, f)

conf = load_conf(filename, user)
_path, conf = load_conf(filename, user)

for key, value in contents.items():
assert conf[key] == value

# Check for extra keys added by load_conf()
assert "conf_path" in conf
assert "derby_dir" in conf
assert "warehouse_dir" in conf
assert "spark_tmp_dir" in conf
Expand Down

0 comments on commit 4c6e602

Please sign in to comment.