Skip to content

Commit

Permalink
[#181] Add a new checkpoint_dir argument to SparkConnection()
Browse files Browse the repository at this point in the history
Previously we always set the checkpoint directory to be the same as
spark.local.dir, which we call "tmp_dir". However, sharing one directory for
both purposes is incorrect: tmp_dir should be on a disk local to each executor,
whereas the checkpoint directory has to be on shared storage to work correctly.
  • Loading branch information
riley-harper committed Dec 13, 2024
1 parent 1f99c93 commit e0bf86e
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 2 deletions.
2 changes: 2 additions & 0 deletions hlink/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,14 @@ def _parse_args():
def _get_spark(run_name: str, args: argparse.Namespace) -> SparkSession:
derby_dir = HLINK_DIR / "derby" / run_name
warehouse_dir = HLINK_DIR / "warehouse" / run_name
checkpoint_dir = HLINK_DIR / "checkpoint" / run_name
tmp_dir = HLINK_DIR / "tmp" / run_name
python = sys.executable

spark_connection = SparkConnection(
derby_dir=derby_dir,
warehouse_dir=warehouse_dir,
checkpoint_dir=checkpoint_dir,
tmp_dir=tmp_dir,
python=python,
db_name="linking",
Expand Down
1 change: 1 addition & 0 deletions hlink/spark/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def create(self):
spark_conn = SparkConnection(
str(self.derby_dir),
str(self.warehouse_dir),
"checkpoint",
str(self.tmp_dir),
self.python,
self.db_name,
Expand Down
9 changes: 8 additions & 1 deletion hlink/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@ class SparkConnection:
"""Handles initialization of spark session and connection to local cluster."""

def __init__(
self, derby_dir, warehouse_dir, tmp_dir, python, db_name, app_name="linking"
self,
derby_dir,
warehouse_dir,
checkpoint_dir,
tmp_dir,
python,
db_name,
app_name="linking",
):
self.derby_dir = derby_dir
self.warehouse_dir = warehouse_dir
Expand Down
1 change: 1 addition & 0 deletions hlink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def spark(tmpdir_factory):
spark_connection = SparkConnection(
tmpdir_factory.mktemp("derby"),
tmpdir_factory.mktemp("warehouse"),
tmpdir_factory.mktemp("checkpoint"),
tmpdir_factory.mktemp("spark_tmp_dir"),
sys.executable,
"linking",
Expand Down
5 changes: 4 additions & 1 deletion hlink/tests/spark_connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
def test_app_name_defaults_to_linking(tmp_path: Path) -> None:
derby_dir = tmp_path / "derby"
warehouse_dir = tmp_path / "warehouse"
checkpoint_dir = tmp_path / "checkpoint"
tmp_dir = tmp_path / "tmp"
connection = SparkConnection(
derby_dir, warehouse_dir, tmp_dir, sys.executable, "test"
derby_dir, warehouse_dir, checkpoint_dir, tmp_dir, sys.executable, "test"
)
spark = connection.local(cores=1, executor_memory="1G")
app_name = spark.conf.get("spark.app.name")
Expand All @@ -19,10 +20,12 @@ def test_app_name_defaults_to_linking(tmp_path: Path) -> None:
def test_app_name_argument(tmp_path: Path) -> None:
derby_dir = tmp_path / "derby"
warehouse_dir = tmp_path / "warehouse"
checkpoint_dir = tmp_path / "checkpoint_dir"
tmp_dir = tmp_path / "tmp"
connection = SparkConnection(
derby_dir,
warehouse_dir,
checkpoint_dir,
tmp_dir,
sys.executable,
"test",
Expand Down

0 comments on commit e0bf86e

Please sign in to comment.