Fix #63. Add --input-meta parameter to explicitly specify the jsonl field dtypes (#75)

* Add dtype support (optional) when reading jsonl files

Signed-off-by: Miguel Martínez <[email protected]>
Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Change input_meta type hint

Signed-off-by: Miguel Martínez <[email protected]>

* Change input_meta type hint

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Resolve merge conflict

Signed-off-by: Miguel Martínez <[email protected]>

* Assign input_meta to the right variable

Signed-off-by: Miguel Martínez <[email protected]>

* Add warning when input_meta is used with non jsonl files.

Signed-off-by: Miguel Martínez <[email protected]>

* Explicitly check for None when validating input_meta

Signed-off-by: Miguel Martínez <[email protected]>

* Add input_meta test

Signed-off-by: Miguel Martínez <[email protected]>

* Add description to function

Signed-off-by: Miguel Martínez <[email protected]>

* Add test_meta_str

Signed-off-by: Miguel Martínez <[email protected]>

---------

Signed-off-by: Miguel Martínez <[email protected]>
Signed-off-by: Miguel Martínez <[email protected]>
Co-authored-by: Miguel Martínez <[email protected]>
miguelusque and miguelusque authored May 30, 2024
1 parent 30416e0 commit 757b389
Showing 10 changed files with 292 additions and 42 deletions.
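
At a glance, the new `input_meta` parameter lets callers pin JSONL field dtypes explicitly instead of relying on pandas' dtype inference. A minimal sketch of the Python-level usage (the `books/` path and the field names are hypothetical, not from this commit; the import path is assumed from the repo layout):

```python
# Sketch only: "books/" and the field names are illustrative.
from nemo_curator.datasets import DocumentDataset

# Pin the dtypes of the JSONL fields up front; pandas' inference can
# otherwise coerce values like "007" into the integer 7.
dataset = DocumentDataset.read_json(
    "books/",
    backend="pandas",
    input_meta={"id": "str", "text": "str"},
)
```
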
42 changes: 25 additions & 17 deletions nemo_curator/datasets/doc_dataset.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import dask.dataframe as dd

from nemo_curator.utils.distributed_utils import read_data, write_to_disk
@@ -36,10 +38,11 @@ def persist(self):
@classmethod
def read_json(
cls,
input_files,
backend="pandas",
files_per_partition=1,
add_filename=False,
input_files: Union[str, List[str]],
backend: str = "pandas",
files_per_partition: int = 1,
add_filename: bool = False,
input_meta: Union[str, dict] = None,
):
return cls(
_read_json_or_parquet(
@@ -48,6 +51,7 @@ def read_json(
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)
)

@@ -77,16 +81,16 @@ def read_pickle(
files_per_partition=1,
add_filename=False,
):
raw_data = read_data(
input_files=input_files,
file_type="pickle",
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
return cls(
read_data(
input_files=input_files,
file_type="pickle",
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
)
)

return cls(raw_data)

def to_json(
self,
output_file_dir,
@@ -128,11 +132,12 @@ def to_pickle(


def _read_json_or_parquet(
input_files,
file_type,
backend,
files_per_partition,
add_filename,
input_files: Union[str, List[str]],
file_type: str,
backend: str,
files_per_partition: int,
add_filename: bool,
input_meta: Union[str, dict] = None,
):
"""
`input_files` may be a list or a string type.
@@ -162,6 +167,7 @@ def _read_json_or_parquet(
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)

# List of directories
@@ -178,6 +184,7 @@
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)
dfs.append(df)

@@ -200,6 +207,7 @@
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)

else:
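
Per the `Union[str, dict]` hints above, `input_meta` accepts either a dict or a string formatted as one. Both sketches below use hypothetical file and field names:

```python
# Equivalent ways to pass the field dtypes (hypothetical names).
from nemo_curator.datasets import DocumentDataset

ds_from_dict = DocumentDataset.read_json(
    "data.jsonl", input_meta={"id": "str", "text": "str"}
)
ds_from_str = DocumentDataset.read_json(
    "data.jsonl", input_meta='{"id": "str", "text": "str"}'
)
```
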
7 changes: 6 additions & 1 deletion nemo_curator/download/doc_builder.py
@@ -15,7 +15,7 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import List, Tuple, Union

import dask.dataframe as dd
import pandas as pd
@@ -111,6 +111,7 @@ def _download_and_extract_single_partition(
output_type: str,
keep_raw_download: bool,
force_download: bool,
input_meta: Union[str, dict] = None,
) -> pd.DataFrame:
url, output_path = paths

@@ -158,6 +159,7 @@ def download_and_extract(
output_type: str = "jsonl",
keep_raw_download=False,
force_download=False,
input_meta: Union[str, dict] = None,
) -> DocumentDataset:
"""
Downloads and extracts a dataset into a format accepted by the NeMo Curator
@@ -174,6 +176,8 @@
keep_raw_download: Whether to keep the pre-extracted download file.
force_download: If False, will skip processing all files in output_paths that already exist and
directly read from them instead.
input_meta: A dictionary or a string formatted as a dictionary, which outlines
the field names and their respective data types within the JSONL input file.
Returns:
A DocumentDataset of the downloaded data
@@ -192,6 +196,7 @@
keep_raw_download=keep_raw_download,
force_download=force_download,
enforce_metadata=False,
input_meta=input_meta,
meta=output_format,
)

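
The docstring above says `input_meta` may be a dictionary or a string formatted as one. A hedged sketch of the call shape follows; everything except `output_format`, `keep_raw_download`, `force_download`, and `input_meta` is an assumption about arguments built elsewhere (e.g. by `build_downloader` in the script below), so this is not a standalone script:

```python
# Call shape only: urls, output_paths, downloader, iterator, extractor,
# and output_format are assumed to be constructed elsewhere.
dataset = download_and_extract(
    urls,                              # assumption: list of source URLs
    output_paths,                      # assumption: one output path per URL
    downloader, iterator, extractor,   # assumption: built via build_downloader
    output_format,
    keep_raw_download=False,
    force_download=False,
    input_meta={"id": "str", "text": "str"},  # hypothetical field names
)
```
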
10 changes: 9 additions & 1 deletion nemo_curator/scripts/download_and_extract.py
@@ -15,7 +15,7 @@
import argparse
import os

from nemo_curator.download import batch_download, download_and_extract
from nemo_curator.download.doc_builder import batch_download, download_and_extract
from nemo_curator.utils.config_utils import build_downloader
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import (
@@ -77,6 +77,7 @@ def main(args):
output_format,
keep_raw_download=args.keep_downloaded_files,
force_download=args.overwrite_existing_json,
input_meta=args.input_meta,
)

# Sample to trigger the dask computation
@@ -120,6 +121,13 @@ def attach_args(
required=False,
help="Path to input data directory",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A string formatted as a dictionary, which outlines the field names and "
"their respective data types within the JSONL input files.",
)
parser.add_argument(
"--output-json-dir",
type=str,
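
The flag itself is a plain string. A self-contained sketch of the pattern these scripts follow, parsing `--input-meta` with `argparse` and converting it downstream with `ast.literal_eval` (as `verify_classification_results.py` does below); this is an illustration, not code from the commit:

```python
# Standalone illustration of the CLI pattern added above.
import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--input-meta", type=str, default=None)
args = parser.parse_args(["--input-meta", '{"id": "str", "text": "str"}'])

# The string form becomes a dict of field name -> dtype.
input_meta = ast.literal_eval(args.input_meta)
print(input_meta)  # {'id': 'str', 'text': 'str'}
```
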
9 changes: 9 additions & 0 deletions nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py
@@ -48,6 +48,7 @@ def main(args):
blocksize=args.text_ddf_blocksize,
id_column=args.input_json_id_field,
text_column=args.input_json_text_field,
input_meta=args.input_meta,
)
print(
"Graph creation for get_text_ddf_from_json_path_with_blocksize" " complete.",
@@ -86,6 +87,13 @@ def attach_args(parser=None):
type=str,
help="The directory containing anchor docs with bk files",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A string formatted as a dictionary, which outlines the field names and "
"their respective data types within the JSONL input files.",
)
parser.add_argument(
"--text-ddf-blocksize",
type=int,
@@ -115,6 +123,7 @@
type=int,
help="The number of bucket parts to process per worker per batch",
)

return parser


12 changes: 12 additions & 0 deletions nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
@@ -34,6 +34,7 @@ def get_anchor_and_output_map_info(
input_bucket_field,
input_id_field,
input_text_field,
input_meta,
):
"""
Get anchor docs with bucket info
@@ -53,6 +54,7 @@
blocksize=text_ddf_blocksize,
id_column=input_id_field,
text_column=input_text_field,
input_meta=input_meta,
)
ddf_bk = get_bucket_ddf_from_parquet_path(
input_bucket_path=input_bucket_path, num_workers=num_workers
@@ -79,6 +81,13 @@ def attach_args(parser=None):
type=str,
help="The directory containing bucket information files",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A string formatted as a dictionary, which outlines the field names and "
"their respective data types within the JSONL input files.",
)
parser.add_argument(
"--text-ddf-blocksize",
type=int,
@@ -116,6 +125,7 @@ def jaccard_get_output_map_workflow(
input_bucket_field,
input_id_field,
input_text_field,
input_meta,
):
"""
Workflow for jaccard shuffle
@@ -140,6 +150,7 @@
input_bucket_field,
input_id_field,
input_text_field,
input_meta=input_meta,
)
ddf_anchor_docs_with_bk.to_parquet(
output_anchor_docs_with_bk_path,
@@ -171,6 +182,7 @@ def main(args):
args.input_bucket_field,
args.input_json_id_field,
args.input_json_text_field,
args.input_meta,
)
et = time.time()
print(f"Bucket Mapping time taken = {et-st} s")
41 changes: 31 additions & 10 deletions nemo_curator/scripts/verify_classification_results.py
@@ -13,7 +13,9 @@
# limitations under the License.

import argparse
import ast
import os
from typing import Union

import pandas as pd

@@ -27,30 +29,39 @@ def parse_args():
"""
parser = argparse.ArgumentParser(description="Run verification")

parser.add_argument(
"--results_file_path",
type=str,
help="The path of the input files",
required=True,
help="The path of the input files",
)
parser.add_argument(
"--expected_results_file_path",
type=str,
help="The path of the expected_result file",
required=True,
help="The path of the expected_result file",
)
parser.add_argument(
"--results_pred_column",
type=str,
help="The prediction column name for the input files",
default="pred",
help="The prediction column name for the input files",
)
parser.add_argument(
"--expected_pred_column",
type=str,
help="The prediction column name for the expected_result file",
default="pred",
help="The prediction column name for the expected_result file",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A string formatted as a dictionary, which outlines the field names and "
"their respective data types within the JSONL input files.",
)

return parser.parse_args()


@@ -122,10 +133,11 @@ def verify_same_dataframe(


def verify_results(
results_file_path,
expected_results_file_path,
results_pred_column,
expected_pred_column,
results_file_path: str,
expected_results_file_path: str,
results_pred_column: str,
expected_pred_column: str,
input_meta: Union[str, dict] = None,
):
"""
This function compares an input file with its expected result file.
@@ -136,9 +148,14 @@
expected_results_file_path: The path of the expected_result file.
results_pred_column: The prediction column name for the input files.
expected_pred_column: The prediction column name for the expected_result file.
input_meta: A dictionary or a string formatted as a dictionary, which outlines
the field names and their respective data types within the JSONL input file.
"""
expected_df = pd.read_json(expected_results_file_path, lines=True)
if type(input_meta) == str:
input_meta = ast.literal_eval(input_meta)

expected_df = pd.read_json(expected_results_file_path, lines=True, dtype=input_meta)
expected_df = expected_df.sort_values(by=["text"]).reset_index(drop=True)
expected_counts = expected_df[expected_pred_column].value_counts().to_dict()

@@ -150,7 +167,10 @@
]

got_paths = [p for p in os.scandir(results_file_path)]
got_df = [pd.read_json(path, lines=True)[expected_columns] for path in got_paths]
got_df = [
pd.read_json(path, lines=True, dtype=input_meta)[expected_columns]
for path in got_paths
]
got_df = pd.concat(got_df, ignore_index=True)
got_df = got_df.sort_values(by=["text"]).reset_index(drop=True)
got_counts = got_df[results_pred_column].value_counts().to_dict()
@@ -172,6 +192,7 @@ def main():
args.expected_results_file_path,
args.results_pred_column,
args.expected_pred_column,
args.input_meta,
)


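
The explicit `dtype` is the point of the whole change: `pd.read_json`'s default inference can silently coerce numeric-looking strings. A runnable illustration (not from the commit) with a hypothetical `id` field:

```python
# Without an explicit dtype, pandas may infer "007" as the integer 7,
# dropping the leading zeros; with a dtype it stays a string.
import ast
from io import StringIO

import pandas as pd

line = '{"id": "007", "pred": 1, "text": "hello"}\n'

df_inferred = pd.read_json(StringIO(line), lines=True)
df_typed = pd.read_json(
    StringIO(line), lines=True, dtype=ast.literal_eval('{"id": "str"}')
)

print(df_inferred["id"].iloc[0])  # 7 (coerced to int64)
print(df_typed["id"].iloc[0])     # "007" (preserved as a string)
```
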
(Diffs for the remaining four changed files are not shown.)
