Skip to content

Commit

Permalink
[#181] Remove the scripts.main.load_conf() function
Browse files Browse the repository at this point in the history
Instead of using this function to get the config and add attributes to it, we
now separately get the config with load_conf_file() and pass attributes to
Spark. I've translated some of the tests for load_conf() to tests for
load_conf_file().
  • Loading branch information
riley-harper committed Dec 13, 2024
1 parent 46f79e3 commit 1f99c93
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 362 deletions.
46 changes: 0 additions & 46 deletions hlink/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,52 +32,6 @@
logger = logging.getLogger(__name__)


def load_conf(conf_name: str, user: str) -> tuple[Path, dict[str, Any]]:
"""Load and return the hlink config dictionary.
Add the following attributes to the config dictionary:
"derby_dir", "warehouse_dir", "spark_tmp_dir", "log_dir", "python",
"conf_path", "run_name"
"""
if "HLINK_CONF" not in os.environ:
global_conf = None
else:
global_conf_file = os.environ["HLINK_CONF"]
with open(global_conf_file) as f:
global_conf = json.load(f)

run_name = Path(conf_name).stem

if global_conf is None:
current_dir = Path.cwd()
hlink_dir = current_dir / "hlink_config"
base_derby_dir = hlink_dir / "derby"
base_warehouse_dir = hlink_dir / "warehouse"
base_spark_tmp_dir = hlink_dir / "spark_tmp_dir"
path, conf = load_conf_file(conf_name)

conf["derby_dir"] = base_derby_dir / run_name
conf["warehouse_dir"] = base_warehouse_dir / run_name
conf["spark_tmp_dir"] = base_spark_tmp_dir / run_name
conf["log_dir"] = hlink_dir / "logs"
conf["python"] = sys.executable
else:
user_dir = Path(global_conf["users_dir"]) / user
user_dir_fast = Path(global_conf["users_dir_fast"]) / user
conf_dir = user_dir / "confs"
conf_path = conf_dir / conf_name
path, conf = load_conf_file(str(conf_path))

conf["derby_dir"] = user_dir / "derby" / run_name
conf["warehouse_dir"] = user_dir_fast / "warehouse" / run_name
conf["spark_tmp_dir"] = user_dir_fast / "tmp" / run_name
conf["log_dir"] = user_dir / "logs"
conf["python"] = global_conf["python"]

conf["run_name"] = run_name
return path, conf


def cli():
"""Called by the hlink script."""
if "--version" in sys.argv:
Expand Down
47 changes: 37 additions & 10 deletions hlink/tests/config_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,50 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

from pathlib import Path

import pytest

from hlink.configs.load_config import load_conf_file
import os.path
from hlink.errors import UsageError


def test_load_conf_file_json(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test")
_path, conf = load_conf_file(conf_file)
@pytest.mark.parametrize("file_name", ["test", "test.json"])
def test_load_conf_file_json(conf_dir_path: str, file_name: str) -> None:
conf_file = Path(conf_dir_path) / file_name
path, conf = load_conf_file(str(conf_file))
assert conf["id_column"] == "id"
assert path == conf_file.with_suffix(".json")


def test_load_conf_file_toml(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test1")
_path, conf = load_conf_file(conf_file)
@pytest.mark.parametrize("file_name", ["test1", "test1.toml"])
def test_load_conf_file_toml(conf_dir_path: str, file_name: str) -> None:
conf_file = Path(conf_dir_path) / file_name
path, conf = load_conf_file(str(conf_file))
assert conf["id_column"] == "id-toml"
assert path == conf_file.with_suffix(".toml")


def test_load_conf_file_json2(conf_dir_path):
conf_file = os.path.join(conf_dir_path, "test_conf_flag_run")
_path, conf = load_conf_file(conf_file)
def test_load_conf_file_json2(conf_dir_path: str) -> None:
conf_file = Path(conf_dir_path) / "test_conf_flag_run"
path, conf = load_conf_file(str(conf_file))
assert conf["id_column"] == "id_conf_flag"
assert path == conf_file.with_suffix(".json")


def test_load_conf_file_does_not_exist(tmp_path: Path) -> None:
conf_file = tmp_path / "notthere"
with pytest.raises(
FileNotFoundError, match="Couldn't find any of these three files:"
):
load_conf_file(str(conf_file))


def test_load_conf_file_unrecognized_extension(tmp_path: Path) -> None:
conf_file = tmp_path / "test.yaml"
conf_file.touch()
with pytest.raises(
UsageError,
match="The file .+ exists, but it doesn't have a '.toml' or '.json' extension",
):
load_conf_file(str(conf_file))
Loading

0 comments on commit 1f99c93

Please sign in to comment.