diff --git a/hlink/spark/factory.py b/hlink/spark/factory.py index 8c4781d..e7d320d 100644 --- a/hlink/spark/factory.py +++ b/hlink/spark/factory.py @@ -24,6 +24,7 @@ def __init__(self): spark_dir = Path("spark").resolve() self.derby_dir = spark_dir / "derby" self.warehouse_dir = spark_dir / "warehouse" + self.checkpoint_dir = spark_dir / "checkpoint" self.tmp_dir = spark_dir / "tmp" self.python = sys.executable self.db_name = "linking" @@ -40,6 +41,10 @@ def set_warehouse_dir(self, warehouse_dir): self.warehouse_dir = warehouse_dir return self + def set_checkpoint_dir(self, checkpoint_dir): + self.checkpoint_dir = checkpoint_dir + return self + def set_tmp_dir(self, tmp_dir): self.tmp_dir = tmp_dir return self @@ -78,7 +83,7 @@ def create(self): spark_conn = SparkConnection( str(self.derby_dir), str(self.warehouse_dir), - "checkpoint", + str(self.checkpoint_dir), str(self.tmp_dir), self.python, self.db_name, diff --git a/hlink/spark/session.py b/hlink/spark/session.py index a0d7841..54723df 100644 --- a/hlink/spark/session.py +++ b/hlink/spark/session.py @@ -44,6 +44,7 @@ def __init__( ): self.derby_dir = derby_dir self.warehouse_dir = warehouse_dir + self.checkpoint_dir = checkpoint_dir self.db_name = db_name self.tmp_dir = tmp_dir self.python = python @@ -122,7 +123,7 @@ def connect( if self.db_name not in [d.name for d in session.catalog.listDatabases()]: session.sql(f"CREATE DATABASE IF NOT EXISTS {self.db_name}") session.catalog.setCurrentDatabase(self.db_name) - session.sparkContext.setCheckpointDir(str(self.tmp_dir)) + session.sparkContext.setCheckpointDir(str(self.checkpoint_dir)) self._register_udfs(session) # If the SynapseML Python package is available, include the Scala diff --git a/hlink/tests/spark_connection_test.py b/hlink/tests/spark_connection_test.py index 707fb22..c45831a 100644 --- a/hlink/tests/spark_connection_test.py +++ b/hlink/tests/spark_connection_test.py @@ -1,4 +1,5 @@ from pathlib import Path +import re import sys from 
hlink.spark.session import SparkConnection @@ -20,7 +21,7 @@ def test_app_name_defaults_to_linking(tmp_path: Path) -> None: def test_app_name_argument(tmp_path: Path) -> None: derby_dir = tmp_path / "derby" warehouse_dir = tmp_path / "warehouse" - checkpoint_dir = tmp_path / "checkpoint_dir" + checkpoint_dir = tmp_path / "checkpoint" tmp_dir = tmp_path / "tmp" connection = SparkConnection( derby_dir, @@ -34,3 +35,22 @@ def test_app_name_argument(tmp_path: Path) -> None: spark = connection.local(cores=1, executor_memory="1G") app_name = spark.conf.get("spark.app.name") assert app_name == "test_app_name" + + +def test_sets_checkpoint_directory(tmp_path: Path) -> None: + derby_dir = tmp_path / "derby" + warehouse_dir = tmp_path / "warehouse" + checkpoint_dir = tmp_path / "checkpoint" + tmp_dir = tmp_path / "tmp" + connection = SparkConnection( + derby_dir, + warehouse_dir, + checkpoint_dir, + tmp_dir, + sys.executable, + "test", + ) + spark = connection.local(cores=1, executor_memory="1G") + + spark_checkpoint_dir = spark.sparkContext.getCheckpointDir() + assert re.search(re.escape(str(checkpoint_dir)), spark_checkpoint_dir) diff --git a/hlink/tests/spark_factory_test.py b/hlink/tests/spark_factory_test.py index 895131c..803bf30 100644 --- a/hlink/tests/spark_factory_test.py +++ b/hlink/tests/spark_factory_test.py @@ -1,4 +1,5 @@ from pathlib import Path +import re from pyspark.sql import Row @@ -33,3 +34,19 @@ def test_spark_factory_can_create_spark_session(tmp_path: Path) -> None: Row(equals_b=True), Row(equals_b=False), ] + + +def test_spark_factory_set_checkpoint_dir(tmp_path: Path) -> None: + checkpoint_dir = tmp_path / "checkpoint" + + factory = ( + SparkFactory() + .set_local() + .set_num_cores(1) + .set_executor_cores(1) + .set_executor_memory("1G") + .set_checkpoint_dir(checkpoint_dir) + ) + spark = factory.create() + spark_checkpoint_dir = spark.sparkContext.getCheckpointDir() + assert re.search(re.escape(str(checkpoint_dir)), spark_checkpoint_dir)