Skip to content

Commit 12ff643

Browse files
authored
Merge pull request #131 from ipums/bug_fix_column_mappings_validation
Fix a bug with the override_column_X attributes in conf_validations.py
2 parents e8db991 + 96d8c0a commit 12ff643

File tree

3 files changed

+263
-23
lines changed

3 files changed

+263
-23
lines changed

hlink/linking/matching/link_step_explode.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,26 +115,32 @@ def _explode(self, df, comparisons, comparison_features, blocking, id_column, is
115115
expand_length = exploding_column["expand_length"]
116116
derived_from_column = exploding_column["derived_from"]
117117
explode_selects = [
118-
explode(self._expand(derived_from_column, expand_length)).alias(
119-
exploding_column_name
118+
(
119+
explode(self._expand(derived_from_column, expand_length)).alias(
120+
exploding_column_name
121+
)
122+
if exploding_column_name == column
123+
else column
120124
)
121-
if exploding_column_name == column
122-
else column
123125
for column in all_column_names
124126
]
125127
else:
126128
explode_selects = [
127-
explode(col(exploding_column_name)).alias(exploding_column_name)
128-
if exploding_column_name == c
129-
else c
129+
(
130+
explode(col(exploding_column_name)).alias(exploding_column_name)
131+
if exploding_column_name == c
132+
else c
133+
)
130134
for c in all_column_names
131135
]
132136
if "dataset" in exploding_column:
133137
derived_from_column = exploding_column["derived_from"]
134138
explode_selects_with_derived_column = [
135-
col(derived_from_column).alias(exploding_column_name)
136-
if exploding_column_name == column
137-
else column
139+
(
140+
col(derived_from_column).alias(exploding_column_name)
141+
if exploding_column_name == column
142+
else column
143+
)
138144
for column in all_column_names
139145
]
140146
if exploding_column["dataset"] == "a":

hlink/scripts/lib/conf_validations.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
# in this project's top-level directory, and also on-line at:
44
# https://github.com/ipums/hlink
55

6-
from pyspark.sql.utils import AnalysisException
76
from os import path
7+
from typing import Any, Literal
8+
89
import colorama
10+
from pyspark.sql.utils import AnalysisException
11+
from pyspark.sql import DataFrame
912

1013

1114
def print_checking(section: str):
@@ -265,7 +268,47 @@ def check_substitution_columns(config, columns_available):
265268
)
266269

267270

268-
def check_column_mappings(config, df_a, df_b):
271+
def check_column_mappings_column_available(
    column_mapping: dict[str, Any],
    df: DataFrame,
    previous_mappings: list[str],
    a_or_b: Literal["a", "b"],
) -> None:
    """
    Verify that the source column of a single column mapping can be found.

    When the mapping has an override_column_<a_or_b> attribute, that override
    column must be one of df's columns (compared case-insensitively).
    Otherwise the mapping's column_name must either be one of df's columns
    (compared case-insensitively) or be listed, with an exact match, in
    previous_mappings -- the columns mapped by previous column mappings.

    Raises a ValueError describing the mapping and the available columns when
    the required column cannot be found. Returns None otherwise.
    """
    column_name = column_mapping["column_name"]
    override_column = column_mapping.get(f"override_column_{a_or_b}")
    lowered_df_columns = [name.lower() for name in df.columns]

    if override_column is None:
        # No override given: fall back to column_name, which may also have
        # been produced by an earlier column mapping.
        is_available = (
            column_name.lower() in lowered_df_columns
            or column_name in previous_mappings
        )
        if is_available:
            return

        raise ValueError(
            f"Within a [[column_mappings]] the column_name '{column_name}' "
            f"does not exist in datasource_{a_or_b} and no previous "
            "[[column_mapping]] alias exists for it.\n"
            f"Column mapping: {column_mapping}.\n"
            f"Available columns:\n {df.columns}"
        )

    # An override column is read directly from the datasource, so previous
    # mappings are not consulted for it.
    if override_column.lower() in lowered_df_columns:
        return

    raise ValueError(
        f"Within a [[column_mappings]] the override_column_{a_or_b} column "
        f"'{override_column}' does not exist in datasource_{a_or_b}.\n"
        f"Column mapping: {column_mapping}\n"
        f"Available columns: {df.columns}"
    )
307+
308+
309+
def check_column_mappings(
310+
config: dict[str, Any], df_a: DataFrame, df_b: DataFrame
311+
) -> list[str]:
269312
column_mappings = config.get("column_mappings")
270313
if not column_mappings:
271314
raise ValueError("No [[column_mappings]] exist in the conf file.")
@@ -276,22 +319,15 @@ def check_column_mappings(config, df_a, df_b):
276319
column_name = c.get("column_name")
277320
set_value_column_a = c.get("set_value_column_a")
278321
set_value_column_b = c.get("set_value_column_b")
322+
279323
if not column_name:
280324
raise ValueError(
281325
f"The following [[column_mappings]] has no 'column_name' attribute: {c}"
282326
)
283327
if set_value_column_a is None:
284-
if column_name.lower() not in [c.lower() for c in df_a.columns]:
285-
if column_name not in columns_available:
286-
raise ValueError(
287-
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_a and no previous [[column_mapping]] alias exists for it. \nColumn mapping: {c}. \nAvailable columns: \n {df_a.columns}"
288-
)
328+
check_column_mappings_column_available(c, df_a, columns_available, "a")
289329
if set_value_column_b is None:
290-
if column_name.lower() not in [c.lower() for c in df_b.columns]:
291-
if column_name not in columns_available:
292-
raise ValueError(
293-
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_b and no previous [[column_mapping]] alias exists for it. Column mapping: {c}. Available columns: \n {df_b.columns}"
294-
)
330+
check_column_mappings_column_available(c, df_b, columns_available, "b")
295331
if alias in columns_available:
296332
duplicates.append(alias)
297333
elif not alias and column_name in columns_available:

hlink/tests/conf_validations_test.py

Lines changed: 199 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import pytest
33

4+
from pyspark.sql import SparkSession
5+
46
from hlink.configs.load_config import load_conf_file
5-
from hlink.scripts.lib.conf_validations import analyze_conf
7+
from hlink.scripts.lib.conf_validations import analyze_conf, check_column_mappings
68
from hlink.linking.link_run import LinkRun
79

810

@@ -25,3 +27,199 @@ def test_invalid_conf(conf_dir_path, spark, conf_name, error_msg):
2527

2628
with pytest.raises(ValueError, match=error_msg):
2729
analyze_conf(link_run)
30+
31+
32+
def test_check_column_mappings_mappings_missing(spark: SparkSession) -> None:
    """
    A config with no column_mappings section is rejected with a ValueError.
    """
    frame_a = spark.createDataFrame([[1], [2], [3]], ["a"])
    frame_b = spark.createDataFrame([[4], [5], [6]], ["b"])
    empty_config = {}
    expected_err = r"No \[\[column_mappings\]\] exist in the conf file"

    with pytest.raises(ValueError, match=expected_err):
        check_column_mappings(empty_config, frame_a, frame_b)
44+
45+
46+
def test_check_column_mappings_no_column_name(spark: SparkSession) -> None:
    """
    A column mapping that lacks the column_name attribute raises a ValueError.
    """
    valid_mapping = {"column_name": "AGE", "alias": "age"}
    nameless_mapping = {"alias": "height"}
    config = {"column_mappings": [valid_mapping, nameless_mapping]}

    df_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    df_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    with pytest.raises(
        ValueError,
        match=r"The following \[\[column_mappings\]\] has no 'column_name' attribute:",
    ):
        check_column_mappings(config, df_a, df_b)
61+
62+
63+
def test_check_column_mappings_column_name_not_available_datasource_a(
    spark: SparkSession,
) -> None:
    """
    A column_name that is missing from datasource A, and was not created by a
    previous column mapping, raises a ValueError naming datasource_a.
    """
    # HEIGHT exists only in datasource B.
    rows_a = [[20], [40], [60]]
    rows_b = [[70, 123], [50, 123], [30, 123]]
    df_a = spark.createDataFrame(rows_a, ["AGE"])
    df_b = spark.createDataFrame(rows_b, ["AGE", "HEIGHT"])

    config = {"column_mappings": [{"column_name": "HEIGHT"}]}

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
            r"does not exist in datasource_a and no previous \[\[column_mapping\]\] "
            "alias exists for it"
        ),
    ):
        check_column_mappings(config, df_a, df_b)
83+
84+
85+
def test_check_column_mappings_set_value_column_a_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    A column mapping with set_value_column_a passes validation even though its
    column is absent from datasource A.
    """
    # HEIGHT exists only in datasource B; A supplies it via set_value_column_a.
    mapping = {"column_name": "HEIGHT", "set_value_column_a": 125}
    config = {"column_mappings": [mapping]}

    rows_a = [[20], [40], [60]]
    rows_b = [[70, 123], [50, 123], [30, 123]]
    df_a = spark.createDataFrame(rows_a, ["AGE"])
    df_b = spark.createDataFrame(rows_b, ["AGE", "HEIGHT"])

    check_column_mappings(config, df_a, df_b)
98+
99+
100+
def test_check_column_mappings_column_name_not_available_datasource_b(
    spark: SparkSession,
) -> None:
    """
    A column_name that is missing from datasource B, and was not created by a
    previous column mapping, raises a ValueError naming datasource_b.
    """
    # HEIGHT exists only in datasource A.
    rows_a = [[70, 123], [50, 123], [30, 123]]
    rows_b = [[20], [40], [60]]
    df_a = spark.createDataFrame(rows_a, ["AGE", "HEIGHT"])
    df_b = spark.createDataFrame(rows_b, ["AGE"])

    config = {"column_mappings": [{"column_name": "HEIGHT"}]}

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
            r"does not exist in datasource_b and no previous \[\[column_mapping\]\] "
            "alias exists for it"
        ),
    ):
        check_column_mappings(config, df_a, df_b)
120+
121+
122+
def test_check_column_mappings_set_value_column_b_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    A column mapping with set_value_column_b passes validation even though its
    column is absent from datasource B.
    """
    # HEIGHT exists only in datasource A; B supplies it via set_value_column_b.
    mapping = {"column_name": "HEIGHT", "set_value_column_b": 125}
    config = {"column_mappings": [mapping]}

    rows_a = [[70, 123], [50, 123], [30, 123]]
    rows_b = [[20], [40], [60]]
    df_a = spark.createDataFrame(rows_a, ["AGE", "HEIGHT"])
    df_b = spark.createDataFrame(rows_b, ["AGE"])

    check_column_mappings(config, df_a, df_b)
135+
136+
137+
def test_check_column_mappings_previous_mappings_are_available(
    spark: SparkSession,
) -> None:
    """
    A later column mapping may use, as its column_name, the alias created by
    an earlier column mapping.
    """
    df_a = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    df_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])

    # The second mapping reads the alias produced by the first one.
    first_mapping = {"column_name": "AGE", "alias": "AGE_HLINK"}
    second_mapping = {"column_name": "AGE_HLINK", "alias": "AGE_HLINK2"}
    config = {"column_mappings": [first_mapping, second_mapping]}

    check_column_mappings(config, df_a, df_b)
154+
155+
156+
def test_check_column_mappings_override_column_a(spark: SparkSession) -> None:
    """
    With override_column_a set, validation passes when datasource A contains
    the override column rather than column_name.
    """
    df_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    df_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    mapping = {"column_name": "AGE", "override_column_a": "ageColumn"}
    config = {"column_mappings": [mapping]}

    check_column_mappings(config, df_a, df_b)
168+
169+
170+
def test_check_column_mappings_override_column_b(spark: SparkSession) -> None:
    """
    With override_column_b set, validation passes when datasource B contains
    the override column rather than column_name.
    """
    df_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    df_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    mapping = {"column_name": "ageColumn", "override_column_b": "AGE"}
    config = {"column_mappings": [mapping]}

    check_column_mappings(config, df_a, df_b)
182+
183+
184+
def test_check_column_mappings_override_column_a_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_a that is absent from datasource A raises a ValueError.
    """
    df_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    df_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    mapping = {"column_name": "AGE", "override_column_a": "oops_not_there"}
    config = {"column_mappings": [mapping]}

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the override_column_a column "
            "'oops_not_there' does not exist in datasource_a"
        ),
    ):
        check_column_mappings(config, df_a, df_b)
204+
205+
206+
def test_check_column_mappings_override_column_b_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_b that is absent from datasource B raises a ValueError.
    """
    df_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    df_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    mapping = {"column_name": "AGE", "override_column_b": "oops_not_there"}
    config = {"column_mappings": [mapping]}

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the override_column_b column "
            "'oops_not_there' does not exist in datasource_b"
        ),
    ):
        check_column_mappings(config, df_a, df_b)

0 commit comments

Comments
 (0)