Skip to content

Commit

Permalink
[#118] Add unit tests for override_column_<a/b> for column_mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 27, 2023
1 parent f9ff38c commit 87090a2
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions hlink/tests/core/column_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,110 @@ def test_select_column_mapping_column_selects_preserved(spark):
assert set(column_selects) == {"occ_with_underscores", "occupation"}


def test_select_column_mapping_override_column_a(spark):
    """
    When override_column_a is given, dataset A takes its values from that
    differently named source column. The override_transforms are then
    applied to dataset A only, so dataset B's values are not lowercased.
    """
    mapping = {
        "column_name": "occ",
        "override_column_a": "occupation",
        "override_transforms": [{"type": "lowercase_strip"}],
    }
    frame_a = spark.createDataFrame(TEST_DF_1)
    frame_b = spark.createDataFrame(TEST_DF_2)

    # Each call receives its own fresh column_selects list.
    selected_a, selects_a = select_column_mapping(
        mapping, frame_a, is_a=True, column_selects=[]
    )
    selected_b, selects_b = select_column_mapping(
        mapping, frame_b, is_a=False, column_selects=[]
    )

    # Both datasets end up with the same output column name.
    assert selects_a == selects_b
    assert selects_b == ["occ"]

    # Dataset A: values have had lowercase_strip applied.
    expected_a = [
        "farmer",
        "computer scientist",
        "waitress",
        "retired",
        "lawyer",
        "doctor",
    ]
    actual_a = selected_a.select("occ").toPandas()["occ"].to_list()
    assert actual_a == expected_a

    # Dataset B: values keep their original casing.
    expected_b = ["RETIRED", "CHILDCARE", None, None]
    actual_b = selected_b.select("occ").toPandas()["occ"].to_list()
    assert actual_b == expected_b


def test_select_column_mapping_override_column_b(spark):
    """
    When override_column_b is given, dataset B takes its values from that
    differently named source column. In this configuration the
    override_transforms are applied to dataset B only, while the regular
    transforms are applied to dataset A only.
    """
    transforms_for_a = [
        {"type": "lowercase_strip"},
        {"type": "concat_two_cols", "column_to_append": "id"},
    ]
    transforms_for_b = [
        {"type": "concat_two_cols", "column_to_append": "identifier"}
    ]
    mapping = {
        "column_name": "occupation",
        "override_column_b": "occ",
        "override_transforms": transforms_for_b,
        "transforms": transforms_for_a,
    }
    frame_a = spark.createDataFrame(TEST_DF_1)
    frame_b = spark.createDataFrame(TEST_DF_2)

    # Each call receives its own fresh column_selects list.
    selected_a, selects_a = select_column_mapping(
        mapping, frame_a, is_a=True, column_selects=[]
    )
    selected_b, selects_b = select_column_mapping(
        mapping, frame_b, is_a=False, column_selects=[]
    )

    # Both datasets end up with the same output column name.
    assert selects_a == selects_b
    assert selects_b == ["occupation"]

    # Dataset A: lowercased, then suffixed with the "id" column.
    expected_a = [
        "farmer0",
        "computer scientist1",
        "waitress2",
        "retired3",
        "lawyer4",
        "doctor5",
    ]
    actual_a = selected_a.select("occupation").toPandas()["occupation"].to_list()
    assert actual_a == expected_a

    # Dataset B: original casing, suffixed with the "identifier" column.
    expected_b = ["RETIRED1000", "CHILDCARE1002", None, None]
    actual_b = selected_b.select("occupation").toPandas()["occupation"].to_list()
    assert actual_b == expected_b


def test_select_column_mapping_error_missing_column_name(spark):
"""
Without a column_name key in the column_mapping, the function raises
Expand Down

0 comments on commit 87090a2

Please sign in to comment.