From 87090a257c6874deaeafde62a53d1520848d6025 Mon Sep 17 00:00:00 2001
From: rileyh
Date: Mon, 27 Nov 2023 17:05:08 +0000
Subject: [PATCH] [#118] Add unit tests for override_column_ for column_mappings

---
 hlink/tests/core/column_mapping_test.py | 104 ++++++++++++++++++++++++
 1 file changed, 104 insertions(+)

diff --git a/hlink/tests/core/column_mapping_test.py b/hlink/tests/core/column_mapping_test.py
index 4db1639..eb694e6 100644
--- a/hlink/tests/core/column_mapping_test.py
+++ b/hlink/tests/core/column_mapping_test.py
@@ -194,6 +194,110 @@ def test_select_column_mapping_column_selects_preserved(spark):
     assert set(column_selects) == {"occ_with_underscores", "occupation"}
 
 
+def test_select_column_mapping_override_column_a(spark):
+    """
+    override_column_a lets the user specify a different column name for
+    dataset A. override_transforms are applied only to dataset A in
+    this case.
+    """
+    column_mapping = {
+        "column_name": "occ",
+        "override_column_a": "occupation",
+        "override_transforms": [{"type": "lowercase_strip"}],
+    }
+    df_a = spark.createDataFrame(TEST_DF_1)
+    df_b = spark.createDataFrame(TEST_DF_2)
+
+    df_selected_a, column_selects_a = select_column_mapping(
+        column_mapping,
+        df_a,
+        is_a=True,
+        column_selects=[],
+    )
+
+    df_selected_b, column_selects_b = select_column_mapping(
+        column_mapping,
+        df_b,
+        is_a=False,
+        column_selects=[],
+    )
+
+    assert column_selects_a == column_selects_b == ["occ"]
+
+    occ_a = df_selected_a.select("occ").toPandas()
+    assert occ_a["occ"].to_list() == [
+        "farmer",
+        "computer scientist",
+        "waitress",
+        "retired",
+        "lawyer",
+        "doctor",
+    ]
+
+    occ_b = df_selected_b.select("occ").toPandas()
+    assert occ_b["occ"].to_list() == [
+        "RETIRED",
+        "CHILDCARE",
+        None,
+        None,
+    ]
+
+
+def test_select_column_mapping_override_column_b(spark):
+    """
+    override_column_b lets the user specify a different column name for
+    dataset B. override_transforms are applied only to dataset B in
+    this case, and transforms are applied only to dataset A.
+    """
+    column_mapping = {
+        "column_name": "occupation",
+        "override_column_b": "occ",
+        "override_transforms": [
+            {"type": "concat_two_cols", "column_to_append": "identifier"}
+        ],
+        "transforms": [
+            {"type": "lowercase_strip"},
+            {"type": "concat_two_cols", "column_to_append": "id"},
+        ],
+    }
+    df_a = spark.createDataFrame(TEST_DF_1)
+    df_b = spark.createDataFrame(TEST_DF_2)
+
+    df_selected_a, column_selects_a = select_column_mapping(
+        column_mapping,
+        df_a,
+        is_a=True,
+        column_selects=[],
+    )
+
+    df_selected_b, column_selects_b = select_column_mapping(
+        column_mapping,
+        df_b,
+        is_a=False,
+        column_selects=[],
+    )
+
+    assert column_selects_a == column_selects_b == ["occupation"]
+
+    occ_a = df_selected_a.select("occupation").toPandas()
+    assert occ_a["occupation"].to_list() == [
+        "farmer0",
+        "computer scientist1",
+        "waitress2",
+        "retired3",
+        "lawyer4",
+        "doctor5",
+    ]
+
+    occ_b = df_selected_b.select("occupation").toPandas()
+    assert occ_b["occupation"].to_list() == [
+        "RETIRED1000",
+        "CHILDCARE1002",
+        None,
+        None,
+    ]
+
+
 def test_select_column_mapping_error_missing_column_name(spark):
     """
     Without a column_name key in the column_mapping, the function raises