From dbb9eeeb6051f31766d33b45436ccc3306ee5db9 Mon Sep 17 00:00:00 2001 From: rileyh Date: Thu, 30 May 2024 21:02:18 +0000 Subject: [PATCH] [#134] Add tests for the array feature selection This includes some failing tests which provide 1 or 3 input columns instead of just 2. #134 should make these tests pass. --- hlink/tests/core/transforms_test.py | 92 +++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 hlink/tests/core/transforms_test.py diff --git a/hlink/tests/core/transforms_test.py b/hlink/tests/core/transforms_test.py new file mode 100644 index 0000000..f940dbb --- /dev/null +++ b/hlink/tests/core/transforms_test.py @@ -0,0 +1,92 @@ +from pyspark.sql import Row, SparkSession +import pytest + +from hlink.linking.core.transforms import generate_transforms +from hlink.linking.link_task import LinkTask + + +@pytest.mark.parametrize("is_a", [True, False]) +def test_generate_transforms_array_transform_1_col( + spark: SparkSession, preprocessing: LinkTask, is_a: bool +) -> None: + df = spark.createDataFrame( + [[1, "Leto II", 3508], [2, "Hwi", 26], [3, "Siona", 25]], + schema=["id", "name", "age"], + ) + feature_selections = [ + { + "transform": "array", + "input_columns": ["name"], + "output_column": "array_column", + } + ] + + df_result = generate_transforms( + spark, df, feature_selections, preprocessing, is_a, "id" + ) + array_column = df_result.select("array_column").collect() + assert array_column == [ + Row(array_column=["Leto II"]), + Row(array_column=["Hwi"]), + Row(array_column=["Siona"]), + ] + + +@pytest.mark.parametrize("is_a", [True, False]) +def test_generate_transforms_array_transform_2_cols( + spark: SparkSession, preprocessing: LinkTask, is_a: bool +) -> None: + df = spark.createDataFrame( + [[1, "Leto II", 3508], [2, "Hwi", 26], [3, "Siona", 25]], + schema=["id", "name", "age"], + ) + feature_selections = [ + { + "transform": "array", + "input_columns": ["name", "age"], + "output_column": "array_column", + } + ] + + df_result = generate_transforms( + spark, df, feature_selections, preprocessing, is_a, "id" + ) + array_column = df_result.select("array_column").collect() + assert array_column == [ + Row(array_column=["Leto II", "3508"]), + Row(array_column=["Hwi", "26"]), + Row(array_column=["Siona", "25"]), + ] + + +@pytest.mark.parametrize("is_a", [True, False]) +def test_generate_transforms_array_transform_3_cols( + spark: SparkSession, + preprocessing: LinkTask, + is_a: bool, +) -> None: + df = spark.createDataFrame( + [ + [1, "Leto II", 3508, "Arrakis"], + [2, "Hwi", 26, "Ix"], + [3, "Siona", 25, "Arrakis"], + ], + schema=["id", "name", "age", "home"], + ) + feature_selections = [ + { + "transform": "array", + "input_columns": ["home", "age", "name"], + "output_column": "array_column", + } + ] + + df_result = generate_transforms( + spark, df, feature_selections, preprocessing, is_a, "id" + ) + array_column = df_result.select("array_column").collect() + assert array_column == [ + Row(array_column=["Arrakis", "3508", "Leto II"]), + Row(array_column=["Ix", "26", "Hwi"]), + Row(array_column=["Arrakis", "25", "Siona"]), + ]