Skip to content

Commit daf4b67

Browse files
committed
Address Reviews and docstrings
1 parent bd43d5d commit daf4b67

File tree

6 files changed

+80
-24
lines changed

6 files changed

+80
-24
lines changed

nemo_curator/modules/config.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,10 @@ class SemDedupConfig(BaseConfig):
111111
id_col_name (str): Column name for ID.
112112
id_col_type (str): Column type for ID.
113113
input_column (str): Input column for embeddings.
114-
input_file_type (str): File type for input embeddings.
115114
embeddings_save_loc (str): Location to save embeddings.
116-
model_name_or_path (str): Model name or path for embeddings.
117-
batch_size (int): Inital Batch size for processing embeddings.
118-
max_mem_gb (int): Maximum memory in GB for embeddings.
115+
embedding_model_name_or_path (str): Model name or path for embeddings.
116+
embedding_batch_size (int): Initial batch size for processing embeddings.
117+
embedding_max_mem_gb (int): Maximum memory in GB for embeddings.
119118
clustering_save_loc (str): Location to save clustering results.
120119
n_clusters (int): Number of clusters.
121120
seed (int): Seed for clustering.
@@ -124,7 +123,8 @@ class SemDedupConfig(BaseConfig):
124123
which_to_keep (str): Which duplicates to keep.
125124
largest_cluster_size_to_process (int): Largest cluster size to process.
126125
sim_metric (str): Similarity metric for deduplication.
127-
eps (str): Epsilon values to calculate if semantically similar or not
126+
eps_thresholds (str): Space-separated epsilon thresholds (e.g. '0.01 0.001') used to calculate if semantically similar or not.
127+
eps_to_extract (float): Epsilon value to extract deduplicated data; must be one of the values in eps_thresholds.
128128
"""
129129

130130
cache_dir: str
@@ -134,7 +134,6 @@ class SemDedupConfig(BaseConfig):
134134
input_column: str = "text"
135135

136136
# Embeddings
137-
input_file_type: str = "json"
138137
embeddings_save_loc: str = "embeddings"
139138
embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
140139
embedding_batch_size: int = 128
@@ -154,11 +153,16 @@ class SemDedupConfig(BaseConfig):
154153

155154
# Extract dedup config
156155
eps_thresholds: str = "0.01 0.001"
157-
eps_to_extract: str = "0.01"
156+
eps_to_extract: float = 0.01
158157

159158
def __post_init__(self):
160159
self.eps_thresholds = [float(x) for x in self.eps_thresholds.split()]
161160
if self.cache_dir is None:
162161
raise ValueError(
163162
"Finding sem-dedup requires a cache directory accessible via all workers to store intermediates"
164163
)
164+
165+
if self.eps_to_extract not in self.eps_thresholds:
166+
raise ValueError(
167+
f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}"
168+
)

nemo_curator/modules/semantic_dedup.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def load_tokenizer(self):
118118
class EmbeddingCreator:
119119
def __init__(
120120
self,
121-
model_name_or_path: str,
122-
max_memory: str,
123-
batch_size: int,
121+
embeddings_model_name_or_path: str,
122+
embedding_max_mem_gb: str,
123+
embedding_batch_size: int,
124124
embedding_output_dir: str,
125125
input_column: str = "text",
126126
write_embeddings_to_disk: bool = True,
@@ -131,9 +131,9 @@ def __init__(
131131
Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.
132132
133133
Args:
134-
model_name_or_path (str): The path or identifier for the model used to generate embeddings.
135-
max_memory (str): Maximum memory usage for the embedding process.
136-
batch_size (int): Number of samples to process in each batch.
134+
embeddings_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
135+
embedding_max_mem_gb (str): Maximum memory usage for the embedding process.
136+
embedding_batch_size (int): Number of samples to process in each batch.
137137
embedding_output_dir (str): Directory path where embeddings will be saved.
138138
input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
139139
write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
@@ -153,9 +153,10 @@ def __init__(
153153
"""
154154

155155
self.embeddings_config = EmbeddingConfig(
156-
model_name_or_path=model_name_or_path, max_mem_gb=max_memory
156+
model_name_or_path=embeddings_model_name_or_path,
157+
max_mem_gb=embedding_max_mem_gb,
157158
)
158-
self.batch_size = batch_size
159+
self.batch_size = embedding_batch_size
159160
self.logger = self._setup_logger(logger)
160161
self.embedding_output_dir = embedding_output_dir
161162
self.input_column = input_column

nemo_curator/scripts/semdedup/clustering.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,23 @@ def main(args):
8080

8181

8282
def attach_args():
83-
parser = ArgumentHelper.parse_semdedup_args(add_input_args=False)
83+
parser = ArgumentHelper.parse_semdedup_args(
84+
description=(
85+
"Performs clustering on the computed embeddings of a collection of documents. "
86+
"This script requires that the embeddings have been created beforehand using: "
87+
"semdedup_extract_embeddings. "
88+
"Input arguments include: "
89+
"--config-file for the path to the semdedup config file. "
90+
"Important configuration parameters include: "
91+
" cache_dir for the directory to store cache,"
92+
" clustering_save_loc for the location to save clustering results,"
93+
" n_clusters for the number of clusters,"
94+
" seed for the seed for clustering,"
95+
" max_iter for the maximum iterations for clustering,"
96+
" Kmeans_with_cos_dist for using KMeans with cosine distance."
97+
),
98+
add_input_args=False,
99+
)
84100
return parser
85101

86102

nemo_curator/scripts/semdedup/compute_embeddings.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def main(args):
6767
# Can repartition here if needed
6868
# ddf = ddf.repartition(partition_size="64MB")
6969
embedding_creator = EmbeddingCreator(
70-
model_name_or_path=semdedup_config.embedding_model_name_or_path,
71-
max_memory=semdedup_config.embedding_max_mem_gb,
72-
batch_size=semdedup_config.embedding_batch_size,
70+
embeddings_model_name_or_path=semdedup_config.embedding_model_name_or_path,
71+
embedding_max_mem_gb=semdedup_config.embedding_max_mem_gb,
72+
embedding_batch_size=semdedup_config.embedding_batch_size,
7373
embedding_output_dir=os.path.join(
7474
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
7575
),
@@ -85,7 +85,28 @@ def main(args):
8585

8686

8787
def attach_args():
88-
parser = ArgumentHelper.parse_semdedup_args(add_input_args=True)
88+
parser = ArgumentHelper.parse_semdedup_args(
89+
description=(
90+
"Computes the embeddings of a collection of documents using the specified model. "
91+
"The model is specified in the config file using embedding_model_name_or_path (e.g. 'sentence-transformers/paraphrase-MiniLM-L6-v2'). "
92+
"The embeddings are saved in the specified cache directory under the embeddings_save_loc directory. "
93+
"Input arguments include: "
94+
"--input_data_dir for the directory containing input data files, "
95+
"--input_file_extension for specifying the file extension of input files (e.g., .jsonl), "
96+
"--input_file_type for the type of input files (e.g., json, csv), "
97+
"--input_text_field for the field in the input files containing the text data to be embedded. "
98+
"Additional configuration can be provided via the --config-file argument. "
99+
"Important configuration parameters include: "
100+
" cache_dir for the directory to store cache,"
101+
" num_files for the number of files to process (default is -1, meaning all files),"
102+
" input_column for specifying the input column for embeddings,"
103+
" embeddings_save_loc for the location to save embeddings,"
104+
" embedding_model_name_or_path for the model name or path for embeddings,"
105+
" embedding_batch_size for the batch size for processing embeddings,"
106+
" embedding_max_mem_gb for the maximum memory in GB for embeddings"
107+
),
108+
add_input_args=True,
109+
)
89110
return parser
90111

91112

nemo_curator/scripts/semdedup/extract_dedup_data.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,26 @@ def main(args):
5757

5858
client.cancel(client.futures, force=True)
5959
client.close()
60-
return
6160

6261

6362
def attach_args():
64-
parser = ArgumentHelper.parse_semdedup_args(add_input_args=False)
63+
parser = ArgumentHelper.parse_semdedup_args(
64+
description=(
65+
"Extracts deduplicated data from the clustered embeddings of a collection of documents. "
66+
"This script requires that embeddings and clustering have been performed beforehand with the specified configurations, "
67+
"using semdedup_extract_embeddings and semdedup_cluster_embeddings. "
68+
"Input arguments include: "
69+
"--config-file for the path to the semdedup config file. "
70+
"Important configuration parameters include: "
71+
" cache_dir for the directory to store cache,"
72+
" which_to_keep for specifying which duplicates to keep,"
73+
" largest_cluster_size_to_process for the largest cluster size to process,"
74+
" sim_metric for the similarity metric for deduplication,"
75+
" eps_thresholds for epsilon thresholds to calculate if semantically similar or not,"
76+
" and eps_to_extract for the epsilon value to extract deduplicated data."
77+
),
78+
add_input_args=False,
79+
)
6580
return parser
6681

6782

nemo_curator/utils/semdedup_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _assign_and_sort_clusters(
5252
keep_hard (bool): When True, sorts cluster items in descending order by similarity to the cluster centroid. Defaults to True.
5353
kmeans_with_cos_dist (bool): Whether to use cosine distance for K-means clustering. Defaults to True.
5454
sorted_clusters_file_loc (str): The location to save the sorted clusters file. Defaults to an empty string.
55-
cluster_ids (list): The range of cluster IDs to sort. Defaults to range(5000).
55+
cluster_ids (list): The range of cluster IDs to sort.
5656
logger (logging.Logger): A logger object to log messages. Defaults to None.
5757
5858
Returns:
@@ -268,7 +268,6 @@ def get_semantic_matches_per_cluster(
268268
points_to_remove_df[f"eps={eps}"] = eps_points_to_remove
269269

270270
points_to_remove_df.to_parquet(output_df_file_path)
271-
return None
272271

273272

274273
def get_num_records(file_path):

0 commit comments

Comments
 (0)