Commit d9f7be3
fixed filters, Signed-off by [email protected]
Signed-off-by: Vinay Raman <[email protected]>
vinay-raman committed Feb 11, 2025
1 parent cf0ec14 commit d9f7be3
Showing 2 changed files with 46 additions and 19 deletions.
51 changes: 35 additions & 16 deletions nemo_curator/filters/synthetic.py
@@ -24,6 +24,15 @@
 
 from nemo_curator.filters.doc_filter import DocumentFilter
 from nemo_curator.utils.decorators import batched
+from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker
+
+
+def create_client(base_url, api_key):
+    openai_client = OpenAI(
+        base_url=base_url,
+        api_key=api_key,
+    )
+    return openai_client
 
 
 # ----------------------------------------------------------------------------80
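Note (not part of the diff): the new create_client helper lets the OpenAI client be built lazily on each Dask worker instead of in __init__, where it would have to be pickled and shipped with the filter. A minimal sketch of the per-worker caching pattern the hunks below rely on; load_object_on_worker and NoWorkerError are the real NeMo Curator utilities imported above, while the caching semantics described in the comments are inferred from their names and call sites, and make_resource / "my_resource" are hypothetical stand-ins:

from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker


def make_resource(base_url, api_key):
    # Stand-in for create_client: anything expensive or unpicklable.
    return {"base_url": base_url, "api_key": api_key}


def score_batch(batch):
    try:
        # Presumably cached on the current worker under the name given by
        # attr: the first call runs make_resource, later calls reuse it.
        resource = load_object_on_worker(
            attr="my_resource",  # hypothetical cache key
            load_object_function=make_resource,
            load_object_kwargs={"base_url": "http://localhost", "api_key": "none"},
        )
    except NoWorkerError:
        # No Dask worker in context (e.g. a dry run on the client process).
        return None
    return resource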
@@ -52,16 +61,21 @@ def __init__(
         self.percentile = percentile
         if truncate:
             self.truncate = truncate
-        try:
-            self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
-        except Exception as e:
-            print(f"Error accessing NIM model: {e}")
         self.batch_size = batch_size
         self.text_fields = text_fields
 
     @batched
     def score_document(self, df: pd.DataFrame):
 
+        try:
+            self.client = load_object_on_worker(
+                attr="openai_client_easiness",
+                load_object_function=create_client,
+                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
+            )
+        except NoWorkerError:
+            return pd.Series(np.ones(len(df)), dtype=float)
+
         document_score = self._calc_similarity_nim(
             df[self.text_fields[0]].to_list(), df[self.text_fields[1]].to_list()
         )
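Note (not part of the diff): the NoWorkerError fallback presumably exists because Dask may invoke score_document outside a worker, for example when inferring output metadata on the client process, so the filter returns a placeholder Series whose dtype matches real easiness scores. A tiny runnable illustration of why the dtype matters for that inference:

import numpy as np
import pandas as pd

df = pd.DataFrame({"text": [], "question": []})  # empty, meta-style frame
fallback = pd.Series(np.ones(len(df)), dtype=float)
assert fallback.dtype == "float64" and len(fallback) == 0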
@@ -116,8 +130,8 @@ def _calc_similarity_nim(self, context, question):
 
         return sim
 
-    def __dask_tokenize__(self):
-        return normalize_token(EasinessFilter)
+    # def __dask_tokenize__(self):
+    #     return normalize_token(EasinessFilter)
 
 
 # ----------------------------------------------------------------------------80
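Note (not part of the diff): __dask_tokenize__ lets a class control the token Dask uses to identify an object in its task graph. The removed implementation returned normalize_token(EasinessFilter), a class-level value, so every instance tokenized identically regardless of configuration; that aliasing is one plausible reason the hook is disabled here, though the commit does not say so. A runnable sketch of the mechanism, using a constant string token to show the same effect as a class-level token:

from dask.base import tokenize


class FilterA:
    def __init__(self, threshold):
        self.threshold = threshold

    def __dask_tokenize__(self):
        # Constant token: ignores self.threshold entirely.
        return "FilterA"


# Differently configured instances collide on the same token, so Dask
# may treat them as interchangeable in its graph and caches.
assert tokenize(FilterA(0.1)) == tokenize(FilterA(0.9))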
@@ -149,19 +163,24 @@ def __init__(
         self.system_prompt = answerability_system_prompt
         self.user_prompt_template = answerability_user_prompt_template
         self.num_criteria = num_criteria
-
-        try:
-            self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
-        except Exception as e:
-            print(f"Error accessing NIM model: {e}")
-
         self.text_fields = text_fields
 
     @batched
     def score_document(self, df: pd.DataFrame):
-        return df.apply(
+
+        try:
+            self.client = load_object_on_worker(
+                attr="openai_client_answerability",
+                load_object_function=create_client,
+                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
+            )
+        except NoWorkerError:
+            return pd.Series(["string"] * len(df))
+
+        return df.progress_apply(
             lambda row: self._llm_as_judge(
-                row[self.text_fields[0]], row[self.text_fields[1]]
+                row[self.text_fields[0]],
+                row[self.text_fields[1]],
             ),
             axis=1,
         )
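Note (not part of the diff): progress_apply is not a native pandas method; it is attached to DataFrame by tqdm's pandas integration and raises AttributeError unless tqdm.pandas() has been called first. The diff does not show where that registration happens, so this is worth verifying. A minimal runnable sketch, with a stand-in lambda rather than the filter's real scoring logic:

import pandas as pd
from tqdm import tqdm

tqdm.pandas()  # registers DataFrame.progress_apply / Series.progress_apply

df = pd.DataFrame({"context": ["a", "bb"], "question": ["x", "yy"]})
# Same semantics as df.apply(..., axis=1), plus a per-row progress bar.
scores = df.progress_apply(
    lambda row: len(row["context"]) + len(row["question"]),
    axis=1,
)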
@@ -212,8 +231,8 @@ def _llm_as_judge(self, context: str, question: str):
 
         return generation
 
-    def __dask_tokenize__(self):
-        return normalize_token(AnswerabilityFilter)
+    # def __dask_tokenize__(self):
+    #     return normalize_token(AnswerabilityFilter)
 
 
 # ----------------------------------------------------------------------------80
14 changes: 11 additions & 3 deletions tutorials/nemo-retriever-synthetic-data-generation/main.py
@@ -26,7 +26,11 @@
 from config.config import RetrieverEvalSDGConfig
 from nemo_curator import AsyncOpenAIClient, ScoreFilter, Sequential, get_client
 from nemo_curator.datasets import DocumentDataset
-from nemo_curator.filters import AnswerabilityFilter, EasinessFilter
+from nemo_curator.filters import (
+    AnswerabilityFilter,
+    EasinessFilter,
+    NonAlphaNumericFilter,
+)
 from nemo_curator.modules.filter import Score, ScoreFilter
 
 # from tqdm.dask import TqdmCallback
@@ -44,6 +48,7 @@ def get_pipeline(args: Any) -> Any:
         ]
     )
     filters = []
+
     if cfg.easiness_filter:
         filters.append(
             ScoreFilter(
@@ -180,25 +185,28 @@ def main():
     print("Generating data ...")
     st_time = time.time()
     generated_dataset = sdg_pipeline(input_dataset)
-    generated_dataset.persist()
+    # generated_dataset.persist()
+
     print("Writing all generated data to disk ...")
     # saving in jsonl format
     all_save_dir = os.path.join(args.output_dir, "jsonl", "all")
     os.makedirs(all_save_dir)
     generated_dataset.to_json(all_save_dir)
+    generated_dataset = DocumentDataset.read_json(all_save_dir)
     print("Time taken to generate data = {:.2f} s".format(time.time() - st_time))
 
     # saving in beir format
     # write_to_beir(args, generated_dataset, filtered=False)
 
     print("Filtering data ...")
     st_time = time.time()
     filtered_dataset = filtering_pipeline(generated_dataset)
+    filtered_dataset.persist()
     print("Writing filtered data to disk ...")
     all_save_dir = os.path.join(args.output_dir, "jsonl", "filtered")
     os.makedirs(all_save_dir)
-    generated_dataset.to_json(all_save_dir)
+    filtered_dataset.to_json(all_save_dir)
     print("Time taken to generate data = {:.2f} s".format(time.time() - st_time))
 
     # saving in beir format
     # write_to_beir(args, filtered_dataset, filtered=True)
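Note (not part of the diff): Dask's persist returns a new, persisted collection rather than mutating in place, so the original bare generated_dataset.persist() discarded its result and was effectively a no-op; the to_json/read_json round trip is a blunt but reliable way to materialize results once and cut the task graph. The newly added filtered_dataset.persist() has the same unassigned shape, assuming DocumentDataset.persist follows Dask's convention. A sketch of the pitfall on a plain Dask DataFrame for illustration:

import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)

ddf.persist()        # result discarded: ddf itself is unchanged
ddf = ddf.persist()  # correct: rebind to the persisted collection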