Skip to content

Commit 1f99c93

Browse files
committed
[#181] Remove the scripts.main.load_conf() function
Instead of using this function to get the config and add attributes to it, we now separately get the config with load_conf_file() and pass attributes to Spark. I've translated some of the tests for load_conf() to tests for load_conf_file().
1 parent 46f79e3 commit 1f99c93

File tree

3 files changed

+37
-362
lines changed

3 files changed

+37
-362
lines changed

hlink/scripts/main.py

Lines changed: 0 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -32,52 +32,6 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35-
def load_conf(conf_name: str, user: str) -> tuple[Path, dict[str, Any]]:
36-
"""Load and return the hlink config dictionary.
37-
38-
Add the following attributes to the config dictionary:
39-
"derby_dir", "warehouse_dir", "spark_tmp_dir", "log_dir", "python",
40-
"conf_path", "run_name"
41-
"""
42-
if "HLINK_CONF" not in os.environ:
43-
global_conf = None
44-
else:
45-
global_conf_file = os.environ["HLINK_CONF"]
46-
with open(global_conf_file) as f:
47-
global_conf = json.load(f)
48-
49-
run_name = Path(conf_name).stem
50-
51-
if global_conf is None:
52-
current_dir = Path.cwd()
53-
hlink_dir = current_dir / "hlink_config"
54-
base_derby_dir = hlink_dir / "derby"
55-
base_warehouse_dir = hlink_dir / "warehouse"
56-
base_spark_tmp_dir = hlink_dir / "spark_tmp_dir"
57-
path, conf = load_conf_file(conf_name)
58-
59-
conf["derby_dir"] = base_derby_dir / run_name
60-
conf["warehouse_dir"] = base_warehouse_dir / run_name
61-
conf["spark_tmp_dir"] = base_spark_tmp_dir / run_name
62-
conf["log_dir"] = hlink_dir / "logs"
63-
conf["python"] = sys.executable
64-
else:
65-
user_dir = Path(global_conf["users_dir"]) / user
66-
user_dir_fast = Path(global_conf["users_dir_fast"]) / user
67-
conf_dir = user_dir / "confs"
68-
conf_path = conf_dir / conf_name
69-
path, conf = load_conf_file(str(conf_path))
70-
71-
conf["derby_dir"] = user_dir / "derby" / run_name
72-
conf["warehouse_dir"] = user_dir_fast / "warehouse" / run_name
73-
conf["spark_tmp_dir"] = user_dir_fast / "tmp" / run_name
74-
conf["log_dir"] = user_dir / "logs"
75-
conf["python"] = global_conf["python"]
76-
77-
conf["run_name"] = run_name
78-
return path, conf
79-
80-
8135
def cli():
8236
"""Called by the hlink script."""
8337
if "--version" in sys.argv:

hlink/tests/config_loader_test.py

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3,23 +3,50 @@
33
# in this project's top-level directory, and also on-line at:
44
# https://github.com/ipums/hlink
55

6+
from pathlib import Path
7+
8+
import pytest
9+
610
from hlink.configs.load_config import load_conf_file
7-
import os.path
11+
from hlink.errors import UsageError
812

913

10-
def test_load_conf_file_json(conf_dir_path):
11-
conf_file = os.path.join(conf_dir_path, "test")
12-
_path, conf = load_conf_file(conf_file)
14+
@pytest.mark.parametrize("file_name", ["test", "test.json"])
15+
def test_load_conf_file_json(conf_dir_path: str, file_name: str) -> None:
16+
conf_file = Path(conf_dir_path) / file_name
17+
path, conf = load_conf_file(str(conf_file))
1318
assert conf["id_column"] == "id"
19+
assert path == conf_file.with_suffix(".json")
1420

1521

16-
def test_load_conf_file_toml(conf_dir_path):
17-
conf_file = os.path.join(conf_dir_path, "test1")
18-
_path, conf = load_conf_file(conf_file)
22+
@pytest.mark.parametrize("file_name", ["test1", "test1.toml"])
23+
def test_load_conf_file_toml(conf_dir_path: str, file_name: str) -> None:
24+
conf_file = Path(conf_dir_path) / file_name
25+
path, conf = load_conf_file(str(conf_file))
1926
assert conf["id_column"] == "id-toml"
27+
assert path == conf_file.with_suffix(".toml")
2028

2129

22-
def test_load_conf_file_json2(conf_dir_path):
23-
conf_file = os.path.join(conf_dir_path, "test_conf_flag_run")
24-
_path, conf = load_conf_file(conf_file)
30+
def test_load_conf_file_json2(conf_dir_path: str) -> None:
31+
conf_file = Path(conf_dir_path) / "test_conf_flag_run"
32+
path, conf = load_conf_file(str(conf_file))
2533
assert conf["id_column"] == "id_conf_flag"
34+
assert path == conf_file.with_suffix(".json")
35+
36+
37+
def test_load_conf_file_does_not_exist(tmp_path: Path) -> None:
38+
conf_file = tmp_path / "notthere"
39+
with pytest.raises(
40+
FileNotFoundError, match="Couldn't find any of these three files:"
41+
):
42+
load_conf_file(str(conf_file))
43+
44+
45+
def test_load_conf_file_unrecognized_extension(tmp_path: Path) -> None:
46+
conf_file = tmp_path / "test.yaml"
47+
conf_file.touch()
48+
with pytest.raises(
49+
UsageError,
50+
match="The file .+ exists, but it doesn't have a '.toml' or '.json' extension",
51+
):
52+
load_conf_file(str(conf_file))

0 commit comments

Comments
 (0)