Skip to content

Commit 7b15b7b

Browse files
authored
Merge pull request #213 from ipums/custom-col-mappings
Support custom column mapping transforms
2 parents ffa4d0d + 02fb8c2 commit 7b15b7b

File tree

6 files changed

+275
-38
lines changed

6 files changed

+275
-38
lines changed

hlink/linking/core/column_mapping.py

Lines changed: 117 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,66 @@
22
# For copyright and licensing information, see the NOTICE and LICENSE files
33
# in this project's top-level directory, and also on-line at:
44
# https://github.com/ipums/hlink
5-
from typing import Any
5+
"""
6+
Column mappings for cleaning and preprocessing input data.
7+
8+
This module provides functions for cleaning and preprocessing columns of Spark
9+
data frames. It is built around the idea of a "column mapping": a dictionary
10+
that specifies an input column, an optional output column alias, and a list
11+
of zero or more transforms to apply, in order, to the input column.
12+
13+
```python
14+
# An example column mapping. The "column_name" attribute gives the name of the
15+
# input column, and "alias" gives the name of the output column. The alias is
16+
# optional and defaults to the input column name.
17+
{
18+
    "column_name": "namefrst",
19+
    "alias": "namefrst_std",
20+
    "transforms": [
21+
        {"type": "lowercase_strip"},
22+
        {"type": "rationalize_name_words"},
23+
        {"type": "remove_qmark_hyphen"},
24+
        {"type": "condense_strip_whitespace"},
25+
        {"type": "split"},
26+
        {"type": "array_index", "value": 0},
27+
    ]
28+
}
29+
```
30+
31+
Hlink has many built-in column mapping transforms, computed by the
32+
`transform_*` functions in this module. Hlink also has support for custom
33+
column mapping transforms via the `custom_transforms` argument to
34+
`select_column_mapping`. This argument must be a mapping from strings to
35+
functions which compute the column mapping transforms. For example, say that
36+
you wanted to implement a custom column mapping transform named "reverse" which
37+
reverses a string. The first thing to do is to write a function which computes
38+
the transform and satisfies the column mapping transform interface (see the
39+
ColumnMappingTransform type alias below).
40+
41+
```python
42+
from collections.abc import Mapping
from typing import Any

from pyspark.sql import Column
43+
from pyspark.sql.functions import reverse
44+
45+
# input_col is the input Column expression.
46+
# transform is the column mapping transform dictionary, like
47+
# {"type": "reverse"}. This lets the transform accept arbitrary arguments from
48+
# the configuration as needed.
49+
# context is a mapping with additional information which may be helpful for some
50+
# transforms. In particular, it always contains at least the key "dataset",
51+
# whose value is the string "a" or "b", identifying the current dataset.
52+
def transform_reverse(input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]) -> Column:
53+
    return reverse(input_col)
54+
```
55+
56+
Then, when you call `select_column_mapping`, you can pass
57+
`custom_transforms={"reverse": transform_reverse}`, and hlink will call your
58+
function for every transform whose "type" is "reverse". Note that a custom
59+
transform with the same name as a built-in transform takes precedence over
60+
the built-in transform.
61+
"""
62+
63+
from collections.abc import Mapping
64+
from typing import Any, Callable, TypeAlias
665

766
from pyspark.sql import Column, DataFrame
867
from pyspark.sql.functions import (
@@ -21,25 +80,40 @@
2180
from pyspark.sql.types import LongType
2281

2382

83+
ColumnMappingTransform: TypeAlias = Callable[
84+
[Column, Mapping[str, Any], Mapping[str, Any]], Column
85+
]
86+
"""
87+
The form of column mapping transform functions. These take an input Column,
88+
the transform mapping from the configuration, and a mapping providing some
89+
additional context. They return a new output Column.
90+
"""
91+
92+
2493
def select_column_mapping(
25-
column_mapping: dict[str, Any],
94+
column_mapping: Mapping[str, Any],
2695
df_selected: DataFrame,
2796
is_a: bool,
2897
column_selects: list[str],
98+
custom_transforms: Mapping[str, ColumnMappingTransform] | None = None,
2999
) -> tuple[DataFrame, list[str]]:
30100
name = column_mapping["column_name"]
31101
if "override_column_a" in column_mapping and is_a:
32102
override_name = column_mapping["override_column_a"]
33103
column_select = col(override_name)
34104
if "override_transforms" in column_mapping:
35105
for transform in column_mapping["override_transforms"]:
36-
column_select = apply_transform(column_select, transform, is_a)
106+
column_select = apply_transform(
107+
column_select, transform, is_a, custom_transforms
108+
)
37109
elif "override_column_b" in column_mapping and not is_a:
38110
override_name = column_mapping["override_column_b"]
39111
column_select = col(override_name)
40112
if "override_transforms" in column_mapping:
41113
for transform in column_mapping["override_transforms"]:
42-
column_select = apply_transform(column_select, transform, is_a)
114+
column_select = apply_transform(
115+
column_select, transform, is_a, custom_transforms
116+
)
43117
elif "set_value_column_a" in column_mapping and is_a:
44118
value_to_set = column_mapping["set_value_column_a"]
45119
column_select = lit(value_to_set)
@@ -49,7 +123,9 @@ def select_column_mapping(
49123
elif "transforms" in column_mapping:
50124
column_select = col(name)
51125
for transform in column_mapping["transforms"]:
52-
column_select = apply_transform(column_select, transform, is_a)
126+
column_select = apply_transform(
127+
column_select, transform, is_a, custom_transforms
128+
)
53129
else:
54130
column_select = col(name)
55131

@@ -59,7 +135,7 @@ def select_column_mapping(
59135
return df_selected.withColumn(alias, column_select), column_selects
60136

61137

62-
def _require_key(transform: dict[str, Any], key: str) -> Any:
138+
def _require_key(transform: Mapping[str, Any], key: str) -> Any:
63139
"""
64140
Extract a key from a transform, or raise a helpful context-aware error if the
65141
key is not present.
@@ -78,7 +154,10 @@ def _require_key(transform: dict[str, Any], key: str) -> Any:
78154

79155
# These apply to the column mappings in the current config
80156
def apply_transform(
81-
column_select: Column, transform: dict[str, Any], is_a: bool
157+
column_select: Column,
158+
transform: Mapping[str, Any],
159+
is_a: bool,
160+
custom_transforms: Mapping[str, ColumnMappingTransform] | None = None,
82161
) -> Column:
83162
"""Return a new column that is the result of applying the given transform
84163
to the given input column (column_select). The is_a parameter controls the
@@ -93,7 +172,7 @@ def apply_transform(
93172
dataset = "a" if is_a else "b"
94173
context = {"dataset": dataset}
95174
transform_type = transform["type"]
96-
transforms = {
175+
builtin_transforms = {
97176
"add_to_a": transform_add_to_a,
98177
"concat_to_a": transform_concat_to_a,
99178
"concat_to_b": transform_concat_to_b,
@@ -123,7 +202,9 @@ def apply_transform(
123202
"get_floor": transform_get_floor,
124203
}
125204

126-
transform_func = transforms.get(transform_type)
205+
custom_transforms = custom_transforms or {}
206+
builtin_func = builtin_transforms.get(transform_type)
207+
transform_func = custom_transforms.get(transform_type, builtin_func)
127208

128209
if transform_func is None:
129210
raise ValueError(f"Invalid transform type for {transform}")
@@ -132,7 +213,7 @@ def apply_transform(
132213

133214

134215
def transform_add_to_a(
135-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
216+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
136217
) -> Column:
137218
is_a = context["dataset"] == "a"
138219
if is_a:
@@ -142,7 +223,7 @@ def transform_add_to_a(
142223

143224

144225
def transform_concat_to_a(
145-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
226+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
146227
) -> Column:
147228
is_a = context["dataset"] == "a"
148229
if is_a:
@@ -153,7 +234,7 @@ def transform_concat_to_a(
153234

154235

155236
def transform_concat_to_b(
156-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
237+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
157238
) -> Column:
158239
is_a = context["dataset"] == "a"
159240
if is_a:
@@ -164,50 +245,50 @@ def transform_concat_to_b(
164245

165246

166247
def transform_concat_two_cols(
167-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
248+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
168249
) -> Column:
169250
column_to_append = _require_key(transform, "column_to_append")
170251
return concat(input_col, column_to_append)
171252

172253

173254
def transform_lowercase_strip(
174-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
255+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
175256
) -> Column:
176257
return lower(trim(input_col))
177258

178259

179260
def transform_rationalize_name_words(
180-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
261+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
181262
) -> Column:
182263
return regexp_replace(input_col, r"[^a-z?'\*\-]+", " ")
183264

184265

185266
def transform_remove_qmark_hyphen(
186-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
267+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
187268
) -> Column:
188269
return regexp_replace(input_col, r"[?\*\-]+", "")
189270

190271

191272
def transform_remove_punctuation(
192-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
273+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
193274
) -> Column:
194275
return regexp_replace(input_col, r"[?\-\\\/\"\':,.\[\]\{\}]+", "")
195276

196277

197278
def transform_replace_apostrophe(
198-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
279+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
199280
) -> Column:
200281
return regexp_replace(input_col, r"'+", " ")
201282

202283

203284
def transform_remove_alternate_names(
204-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
285+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
205286
) -> Column:
206287
return regexp_replace(input_col, r"(\w+)( or \w+)+", "$1")
207288

208289

209290
def transform_remove_suffixes(
210-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
291+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
211292
) -> Column:
212293
values = _require_key(transform, "values")
213294
suffixes = "|".join(values)
@@ -216,7 +297,7 @@ def transform_remove_suffixes(
216297

217298

218299
def transform_remove_stop_words(
219-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
300+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
220301
) -> Column:
221302
values = _require_key(transform, "values")
222303
words = "|".join(values)
@@ -225,7 +306,7 @@ def transform_remove_stop_words(
225306

226307

227308
def transform_remove_prefixes(
228-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
309+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
229310
) -> Column:
230311
values = _require_key(transform, "values")
231312
prefixes = "|".join(values)
@@ -234,7 +315,7 @@ def transform_remove_prefixes(
234315

235316

236317
def transform_condense_prefixes(
237-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
318+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
238319
) -> Column:
239320
values = _require_key(transform, "values")
240321
prefixes = "|".join(values)
@@ -243,38 +324,38 @@ def transform_condense_prefixes(
243324

244325

245326
def transform_condense_strip_whitespace(
246-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
327+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
247328
) -> Column:
248329
return regexp_replace(trim(input_col), r"\s\s+", " ")
249330

250331

251332
def transform_remove_one_letter_names(
252-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
333+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
253334
) -> Column:
254335
return regexp_replace(input_col, r"^((?:\w )+)(\w+)", r"$2")
255336

256337

257338
def transform_split(
258-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
339+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
259340
) -> Column:
260341
return split(input_col, " ")
261342

262343

263344
def transform_length(
264-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
345+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
265346
) -> Column:
266347
return length(input_col)
267348

268349

269350
def transform_array_index(
270-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
351+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
271352
) -> Column:
272353
value = _require_key(transform, "value")
273354
return input_col[value]
274355

275356

276357
def transform_mapping(
277-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
358+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
278359
) -> Column:
279360
mapped_column = input_col
280361
mappings = _require_key(transform, "mappings")
@@ -290,7 +371,7 @@ def transform_mapping(
290371

291372

292373
def transform_swap_words(
293-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
374+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
294375
) -> Column:
295376
mapped_column = input_col
296377
values = _require_key(transform, "values")
@@ -304,7 +385,7 @@ def transform_swap_words(
304385

305386

306387
def transform_substring(
307-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
388+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
308389
) -> Column:
309390
values = _require_key(transform, "values")
310391
if len(values) == 2:
@@ -318,27 +399,27 @@ def transform_substring(
318399

319400

320401
def transform_expand(
321-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
402+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
322403
) -> Column:
323404
expand_length = _require_key(transform, "value")
324405
return array([input_col + i for i in range(-expand_length, expand_length + 1)])
325406

326407

327408
def transform_cast_as_int(
328-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
409+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
329410
) -> Column:
330411
return input_col.cast("int")
331412

332413

333414
def transform_divide_by_int(
334-
input_col: Column, transform: dict[str, Any], context
415+
input_col: Column, transform: Mapping[str, Any], context
335416
) -> Column:
336417
divisor = _require_key(transform, "value")
337418
return input_col.cast("int") / divisor
338419

339420

340421
def transform_when_value(
341-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
422+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
342423
) -> Column:
343424
threshold = _require_key(transform, "value")
344425
if_value = _require_key(transform, "if_value")
@@ -347,6 +428,6 @@ def transform_when_value(
347428

348429

349430
def transform_get_floor(
350-
input_col: Column, transform: dict[str, Any], context: dict[str, Any]
431+
input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]
351432
) -> Column:
352433
return floor(input_col).cast("int")

0 commit comments

Comments
 (0)