From d9f7be32c86c7792d67be1bee10d124c86f543d4 Mon Sep 17 00:00:00 2001 From: Vinay Raman Date: Mon, 10 Feb 2025 16:07:52 -0800 Subject: [PATCH] fixed filters, Signed-off by viraman@nvidia.com Signed-off-by: Vinay Raman --- nemo_curator/filters/synthetic.py | 51 +++++++++++++------ .../main.py | 14 +++-- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/nemo_curator/filters/synthetic.py b/nemo_curator/filters/synthetic.py index c54a32a5b..166c729e7 100644 --- a/nemo_curator/filters/synthetic.py +++ b/nemo_curator/filters/synthetic.py @@ -24,6 +24,15 @@ from nemo_curator.filters.doc_filter import DocumentFilter from nemo_curator.utils.decorators import batched +from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker + + +def create_client(base_url, api_key): + openai_client = OpenAI( + base_url=base_url, + api_key=api_key, + ) + return openai_client # ----------------------------------------------------------------------------80 @@ -52,16 +61,21 @@ def __init__( self.percentile = percentile if truncate: self.truncate = truncate - try: - self.client = OpenAI(base_url=self.base_url, api_key=self.api_key) - except Exception as e: - print(f"Error accessing NIM model: {e}") self.batch_size = batch_size self.text_fields = text_fields @batched def score_document(self, df: pd.DataFrame): + try: + self.client = load_object_on_worker( + attr="openai_client_easiness", + load_object_function=create_client, + load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key}, + ) + except NoWorkerError: + return pd.Series(np.ones(len(df)), dtype=float) + document_score = self._calc_similarity_nim( df[self.text_fields[0]].to_list(), df[self.text_fields[1]].to_list() ) @@ -116,8 +130,8 @@ def _calc_similarity_nim(self, context, question): return sim - def __dask_tokenize__(self): - return normalize_token(EasinessFilter) + # def __dask_tokenize__(self): + # return normalize_token(EasinessFilter) # ----------------------------------------------------------------------------80 @@ -149,19 +163,24 @@ def __init__( self.system_prompt = answerability_system_prompt self.user_prompt_template = answerability_user_prompt_template self.num_criteria = num_criteria - - try: - self.client = OpenAI(base_url=self.base_url, api_key=self.api_key) - except Exception as e: - print(f"Error accessing NIM model: {e}") - self.text_fields = text_fields @batched def score_document(self, df: pd.DataFrame): - return df.apply( + + try: + self.client = load_object_on_worker( + attr="openai_client_answerability", + load_object_function=create_client, + load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key}, + ) + except NoWorkerError: + return pd.Series(["string"] * len(df)) + + return df.progress_apply( lambda row: self._llm_as_judge( - row[self.text_fields[0]], row[self.text_fields[1]] + row[self.text_fields[0]], + row[self.text_fields[1]], ), axis=1, ) @@ -212,8 +231,8 @@ def _llm_as_judge(self, context: str, question: str): return generation - def __dask_tokenize__(self): - return normalize_token(AnswerabilityFilter) + # def __dask_tokenize__(self): + # return normalize_token(AnswerabilityFilter) # ----------------------------------------------------------------------------80 diff --git a/tutorials/nemo-retriever-synthetic-data-generation/main.py b/tutorials/nemo-retriever-synthetic-data-generation/main.py index ca2acd56b..573d8113d 100644 --- a/tutorials/nemo-retriever-synthetic-data-generation/main.py +++ b/tutorials/nemo-retriever-synthetic-data-generation/main.py @@ -26,7 +26,11 @@ from config.config import RetrieverEvalSDGConfig from nemo_curator import AsyncOpenAIClient, ScoreFilter, Sequential, get_client from nemo_curator.datasets import DocumentDataset -from nemo_curator.filters import AnswerabilityFilter, EasinessFilter +from nemo_curator.filters import ( + AnswerabilityFilter, + EasinessFilter, + NonAlphaNumericFilter, +) from nemo_curator.modules.filter import Score, ScoreFilter # from tqdm.dask import TqdmCallback @@ -44,6 +48,7 @@ def get_pipeline(args: Any) -> Any: ] ) filters = [] + if cfg.easiness_filter: filters.append( ScoreFilter( @@ -180,25 +185,28 @@ def main(): print("Generating data ...") st_time = time.time() generated_dataset = sdg_pipeline(input_dataset) - generated_dataset.persist() + # generated_dataset.persist() print("Writing all generated data to disk ...") # saving in jsonl format all_save_dir = os.path.join(args.output_dir, "jsonl", "all") os.makedirs(all_save_dir) generated_dataset.to_json(all_save_dir) + generated_dataset = DocumentDataset.read_json(all_save_dir) print("Time taken to generate data = {:.2f} s".format(time.time() - st_time)) # saving in beir format # write_to_beir(args, generated_dataset, filtered=False) print("Filtering data ...") + st_time = time.time() filtered_dataset = filtering_pipeline(generated_dataset) filtered_dataset.persist() print("Writing filtered data to disk ...") all_save_dir = os.path.join(args.output_dir, "jsonl", "filtered") os.makedirs(all_save_dir) - generated_dataset.to_json(all_save_dir) + filtered_dataset.to_json(all_save_dir) + print("Time taken to generate data = {:.2f} s".format(time.time() - st_time)) # saving in beir format # write_to_beir(args, filtered_dataset, filtered=True)