@@ -335,6 +335,7 @@ def generate(
         return_attn=False,
         return_hidden=False,
         stream=False,
+        show_tqdm=True,
         context=Context(),
     ):

@@ -368,160 +369,165 @@ def generate(
         attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
             attention_mask
         )
+
+        pbar: Optional[tqdm] = None
+
+        if show_tqdm:
+            pbar = tqdm(
+                total=max_new_token,
+                desc="text" if infer_text else "code",
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
+            )

-        with tqdm(
-            total=max_new_token,
-            desc="text" if infer_text else "code",
-            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
-        ) as pbar:
-
-            past_key_values = None
+        past_key_values = None

-            for i in range(max_new_token):
-                model_input = self._prepare_generation_inputs(
-                    inputs_ids,
-                    past_key_values,
-                    attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
-                    use_cache=True,
-                )
+        for i in range(max_new_token):
+            model_input = self._prepare_generation_inputs(
+                inputs_ids,
+                past_key_values,
+                attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
+                use_cache=True,
+            )

-                if i > 0:
-                    del emb
-                    inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
-                    if infer_text:
-                        emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
-                    else:
-                        code_emb = [
-                            self.emb_code[i](inputs_ids_emb[:, :, i])
-                            for i in range(self.num_vq)
-                        ]
-                        emb = torch.stack(code_emb, 3).sum(3)
-                    del inputs_ids_emb, model_input.input_ids
-                model_input.inputs_embeds = emb
-
-                model_input.to(self.device_gpt)
-
-                outputs: BaseModelOutputWithPast = self.gpt(
-                    attention_mask=model_input.attention_mask,
-                    position_ids=model_input.position_ids,
-                    past_key_values=model_input.past_key_values,
-                    inputs_embeds=model_input.inputs_embeds,
-                    use_cache=model_input.use_cache,
-                    output_attentions=return_attn,
-                    cache_position=model_input.cache_position,
-                )
-                del_all(model_input)
-                attentions.append(outputs.attentions)
-                hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
-                past_key_values = outputs.past_key_values
-                del_all(outputs)
-                if return_hidden:
-                    hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
-
-                with P.cached():
-                    if infer_text:
-                        logits: torch.Tensor = self.head_text(hidden_states)
-                    else:
-                        # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
-                        logits = torch.empty(
-                            hidden_states.size(0),
-                            hidden_states.size(1),
-                            self.num_audio_tokens,
-                            self.num_vq,
-                            dtype=torch.float,
-                            device=self.device,
-                        )
-                        for i in range(self.num_vq):
-                            x: torch.Tensor = self.head_code[i](hidden_states)
-                            logits[..., i] = x
-                            del x
-
-                # logits = logits[:, -1].float()
-                logits = logits.narrow(1, -1, 1).squeeze_(1).float()
-
-                if not infer_text:
-                    # logits = rearrange(logits, "b c n -> (b n) c")
-                    logits = logits.permute(0, 2, 1)
-                    logits = logits.reshape(-1, logits.size(2))
-                    # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
-                    inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
-                    logits_token = inputs_ids_sliced.reshape(
-                        inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-                        -1,
-                    ).to(self.device)
+            if i > 0:
+                del emb
+                inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
+                if infer_text:
+                    emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
+                else:
+                    code_emb = [
+                        self.emb_code[i](inputs_ids_emb[:, :, i])
+                        for i in range(self.num_vq)
+                    ]
+                    emb = torch.stack(code_emb, 3).sum(3)
+                del inputs_ids_emb, model_input.input_ids
+            model_input.inputs_embeds = emb
+
+            model_input.to(self.device_gpt)
+
+            outputs: BaseModelOutputWithPast = self.gpt(
+                attention_mask=model_input.attention_mask,
+                position_ids=model_input.position_ids,
+                past_key_values=model_input.past_key_values,
+                inputs_embeds=model_input.inputs_embeds,
+                use_cache=model_input.use_cache,
+                output_attentions=return_attn,
+                cache_position=model_input.cache_position,
+            )
+            del_all(model_input)
+            attentions.append(outputs.attentions)
+            hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
+            past_key_values = outputs.past_key_values
+            del_all(outputs)
+            if return_hidden:
+                hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
+
+            with P.cached():
+                if infer_text:
+                    logits: torch.Tensor = self.head_text(hidden_states)
                 else:
-                    logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
+                    # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
+                    logits = torch.empty(
+                        hidden_states.size(0),
+                        hidden_states.size(1),
+                        self.num_audio_tokens,
+                        self.num_vq,
+                        dtype=torch.float,
+                        device=self.device,
+                    )
+                    for i in range(self.num_vq):
+                        x: torch.Tensor = self.head_code[i](hidden_states)
+                        logits[..., i] = x
+                        del x
+
+            # logits = logits[:, -1].float()
+            logits = logits.narrow(1, -1, 1).squeeze_(1).float()
+
+            if not infer_text:
+                # logits = rearrange(logits, "b c n -> (b n) c")
+                logits = logits.permute(0, 2, 1)
+                logits = logits.reshape(-1, logits.size(2))
+                # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
+                inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
+                logits_token = inputs_ids_sliced.reshape(
+                    inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
+                    -1,
+                ).to(self.device)
+            else:
+                logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
+
+            logits /= temperature

-                logits /= temperature
+            for logitsProcessors in logits_processors:
+                logits = logitsProcessors(logits_token, logits)

-                for logitsProcessors in logits_processors:
-                    logits = logitsProcessors(logits_token, logits)
+            for logitsWarpers in logits_warpers:
+                logits = logitsWarpers(logits_token, logits)

-                for logitsWarpers in logits_warpers:
-                    logits = logitsWarpers(logits_token, logits)
+            del logits_token

-                del logits_token
+            if i < min_new_token:
+                logits[:, eos_token] = -torch.inf

-                if i < min_new_token:
-                    logits[:, eos_token] = -torch.inf
+            scores = F.softmax(logits, dim=-1)

-                scores = F.softmax(logits, dim=-1)
+            del logits

-                del logits
+            idx_next = torch.multinomial(scores, num_samples=1).to(
+                finish.device
+            )

-                idx_next = torch.multinomial(scores, num_samples=1).to(
-                    finish.device
+            if not infer_text:
+                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+                idx_next = idx_next.view(-1, self.num_vq)
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [inputs_ids, idx_next.unsqueeze_(1)], 1
+                )
+            else:
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [
+                        inputs_ids,
+                        idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
+                    ],
+                    1,
                 )

-                if not infer_text:
-                    # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
-                    idx_next = idx_next.view(-1, self.num_vq)
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [inputs_ids, idx_next.unsqueeze_(1)], 1
-                    )
-                else:
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [
-                            inputs_ids,
-                            idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
-                        ],
-                        1,
+            del inputs_ids
+            inputs_ids = inputs_ids_tmp
+            del inputs_ids_tmp, idx_next
+
+            if stream:
+                minus_prev_end_index = end_idx.neg()
+            end_idx.add_((finish.logical_not().to(end_idx.device)).int())
+            if stream:
+                if (
+                    end_idx.all()
+                    and end_idx.fmod(24).eq(0).any()
+                    and minus_prev_end_index.add_(end_idx).any()
+                ):
+                    self.logger.debug("yield stream result, end: %d", end_idx)
+                    yield self._prepare_generation_outputs(
+                        inputs_ids,
+                        start_idx,
+                        end_idx,
+                        attentions,
+                        hiddens,
+                        infer_text,
                     )
+                del minus_prev_end_index
+
+            if finish.all() or context.get():
+                break

-                del inputs_ids
-                inputs_ids = inputs_ids_tmp
-                del inputs_ids_tmp, idx_next
-
-                if stream:
-                    minus_prev_end_index = end_idx.neg()
-                end_idx.add_((finish.logical_not().to(end_idx.device)).int())
-                if stream:
-                    if (
-                        end_idx.all()
-                        and end_idx.fmod(24).eq(0).any()
-                        and minus_prev_end_index.add_(end_idx).any()
-                    ):
-                        self.logger.debug("yield stream result, end: %d", end_idx)
-                        yield self._prepare_generation_outputs(
-                            inputs_ids,
-                            start_idx,
-                            end_idx,
-                            attentions,
-                            hiddens,
-                            infer_text,
-                        )
-                    del minus_prev_end_index
-
-                if finish.all() or context.get():
-                    break
-
-                pbar.update(1)
+            if pbar is not None: pbar.update(1)
+
+        if pbar is not None: pbar.close()

         if not finish.all():
             if context.get():
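For reference, here is a minimal, self-contained sketch of the pattern this diff introduces in isolation: the `with tqdm(...) as pbar:` context manager is replaced by an `Optional[tqdm]` that is created only when `show_tqdm` is true and is updated and closed behind `pbar is not None` guards, which is why the whole generation loop loses one indentation level. The helper name `run_steps` and its loop body are illustrative placeholders, not part of the ChatTTS code above.

# Sketch of the optional-progress-bar pattern; names are placeholders.
from typing import Optional

from tqdm import tqdm


def run_steps(max_new_token: int = 8, show_tqdm: bool = True) -> None:
    # Create the bar only on request, instead of wrapping the loop in `with`.
    pbar: Optional[tqdm] = None
    if show_tqdm:
        pbar = tqdm(total=max_new_token, desc="code")
    for _ in range(max_new_token):
        pass  # one decoding step would run here
        if pbar is not None:
            pbar.update(1)
    # Without the context manager, the bar must be closed explicitly; doing it
    # after the loop means an early `break` still leaves the bar closed.
    if pbar is not None:
        pbar.close()


run_steps(show_tqdm=False)  # runs silently; with True, a progress bar renders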