doc: sync to latest grammar (#427)
fumiama authored Jun 24, 2024
1 parent 8235a46 commit f412184
Showing 11 changed files with 224 additions and 86 deletions.
6 changes: 3 additions & 3 deletions ChatTTS/core.py
@@ -46,7 +46,7 @@ def has_loaded(self, use_decoder = False):

for module in check_list:
if not hasattr(self, module) and module not in self.pretrain_models:
- self.logger.warn(f'{module} not initialized.')
+ self.logger.warning(f'{module} not initialized.')
not_finish = True

if not not_finish:
@@ -75,7 +75,7 @@ def download_models(
except:
download_path = None
if download_path is None or force_redownload:
- self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
+ self.logger.log(logging.INFO, f'download from HF: https://huggingface.co/2Noise/ChatTTS')
try:
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
except:
@@ -232,7 +232,7 @@ def _load(
try:
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
except RuntimeError as e:
- self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
+ self.logger.warning(f'compile failed: {e}. fallback to normal mode.')
self.gpt = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
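Most of the edits in this commit swap `Logger.warn` for `Logger.warning`: `warn` is only a deprecated alias of `warning` in Python's standard `logging` module, so the new spelling is the one to keep. A minimal sketch of the preferred call (the logger name is illustrative, not taken from the diff):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ChatTTS")  # illustrative name

# Preferred spelling; Logger.warn() is only a deprecated alias of Logger.warning().
logger.warning("%s not initialized.", "gpt")
```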
2 changes: 1 addition & 1 deletion ChatTTS/model/gpt.py
@@ -412,7 +412,7 @@ def generate(
pbar.update(1)

if not finish.all():
- self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
+ self.logger.warning(f'incomplete result. hit max_new_token: {max_new_token}')

del finish

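The warning adjusted above fires when generation stops at the decoding cap rather than at an end-of-sequence token. If it appears, the usual remedy is to raise that cap through the inference parameters; a sketch under the assumption that the cap is exposed as a `max_new_token` field on `ChatTTS.Chat.InferCodeParams` (verify the exact field name against the installed version):

```python
import ChatTTS

chat = ChatTTS.Chat()
chat.load_models(compile=False)

# `max_new_token` is an assumed field name; raising it gives long prompts
# more room before the 'incomplete result' warning is hit.
params_infer_code = ChatTTS.Chat.InferCodeParams(max_new_token=4096)
wavs = chat.infer(["PUT YOUR TEXT HERE"], params_infer_code=params_infer_code)
```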
6 changes: 3 additions & 3 deletions ChatTTS/norm.py
@@ -137,7 +137,7 @@ def __call__(
text = self._apply_half2full_map(text)
invalid_characters = self._count_invalid_characters(text)
if len(invalid_characters):
- self.logger.warn(f'found invalid characters: {invalid_characters}')
+ self.logger.warning(f'found invalid characters: {invalid_characters}')
text = self._apply_character_map(text)
if do_homophone_replacement:
arr, replaced_words = _fast_replace(
@@ -153,10 +153,10 @@

def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
if name in self.normalizers:
self.logger.warn(f"name {name} has been registered")
self.logger.warning(f"name {name} has been registered")
return False
if not isinstance(normalizer, Callable[[str], str]):
self.logger.warn("normalizer must have caller type (str) -> str")
self.logger.warning("normalizer must have caller type (str) -> str")
return False
self.normalizers[name] = normalizer
return True
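For context, `register` accepts any `(str) -> str` callable and rejects duplicate names. A usage sketch, assuming the normalizer instance is reachable from the `Chat` object as `chat.normalizer` (an assumed accessor, not shown in this diff):

```python
import re

# A trivial normalizer with the expected (str) -> str shape:
# replace bare URLs with a spoken-friendly word before synthesis.
def strip_urls(text: str) -> str:
    return re.sub(r"https?://\S+", "link", text)

# register() returns False (and logs a warning) when the name is already
# taken or the callable does not look like (str) -> str.
ok = chat.normalizer.register("strip_urls", strip_urls)  # `chat.normalizer` is assumed
```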
4 changes: 2 additions & 2 deletions ChatTTS/utils/gpu.py
@@ -18,10 +18,10 @@ def select_device(min_memory=2047):
device = torch.device('cpu')
elif torch.backends.mps.is_available():
# For Apple M1/M2 chips with Metal Performance Shaders
- logger.get_logger().info('Apple GPU found, using MPS.')
+ logger.get_logger().info('apple GPU found, using MPS.')
device = torch.device('mps')
else:
- logger.get_logger().warning('No GPU found, use CPU instead')
+ logger.get_logger().warning('no GPU found, use CPU instead')
device = torch.device('cpu')

return device
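The two log lines touched here sit inside the usual CUDA -> MPS -> CPU fallback. A standalone sketch of that pattern (the free-memory check uses `torch.cuda.mem_get_info` and a megabyte threshold as simplifying assumptions, not the project's exact logic):

```python
import logging
import torch

logger = logging.getLogger("ChatTTS")

def pick_device(min_memory_mb: int = 2047) -> torch.device:
    """Prefer CUDA with enough free memory, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        free_bytes, _total = torch.cuda.mem_get_info(0)
        if free_bytes // (1024 * 1024) >= min_memory_mb:
            return torch.device("cuda:0")
        logger.warning("GPU memory below threshold, use CPU instead")
        return torch.device("cpu")
    if torch.backends.mps.is_available():
        logger.info("apple GPU found, using MPS.")
        return torch.device("mps")
    logger.warning("no GPU found, use CPU instead")
    return torch.device("cpu")
```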
62 changes: 46 additions & 16 deletions README.md
@@ -112,7 +112,7 @@ chat.load_models(compile=False) # Set to True for better performance

texts = ["PUT YOUR TEXT HERE",]

- wavs = chat.infer(texts, )
+ wavs = chat.infer(texts)

torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
```
@@ -125,23 +125,27 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)

rand_spk = chat.sample_random_speaker()

- params_infer_code = {
-     'spk_emb': rand_spk, # add sampled speaker
-     'temperature': .3, # using custom temperature
-     'top_P': 0.7, # top P decode
-     'top_K': 20, # top K decode
- }
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
+     spk_emb = rand_spk, # add sampled speaker
+     temperature = .3, # using custom temperature
+     top_P = 0.7, # top P decode
+     top_K = 20, # top K decode
+ )

###################################
# For sentence level manual control.

# use oral_(0-9), laugh_(0-2), break_(0-7)
# to generate special token in text to synthesize.
- params_refine_text = {
-     'prompt': '[oral_2][laugh_0][break_6]'
- }
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
+     prompt='[oral_2][laugh_0][break_6]',
+ )

- wavs = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
+ wavs = chat.infer(
+     texts,
+     params_refine_text=params_refine_text,
+     params_infer_code=params_infer_code,
+ )

###################################
# For word level manual control.
@@ -163,16 +167,42 @@ capabilities with precise control over prosodic elements [laugh]like like
[uv_break] use the project responsibly at your own risk.[uv_break]
""".replace('\n', '') # English is still experimental.

- params_refine_text = {
-     'prompt': '[oral_2][laugh_0][break_4]'
- }
- # audio_array_cn = chat.infer(inputs_cn, params_refine_text=params_refine_text)
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
+     prompt='[oral_2][laugh_0][break_4]',
+ )

audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
```

+ <table>
+ <tr>
+ <td align="center">
+
+ **male speaker**
+
+ </td>
+ <td align="center">
+
+ **female speaker**
+
+ </td>
+ </tr>
+ <tr>
+ <td align="center">
+
+ [male speaker](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
+
+ </td>
+ <td align="center">
+
+ [female speaker](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
+
+ </td>
+ </tr>
+ </table>


</details>

## FAQ
@@ -206,4 +236,4 @@ In the current released model, the only token-level control units are `[laugh]`,

![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)

- </div>
+ </div>
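As the FAQ hunk above notes, token-level control is limited to a few inline markers such as `[laugh]` and `[uv_break]`. A minimal end-to-end sketch of using them directly in the text to synthesize (marker usage mirrors the README excerpt above; treat it as an illustration, not an exhaustive list of supported tokens):

```python
import ChatTTS
import torch
import torchaudio

chat = ChatTTS.Chat()
chat.load_models(compile=False)

# Inline control tokens go straight into the input text.
text = "ChatTTS supports token level control, [uv_break] like this [laugh] pause and laugh.[uv_break]"
wavs = chat.infer([text])
torchaudio.save("output_control.wav", torch.from_numpy(wavs[0]), 24000)
```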