Skip to content

Commit

Permalink
Create check_dask_cwd function (NVIDIA#484)
Browse files: browse the repository at this point in the history
* Create `check_dask_cwd` function

Signed-off-by: Sarah Yurick <[email protected]>

* add Praateek's suggestions

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick authored Jan 21, 2025
1 parent 2a39616 commit 57f0e3c
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions nemo_curator/utils/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ast
import os
import shutil
import subprocess

import dask

Expand Down Expand Up @@ -565,6 +566,9 @@ def read_data(
"""
if isinstance(input_files, str):
input_files = [input_files]

check_dask_cwd(input_files)

if file_type == "pickle":
df = read_pandas_pickle(
input_files[0], add_filename=add_filename, columns=columns, **kwargs
Expand Down Expand Up @@ -1013,6 +1017,24 @@ def get_current_client():
return None


def check_dask_cwd(file_list: List[str]) -> None:
    """Ensure relative paths resolve the same way on the client and on Dask workers.

    Relative paths are resolved against the working directory of whichever
    process opens them. If the client process and the Dask workers run in
    different directories, a relative path can silently point at different
    files, so this check fails fast instead.

    The check is skipped when every path in ``file_list`` is absolute, when no
    Dask client is active, or when the client reports no workers, because in
    those cases there is nothing to compare.

    Args:
        file_list: Paths that are about to be read.

    Raises:
        RuntimeError: If the workers disagree with each other about their
            working directory, or if they agree with each other but differ
            from the client process.
    """
    # Absolute paths do not depend on any process's working directory.
    if all(os.path.isabs(file_path) for file_path in file_list):
        return

    client = get_current_client()
    if client is None:
        # No distributed client is active, so there are no workers to compare.
        return

    # Client.run returns a mapping of worker address -> result; only the
    # distinct directories matter here.
    worker_cwds = set(client.run(os.getcwd).values())
    if not worker_cwds:
        # No workers are connected, so there is nothing to compare against.
        return

    if len(worker_cwds) > 1:
        raise RuntimeError(
            "Mismatch between at least 2 Dask workers' working directories. "
            "Use absolute file paths to ensure the correct files are read as intended."
        )

    (worker_cwd,) = worker_cwds
    # Use os.getcwd() on the client as well, so both sides are measured the
    # same way. Shelling out to `pwd` can report the logical, symlink-preserving
    # path and disagree with os.getcwd() for the very same directory.
    if worker_cwd != os.getcwd():
        raise RuntimeError(
            "Mismatch between Dask client and worker working directories. "
            "Use absolute file paths to ensure the correct files are read as intended."
        )


def performance_report_if(
path: Optional[str] = None, report_name: str = "dask-profile.html"
):
Expand Down

0 comments on commit 57f0e3c

Please sign in to comment.