Skip to content

Commit

Permalink
[#181] Implement checkpoint_dir behavior for SparkConnection and SparkFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Dec 13, 2024
1 parent e0bf86e commit 3dbc75b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
7 changes: 6 additions & 1 deletion hlink/spark/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self):
spark_dir = Path("spark").resolve()
self.derby_dir = spark_dir / "derby"
self.warehouse_dir = spark_dir / "warehouse"
self.checkpoint_dir = spark_dir / "checkpoint"
self.tmp_dir = spark_dir / "tmp"
self.python = sys.executable
self.db_name = "linking"
Expand All @@ -40,6 +41,10 @@ def set_warehouse_dir(self, warehouse_dir):
self.warehouse_dir = warehouse_dir
return self

def set_checkpoint_dir(self, checkpoint_dir):
    """Set the directory Spark will use for RDD/DataFrame checkpoints.

    Part of the factory's fluent interface: stores the value and returns
    self so configuration calls can be chained. The value is converted
    with str() at create() time, so a pathlib.Path or a string both work.
    """
    self.checkpoint_dir = checkpoint_dir
    return self

def set_tmp_dir(self, tmp_dir):
    """Set the temporary/scratch directory for the Spark connection.

    Fluent setter: stores the value and returns self for chaining.
    """
    self.tmp_dir = tmp_dir
    return self
Expand Down Expand Up @@ -78,7 +83,7 @@ def create(self):
spark_conn = SparkConnection(
str(self.derby_dir),
str(self.warehouse_dir),
"checkpoint",
str(self.checkpoint_dir),
str(self.tmp_dir),
self.python,
self.db_name,
Expand Down
3 changes: 2 additions & 1 deletion hlink/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
):
self.derby_dir = derby_dir
self.warehouse_dir = warehouse_dir
self.checkpoint_dir = checkpoint_dir
self.db_name = db_name
self.tmp_dir = tmp_dir
self.python = python
Expand Down Expand Up @@ -122,7 +123,7 @@ def connect(
if self.db_name not in [d.name for d in session.catalog.listDatabases()]:
session.sql(f"CREATE DATABASE IF NOT EXISTS {self.db_name}")
session.catalog.setCurrentDatabase(self.db_name)
session.sparkContext.setCheckpointDir(str(self.tmp_dir))
session.sparkContext.setCheckpointDir(str(self.checkpoint_dir))
self._register_udfs(session)

# If the SynapseML Python package is available, include the Scala
Expand Down
22 changes: 21 additions & 1 deletion hlink/tests/spark_connection_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import re
import sys

from hlink.spark.session import SparkConnection
Expand All @@ -20,7 +21,7 @@ def test_app_name_defaults_to_linking(tmp_path: Path) -> None:
def test_app_name_argument(tmp_path: Path) -> None:
derby_dir = tmp_path / "derby"
warehouse_dir = tmp_path / "warehouse"
checkpoint_dir = tmp_path / "checkpoint_dir"
checkpoint_dir = tmp_path / "checkpoint"
tmp_dir = tmp_path / "tmp"
connection = SparkConnection(
derby_dir,
Expand All @@ -34,3 +35,22 @@ def test_app_name_argument(tmp_path: Path) -> None:
spark = connection.local(cores=1, executor_memory="1G")
app_name = spark.conf.get("spark.app.name")
assert app_name == "test_app_name"


def test_sets_checkpoint_directory(tmp_path: Path) -> None:
    """The checkpoint_dir argument to SparkConnection should become the
    Spark context's checkpoint directory.

    Spark reports the checkpoint directory as a URI (it may gain a scheme
    prefix and a generated subdirectory), so we check that our requested
    path appears within it rather than comparing for equality.
    """
    derby_dir = tmp_path / "derby"
    warehouse_dir = tmp_path / "warehouse"
    checkpoint_dir = tmp_path / "checkpoint"
    tmp_dir = tmp_path / "tmp"
    connection = SparkConnection(
        derby_dir,
        warehouse_dir,
        checkpoint_dir,
        tmp_dir,
        sys.executable,
        "test",
    )
    spark = connection.local(cores=1, executor_memory="1G")

    spark_checkpoint_dir = spark.sparkContext.getCheckpointDir()
    # getCheckpointDir() returns None when no checkpoint dir is set; fail
    # clearly in that case instead of passing a None to a string check.
    assert spark_checkpoint_dir is not None
    # Plain substring check instead of re.search(): a filesystem path is not
    # a regex, and characters like '\' (Windows separators) or '.' would be
    # interpreted as regex metacharacters, causing errors or false matches.
    assert str(checkpoint_dir) in spark_checkpoint_dir
17 changes: 17 additions & 0 deletions hlink/tests/spark_factory_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import re

from pyspark.sql import Row

Expand Down Expand Up @@ -33,3 +34,19 @@ def test_spark_factory_can_create_spark_session(tmp_path: Path) -> None:
Row(equals_b=True),
Row(equals_b=False),
]


def test_spark_factory_set_checkpoint_dir(tmp_path: Path) -> None:
    """SparkFactory.set_checkpoint_dir() should control the checkpoint
    directory of the Spark session the factory creates.

    As in the SparkConnection test, Spark reports the directory as a URI
    that may include a scheme and a generated subdirectory, so we assert
    containment of the requested path, not equality.
    """
    checkpoint_dir = tmp_path / "checkpoint"

    factory = (
        SparkFactory()
        .set_local()
        .set_num_cores(1)
        .set_executor_cores(1)
        .set_executor_memory("1G")
        .set_checkpoint_dir(checkpoint_dir)
    )
    spark = factory.create()
    spark_checkpoint_dir = spark.sparkContext.getCheckpointDir()
    # Guard against None before the string check, and use a substring test
    # rather than re.search(): the path is data, not a regex pattern, and
    # regex metacharacters in it ('\', '.') would misbehave.
    assert spark_checkpoint_dir is not None
    assert str(checkpoint_dir) in spark_checkpoint_dir

0 comments on commit 3dbc75b

Please sign in to comment.