diff --git a/ChatTTS/utils/infer_utils.py b/ChatTTS/utils/infer_utils.py index 2c083fbf3..b7d70bf52 100644 --- a/ChatTTS/utils/infer_utils.py +++ b/ChatTTS/utils/infer_utils.py @@ -22,6 +22,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to freq = F.one_hot(input_ids, scores.size(1)).sum(1) freq[self.max_input_ids:] = 0 alpha = self.penalty**freq + scores = scores.contiguous() scores = torch.where(scores < 0, scores*alpha, scores/alpha) return scores