[BUGFIX] Fix log duplication when using specific super call (#168)
Addressed an issue with duplicate logs when using a `super()` call in the `execute` method of a Step class, under the specific condition where `super()` did not point to the direct ancestor of the called `Step`. The problem was caused by the `_is_called_through_super` method of the `StepMetaClass` not being able to inspect beyond the immediate ancestor of the called object.

The fix updates the `_is_called_through_super` method to traverse the entire method resolution order (MRO) and correctly identify whether the `execute` method is called through `super()` in any class along the caller's ancestry, not just its direct parent.

Additionally, the `_execute_wrapper` method was updated to ensure logging is only triggered once per `execute` call.
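
To make the failure mode concrete, here is a minimal stand-alone sketch of the condition that slipped through — the class names are hypothetical, and the koheesio wrapper around `execute` is reduced to the two membership checks:

```python
class GrandParent:
    def execute(self) -> None:  # wrapped by StepMetaClass in the real code
        ...


class Parent(GrandParent):
    """Does not override execute()."""


class Child(Parent):
    def execute(self) -> None:
        super().execute()  # resolves to GrandParent.execute, skipping Parent


caller_self, caller_name = Child(), "execute"

# Old check: only the immediate parent was inspected. Parent does not define
# execute(), so the super() call went undetected and start/end was logged twice.
old_check = caller_name in caller_self.__class__.__bases__[0].__dict__

# New check: walk the full MRO. Child (and GrandParent) define execute(), so the
# call is recognised as coming through super() and the extra logging is skipped.
new_check = any(caller_name in klass.__dict__ for klass in type(caller_self).__mro__)

assert old_check is False  # undetected -> duplicate logs
assert new_check is True   # detected -> logged once
```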

While fixing this issue, I came across a few more problems that needed to be addressed; a summary follows below.
All relevant tests have been updated accordingly.

## Snowflake
Switched the Snowflake classes to use `params` instead of `options`, in line with the rest of the Koheesio classes.

1. `src/koheesio/integrations/snowflake/__init__.py`:
       - Introduced `SF_DEFAULT_PARAMS` with default Snowflake parameters.
       - Renamed `options` to `params` to accommodate the switch to `ExtraParamsMixin` and updated the class to use `default_factory=partial(dict, **SF_DEFAULT_PARAMS)` (this keeps mypy and pydantic happy).
       - Added a property named `options` for backwards compatibility (see the sketch after this list).
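
To show what the backwards compatibility is intended to look like, here is a stand-alone pydantic sketch that mirrors the new field layout — it is not koheesio's actual `SnowflakeBaseModel`, just the same pattern in isolation:

```python
from functools import partial
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field

SF_DEFAULT_PARAMS = {"sfCompress": "on", "continue_on_error": "off"}


class SnowflakeParamsPattern(BaseModel):
    """Illustrative stand-in for SnowflakeBaseModel's params/options handling."""

    model_config = ConfigDict(populate_by_name=True)

    params: Optional[Dict[str, Any]] = Field(
        default_factory=partial(dict, **SF_DEFAULT_PARAMS),  # avoids a mutable default
        alias="options",
    )

    @property
    def options(self) -> Dict[str, Any]:
        """Backwards-compatible shorthand for self.params."""
        return self.params


assert SnowflakeParamsPattern().params == SF_DEFAULT_PARAMS       # defaults applied
legacy = SnowflakeParamsPattern(options={"sfCompress": "off"})     # old keyword still accepted
assert legacy.options == legacy.params == {"sfCompress": "off"}
```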

## JDBC switch to `ExtraParamsMixin`

1. `spark/readers/jdbc.py`:
       - Introduced `ExtraParamsMixin` to handle additional parameters natively.
       - Renamed the `options` Field to `params` to accommodate the switch to `ExtraParamsMixin` and added `alias="options"`.
       - Added a property named `options` for backwards compatibility.
       - `dbtable` and `query` validation is now handled at `__init__` time rather than at runtime (this is more in line with how Koheesio's other classes work and how they are intended to be used).
       - By default, either `dbtable` or `query` needs to be provided to use JDBC, as was always intended (see the sketch after this list).

2. `spark/readers/hana.py`: (depends on jdbc)
       - Renamed the `options` Field to `params` to accommodate the switch to `ExtraParamsMixin` and added `alias="options"`.

3. `spark/readers/teradata.py`: (depends on jdbc)
       - Renamed the `options` Field to `params` to accommodate the switch to `ExtraParamsMixin` and added `alias="options"`.
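
For reference, a hedged sketch of how the reworked `JdbcReader` is meant to behave from the call site — the driver, URL, and credentials below are placeholders, not values from this PR:

```python
import pytest

from koheesio.spark.readers.jdbc import JdbcReader

common = dict(
    driver="org.postgresql.Driver",             # placeholder driver
    url="jdbc:postgresql://localhost:5432/db",  # placeholder URL
    user="user",
    password="***",
)

# Extra keyword arguments end up in `params`; `options` keeps working via the alias
# and the backwards-compatible property.
reader = JdbcReader(**common, dbtable="schema_name.table_name", options={"fetchsize": 100})
assert reader.params["fetchsize"] == 100
assert reader.options == reader.params

# Leaving both `dbtable` and `query` empty now fails at construction time
# (pydantic surfaces the ValueError), not when execute() is called.
with pytest.raises(ValueError):
    JdbcReader(**common)
```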

## Hash Transformation
A new error popped up (only when using Spark Connect) that uncovered some bugs in how missing columns are handled.

1. `src/koheesio/spark/transformations/hash.py`:
       - Updated the `sha2` function call to use named parameters.
       - Added a check for missing columns in the `Sha2Hash` class.
       - Improved the `Sha2Hash` class to handle cases where no columns are provided (a usage sketch follows).
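
A rough usage sketch of the new behaviour — DataFrame and column names are made up, and it assumes a working SparkSession and that `df` can be set at construction time:

```python
from pyspark.sql import SparkSession

from koheesio.spark.transformations.hash import Sha2Hash

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", "b")], ["c1", "c2"])

# Missing columns now raise a clear ValueError up front instead of surfacing as an
# opaque error inside Spark (Connect) at execution time.
try:
    Sha2Hash(df=df, columns=["c1", "does_not_exist"], target_column="hash").execute()
except ValueError as err:
    print(err)  # e.g. "Columns {'does_not_exist'} not found in dataframe"

# When no columns resolve, the input DataFrame is passed through unchanged instead of failing.
```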


## Easier debugging and dev improvements
To make debugging easier, I changed `pyproject.toml` so that Spark Connect can be run more easily in your local dev environment:
- Added extra dependencies for `pyspark[connect]==3.5.4`.
- Added environment variables for Spark Connect in the development environment (see the snippet below).
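
With the `pyspark[connect]` extra installed and `SPARK_REMOTE="local"` exported, a plain builder call should hand back a Spark Connect session. A quick sanity check — the expected class path reflects my understanding of PySpark 3.5, not something this change enforces:

```python
from pyspark.sql import SparkSession

# With SPARK_REMOTE=local, PySpark spins up a local Spark Connect server and
# returns a Connect session instead of a classic JVM-backed one.
spark = SparkSession.builder.getOrCreate()
print(type(spark))  # expected: <class 'pyspark.sql.connect.session.SparkSession'>
```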

Additionally, switched the pytest output to verbose logging:
- Changed pytest options from `-q --color=yes --order-scope=module` to `-vv --color=yes --order-scope=module`, which makes test log output in CI/CD more readable.

## Related Issue
#167 

## Motivation and Context
This change is required to prevent duplicate logs when using `super()` in nested Step classes. The updated logic ensures that the logging mechanism correctly identifies and handles `super()` calls, producing accurate, non-redundant log entries. The `_is_called_through_super` method is used not only for logging but also for `Output` validation; although I did not witness any direct issues with this, the fix ensures it is only invoked once as well.

---------

Co-authored-by: Danny Meijer <[email protected]>
dannymeijer authored Feb 26, 2025
1 parent cd13c81 commit 522fd70
Showing 12 changed files with 234 additions and 80 deletions.
10 changes: 9 additions & 1 deletion pyproject.toml
@@ -314,7 +314,7 @@ name.".*(pyspark35r).*".env-vars = [


[tool.pytest.ini_options]
addopts = "-q --color=yes --order-scope=module"
addopts = "-vv --color=yes --order-scope=module"
log_level = "CRITICAL"
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "scope"
@@ -428,6 +428,14 @@ features = [
"test",
"docs",
]
extra-dependencies = [
"pyspark[connect]==3.5.4",
]

# Enable this if you want Spark Connect in your dev environment
[tool.hatch.envs.dev.env-vars]
SPARK_REMOTE = "local"
SPARK_TESTING = "True"


### ~~~~~~~~~~~~~~~~~~ ###
22 changes: 16 additions & 6 deletions src/koheesio/integrations/snowflake/__init__.py
@@ -46,6 +46,7 @@
from typing import Any, Dict, Generator, List, Optional, Set, Union
from abc import ABC
from contextlib import contextmanager
from functools import partial
import os
import tempfile
from types import ModuleType
@@ -145,6 +146,7 @@ def safe_import_snowflake_connector() -> Optional[ModuleType]:
)
return None

SF_DEFAULT_PARAMS = {"sfCompress": "on", "continue_on_error": "off"}

class SnowflakeBaseModel(BaseModel, ExtraParamsMixin, ABC): # type: ignore[misc]
"""
@@ -215,11 +217,19 @@ class SnowflakeBaseModel(BaseModel, ExtraParamsMixin, ABC): # type: ignore[misc
sfSchema: Optional[str] = Field(
default=..., alias="schema", description="The schema to use for the session after connecting"
)
options: Optional[Dict[str, Any]] = Field(
default={"sfCompress": "on", "continue_on_error": "off"},
description="Extra options to pass to the Snowflake connector",
params: Optional[Dict[str, Any]] = Field(
default_factory=partial(dict, **SF_DEFAULT_PARAMS),
description="Extra options to pass to the Snowflake connector, by default it includes "
"'sfCompress': 'on' and 'continue_on_error': 'off'",
alias="options",
examples=[{"sfCompress": "on", "continue_on_error": "off"}]
)

@property
def options(self) -> Dict[str, Any]:
"""Shorthand for accessing self.params provided for backwards compatibility"""
return self.params

def get_options(self, by_alias: bool = True, include: Optional[Set[str]] = None) -> Dict[str, Any]:
"""Get the sfOptions as a dictionary.
@@ -242,7 +252,7 @@ def get_options(self, by_alias: bool = True, include: Optional[Set[str]] = None)
# Exclude koheesio specific fields
"name",
"description",
# options and params are separately implemented
# params are separately implemented
"params",
"options",
# schema and password have to be handled separately
@@ -270,11 +280,11 @@ def get_options(self, by_alias: bool = True, include: Optional[Set[str]] = None)
fields = {key: value for key, value in fields.items() if key in include}
else:
# default filter
include = {"options", "params"}
include = {"params"}

# handle options
if "options" in include:
options = fields.pop("options", self.options)
options = fields.pop("params", self.params)
fields.update(**options)

# handle params
3 changes: 2 additions & 1 deletion src/koheesio/spark/readers/hana.py
@@ -61,7 +61,8 @@ class HanaReader(JdbcReader):
default="com.sap.db.jdbc.Driver",
description="Make sure that the necessary JARs are available in the cluster: ngdbc-2-x.x.x.x",
)
options: Optional[Dict[str, Any]] = Field(
params: Optional[Dict[str, Any]] = Field(
default={"fetchsize": 2000, "numPartitions": 10},
description="Extra options to pass to the SAP HANA JDBC driver",
alias="options",
)
50 changes: 35 additions & 15 deletions src/koheesio/spark/readers/jdbc.py
@@ -9,11 +9,12 @@

from typing import Any, Dict, Optional

from koheesio.models import Field, SecretStr
from koheesio import ExtraParamsMixin
from koheesio.models import Field, SecretStr, model_validator
from koheesio.spark.readers import Reader


class JdbcReader(Reader):
class JdbcReader(Reader, ExtraParamsMixin):
"""
Reader for JDBC tables.
@@ -49,10 +50,16 @@ class JdbcReader(Reader):
user="YOUR_USERNAME",
password="***",
dbtable="schema_name.table_name",
options={"fetchsize": 100},
options={"fetchsize": 100}, # you can also use 'params' instead of 'options'
)
df = jdbc_mssql.read()
```
### ExtraParamsMixin
The `ExtraParamsMixin` is a mixin class that provides a way to pass extra parameters to the reader. The extra
parameters are stored in the `params` (or `options`) attribute. Any key-value pairs passed to the reader will be
stored in the `params` attribute.
"""

format: str = Field(default="jdbc", description="The type of format to load. Defaults to 'jdbc'.")
@@ -71,41 +78,54 @@ class JdbcReader(Reader):
default=None, description="Database table name, also include schema name", alias="table"
)
query: Optional[str] = Field(default=None, description="Query")
options: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra options to pass to spark reader")
params: Dict[str, Any] = Field(default_factory=dict, description="Extra options to pass to spark reader", alias="options")

def get_options(self) -> Dict[str, Any]:
"""
Dictionary of options required for the specific JDBC driver.
Note: override this method if driver requires custom names, e.g. Snowflake: `sfUrl`, `sfUser`, etc.
"""
return {
_options = {
"driver": self.driver,
"url": self.url,
"user": self.user,
"password": self.password,
**self.options, # type: ignore
**self.params,
}
if query := self.query:
_options["query"] = query

def execute(self) -> Reader.Output:
"""Wrapper around Spark's jdbc read format"""
# Check that only one of them is filled in
if query and self.dbtable:
self.log.warning("Query is filled in, dbtable will be ignored!")
else:
_options["dbtable"] = self.dbtable

# Can't have both dbtable and query empty
return _options

@model_validator(mode="after")
def check_dbtable_or_query(self) -> "JdbcReader":
"""Check that dbtable or query is filled in and that only one of them is filled in (query has precedence)"""
# Check that dbtable or query is filled in
if not self.dbtable and not self.query:
raise ValueError("Please do not leave dbtable and query both empty!")

if self.query and self.dbtable:
self.log.info("Both 'query' and 'dbtable' are filled in, 'dbtable' will be ignored!")
return self

@property
def options(self) -> Dict[str, Any]:
"""Shorthand for accessing self.params provided for backwards compatibility"""
return self.params

def execute(self) -> "JdbcReader.Output":
"""Wrapper around Spark's jdbc read format"""
options = self.get_options()

if pw := self.password:
options["password"] = pw.get_secret_value()

if query := self.query:
options["query"] = query
self.log.info(f"Executing query: {self.query}")
else:
options["dbtable"] = self.dbtable
self.log.info(f"Executing query: {query}")

self.output.df = self.spark.read.format(self.format).options(**options).load()
5 changes: 3 additions & 2 deletions src/koheesio/spark/readers/teradata.py
@@ -67,6 +67,7 @@ class TeradataReader(JdbcReader):
"com.teradata.jdbc.TeraDriver",
description="Make sure that the necessary JARs are available in the cluster: terajdbc4-x.x.x.x",
)
options: Optional[Dict[str, Any]] = Field(
{"fetchsize": 2000, "numPartitions": 10}, description="Extra options to pass to the Teradata JDBC driver"
params: Optional[Dict[str, Any]] = Field(
{"fetchsize": 2000, "numPartitions": 10}, description="Extra options to pass to the Teradata JDBC driver",
alias="options",
)
18 changes: 13 additions & 5 deletions src/koheesio/spark/transformations/hash.py
@@ -52,9 +52,10 @@ def sha2_hash(columns: List[str], delimiter: Optional[str] = "|", num_bits: Opti
else:
column = _columns[0]

return sha2(column, num_bits) # type: ignore
return sha2(col=column, numBits=num_bits) # type: ignore


# TODO: convert this class to a ColumnsTransformationWithTarget
class Sha2Hash(ColumnsTransformation):
"""
hash the value of 1 or more columns using SHA-2 family of hash functions
@@ -92,12 +93,19 @@ class Sha2Hash(ColumnsTransformation):
default=..., description="The generated hash will be written to the column name specified here"
)

def execute(self) -> ColumnsTransformation.Output:
columns = list(self.get_columns())
def execute(self) -> "Sha2Hash.Output":
if not (columns := list(self.get_columns())):
self.output.df = self.df
return self.output

# check if columns exist in the dataframe
missing_columns = set(columns) - set(self.df.columns)
if missing_columns:
raise ValueError(f"Columns {missing_columns} not found in dataframe")

self.output.df = (
self.df.withColumn(
self.target_column, sha2_hash(columns=columns, delimiter=self.delimiter, num_bits=self.num_bits)
)
if columns
else self.df
)

19 changes: 13 additions & 6 deletions src/koheesio/steps/__init__.py
@@ -178,7 +178,7 @@ def __new__(
@staticmethod
def _is_called_through_super(caller_self: Any, caller_name: str, *_args, **_kwargs) -> bool: # type: ignore[no-untyped-def]
"""
Check if the method is called through super() in the immediate parent class.
Check if the method is called through super() using MRO (Method Resolution Order).
Parameters
----------
@@ -197,9 +197,10 @@ def _is_called_through_super(caller_self: Any, caller_name: str, *_args, **_kwar
True if the method is called through super(), False otherwise.
"""

base_class = caller_self.__class__.__bases__[0]
return caller_name in base_class.__dict__
for base_class in caller_self.__class__.__mro__:
if caller_name in base_class.__dict__:
return True
return False

@classmethod
def _partialmethod_impl(mcs, cls: type, execute_method: Callable) -> partialmethod:
@@ -273,8 +274,14 @@ def _execute_wrapper(cls, step: Step, execute_method: Callable, *args, **kwargs)
"""

# check if the method is called through super() in the immediate parent class
caller_name = inspect.currentframe().f_back.f_back.f_code.co_name
# Check if the method is called through super() in the immediate parent class
caller_name = (
inspect.currentframe() # Current stack frame
.f_back # Previous stack frame (caller of the current function)
.f_back # Parent stack frame (caller of the caller function)
.f_code # Code object of that frame
.co_name # Name of the function from the code object
)
is_called_through_super_ = cls._is_called_through_super(step, caller_name)

cls._log_start_message(step=step, skip_logging=is_called_through_super_)
10 changes: 5 additions & 5 deletions tests/spark/integrations/snowflake/test_sync_task.py
@@ -125,15 +125,15 @@ def mock_drop_table(table):
task.execute()
chispa.assert_df_equality(task.output.target_df, df)

@mock.patch.object(SnowflakeRunQueryPython, "execute")
def test_merge(
self,
mocked_sf_query_execute,
spark,
foreach_batch_stream_local,
snowflake_staging_file,
mocker
):
# Arrange - Prepare Delta requirements
mocker.patch("koheesio.integrations.spark.snowflake.SnowflakeRunQueryPython.execute")
source_table = DeltaTableStep(database="klettern", table="test_merge")
spark.sql(
dedent(
@@ -167,9 +167,9 @@ def test_merge(

# Act - Run code
# Note: We are using the foreach_batch_stream_local fixture to simulate writing to a live environment
with mock.patch.object(SynchronizeDeltaToSnowflakeTask, "writer", new=foreach_batch_stream_local):
task.execute()
task.writer.await_termination()
mocker.patch.object(SynchronizeDeltaToSnowflakeTask, "writer", new=foreach_batch_stream_local)
task.execute()
task.writer.await_termination()

# Assert - Validate result
df = spark.read.parquet(snowflake_staging_file).select("Country", "NumVaccinated", "AvailableDoses")
12 changes: 7 additions & 5 deletions tests/spark/readers/test_jdbc.py
@@ -14,11 +14,11 @@ class TestJdbcReader:
}

def test_get_options_wo_extra_options(self):
jr = JdbcReader(**self.common_options)
jr = JdbcReader(**self.common_options, dbtable="table")
actual = jr.get_options()
del actual["password"] # we don't need to test for this

expected = {**self.common_options}
expected = {**self.common_options, "dbtable": "table"}
del expected["password"] # we don't need to test for this

assert actual == expected
@@ -29,6 +29,8 @@ def test_get_options_w_extra_options(self):
"foo": "foo",
"bar": "bar",
},
query = "unit test",
dbtable = "table",
**self.common_options,
)

@@ -37,17 +39,17 @@

expected = {
**self.common_options,
"query": "unit test",
"foo": "foo",
"bar": "bar",
}
del expected["password"] # we don't need to test for this

assert actual == expected
assert sorted(actual) == sorted(expected)

def test_execute_wo_dbtable_and_query(self):
jr = JdbcReader(**self.common_options)
with pytest.raises(ValueError) as e:
jr.execute()
_ = JdbcReader(**self.common_options)
assert e.type is ValueError

def test_execute_w_dbtable_and_query(self, dummy_spark):