Use NeMo Curator Utils
Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa committed Jun 4, 2024
1 parent 115598f commit 9e8c2cd
Showing 4 changed files with 53 additions and 81 deletions.
@@ -1,29 +1,32 @@
 import argparse
 import os
-import warnings
 import time
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path

-import cudf
-import dask_cudf
 import torch
 import torch.nn as nn
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
-from dask.distributed import Client
-from dask_cuda import LocalCUDACluster
 from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

+from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
+from nemo_curator.utils.script_utils import (
+    parse_client_args,
+    parse_distributed_classifier_args,
+)


 @dataclass
 class TranslationConfig:
     pretrained_model_name_or_path: str
     max_length: int = 256
     num_beams: int = 5
+    autocast: bool = False


 class CustomModel(nn.Module):
-    def __init__(self, pretrained_model_name_or_path: str):
+    def __init__(self, pretrained_model_name_or_path: str, autocast: bool = False):
         super().__init__()
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -33,16 +36,25 @@ def __init__(self, pretrained_model_name_or_path: str):
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             trust_remote_code=True,
         )
+        self.autocast = autocast

-    def forward(self, batch):
-        outputs = self.model.generate(
+    def _forward(self, batch):
+        return self.model.generate(
             **batch,
             use_cache=True,
             min_length=0,
             max_length=self.config.max_length,
             num_beams=self.config.num_beams,
             num_return_sequences=1,
         )

+    @torch.no_grad()
+    def forward(self, batch):
+        if self.autocast:
+            with torch.autocast(device_type="cuda"):
+                outputs = self._forward(batch)
+        else:
+            outputs = self._forward(batch)
+        return outputs
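For reference, the autocast path added above as a standalone sketch, assuming a CUDA device and the public t5-small checkpoint as stand-ins for the IndicTrans2 setup in this script:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Stand-in model; the script itself targets ai4bharat/indictrans2-en-indic-1B.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to("cuda").eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
batch = tokenizer(["translate English to German: Hello"], return_tensors="pt").to("cuda")

# Same structure as CustomModel.forward: no_grad always, autocast optionally.
with torch.no_grad(), torch.autocast(device_type="cuda"):
    tokens = model.generate(**batch, max_length=64, num_beams=5)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True))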


@@ -52,7 +64,12 @@ def __init__(self, config):
         super().__init__(config.pretrained_model_name_or_path)

     def load_model(self, device="cuda"):
-        return load_model(config=self.config, device=device)
+        model = CustomModel(
+            self.config.pretrained_model_name_or_path, self.config.autocast
+        )
+        model = model.to(device)
+        model.eval()
+        return model

     def load_config(self):
         return AutoConfig.from_pretrained(
@@ -78,13 +95,6 @@ def load_cfg(self):
         )


-def load_model(config, device):
-    model = CustomModel(config.pretrained_model_name_or_path)
-    model = model.to(device)
-    model.eval()
-    return model
-
-
 def translate_tokens(df, model):
     tokenizer = model.load_tokenizer()
     generated_tokens = df["translation"].to_arrow().to_pylist()
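The remainder of translate_tokens is collapsed in this diff; on each partition it presumably decodes the generated ids back to text, along the lines of this sketch (t5-small and the sample ids are placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
generated_tokens = [[13959, 1566, 12, 2968, 10, 1]]  # one row of generated ids
translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(translations)  # decoded strings, one per document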
@@ -99,78 +109,39 @@ def translate_tokens(df, model):


 def parse_arguments():
-    parser = argparse.ArgumentParser(
-        description="PyTorch Model Predictions using Crossfit"
-    )
-    parser.add_argument(
-        "--input-jsonl-path", help="Input JSONL file path", required=True
-    )
-    parser.add_argument(
-        "--output-parquet-path", help="Output Parquet file path", required=True
-    )
+    parser = parse_distributed_classifier_args()
     parser.add_argument(
         "--input-column",
         type=str,
         required=False,
         default="text",
-        help="Column name in input dataframe",
+        help="The column name in the input data that contains the text to be translated",
     )
-    parser.add_argument("--pool-size", type=str, default="1GB", help="RMM pool size")
-    parser.add_argument("--num-workers", type=int, default=1, help="Number of GPUs")
-    parser.add_argument(
-        "--pretrained-model-name-or-path",
-        type=str,
-        default="ai4bharat/indictrans2-en-indic-1B",
-        help="Model name",
-    )
-    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
     return parser.parse_args()
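The switch to parse_distributed_classifier_args() means the shared flags (I/O paths, file types, batch size, model name, autocast, client options) come from one helper, and each script layers its own on top. A self-contained sketch of that composition pattern, with shared_parser() as a hypothetical stand-in for the real helper:

import argparse

def shared_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for parse_distributed_classifier_args(); the
    # real flag set lives in nemo_curator/utils/script_utils.py.
    parser = argparse.ArgumentParser(description="Distributed model inference")
    parser.add_argument("--input-data-dir", required=True)
    parser.add_argument("--batch-size", type=int, default=256)
    return parser

parser = shared_parser()
parser.add_argument("--input-column", type=str, default="text")
args = parser.parse_args(["--input-data-dir", "/data/in"])
print(args.input_column, args.batch_size)  # -> text 256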


-def single_partition_write_with_filename(df, output_file_dir):
-    assert "path" in df.columns
-
-    if len(df) > 0:
-        empty_partition = True
-    else:
-        warnings.warn("Empty partition found")
-        empty_partition = False
-
-    success_ser = cudf.Series([empty_partition])
-    if empty_partition:
-        filename = df.path.iloc[0]
-        num_filenames = len(df.path.unique())
-        if num_filenames > 1:
-            raise ValueError(
-                f"More than one filename found in partition: {num_filenames}"
-            )
-        filename = Path(filename).stem
-        output_file_path = os.path.join(output_file_dir, f"{filename}.parquet")
-        df["path"] = df["path"].astype(str)
-        df.to_parquet(output_file_path)
-
-    return success_ser
-
-
 def main():
     args = parse_arguments()
-
-    cluster = LocalCUDACluster(
-        rmm_pool_size=args.pool_size, n_workers=args.num_workers, rmm_async=True
-    )
-    client = Client(cluster)
     print(f"Arguments parsed = {args}")
+    client = get_client(**parse_client_args(args))
     print(client.dashboard_link)

     st = time.time()
     translation_config = TranslationConfig(
         pretrained_model_name_or_path=args.pretrained_model_name_or_path,
         max_length=256,
         num_beams=5,
+        autocast=args.autocast,
     )
     input_files = [
-        os.path.join(args.input_jsonl_path, x)
-        for x in os.listdir(args.input_jsonl_path)
+        os.path.join(args.input_data_dir, x) for x in os.listdir(args.input_data_dir)
     ]
-    # ddf = dask_cudf.read_json(input_files, lines=True, include_path_column=True)
-    ddf = dask_cudf.read_parquet(input_files, include_path_column=True)
+    ddf = read_data(
+        input_files,
+        file_type=args.input_file_type,
+        backend="cudf",
+        files_per_partition=1,
+        add_filename=True,
+    )
     columns = ddf.columns.tolist()
     model = ModelForSeq2SeqModel(translation_config)
     pipe = op.Sequential(
@@ -188,13 +159,14 @@ def main():
     translated_meta["translation"] = "DUMMY_STRING"
     ddf = ddf.map_partitions(translate_tokens, model=model, meta=translated_meta)

-    # Create output directory if it does not exist
-    os.makedirs(args.output_parquet_path, exist_ok=True)
-    ddf.map_partitions(
-        single_partition_write_with_filename,
-        output_file_dir=args.output_parquet_path,
-        meta=cudf.Series(dtype=bool),
-    ).compute()
+    write_to_disk(
+        ddf,
+        output_file_dir=args.output_data_dir,
+        write_to_filename=True,
+        output_type=args.output_file_type,
+    )
     print("Total time taken for translation: {}".format(time.time() - st))
     client.close()


 if __name__ == "__main__":
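write_to_disk(..., write_to_filename=True) replaces the hand-rolled helper above: each partition holds rows from one source file and is written to a matching output file. A rough pandas-based sketch of that per-file behavior (pandas stands in for cudf so it runs anywhere; names are illustrative):

import os
from pathlib import Path

import pandas as pd

def write_partition_by_filename(df: pd.DataFrame, output_dir: str) -> None:
    # Each row remembers its source file in the "path" column, so the
    # output directory mirrors the input file layout one-to-one.
    assert "path" in df.columns
    stem = Path(df["path"].iloc[0]).stem
    os.makedirs(output_dir, exist_ok=True)
    df.to_parquet(os.path.join(output_dir, f"{stem}.parquet"))

write_partition_by_filename(
    pd.DataFrame({"text": ["hola"], "path": ["/data/in/doc_0.jsonl"]}), "/tmp/out"
)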
2 changes: 1 addition & 1 deletion nemo_curator/scripts/domain_classifier_inference.py
@@ -86,7 +86,7 @@ def main():
     add_filename = True

     domain_classifier = DomainClassifier(
-        model_path=args.model_path,
+        model_path=args.pretrained_model_name_or_path,
         labels=labels,
         max_chars=max_chars,
         batch_size=args.batch_size,
2 changes: 1 addition & 1 deletion nemo_curator/scripts/quality_classifier_inference.py
@@ -89,7 +89,7 @@ def main():
     add_filename = True

     classifier = QualityClassifier(
-        model_path=args.model_path,
+        model_path=args.pretrained_model_name_or_path,
         max_chars=max_chars,
         labels=labels,
         batch_size=args.batch_size,
2 changes: 1 addition & 1 deletion nemo_curator/utils/script_utils.py
@@ -219,7 +219,7 @@ def parse_distributed_classifier_args(
         required=True,
     )
     parser.add_argument(
-        "--model-path",
+        "--pretrained-model-name-or-path",
         type=str,
         help="The path to the model file",
         required=True,
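The flag rename works because argparse converts dashes to underscores when it builds the namespace, which is exactly why the classifier scripts above now read args.pretrained_model_name_or_path. A minimal, runnable illustration (the path value is a placeholder):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained-model-name-or-path", type=str, required=True)
args = parser.parse_args(["--pretrained-model-name-or-path", "/models/classifier.pth"])
# "--pretrained-model-name-or-path" becomes the attribute below:
print(args.pretrained_model_name_or_path)  # -> /models/classifier.pth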
