optimize(utils): move custom processors into model (#419)
fumiama authored Jun 24, 2024
1 parent e0a9e7e commit b62e0dc
Showing 6 changed files with 50 additions and 49 deletions.
ChatTTS/core.py: 2 changes (1 addition & 1 deletion)

@@ -119,7 +119,7 @@ def _load(
         coef: Optional[str] = None
     ):
         if device is None:
-            device = select_device(4096)
+            device = select_device()
         self.logger.log(logging.INFO, f'use {device}')
         self.device = device
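For orientation, a minimal sketch of the changed call path: after this commit, `_load` no longer hard-codes a 4096 MB threshold and instead defers to `select_device()`'s own default (see the ChatTTS/utils/gpu.py hunk below). The wrapper name here is hypothetical; the real code sits inside `_load` in ChatTTS/core.py.

```python
from ChatTTS.utils.gpu import select_device  # module path as in this repo

def resolve_device(device=None):
    # Hypothetical wrapper mirroring the diffed lines of _load:
    # the memory threshold now comes from select_device()'s default
    # (min_memory=2047 per gpu.py) instead of an explicit 4096.
    if device is None:
        device = select_device()
    return device
```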
ChatTTS/infer/api.py: 2 changes (1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
-from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat
+from ..model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils.io import del_all
 from ..model.gpt import GPT
ChatTTS/model/gpt.py: 2 changes (1 addition & 1 deletion)

@@ -16,7 +16,7 @@
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat
+from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils.io import del_all
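Both consumers now resolve the class from its new home. The two updated imports, exactly as in the hunks above, differ only in relative depth:

```python
# In ChatTTS/infer/api.py: model/ is a sibling package of infer/,
# so the import climbs one level.
from ..model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat

# In ChatTTS/model/gpt.py: processors.py is a sibling module,
# so a single-dot relative import is enough.
from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
```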
ChatTTS/model/processors.py: 45 changes (45 additions & 0 deletions, new file)

@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+
+
+class CustomRepetitionPenaltyLogitsProcessorRepeat():
+
+    def __init__(self, penalty: float, max_input_ids, past_window):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        self.penalty = penalty
+        self.max_input_ids = max_input_ids
+        self.past_window = past_window
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        input_ids = input_ids[:, -self.past_window:]
+        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
+        freq[self.max_input_ids:] = 0
+        alpha = self.penalty**freq
+        scores = scores.contiguous()
+        scores = torch.where(scores < 0, scores*alpha, scores/alpha)
+
+        return scores
+
+class CustomRepetitionPenaltyLogitsProcessor():
+
+    def __init__(self, penalty: float, max_input_ids, past_window):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        self.penalty = penalty
+        self.max_input_ids = max_input_ids
+        self.past_window = past_window
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        input_ids = input_ids[:, -self.past_window:]
+        score = torch.gather(scores, 1, input_ids)
+        _score = score.detach().clone()
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
+        scores.scatter_(1, input_ids, score)
+
+        return scores
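A usage sketch for the relocated class (all values illustrative, not the ones ChatTTS passes): the Repeat variant counts how often each vocabulary id occurs in the recent window, then rescales every logit by `penalty**count`; the plain `CustomRepetitionPenaltyLogitsProcessor` instead gathers only the seen positions and writes them back with `scatter_`.

```python
import torch
from ChatTTS.model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat

vocab_size = 8  # toy vocabulary for the example
processor = CustomRepetitionPenaltyLogitsProcessorRepeat(
    penalty=1.2, max_input_ids=vocab_size, past_window=16
)

input_ids = torch.tensor([[1, 2, 2, 3]])  # (batch, seq_len); token 2 repeats
scores = torch.randn(1, vocab_size)       # (batch, vocab_size) raw logits

# Repeated tokens become less likely: positive logits are divided by
# penalty**count, negative logits are multiplied by it.
penalized = processor(input_ids, scores)
```

One observation on the code as committed: `freq[self.max_input_ids:] = 0` indexes the batch dimension of the `(batch, vocab)` frequency tensor, so for typical batch sizes it appears to be a no-op; the move leaves this behavior untouched.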
ChatTTS/utils/gpu.py: 2 changes (1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 
 from .log import logger
 
-def select_device(min_memory=2048):
+def select_device(min_memory=2047):
     if torch.cuda.is_available():
         available_gpus = []
         for i in range(torch.cuda.device_count()):
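The hunk ends inside the function, so the selection logic itself is not shown; lowering the default from 2048 to 2047 MB presumably lets cards that report just under 2 GiB free pass the check. As a loosely hedged sketch of what a selector with this signature typically does (an assumption, not the actual ChatTTS body):

```python
import torch

def select_device_sketch(min_memory=2047):
    """Illustrative only; the real select_device body lies outside this hunk."""
    if torch.cuda.is_available():
        best_idx, best_free_mb = None, 0.0
        for i in range(torch.cuda.device_count()):
            free_bytes, _total_bytes = torch.cuda.mem_get_info(i)
            free_mb = free_bytes / 2**20
            if free_mb > best_free_mb:
                best_idx, best_free_mb = i, free_mb
        # Fall back to CPU when no GPU clears the threshold (in MB).
        if best_idx is not None and best_free_mb >= min_memory:
            return torch.device(f"cuda:{best_idx}")
    return torch.device("cpu")
```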
ChatTTS/utils/infer.py: 46 changes (1 addition & 45 deletions)

@@ -5,51 +5,7 @@
 
 from numba import jit
 import numpy as np
-import torch
-import torch.nn.functional as F
 
 
-class CustomRepetitionPenaltyLogitsProcessorRepeat():
-
-    def __init__(self, penalty: float, max_input_ids, past_window):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
-        freq[self.max_input_ids:] = 0
-        alpha = self.penalty**freq
-        scores = scores.contiguous()
-        scores = torch.where(scores < 0, scores*alpha, scores/alpha)
-
-        return scores
-
-class CustomRepetitionPenaltyLogitsProcessor():
-
-    def __init__(self, penalty: float, max_input_ids, past_window):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        score = torch.gather(scores, 1, input_ids)
-        _score = score.detach().clone()
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
-        scores.scatter_(1, input_ids, score)
-
-        return scores
-
-
 @jit
 def _find_index(table: np.ndarray, val: np.uint16):
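Note: the deleted block matches the code added in ChatTTS/model/processors.py above line for line, so this commit is a pure relocation; neither processor's behavior changes.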
