@@ -28,7 +28,7 @@ def _transform_func(tokenizer,
 
 # Triton is not thread safe AFAICT so using naive DataParallel fails
 class EncoderWorker(mp.Process):
-    def __init__(self, rank, world_size, input_queue, output_queue, model_name, tokenizer_name, batch_size, master_port=12345):
+    def __init__(self, rank, world_size, input_queue, output_queue, model_name, tokenizer_name, batch_size, master_port=12344):
         super().__init__()
         self.rank = rank
         self.world_size = world_size
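The port change from 12345 to 12344 suggests avoiding a rendezvous collision with another process group already bound to the old port on the same host. For reference, a minimal sketch of how a per-rank worker typically joins a `torch.distributed` group from such a `master_port` argument; `EncoderWorker.run()` is not shown in this hunk, so the helper name and the `localhost`/NCCL choices below are assumptions:

```python
# Hypothetical sketch of a standard torch.distributed rendezvous; the real
# EncoderWorker.run() body is outside this diff. Two groups cannot share a
# MASTER_PORT, which is presumably why the port is configurable here.
import os
import torch.distributed as dist

def init_worker_group(rank: int, world_size: int, master_port: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)
    # NCCL backend: one process per GPU, identified by its rank.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
```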
@@ -99,7 +99,7 @@ def run(self):
 
         local_embeds = []
         with torch.no_grad():
-            for batch_dict in tqdm(loader, desc=f"Rank {self.rank}"):
+            for batch_dict in tqdm(loader, desc=f"Rank {self.rank}", disable=True):
                 batch_dict = {k: v.cuda(self.rank) for k, v in batch_dict.items()}
                 with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                     outputs = encoder(**batch_dict)
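Each rank now iterates its loader silently (per-rank tqdm bars interleave badly across processes, hence presumably `disable=True`), accumulating into `local_embeds` under `no_grad` and bf16 autocast. What happens to `outputs` after `encoder(**batch_dict)` is outside this hunk; a plausible continuation of the loop body, with the pooling strategy being an assumption, looks like:

```python
# Hypothetical continuation of the loop body shown above (not in this diff):
# pool the hidden states, L2-normalize, and park the batch on the CPU so GPU
# memory is not held across the whole corpus.
import torch.nn.functional as F

embeds = outputs.last_hidden_state[:, 0]   # CLS-token pooling (assumed)
embeds = F.normalize(embeds, p=2, dim=-1)  # unit-length embedding vectors
local_embeds.append(embeds.float().cpu())
```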
@@ -215,11 +215,15 @@ def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
     def encode_single_device(
         self,
         sentences: Union[List[str], str],
-        batch_size: int = 256,
+        batch_size: int = 512,
         max_length: int = 512,
         convert_to_numpy: bool = True,
         device: Optional[str] = None,
     ):
+        if isinstance(sentences, str):
+            sentences = [sentences]
+
+        # Initialize workers if not already initialized
         if len(self.workers) == 0:
             for rank in range(self.world_size):
                 worker = EncoderWorker(
@@ -234,17 +238,41 @@ def encode_single_device(
             worker.start()
             self.workers.append(worker)
 
-        if isinstance(sentences, str):
-            sentences = [sentences]
-
-        for _ in range(self.world_size):
-            self.input_queue.put(sentences)
-        result = self.output_queue.get()
-
-        if isinstance(result, Exception):
-            raise result
-
-        return result
+        # Calculate number of batches
+        total_samples = len(sentences)
+        batch_size = 65536
+        num_batches = (total_samples + batch_size - 1) // batch_size
+
+        all_results = []
+
+        # Process sentences in batches
+        for batch_idx in tqdm(range(num_batches)):
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, total_samples)
+            batch_sentences = sentences[start_idx:end_idx]
+
+            # Distribute batch to workers
+            for _ in range(self.world_size):
+                self.input_queue.put(batch_sentences)
+
+            # Get results for this batch
+            batch_result = self.output_queue.get()
+
+            if isinstance(batch_result, Exception):
+                raise batch_result
+
+            all_results.append(batch_result)
+
+        # Concatenate results from all batches
+        if len(all_results) > 1:
+            if isinstance(all_results[0], np.ndarray):
+                final_result = np.concatenate(all_results, axis=0)
+            else:  # Assuming torch.Tensor
+                final_result = torch.cat(all_results, dim=0)
+        else:
+            final_result = all_results[0]
+
+        return final_result
 
     def __del__(self):
         # Send poison pills to workers
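The hunk is cut off inside `__del__`, but given that every worker blocks on `input_queue.get()`, the poison-pill shutdown the comment refers to is conventionally one sentinel per worker followed by a join. A sketch under that assumption:

```python
# Hypothetical completion of __del__ (its body is truncated in this diff).
# One None per worker unblocks each pending input_queue.get(), letting every
# run() loop exit before the processes are joined.
def __del__(self):
    for _ in range(self.world_size):
        self.input_queue.put(None)  # sentinel: workers treat None as "stop"
    for worker in self.workers:
        worker.join()
```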