
Commit 5f8f35d

Accumulate tokens in generate mode (#1534)
* batch callbacks in generate mode
* command line argument
1 parent 371eb8b commit 5f8f35d

2 files changed: +53 −17


torchchat/cli/cli.py (+6)
@@ -359,6 +359,12 @@ def _add_generation_args(parser, verb: str) -> None:
         default=1,
         help="Number of samples",
     )
+    generator_parser.add_argument(
+        "--accumulate-tokens",
+        type=int,
+        default=8,
+        help="Number of generated tokens to accumulate before calling the callback on each one of them.",
+    )
 
     generator_parser.add_argument(
         "--image-prompts",

torchchat/generate.py (+47 −17)
@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
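The getattr fallback keeps GeneratorArgs.from_args working for verbs whose parsers never register --accumulate-tokens: those argparse namespaces simply lack the attribute. A minimal sketch of the pattern, with illustrative Namespace contents:

    import argparse

    # Namespace from a parser that never registered the flag
    # (e.g. a different verb); getattr falls back to the default of 8.
    args = argparse.Namespace(model="llama3-tune")
    print(getattr(args, "accumulate_tokens", 8))  # 8

    # Namespace from the generate parser, where the flag is present and wins.
    args = argparse.Namespace(model=None, accumulate_tokens=16)
    print(getattr(args, "accumulate_tokens", 8))  # 16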
@@ -530,11 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]], # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -552,29 +556,52 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            new_tokens.append(next_token.clone())
+
+            done_generating = _i == num_new_tokens - 2
+            if need_probs:
+                callback(new_tokens[-1], done_generating=done_generating)
             if not need_probs or next_prob is None:
                 yield out_token, None
             else:
                 yield out_token, next_prob.clone()
             cur_token = next_token
 
-            # encountered eos
-            if next_token.item() == eos_token_id or (
-                eot_id is not None and next_token.item() == eot_id
-            ):
-                encountered_eos = True
-                final_token, next_prob = self.decode_one_token(
-                    model,
-                    cur_token,
-                    input_pos,
-                    need_probs,
-                    batch=batch,
-                    **sampling_kwargs,
-                )
-                input_pos += 1
-                yield cur_token.clone(), next_prob.clone()
-                break
+            if need_probs:
+                # encountered eos
+                if next_token.item() == eos_token_id or (
+                    eot_id is not None and next_token.item() == eot_id
+                ):
+                    encountered_eos = True
+                    final_token, next_prob = self.decode_one_token(
+                        model,
+                        cur_token,
+                        input_pos,
+                        need_probs,
+                        batch=batch,
+                        **sampling_kwargs,
+                    )
+                    input_pos += 1
+                    yield cur_token.clone(), next_prob.clone()
+                    break
+            else:
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
+                    for i in range(callback_num, 0, -1):
+                        callback(new_tokens[-i], done_generating=done_generating)
+
+                        token_item = new_tokens[-i].item()
+                        # encountered eos
+                        if token_item == eos_token_id or (
+                            eot_id is not None and token_item == eot_id
+                        ):
+                            encountered_eos = True
+                            input_pos += 1
+                            yield new_tokens[-i].clone(), None
+                            break
+                    if encountered_eos:
+                        break
 
         if not encountered_eos:
             eos_token = torch.tensor(
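
The heart of the change is the new else branch (the need_probs=False path used by plain generation): tokens are buffered in new_tokens and the callback is flushed every accumulate_tokens steps, or immediately on done_generating, instead of once per decoded token. Callbacks typically call .item() or decode and force a GPU sync, so batching them trims per-token overhead at the cost of slightly burstier streaming. A standalone sketch of the flush arithmetic (function and variable names are illustrative, not from the commit):

    def flush_schedule(num_new_tokens: int, accumulate_tokens: int):
        """Yield (step, flushed_tokens) pairs following the diff's arithmetic."""
        buffered = []
        for i in range(num_new_tokens - 1):  # same loop bound as decode_n_tokens
            buffered.append(i)  # stand-in for next_token.clone()
            done_generating = i == num_new_tokens - 2
            callback_pos = i % accumulate_tokens + 1
            if done_generating or callback_pos == accumulate_tokens:
                callback_num = min(accumulate_tokens, callback_pos)
                yield i, buffered[-callback_num:]

    for step, flushed in flush_schedule(num_new_tokens=12, accumulate_tokens=8):
        print(step, flushed)
    # 7  [0, 1, 2, 3, 4, 5, 6, 7]   <- one full batch of 8
    # 10 [8, 9, 10]                 <- final partial batch on done_generating

Since callback_pos is _i % accumulate_tokens + 1, it already lies in [1, accumulate_tokens], so the min() looks defensive rather than load-bearing. Note also that in this path EOS is only detected when a batch is flushed, so up to accumulate_tokens - 1 positions past an EOS may already have been decoded before the loop breaks.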
@@ -681,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
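
One detail: accumulate_tokens is declared without a default even though it follows defaulted parameters, like the existing max_seq_length. That is only valid Python when the parameters are keyword-only, i.e. the full generate signature (not shown in this hunk) presumably has a bare * earlier, since keyword-only parameters may omit defaults in any order:

    def f(*, a=1, b, c=3):  # legal: b is keyword-only, so no default is required
        return a + b + c

    print(f(b=2))  # 6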
@@ -791,6 +819,7 @@ def generate(
             max_new_tokens - 1,
             batch=batch,
             callback=callback,
+            accumulate_tokens=accumulate_tokens,
             need_probs=False,
             eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
             eot_id=(
@@ -1179,6 +1208,7 @@ def callback(x, *, done_generating=False):
            chat_mode=generator_args.chat_mode,
            batch=batch,
            callback=callback,
+           accumulate_tokens=generator_args.accumulate_tokens,
            temperature=generator_args.temperature,
            top_k=generator_args.top_k,
            sequential_prefill=generator_args.sequential_prefill,
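
Because the batched path still invokes the callback once per token (just in bursts), existing streaming callbacks keep working unchanged; the hunk header above shows the expected shape, def callback(x, *, done_generating=False). A minimal compatible callback, assuming a tokenizer object in scope (illustrative, not from the commit):

    def callback(x, *, done_generating=False):
        # x is a single-token tensor; with accumulation these calls
        # simply arrive in groups of up to accumulate_tokens.
        print(tokenizer.decode([x.item()]), end="", flush=True)
        if done_generating:
            print()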
