Skip to content

Commit

Permalink
[#162] Fix a bug in RenameVectorAttributes
Browse files Browse the repository at this point in the history
The bug was that we modified the column metadata only on the Python side and
never propagated the change to the underlying Java DataFrame, so the renamed
attributes were not persisted in contexts like a Pipeline. By calling
withMetadata(), the updated metadata is now written through to the JVM and
persists correctly.
  • Loading branch information
riley-harper committed Nov 20, 2024
1 parent a4f3534 commit 2e58078
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
7 changes: 4 additions & 3 deletions hlink/linking/transformers/rename_vector_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ def setParams(
return self._set(**kwargs)

def _transform(self, dataset: DataFrame) -> DataFrame:
metadata = dataset.schema[self.getInputCol()].metadata
attributes_by_type = metadata["ml_attr"]["attrs"]
input_col = self.getInputCol()
to_replace = self.getOrDefault("strsToReplace")
replacement_str = self.getOrDefault("replaceWith")
metadata = dataset.schema[input_col].metadata
attributes_by_type = metadata["ml_attr"]["attrs"]

# The attributes are grouped by type, which may be numeric, binary, or
# nominal. We don't care about the type here; we'll just rename all of
Expand All @@ -70,4 +71,4 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
substring, replacement_str
)

return dataset
return dataset.withMetadata(input_col, metadata)
12 changes: 10 additions & 2 deletions hlink/tests/transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def test_rename_vector_attributes(spark: SparkSession) -> None:
)
transformed = remove_colons.transform(assembler.transform(df))

attrs = transformed.schema["vectorized"].metadata["ml_attr"]["attrs"]["numeric"]
# Save to Java, then reload to confirm that the metadata changes are persistent
transformed.write.mode("overwrite").saveAsTable("transformed")
df = spark.table("transformed")

attrs = df.schema["vectorized"].metadata["ml_attr"]["attrs"]["numeric"]
attr_names = [attr["name"] for attr in attrs]
assert attr_names == ["A", "regionf_0_namelast_jw"]

Expand All @@ -35,6 +39,10 @@ def test_rename_vector_attributes_multiple_replacements(spark: SparkSession) ->
)
transformed = rename_attrs.transform(assembler.transform(df))

attrs = transformed.schema["vector"].metadata["ml_attr"]["attrs"]["numeric"]
# Save to Java, then reload to confirm that the metadata changes are persistent
transformed.write.mode("overwrite").saveAsTable("transformed")
df = spark.table("transformed")

attrs = df.schema["vector"].metadata["ml_attr"]["attrs"]["numeric"]
attr_names = [attr["name"] for attr in attrs]
assert attr_names == ["column1hasstars", "column2multiplesymbols"]

0 comments on commit 2e58078

Please sign in to comment.