Commit 21af732

ryantwolf authored and sarahyurick committed
Add Synthetic Data Generation Module (NVIDIA-NeMo#136)
* Begin implementation on OpenAI client
* Fix relative import
* Add temperature
* Modify client interface and begin ultrachat
* Change type annotation in openai client
* Make imports easier
* Reformat to match nemotron report
* Add yaml conversion
* Fix index error
* Add error handling for yaml parsing
* Fix error
* Add additional yaml parsing check
* Add more yaml error handling
* Export conversion error
* Change variable naming
* Make error catching more general
* Refactor list out of nemotron
* Add prompt helper function
* Add revisions and writing prompts
* Fix default prompt templates
* Add closed qa
* Fix prompt
* Add math and coding
* Add problem generation
* Rename function
* Add dialogue support
* Fix mispell
* Add two turn generation
* Add reward model as judge
* Refactor reward query
* Add error handling for non-reward models
* Add error handling to sync client
* Add open qa pipeline
* Improve docs and add writing pipeline
* Add closed qa pipeline
* Add math pipeline
* Add python pipeline
* Add async nemotron generator
* Fix await with index
* Add seed parameter
* Add missing await
* Fix parameter names
* Fix subscript await issues
* Switch parsing method for reward model
* Add initial docs
* Add nemo deploy client
* Add easy import
* Move conversation formatter
* Add other file
* Update nemotron import
* Update model client import
* Remove model in query call
* Add extra index
* Fix response indexing
* Add top k
* Remove extras
* Add safe import for nemo deploy
* Add pandas conversions
* Add partition default
* Add no format
* Move no format location
* Use top_k in nemo client
* Address vibhu's review
* Add logging import
* Fix import
* Fix tqdm
* Add missing awaits
* Standardize names
* Address Ayush nit

---------

Signed-off-by: Ryan Wolf <[email protected]>
1 parent c87233d commit 21af732

18 files changed, +3883 -2 lines changed

docs/user-guide/index.rst

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@
 :ref:`GPU Accelerated Exact and Fuzzy Deduplication <data-curator-gpu-deduplication>`
    Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF.
 
+:ref:`Synthetic Data Generation <data-curator-syntheticdata>`
+   Synthetic data generation tools and example pipelines are available within NeMo Curator.
+
 :ref:`Downstream Task Decontamination <data-curator-downstream>`
    After training, large language models are usually evaluated by their performance on downstream tasks consisting of unseen test data. When dealing with large datasets, there is a potential for leakage of this test data into the model’s training dataset. NeMo Curator allows you to remove sections of documents in your dataset that are present in downstream tasks.
 

docs/user-guide/syntheticdata.rst

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+
+.. _data-curator-syntheticdata:
+
+======================================
+Synthetic Data Generation
+======================================
+--------------------------------------
+Background
+--------------------------------------
+Synthetic data generation has become increasingly useful in large language model training.
+It is used in pretraining, fine-tuning, and evaluation.
+Synthetically generated data can be useful for adapting an LLM to low-resource languages or domains, or for performing knowledge distillation from other models, among other purposes.
+There are a variety of ways to construct synthetic data generation pipelines, with numerous LLM and classical filters.
+
+NeMo Curator has a simple, easy-to-use set of tools that allow you to use prebuilt synthetic generation pipelines or build your own.
+Any model inference service that uses the OpenAI API is compatible with the synthetic data generation module, allowing you to generate your data from any model.
+NeMo Curator has prebuilt synthetic data generation pipelines for supervised fine-tuning (SFT) and preference data that were used to generate data for the training of `Nemotron-4 340B <https://research.nvidia.com/publication/2024-06_nemotron-4-340b>`_.
+You can also easily interweave filtering and deduplication steps in your synthetic data pipeline with the other modules in NeMo Curator.

nemo_curator/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,13 @@
 
 
 from .modules import *
+from .services import (
+    AsyncLLMClient,
+    AsyncOpenAIClient,
+    LLMClient,
+    NemoDeployClient,
+    OpenAIClient,
+)
 from .utils.distributed_utils import get_client
 
 # Dask will automatically convert the list score type

nemo_curator/datasets/doc_dataset.py

Lines changed: 39 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import dask.dataframe as dd
 
@@ -130,6 +130,44 @@ def to_pickle(
     ):
         raise NotImplementedError("DocumentDataset does not support to_pickle yet")
 
+    @classmethod
+    def from_pandas(
+        cls,
+        data,
+        npartitions: Optional[int] = 1,
+        chunksize: Optional[int] = None,
+        sort: Optional[bool] = True,
+        name: Optional[str] = None,
+    ):
+        """
+        Creates a document dataset from a pandas data frame.
+        For more information on the arguments see Dask's from_pandas documentation
+        https://docs.dask.org/en/stable/generated/dask.dataframe.from_pandas.html
+
+        Args:
+            data: A pandas dataframe
+        Returns:
+            A document dataset with a pandas backend (on the CPU).
+        """
+        return cls(
+            dd.from_pandas(
+                data=data,
+                npartitions=npartitions,
+                chunksize=chunksize,
+                sort=sort,
+                name=name,
+            )
+        )
+
+    def to_pandas(self):
+        """
+        Creates a pandas dataframe from a DocumentDataset
+
+        Returns:
+            A pandas dataframe (on the CPU)
+        """
+        return self.df.to_backend("pandas").compute()
+
 
 def _read_json_or_parquet(
     input_files: Union[str, List[str]],

nemo_curator/services/__init__.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .conversation_formatter import ConversationFormatter
+from .model_client import AsyncLLMClient, LLMClient
+from .nemo_client import NemoDeployClient
+from .openai_client import AsyncOpenAIClient, OpenAIClient
+
+__all__ = [
+    "AsyncLLMClient",
+    "LLMClient",
+    "AsyncOpenAIClient",
+    "OpenAIClient",
+    "NemoDeployClient",
+    "ConversationFormatter",
+]

nemo_curator/services/conversation_formatter.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class ConversationFormatter(ABC):
+    """
+    Represents a way of formatting a conversation with an LLM
+    such that it can respond appropriately
+    """
+
+    @abstractmethod
+    def format_conversation(self, conv: List[dict]) -> str:
+        raise NotImplementedError(
+            "format_conversation must be implemented by subclasses"
+        )
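
Reviewer note: to make the `ConversationFormatter` contract concrete, here is a minimal sketch of what a subclass might look like. `SimpleChatFormatter` and its `User:`/`Assistant:` prompt layout are hypothetical and not part of this commit. The abstract base class is redeclared locally only so that the snippet runs without NeMo Curator installed.

```python
from abc import ABC, abstractmethod
from typing import List


class ConversationFormatter(ABC):
    """Local stand-in mirroring the abstract interface added in this commit."""

    @abstractmethod
    def format_conversation(self, conv: List[dict]) -> str:
        raise NotImplementedError


class SimpleChatFormatter(ConversationFormatter):
    """Hypothetical formatter that writes one 'Role: content' line per turn."""

    def format_conversation(self, conv: List[dict]) -> str:
        lines = [f"{turn['role'].capitalize()}: {turn['content']}" for turn in conv]
        # End with an open assistant turn so the model continues from there.
        lines.append("Assistant:")
        return "\n".join(lines)


conversation = [
    {"role": "user", "content": "Write a sentence"},
    {"role": "assistant", "content": "This is a sentence"},
    {"role": "user", "content": "Write another"},
]
prompt = SimpleChatFormatter().format_conversation(conversation)
print(prompt)
```

A formatter like this is what `NemoDeployClient.query_model` expects through its `conversation_formatter` argument, since that client sends a single prompt string rather than a list of messages.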

nemo_curator/services/model_client.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional, Union
+
+from nemo_curator.services.conversation_formatter import ConversationFormatter
+
+
+class LLMClient(ABC):
+    """
+    Interface representing a client connecting to an LLM inference server
+    and making requests synchronously
+    """
+
+    @abstractmethod
+    def query_model(
+        self,
+        *,
+        messages: Iterable,
+        model: str,
+        conversation_formatter: Optional[ConversationFormatter] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = 1,
+        seed: Optional[int] = None,
+        stop: Union[Optional[str], List[str]] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> List[str]:
+        raise NotImplementedError("Subclass of LLMClient must implement 'query_model'")
+
+    @abstractmethod
+    def query_reward_model(
+        self,
+        *,
+        messages: Iterable,
+        model: str,
+        conversation_formatter: Optional[ConversationFormatter] = None,
+    ) -> dict:
+        raise NotImplementedError(
+            "Subclass of LLMClient must implement 'query_reward_model'"
+        )
+
+
+class AsyncLLMClient(ABC):
+    """
+    Interface representing a client connecting to an LLM inference server
+    and making requests asynchronously
+    """
+
+    @abstractmethod
+    async def query_model(
+        self,
+        *,
+        messages: Iterable,
+        model: str,
+        conversation_formatter: Optional[ConversationFormatter] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = 1,
+        seed: Optional[int] = None,
+        stop: Union[Optional[str], List[str]] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> List[str]:
+        raise NotImplementedError(
+            "Subclass of AsyncLLMClient must implement 'query_model'"
+        )
+
+    @abstractmethod
+    async def query_reward_model(
+        self,
+        *,
+        messages: Iterable,
+        model: str,
+        conversation_formatter: Optional[ConversationFormatter] = None,
+    ) -> dict:
+        raise NotImplementedError(
+            "Subclass of AsyncLLMClient must implement 'query_reward_model'"
+        )
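
Reviewer note: the value of the asynchronous interface is that many `query_model` calls can be in flight at once. The sketch below illustrates that pattern with a hypothetical `EchoAsyncClient` that follows a subset of the keyword-only `query_model` signature and simulates latency instead of contacting a server. The class, the `generate_all` helper, and the model name `example-model` are assumptions made for illustration.

```python
import asyncio
from typing import Iterable, List, Optional


class EchoAsyncClient:
    """Hypothetical stand-in that follows a subset of the query_model signature."""

    async def query_model(
        self,
        *,
        messages: Iterable,
        model: str,
        max_tokens: Optional[int] = None,
        n: Optional[int] = 1,
        temperature: Optional[float] = None,
    ) -> List[str]:
        # Simulate network latency instead of calling a real inference server.
        await asyncio.sleep(0.01)
        last_turn = list(messages)[-1]["content"]
        return [f"[{model}] {last_turn}"] * (n or 1)


async def generate_all(client: EchoAsyncClient, prompts: List[str]) -> List[str]:
    # Issue every request concurrently; gather preserves the input order.
    tasks = [
        client.query_model(
            messages=[{"role": "user", "content": prompt}],
            model="example-model",
        )
        for prompt in prompts
    ]
    responses = await asyncio.gather(*tasks)
    # Each call returns a list of n completions; keep the first one.
    return [response[0] for response in responses]


results = asyncio.run(generate_all(EchoAsyncClient(), ["Hello", "World"]))
print(results)
```

Because every parameter after `self` is keyword-only, callers must name each argument, which keeps call sites readable as the parameter list grows.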

nemo_curator/services/nemo_client.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Iterable, List, Optional, Union
+
+from nemo_curator.services.conversation_formatter import ConversationFormatter
+from nemo_curator.utils.import_utils import safe_import_from
+
+from .model_client import AsyncLLMClient, LLMClient
+
+NemoQueryLLM = safe_import_from("nemo.deploy.nlp", "NemoQueryLLM")
+
+
+class NemoDeployClient(LLMClient):
+    """
+    A wrapper around NemoQueryLLM for querying models in synthetic data generation
+    """
+
+    def __init__(self, nemo_deploy: NemoQueryLLM) -> None:
+        self.client = nemo_deploy
+
+    def query_model(
+        self,
+        *,
+        messages: Iterable,
+        model: str,
+        conversation_formatter: Optional[ConversationFormatter] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        seed: Optional[int] = None,
+        stop: Union[Optional[str], List[str]] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> List[str]:
+        if conversation_formatter is None:
+            raise ValueError(
+                "NemoDeployClient's query_model requires a conversation_formatter"
+            )
+
+        prompt = conversation_formatter.format_conversation(messages)
+        self.client.model_name = model
+
+        if n is not None:
+            warnings.warn("n is not supported in NemoDeployClient")
+        if stream:
+            warnings.warn("streaming is not supported in NemoDeployClient")
+
+        if isinstance(stop, str):
+            stop = [stop]
+
+        response = self.client.query_llm(
+            prompts=[prompt],
+            max_output_len=max_tokens,
+            random_seed=seed,
+            stop_words_list=stop,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+        )[0]
+
+        return self._postprocess_response(response, stop)
+
+    @staticmethod
+    def _postprocess_response(responses: List[str], stop_words: Optional[List[str]]) -> List[str]:
+        processed_responses = []
+        for response in responses:
+            for stop in stop_words or []:  # stop may be None when no stop words are given
+                if response.endswith(stop):
+                    response = response[: -len(stop)]
+            processed_responses.append(response.strip())
+        return processed_responses
+
+    def query_reward_model(self, *, messages: Iterable, model: str) -> dict:
+        """
+        Prompts an LLM Reward model to score a conversation between a user and assistant
+        Args:
+            messages: The conversation to calculate a score for.
+                Should be formatted like:
+                    [{"role": "user", "content": "Write a sentence"}, {"role": "assistant", "content": "This is a sentence"}, ...]
+            model: The name of the model that should be used to calculate the reward.
+                Must be a reward model, cannot be a regular LLM.
+        Returns:
+            A mapping of score_name -> score
+        """
+        raise NotImplementedError(
+            "Reward model inference is not supported in NeMo Deploy Clients"
+        )
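
Reviewer note: the stop-word post-processing in `_postprocess_response` can be exercised on its own. The following is a minimal sketch, assuming a hypothetical standalone function `strip_stop_words` that applies the same trailing-stop-word logic and also accepts `stop_words=None`. The sample strings, including the `<extra_id_1>` marker, are invented for illustration.

```python
from typing import List, Optional


def strip_stop_words(responses: List[str], stop_words: Optional[List[str]]) -> List[str]:
    """Remove trailing stop words from each response, then trim whitespace."""
    processed = []
    for response in responses:
        # Treat a missing stop list as empty so that stop_words=None is safe.
        for stop in stop_words or []:
            if response.endswith(stop):
                response = response[: -len(stop)]
        processed.append(response.strip())
    return processed


cleaned = strip_stop_words(["Paris.<extra_id_1>", "  Berlin  "], ["<extra_id_1>"])
print(cleaned)
print(strip_stop_words(["  no stops  "], None))
```

Only a stop word at the very end of a response is removed; occurrences in the middle of the text are left alone.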
