22# For copyright and licensing information, see the NOTICE and LICENSE files
33# in this project's top-level directory, and also on-line at:
44# https://github.com/ipums/hlink
5- from typing import Any
5+ """
6+ Column mappings for cleaning and preprocessing input data.
7+
8+ This module provides functions for cleaning and preprocessing columns of Spark
9+ data frames. It depends on the idea of a "column mapping", which is a
10+ dictionary which specifies an input column, an optional output column alias,
11+ and a list of zero or more transforms to apply to the input column.
12+
13+ ```python
14+ # An example column mapping. The "column_name" attribute gives the name of the
15+ # input column, and "alias" gives the name of the output column. The alias is
16+ # optional and defaults to the input column name.
17+ {
18+ "column_name": "namefrst",
19+ "alias": "namefrst_std",
20+ "transforms": [
21+ {"type": "lowercase_strip"},
22+ {"type": "rationalize_name_words"},
23+ {"type": "remove_qmark_hyphen"},
24+ {"type": "condense_strip_whitespace"},
25+ {"type": "split"},
26+ {"type": "array_index", "value": 0},
27+ ]
28+ }
29+ ```
30+
31+ Hlink has many built-in column mapping transforms, computed by the
32+ `transform_*` functions in this module. Hlink also has support for custom
33+ column mapping transforms via the `custom_transforms` argument to
34+ `select_column_mapping`. This argument must be a mapping from strings to
35+ functions which compute the column mapping transforms. For example, say that
36+ you wanted to implement a custom column mapping transform named "reverse" which
37+ reverses a string. The first thing to do is to write a function which computes
38+ the transform and satisfies the column mapping transform interface (see the
39+ ColumnMappingTransform type alias below).
40+
41+ ```python
42+ from pyspark.sql import Column
43+ from pyspark.sql.functions import reverse
44+
45+ # input_col is the input Column expression.
46+ # transform is the column mapping transform dictionary, like
47+ # {"type": "reverse"}. This lets the transform accept arbitrary arguments from
48+ # the configuration as needed.
49+ # context is a dictionary with additional context which may be helpful for some
50+ # transforms. In particular, it always contains at least the key "dataset",
51+ # which indicates whether the current dataset is dataset "a" or "b".
52+ def transform_reverse(input_col: Column, transform: Mapping[str, Any], context: Mapping[str, Any]) -> Column:
53+ return reverse(input_col)
54+ ```
55+
56+ Then, when you call `select_column_mapping`, you can pass
57+ `custom_transforms={"reverse": transform_reverse}`, and hlink will
58+ automatically use your custom transform when appropriate. Note that custom
59+ transforms which have the same name as a built-in transform override the
60+ built-in transform.
61+ """
62+
63+ from collections .abc import Mapping
64+ from typing import Any , Callable , TypeAlias
665
766from pyspark .sql import Column , DataFrame
867from pyspark .sql .functions import (
2180from pyspark .sql .types import LongType
2281
2382
83+ ColumnMappingTransform : TypeAlias = Callable [
84+ [Column , Mapping [str , Any ], Mapping [str , Any ]], Column
85+ ]
86+ """
87+ The form of column mapping transform functions. These take an input Column,
88+ the transform mapping from the configuration, and a mapping providing some
89+ additional context. They return a new output Column.
90+ """
91+
92+
2493def select_column_mapping (
25- column_mapping : dict [str , Any ],
94+ column_mapping : Mapping [str , Any ],
2695 df_selected : DataFrame ,
2796 is_a : bool ,
2897 column_selects : list [str ],
98+ custom_transforms : Mapping [str , ColumnMappingTransform ] | None = None ,
2999) -> tuple [DataFrame , list [str ]]:
30100 name = column_mapping ["column_name" ]
31101 if "override_column_a" in column_mapping and is_a :
32102 override_name = column_mapping ["override_column_a" ]
33103 column_select = col (override_name )
34104 if "override_transforms" in column_mapping :
35105 for transform in column_mapping ["override_transforms" ]:
36- column_select = apply_transform (column_select , transform , is_a )
106+ column_select = apply_transform (
107+ column_select , transform , is_a , custom_transforms
108+ )
37109 elif "override_column_b" in column_mapping and not is_a :
38110 override_name = column_mapping ["override_column_b" ]
39111 column_select = col (override_name )
40112 if "override_transforms" in column_mapping :
41113 for transform in column_mapping ["override_transforms" ]:
42- column_select = apply_transform (column_select , transform , is_a )
114+ column_select = apply_transform (
115+ column_select , transform , is_a , custom_transforms
116+ )
43117 elif "set_value_column_a" in column_mapping and is_a :
44118 value_to_set = column_mapping ["set_value_column_a" ]
45119 column_select = lit (value_to_set )
@@ -49,7 +123,9 @@ def select_column_mapping(
49123 elif "transforms" in column_mapping :
50124 column_select = col (name )
51125 for transform in column_mapping ["transforms" ]:
52- column_select = apply_transform (column_select , transform , is_a )
126+ column_select = apply_transform (
127+ column_select , transform , is_a , custom_transforms
128+ )
53129 else :
54130 column_select = col (name )
55131
@@ -59,7 +135,7 @@ def select_column_mapping(
59135 return df_selected .withColumn (alias , column_select ), column_selects
60136
61137
62- def _require_key (transform : dict [str , Any ], key : str ) -> Any :
138+ def _require_key (transform : Mapping [str , Any ], key : str ) -> Any :
63139 """
64140 Extract a key from a transform, or raise a helpful context-aware error if the
65141 key is not present.
@@ -78,7 +154,10 @@ def _require_key(transform: dict[str, Any], key: str) -> Any:
78154
79155# These apply to the column mappings in the current config
80156def apply_transform (
81- column_select : Column , transform : dict [str , Any ], is_a : bool
157+ column_select : Column ,
158+ transform : Mapping [str , Any ],
159+ is_a : bool ,
160+ custom_transforms : Mapping [str , ColumnMappingTransform ] | None = None ,
82161) -> Column :
83162 """Return a new column that is the result of applying the given transform
84163 to the given input column (column_select). The is_a parameter controls the
@@ -93,7 +172,7 @@ def apply_transform(
93172 dataset = "a" if is_a else "b"
94173 context = {"dataset" : dataset }
95174 transform_type = transform ["type" ]
96- transforms = {
175+ builtin_transforms = {
97176 "add_to_a" : transform_add_to_a ,
98177 "concat_to_a" : transform_concat_to_a ,
99178 "concat_to_b" : transform_concat_to_b ,
@@ -123,7 +202,9 @@ def apply_transform(
123202 "get_floor" : transform_get_floor ,
124203 }
125204
126- transform_func = transforms .get (transform_type )
205+ custom_transforms = custom_transforms or {}
206+ builtin_func = builtin_transforms .get (transform_type )
207+ transform_func = custom_transforms .get (transform_type , builtin_func )
127208
128209 if transform_func is None :
129210 raise ValueError (f"Invalid transform type for { transform } " )
@@ -132,7 +213,7 @@ def apply_transform(
132213
133214
134215def transform_add_to_a (
135- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
216+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
136217) -> Column :
137218 is_a = context ["dataset" ] == "a"
138219 if is_a :
@@ -142,7 +223,7 @@ def transform_add_to_a(
142223
143224
144225def transform_concat_to_a (
145- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
226+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
146227) -> Column :
147228 is_a = context ["dataset" ] == "a"
148229 if is_a :
@@ -153,7 +234,7 @@ def transform_concat_to_a(
153234
154235
155236def transform_concat_to_b (
156- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
237+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
157238) -> Column :
158239 is_a = context ["dataset" ] == "a"
159240 if is_a :
@@ -164,50 +245,50 @@ def transform_concat_to_b(
164245
165246
166247def transform_concat_two_cols (
167- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
248+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
168249) -> Column :
169250 column_to_append = _require_key (transform , "column_to_append" )
170251 return concat (input_col , column_to_append )
171252
172253
173254def transform_lowercase_strip (
174- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
255+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
175256) -> Column :
176257 return lower (trim (input_col ))
177258
178259
179260def transform_rationalize_name_words (
180- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
261+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
181262) -> Column :
182263 return regexp_replace (input_col , r"[^a-z?'\*\-]+" , " " )
183264
184265
185266def transform_remove_qmark_hyphen (
186- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
267+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
187268) -> Column :
188269 return regexp_replace (input_col , r"[?\*\-]+" , "" )
189270
190271
191272def transform_remove_punctuation (
192- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
273+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
193274) -> Column :
194275 return regexp_replace (input_col , r"[?\-\\\/\"\':,.\[\]\{\}]+" , "" )
195276
196277
197278def transform_replace_apostrophe (
198- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
279+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
199280) -> Column :
200281 return regexp_replace (input_col , r"'+" , " " )
201282
202283
203284def transform_remove_alternate_names (
204- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
285+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
205286) -> Column :
206287 return regexp_replace (input_col , r"(\w+)( or \w+)+" , "$1" )
207288
208289
209290def transform_remove_suffixes (
210- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
291+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
211292) -> Column :
212293 values = _require_key (transform , "values" )
213294 suffixes = "|" .join (values )
@@ -216,7 +297,7 @@ def transform_remove_suffixes(
216297
217298
218299def transform_remove_stop_words (
219- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
300+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
220301) -> Column :
221302 values = _require_key (transform , "values" )
222303 words = "|" .join (values )
@@ -225,7 +306,7 @@ def transform_remove_stop_words(
225306
226307
227308def transform_remove_prefixes (
228- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
309+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
229310) -> Column :
230311 values = _require_key (transform , "values" )
231312 prefixes = "|" .join (values )
@@ -234,7 +315,7 @@ def transform_remove_prefixes(
234315
235316
236317def transform_condense_prefixes (
237- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
318+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
238319) -> Column :
239320 values = _require_key (transform , "values" )
240321 prefixes = "|" .join (values )
@@ -243,38 +324,38 @@ def transform_condense_prefixes(
243324
244325
245326def transform_condense_strip_whitespace (
246- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
327+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
247328) -> Column :
248329 return regexp_replace (trim (input_col ), r"\s\s+" , " " )
249330
250331
251332def transform_remove_one_letter_names (
252- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
333+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
253334) -> Column :
254335 return regexp_replace (input_col , r"^((?:\w )+)(\w+)" , r"$2" )
255336
256337
257338def transform_split (
258- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
339+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
259340) -> Column :
260341 return split (input_col , " " )
261342
262343
263344def transform_length (
264- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
345+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
265346) -> Column :
266347 return length (input_col )
267348
268349
269350def transform_array_index (
270- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
351+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
271352) -> Column :
272353 value = _require_key (transform , "value" )
273354 return input_col [value ]
274355
275356
276357def transform_mapping (
277- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
358+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
278359) -> Column :
279360 mapped_column = input_col
280361 mappings = _require_key (transform , "mappings" )
@@ -290,7 +371,7 @@ def transform_mapping(
290371
291372
292373def transform_swap_words (
293- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
374+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
294375) -> Column :
295376 mapped_column = input_col
296377 values = _require_key (transform , "values" )
@@ -304,7 +385,7 @@ def transform_swap_words(
304385
305386
306387def transform_substring (
307- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
388+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
308389) -> Column :
309390 values = _require_key (transform , "values" )
310391 if len (values ) == 2 :
@@ -318,27 +399,27 @@ def transform_substring(
318399
319400
320401def transform_expand (
321- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
402+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
322403) -> Column :
323404 expand_length = _require_key (transform , "value" )
324405 return array ([input_col + i for i in range (- expand_length , expand_length + 1 )])
325406
326407
327408def transform_cast_as_int (
328- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
409+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
329410) -> Column :
330411 return input_col .cast ("int" )
331412
332413
333414def transform_divide_by_int (
334- input_col : Column , transform : dict [str , Any ], context
415+ input_col : Column , transform : Mapping [str , Any ], context
335416) -> Column :
336417 divisor = _require_key (transform , "value" )
337418 return input_col .cast ("int" ) / divisor
338419
339420
340421def transform_when_value (
341- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
422+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
342423) -> Column :
343424 threshold = _require_key (transform , "value" )
344425 if_value = _require_key (transform , "if_value" )
@@ -347,6 +428,6 @@ def transform_when_value(
347428
348429
349430def transform_get_floor (
350- input_col : Column , transform : dict [str , Any ], context : dict [str , Any ]
431+ input_col : Column , transform : Mapping [str , Any ], context : Mapping [str , Any ]
351432) -> Column :
352433 return floor (input_col ).cast ("int" )
0 commit comments