Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug with the override_column_X attributes in conf_validations.py #131

Merged
merged 8 commits into from
Feb 20, 2024
26 changes: 16 additions & 10 deletions hlink/linking/matching/link_step_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,32 @@ def _explode(self, df, comparisons, comparison_features, blocking, id_column, is
expand_length = exploding_column["expand_length"]
derived_from_column = exploding_column["derived_from"]
explode_selects = [
explode(self._expand(derived_from_column, expand_length)).alias(
exploding_column_name
(
explode(self._expand(derived_from_column, expand_length)).alias(
exploding_column_name
)
if exploding_column_name == column
else column
)
if exploding_column_name == column
else column
for column in all_column_names
]
else:
explode_selects = [
explode(col(exploding_column_name)).alias(exploding_column_name)
if exploding_column_name == c
else c
(
explode(col(exploding_column_name)).alias(exploding_column_name)
if exploding_column_name == c
else c
)
for c in all_column_names
]
if "dataset" in exploding_column:
derived_from_column = exploding_column["derived_from"]
explode_selects_with_derived_column = [
col(derived_from_column).alias(exploding_column_name)
if exploding_column_name == column
else column
(
col(derived_from_column).alias(exploding_column_name)
if exploding_column_name == column
else column
)
for column in all_column_names
]
if exploding_column["dataset"] == "a":
Expand Down
60 changes: 48 additions & 12 deletions hlink/scripts/lib/conf_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

from pyspark.sql.utils import AnalysisException
from os import path
from typing import Any, Literal

import colorama
from pyspark.sql.utils import AnalysisException
from pyspark.sql import DataFrame


def print_checking(section: str):
Expand Down Expand Up @@ -265,7 +268,47 @@ def check_substitution_columns(config, columns_available):
)


def check_column_mappings(config, df_a, df_b):
def check_column_mappings_column_available(
    column_mapping: dict[str, Any],
    df: DataFrame,
    previous_mappings: list[str],
    a_or_b: Literal["a", "b"],
) -> None:
    """
    Validate that the column a [[column_mappings]] entry reads from exists.

    When the mapping has an override_column_a/b attribute for this datasource,
    that override column must be present in df. Otherwise the mapping's
    column_name must be present in df (matched case-insensitively) or have
    been produced by an earlier mapping listed in previous_mappings.

    Raises ValueError when the required column cannot be found.
    """
    column_name = column_mapping["column_name"]
    override_column = column_mapping.get(f"override_column_{a_or_b}")
    # Compare lowered names so the check is case-insensitive.
    lowered_columns = {name.lower() for name in df.columns}

    if override_column is not None:
        # An override takes precedence and must exist in the datasource itself.
        if override_column.lower() in lowered_columns:
            return
        raise ValueError(
            f"Within a [[column_mappings]] the override_column_{a_or_b} column "
            f"'{override_column}' does not exist in datasource_{a_or_b}.\n"
            f"Column mapping: {column_mapping}\n"
            f"Available columns: {df.columns}"
        )

    # No override: the column must come from the datasource or from the output
    # of a previous column mapping.
    if column_name.lower() in lowered_columns or column_name in previous_mappings:
        return
    raise ValueError(
        f"Within a [[column_mappings]] the column_name '{column_name}' "
        f"does not exist in datasource_{a_or_b} and no previous "
        "[[column_mapping]] alias exists for it.\n"
        f"Column mapping: {column_mapping}.\n"
        f"Available columns:\n {df.columns}"
    )


def check_column_mappings(
config: dict[str, Any], df_a: DataFrame, df_b: DataFrame
) -> list[str]:
column_mappings = config.get("column_mappings")
if not column_mappings:
raise ValueError("No [[column_mappings]] exist in the conf file.")
Expand All @@ -276,22 +319,15 @@ def check_column_mappings(config, df_a, df_b):
column_name = c.get("column_name")
set_value_column_a = c.get("set_value_column_a")
set_value_column_b = c.get("set_value_column_b")

if not column_name:
raise ValueError(
f"The following [[column_mappings]] has no 'column_name' attribute: {c}"
)
if set_value_column_a is None:
if column_name.lower() not in [c.lower() for c in df_a.columns]:
if column_name not in columns_available:
raise ValueError(
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_a and no previous [[column_mapping]] alias exists for it. \nColumn mapping: {c}. \nAvailable columns: \n {df_a.columns}"
)
check_column_mappings_column_available(c, df_a, columns_available, "a")
if set_value_column_b is None:
if column_name.lower() not in [c.lower() for c in df_b.columns]:
if column_name not in columns_available:
raise ValueError(
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_b and no previous [[column_mapping]] alias exists for it. Column mapping: {c}. Available columns: \n {df_b.columns}"
)
check_column_mappings_column_available(c, df_b, columns_available, "b")
if alias in columns_available:
duplicates.append(alias)
elif not alias and column_name in columns_available:
Expand Down
200 changes: 199 additions & 1 deletion hlink/tests/conf_validations_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import pytest

from pyspark.sql import SparkSession

from hlink.configs.load_config import load_conf_file
from hlink.scripts.lib.conf_validations import analyze_conf
from hlink.scripts.lib.conf_validations import analyze_conf, check_column_mappings
from hlink.linking.link_run import LinkRun


Expand All @@ -25,3 +27,199 @@ def test_invalid_conf(conf_dir_path, spark, conf_name, error_msg):

with pytest.raises(ValueError, match=error_msg):
analyze_conf(link_run)


def test_check_column_mappings_mappings_missing(spark: SparkSession) -> None:
    """
    A conf file without a column_mappings section is rejected.
    """
    frame_a = spark.createDataFrame([[1], [2], [3]], ["a"])
    frame_b = spark.createDataFrame([[4], [5], [6]], ["b"])
    conf = {}

    with pytest.raises(
        ValueError, match=r"No \[\[column_mappings\]\] exist in the conf file"
    ):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_no_column_name(spark: SparkSession) -> None:
    """
    A column mapping without a column_name attribute is rejected.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    conf = {
        "column_mappings": [{"column_name": "AGE", "alias": "age"}, {"alias": "height"}]
    }

    err_pattern = (
        r"The following \[\[column_mappings\]\] has no 'column_name' attribute:"
    )
    with pytest.raises(ValueError, match=err_pattern):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_column_name_not_available_datasource_a(
    spark: SparkSession,
) -> None:
    """
    A column_name must exist in datasource A or in a previous column mapping.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"])
    conf = {"column_mappings": [{"column_name": "HEIGHT"}]}

    err_pattern = (
        r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
        r"does not exist in datasource_a and no previous \[\[column_mapping\]\] "
        "alias exists for it"
    )

    with pytest.raises(ValueError, match=err_pattern):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_set_value_column_a_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    set_value_column_a exempts a mapping's column from having to exist in
    datasource A.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"])
    conf = {"column_mappings": [{"column_name": "HEIGHT", "set_value_column_a": 125}]}

    # Should not raise even though HEIGHT is absent from datasource A.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_column_name_not_available_datasource_b(
    spark: SparkSession,
) -> None:
    """
    A column_name must exist in datasource B or in a previous column mapping.
    """
    frame_a = spark.createDataFrame([[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"])
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    conf = {"column_mappings": [{"column_name": "HEIGHT"}]}

    err_pattern = (
        r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
        r"does not exist in datasource_b and no previous \[\[column_mapping\]\] "
        "alias exists for it"
    )

    with pytest.raises(ValueError, match=err_pattern):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_set_value_column_b_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    set_value_column_b exempts a mapping's column from having to exist in
    datasource B.
    """
    frame_a = spark.createDataFrame([[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"])
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    conf = {"column_mappings": [{"column_name": "HEIGHT", "set_value_column_b": 125}]}

    # Should not raise even though HEIGHT is absent from datasource B.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_previous_mappings_are_available(
    spark: SparkSession,
) -> None:
    """
    A column created by an earlier column mapping is usable as the column_name
    of a later mapping.
    """
    frame_a = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "alias": "AGE_HLINK"},
            {"column_name": "AGE_HLINK", "alias": "AGE_HLINK2"},
        ]
    }

    # AGE_HLINK only exists because of the first mapping; this must not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_a(spark: SparkSession) -> None:
    """
    override_column_a selects the source column read from datasource A.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    conf = {
        "column_mappings": [{"column_name": "AGE", "override_column_a": "ageColumn"}]
    }

    # AGE is missing from datasource A, but the override supplies it.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_b(spark: SparkSession) -> None:
    """
    override_column_b selects the source column read from datasource B.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    conf = {
        "column_mappings": [{"column_name": "ageColumn", "override_column_b": "AGE"}]
    }

    # ageColumn is missing from datasource B, but the override supplies it.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_a_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_a that names a column missing from datasource A is
    rejected.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "override_column_a": "oops_not_there"}
        ]
    }

    err_pattern = (
        r"Within a \[\[column_mappings\]\] the override_column_a column "
        "'oops_not_there' does not exist in datasource_a"
    )
    with pytest.raises(ValueError, match=err_pattern):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_b_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_b that names a column missing from datasource B is
    rejected.
    """
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "override_column_b": "oops_not_there"}
        ]
    }

    err_pattern = (
        r"Within a \[\[column_mappings\]\] the override_column_b column "
        "'oops_not_there' does not exist in datasource_b"
    )
    with pytest.raises(ValueError, match=err_pattern):
        check_column_mappings(conf, frame_a, frame_b)