Commit 21e9007

Merge pull request #143 from ipums/multiple_exploded_columns
Support blocking sections with multiple exploded columns
2 parents 94f0e8b + 1ae9a69 commit 21e9007
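
To illustrate what the change enables: a minimal standalone sketch (not code from this repository) of the pattern the commit adopts. Each exploded blocking column becomes its own withColumn(..., explode(...)) call, so several exploded columns compose and multiply the row count independently. Here sequence() stands in for hlink's _expand() helper, whose implementation is not shown in this diff, and the SparkSession setup is assumed.

# Hedged sketch: two exploded blocking-style columns derived from birthyr.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, sequence

spark = SparkSession.builder.getOrCreate()  # assumed local session
df = spark.createDataFrame([(1, 1900), (2, 1910)], ["id", "birthyr"])

# sequence(birthyr - n, birthyr + n) yields 2 * n + 1 values per row;
# exploding it multiplies the row count by that factor.
exploded = df.withColumn(
    "birthyr_3", explode(sequence(col("birthyr") - 3, col("birthyr") + 3))
).withColumn(
    "birthyr_4", explode(sequence(col("birthyr") - 4, col("birthyr") + 4))
)

# Each input row becomes (2 * 3 + 1) * (2 * 4 + 1) = 7 * 9 = 63 output rows.
assert exploded.count() == df.count() * 63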

File tree

2 files changed: +78 -37 lines changed

hlink/linking/matching/link_step_explode.py

Lines changed: 25 additions & 37 deletions
@@ -124,49 +124,37 @@ def _explode(
             if exploding_column.get("expand_length", False):
                 expand_length = exploding_column["expand_length"]
                 derived_from_column = exploding_column["derived_from"]
-                explode_selects = [
-                    (
-                        explode(self._expand(derived_from_column, expand_length)).alias(
-                            exploding_column_name
-                        )
-                        if exploding_column_name == column
-                        else column
-                    )
-                    for column in all_column_names
-                ]
+
+                explode_col_expr = explode(
+                    self._expand(derived_from_column, expand_length)
+                )
             else:
-                explode_selects = [
-                    (
-                        explode(col(exploding_column_name)).alias(exploding_column_name)
-                        if exploding_column_name == c
-                        else c
-                    )
-                    for c in all_column_names
-                ]
+                explode_col_expr = explode(col(exploding_column_name))
+
             if "dataset" in exploding_column:
                 derived_from_column = exploding_column["derived_from"]
-                explode_selects_with_derived_column = [
-                    (
-                        col(derived_from_column).alias(exploding_column_name)
-                        if exploding_column_name == column
-                        else column
-                    )
-                    for column in all_column_names
-                ]
+                no_explode_col_expr = col(derived_from_column)
+
                 if exploding_column["dataset"] == "a":
-                    exploded_df = (
-                        exploded_df.select(explode_selects)
-                        if is_a
-                        else exploded_df.select(explode_selects_with_derived_column)
-                    )
+                    expr = explode_col_expr if is_a else no_explode_col_expr
+                    exploded_df = exploded_df.withColumn(exploding_column_name, expr)
                 elif exploding_column["dataset"] == "b":
-                    exploded_df = (
-                        exploded_df.select(explode_selects)
-                        if not (is_a)
-                        else exploded_df.select(explode_selects_with_derived_column)
-                    )
+                    expr = explode_col_expr if not is_a else no_explode_col_expr
+                    exploded_df = exploded_df.withColumn(exploding_column_name, expr)
             else:
-                exploded_df = exploded_df.select(explode_selects)
+                exploded_df = exploded_df.withColumn(
+                    exploding_column_name, explode_col_expr
+                )
+
+        # If there are exploding columns, then select out "all_column_names".
+        # Otherwise, just let all of the columns through without selecting
+        # specific ones. I believe this is an artifact of a previous
+        # implementation, but the tests currently enforce it. It may or may not
+        # be a breaking change to remove this. We'd have to look into the
+        # ramifications.
+        if len(all_exploding_columns) > 0:
+            exploded_df = exploded_df.select(sorted(all_column_names))
+
         return exploded_df
 
     def _expand(self, column_name: str, expand_length: int) -> Column:
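
The dataset branch above is what keeps the explosion one-sided: a blocking column marked dataset "a" is exploded only while building dataset A, while dataset B just gets the derived column copied under the new name so both sides share a column to join on. A rough sketch of that behavior, again using sequence() as a stand-in for _expand() and with df_a / df_b as hypothetical prepped DataFrames that have a birthyr column:

from pyspark.sql.functions import col, explode, sequence

# Dataset A (hypothetical df_a): birthyr_3 fans out into 2 * 3 + 1 = 7
# candidate years per input row.
df_a_exploded = df_a.withColumn(
    "birthyr_3", explode(sequence(col("birthyr") - 3, col("birthyr") + 3))
)

# Dataset B (hypothetical df_b): same column name, but just the original
# birthyr value, so the row count of dataset B is unchanged (the new test
# asserts exactly this).
df_b_exploded = df_b.withColumn("birthyr_3", col("birthyr"))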

hlink/tests/matching_blocking_explode_test.py

Lines changed: 53 additions & 0 deletions
@@ -124,6 +124,59 @@ def test_blocking_multi_layer_comparison(
         ) or (row["namelast_jw_x"] < 0.7)
 
 
+def test_blocking_multiple_exploded_columns(
+    spark, blocking_explode_conf, matching_test_input, matching
+):
+    """
+    Matching supports multiple exploded blocking columns. Each column is
+    exploded independently. See GitHub issue #142.
+    """
+    table_a, table_b = matching_test_input
+    table_a.createOrReplaceTempView("prepped_df_a")
+    table_b.createOrReplaceTempView("prepped_df_b")
+
+    blocking_explode_conf["blocking"] = [
+        {
+            "column_name": "birthyr_3",
+            "dataset": "a",
+            "derived_from": "birthyr",
+            "expand_length": 3,
+            "explode": True,
+        },
+        {
+            "column_name": "birthyr_4",
+            "dataset": "a",
+            "derived_from": "birthyr",
+            "expand_length": 4,
+            "explode": True,
+        },
+        {"column_name": "sex"},
+    ]
+
+    matching.run_step(0)
+
+    exploded_a = spark.table("exploded_df_a").toPandas()
+    exploded_b = spark.table("exploded_df_b").toPandas()
+
+    input_size_a = spark.table("prepped_df_a").count()
+    input_size_b = spark.table("prepped_df_b").count()
+    output_size_a = len(exploded_a)
+    output_size_b = len(exploded_b)
+
+    assert "sex" in exploded_a.columns
+    assert "birthyr_3" in exploded_a.columns
+    assert "birthyr_4" in exploded_a.columns
+    assert "sex" in exploded_b.columns
+    assert "birthyr_3" in exploded_b.columns
+    assert "birthyr_4" in exploded_b.columns
+
+    # birthyr_3 multiplies the number of rows by 2 * 3 + 1 = 7
+    # birthyr_4 multiplies the number of rows by 2 * 4 + 1 = 9
+    assert input_size_a * 63 == output_size_a
+    # Both columns are only exploded in dataset A
+    assert input_size_b == output_size_b
+
+
 def test_blocking_or_groups(
     spark, blocking_or_groups_conf, matching_or_groups_test_input, matching
 ):