@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8

     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )

@@ -530,11 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -552,29 +556,52 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            new_tokens.append(next_token.clone())
+
+            done_generating = _i == num_new_tokens - 2
+            if need_probs:
+                callback(new_tokens[-1], done_generating=done_generating)
             if not need_probs or next_prob is None:
                 yield out_token, None
             else:
                 yield out_token, next_prob.clone()
             cur_token = next_token

-            # encountered eos
-            if next_token.item() == eos_token_id or (
-                eot_id is not None and next_token.item() == eot_id
-            ):
-                encountered_eos = True
-                final_token, next_prob = self.decode_one_token(
-                    model,
-                    cur_token,
-                    input_pos,
-                    need_probs,
-                    batch=batch,
-                    **sampling_kwargs,
-                )
-                input_pos += 1
-                yield cur_token.clone(), next_prob.clone()
-                break
+            if need_probs:
+                # encountered eos
+                if next_token.item() == eos_token_id or (
+                    eot_id is not None and next_token.item() == eot_id
+                ):
+                    encountered_eos = True
+                    final_token, next_prob = self.decode_one_token(
+                        model,
+                        cur_token,
+                        input_pos,
+                        need_probs,
+                        batch=batch,
+                        **sampling_kwargs,
+                    )
+                    input_pos += 1
+                    yield cur_token.clone(), next_prob.clone()
+                    break
+            else:
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
+                    for i in range(callback_num, 0, -1):
+                        callback(new_tokens[-i], done_generating=done_generating)
+
+                        token_item = new_tokens[-i].item()
+                        # encountered eos
+                        if token_item == eos_token_id or (
+                            eot_id is not None and token_item == eot_id
+                        ):
+                            encountered_eos = True
+                            input_pos += 1
+                            yield new_tokens[-i].clone(), None
+                            break
+                    if encountered_eos:
+                        break

         if not encountered_eos:
             eos_token = torch.tensor(
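The hunk above changes when the UI callback fires: decoded tokens are buffered in new_tokens, and the callback is flushed in batches of accumulate_tokens (or whatever remains on the final step) instead of once per token, cutting per-token Python overhead in the hot decode loop. Below is a minimal, self-contained sketch of that accumulation pattern; decode_step and on_token are hypothetical stand-ins, not torchchat APIs, and tokens are simplified to ints:

from typing import Callable, Iterator, List

def decode_n_tokens_sketch(
    num_new_tokens: int,
    decode_step: Callable[[], int],         # hypothetical: returns the next token id
    on_token: Callable[[int, bool], None],  # hypothetical UI callback (token, done)
    accumulate_tokens: int = 8,
    eos_token_id: int = 2,
) -> Iterator[int]:
    new_tokens: List[int] = []
    encountered_eos = False
    for _i in range(num_new_tokens - 1):
        token = decode_step()
        new_tokens.append(token)
        done_generating = _i == num_new_tokens - 2
        yield token

        # Flush buffered tokens to the callback every `accumulate_tokens`
        # steps, or on the final step even if the buffer is only partly full.
        callback_pos = _i % accumulate_tokens + 1
        if done_generating or callback_pos == accumulate_tokens:
            for i in range(callback_pos, 0, -1):
                on_token(new_tokens[-i], done_generating)
                if new_tokens[-i] == eos_token_id:
                    encountered_eos = True
                    break
            if encountered_eos:
                break

For example, list(decode_n_tokens_sketch(6, iter([9, 8, 7, 6, 5]).__next__, lambda t, d: print(t, d))) drains the generator and emits all five tokens to the callback in a single flush on the final step. The design trades per-token display latency for fewer Python callback round trips per decode step, which matters when the callback does terminal I/O.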
@@ -681,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
@@ -791,6 +819,7 @@ def generate(
             max_new_tokens - 1,
             batch=batch,
             callback=callback,
+            accumulate_tokens=accumulate_tokens,
             need_probs=False,
             eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
             eot_id=(
@@ -1179,6 +1208,7 @@ def callback(x, *, done_generating=False):
         chat_mode=generator_args.chat_mode,
         batch=batch,
         callback=callback,
+        accumulate_tokens=generator_args.accumulate_tokens,
         temperature=generator_args.temperature,
         top_k=generator_args.top_k,
         sequential_prefill=generator_args.sequential_prefill,
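Because from_args reads the field defensively with getattr(args, "accumulate_tokens", 8), the change works even when no corresponding CLI flag is registered. If one were exposed, the wiring would presumably look like the following hypothetical argparse sketch; the flag name and help text are assumptions, not part of this commit:

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag registration; the commit itself only reads the
# attribute via getattr(args, "accumulate_tokens", 8).
parser.add_argument(
    "--accumulate-tokens",
    type=int,
    default=8,
    help="Number of decoded tokens to buffer before flushing the callback.",
)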