@@ -174,17 +174,15 @@ class RefineTextParams:
174
174
min_new_token : int = 0
175
175
show_tqdm : bool = True
176
176
ensure_non_empty : bool = True
177
- manual_seed : Optional [int ] = 0
177
+ manual_seed : Optional [int ] = None
178
178
179
179
@dataclass (repr = False , eq = False )
180
180
class InferCodeParams (RefineTextParams ):
181
181
prompt : str = "[speed_5]"
182
182
spk_emb : Optional [str ] = None
183
183
spk_smp : Optional [str ] = None
184
184
txt_smp : Optional [str ] = None
185
- top_P : float = 1
186
- top_K : int = 1
187
- temperature : float = 0.01
185
+ temperature : float = 0.3
188
186
repetition_penalty : float = 1.05
189
187
max_new_token : int = 2048
190
188
stream_batch : int = 24
@@ -196,13 +194,13 @@ def infer(
196
194
text ,
197
195
stream = False ,
198
196
lang = None ,
199
- skip_refine_text = True ,
197
+ skip_refine_text = False ,
200
198
refine_text_only = False ,
201
199
use_decoder = True ,
202
200
do_text_normalization = True ,
203
201
do_homophone_replacement = True ,
204
- params_refine_text = None ,
205
- params_infer_code = None ,
202
+ params_refine_text = RefineTextParams () ,
203
+ params_infer_code = InferCodeParams () ,
206
204
stream_batch_size = 16 ,
207
205
):
208
206
self .context .set (False )
@@ -273,7 +271,7 @@ def _load(
273
271
vq_config = asdict (self .config .dvae .vq ),
274
272
dim = self .config .dvae .decoder .idim ,
275
273
coef = coef ,
276
- device = device ,
274
+ device = self . device ,
277
275
)
278
276
.to (device )
279
277
.eval ()
@@ -290,8 +288,8 @@ def _load(
290
288
self .config .embed .num_text_tokens ,
291
289
self .config .embed .num_vq ,
292
290
)
293
- embed .from_pretrained (embed_path , device = device )
294
- self .embed = embed .to (device )
291
+ embed .from_pretrained (embed_path , device = self . device )
292
+ self .embed = embed .to (self . device )
295
293
self .logger .log (logging .INFO , "embed loaded." )
296
294
297
295
gpt = GPT (
@@ -343,15 +341,15 @@ def _load(
343
341
async def _infer (
344
342
self ,
345
343
text ,
346
- stream = True ,
344
+ stream = False ,
347
345
lang = None ,
348
- skip_refine_text = True ,
346
+ skip_refine_text = False ,
349
347
refine_text_only = False ,
350
348
use_decoder = True ,
351
349
do_text_normalization = True ,
352
350
do_homophone_replacement = True ,
353
- params_refine_text = None ,
354
- params_infer_code = None ,
351
+ params_refine_text = RefineTextParams () ,
352
+ params_infer_code = InferCodeParams () ,
355
353
stream_batch_size = 16 ,
356
354
):
357
355
@@ -399,13 +397,11 @@ async def _infer(
399
397
result .hiddens if use_decoder else result .ids ,
400
398
use_decoder ,
401
399
)
402
-
403
400
if result .finished :
404
401
yield wavs [:, length :]
405
402
else :
406
403
# HACK: Check whether there are any silent segments; if so, take the last segment. Otherwise, wait for another loop iteration.
407
404
import librosa
408
-
409
405
silence_intervals = librosa .effects .split (wavs [0 ][length :], top_db = 10 )
410
406
silence_left = 0
411
407
if len (silence_intervals ) == 0 :
@@ -504,8 +500,8 @@ async def _infer_code(
504
500
repetition_penalty = params .repetition_penalty ,
505
501
)
506
502
507
- speaker_embedding_param = gpt (input_ids , text_mask )
508
-
503
+ speaker_embedding_param = self . embed (input_ids , text_mask )
504
+ del text_mask
509
505
if params .spk_emb is not None :
510
506
self .speaker .apply (
511
507
speaker_embedding_param ,
@@ -536,7 +532,7 @@ async def _infer_code(
536
532
async for i in results_generator :
537
533
token_ids = []
538
534
hidden_states = []
539
- if len (i .outputs [0 ].token_ids ) % stream_batch_size == 0 or i .finished :
535
+ if ( stream and len (i .outputs [0 ].token_ids ) % stream_batch_size == 0 ) or i .finished :
540
536
token_ids .append (torch .tensor (i .outputs [0 ].token_ids ))
541
537
hidden_states .append (
542
538
i .outputs [0 ].hidden_states .to (torch .float32 ).to (self .device )
@@ -547,6 +543,40 @@ async def _infer_code(
547
543
hiddens = hidden_states ,
548
544
attentions = [],
549
545
)
546
+ else :
547
+ results_generator = gpt .generate (
548
+ speaker_embedding_param ,
549
+ input_ids ,
550
+ temperature = torch .tensor (temperature , device = device ),
551
+ eos_token = num_code ,
552
+ attention_mask = attention_mask ,
553
+ max_new_token = params .max_new_token ,
554
+ min_new_token = params .min_new_token ,
555
+ logits_processors = (* logits_processors , * logits_warpers ),
556
+ infer_text = False ,
557
+ return_hidden = return_hidden ,
558
+ stream = stream ,
559
+ show_tqdm = params .show_tqdm ,
560
+ ensure_non_empty = params .ensure_non_empty ,
561
+ stream_batch = params .stream_batch ,
562
+ manual_seed = params .manual_seed ,
563
+ context = self .context ,
564
+ )
565
+ del speaker_embedding_param , input_ids
566
+ async for i in results_generator :
567
+ token_ids = []
568
+ hidden_states = []
569
+ if (stream and len (i .ids [0 ]) % stream_batch_size == 0 ) or i .finished :
570
+ token_ids .append (i .ids [0 ])
571
+ hidden_states .append (
572
+ i .hiddens [0 ].to (torch .float32 ).to (self .device )
573
+ )
574
+ yield GPT .GenerationOutputs (
575
+ ids = token_ids ,
576
+ finished = i .finished ,
577
+ hiddens = hidden_states ,
578
+ attentions = [],
579
+ )
550
580
551
581
@torch .no_grad ()
552
582
def _refine_text (
0 commit comments