feat: add normalizer interface & move instances out (#420)
fumiama authored Jun 24, 2024
1 parent b62e0dc commit c8cb6bd
Showing 8 changed files with 270 additions and 247 deletions.
69 changes: 12 additions & 57 deletions ChatTTS/core.py
@@ -1,7 +1,6 @@
import os
import logging
import tempfile
from functools import partial
from typing import Literal, Optional, List, Callable

import numpy as np
@@ -13,22 +12,24 @@
from .model.dvae import DVAE
from .model.gpt import GPT
from .utils.gpu import select_device
from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
from .utils.io import get_latest_modified_file, del_all
from .infer.api import refine_text, infer_code
from .utils.dl import check_all_assets, download_all_assets
from .utils.log import logger as utils_logger

from .norm import Normalizer


class Chat:
def __init__(self, logger=logging.getLogger(__name__)):
self.logger = logger
utils_logger.set_logger(logger)

self.pretrain_models = {}
self.normalizer = {}
self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))

self.normalizer = Normalizer(
os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'),
logger,
)

def has_loaded(self, use_decoder = False):
not_finish = False
@@ -188,6 +189,8 @@ def _load(
def unload(self):
logger = self.logger
del_all(self.pretrain_models)
self.normalizer.destroy()
del self.normalizer
del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
for module in del_list:
if hasattr(self, module):
@@ -212,23 +215,10 @@ def _infer(

if not isinstance(text, list):
text = [text]
if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
if self._init_normalizer(_lang):
text[i] = self.normalizer[_lang](t)
if _lang == 'zh':
text[i] = apply_half2full_map(text[i])
for i, t in enumerate(text):
invalid_characters = count_invalid_characters(t)
if len(invalid_characters):
self.logger.warn(f'Invalid characters found! : {invalid_characters}')
text[i] = apply_character_map(t)
if do_homophone_replacement:
text[i], replaced_words = self.homophones_replacer.replace(text[i])
if replaced_words:
repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
self.logger.log(logging.INFO, f'Homophones replace: {repl_res}')

text = [self.normalizer(
t, do_text_normalization, do_homophone_replacement, lang,
) for t in text]

if not skip_refine_text:
refined = refine_text(
@@ -314,38 +304,3 @@ def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int],
result.destroy()
del_all(x)
return wavs

def _init_normalizer(self, lang) -> bool:

if lang in self.normalizer:
return True

if lang == 'zh':
try:
from tn.chinese.normalizer import Normalizer
self.normalizer[lang] = Normalizer().normalize
return True
except:
self.logger.log(
logging.WARNING,
'Package WeTextProcessing not found!',
)
self.logger.log(
logging.WARNING,
'Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing',
)
else:
try:
from nemo_text_processing.text_normalization.normalize import Normalizer
self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
return True
except:
self.logger.log(
logging.WARNING,
'Package nemo_text_processing not found!',
)
self.logger.log(
logging.WARNING,
'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',
)
return False
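
With _init_normalizer gone, the optional language-specific normalizers are no longer created inside Chat; callers attach them through the new Normalizer.register interface instead. Below is a minimal sketch (not part of this commit) of that wiring, assuming the optional packages are installed and reusing the same imports and calls as the removed zh/en branches above.

import ChatTTS
from functools import partial

chat = ChatTTS.Chat()

try:
    # Chinese normalization via WeTextProcessing, mirroring the removed zh branch.
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    chat.normalizer.register("zh", ZhNormalizer().normalize)
except ImportError:
    pass  # WeTextProcessing not installed; zh text is passed through unnormalized.

try:
    # English normalization via nemo_text_processing, mirroring the removed en branch.
    from nemo_text_processing.text_normalization.normalize import Normalizer as EnNormalizer
    chat.normalizer.register(
        "en",
        partial(
            EnNormalizer(input_case="cased", lang="en").normalize,
            verbose=False,
            punct_post_process=True,
        ),
    )
except ImportError:
    pass  # nemo_text_processing not installed.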
199 changes: 199 additions & 0 deletions ChatTTS/norm.py
@@ -0,0 +1,199 @@
import json
import logging
import re
from typing import Dict, Tuple, List, Literal, Callable, Optional
import sys

from numba import jit
import numpy as np

from .utils.io import del_all


@jit
def _find_index(table: np.ndarray, val: np.uint16):
for i in range(table.size):
if table[i] == val:
return i
return -1

@jit
def _fast_replace(table: np.ndarray, text: bytes) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
result = np.frombuffer(text, dtype=np.uint16).copy()
replaced_words = []
for i in range(result.size):
ch = result[i]
p = _find_index(table[0], ch)
if p >= 0:
repl_char = table[1][p]
result[i] = repl_char
replaced_words.append((chr(ch), chr(repl_char)))
return result, replaced_words

class Normalizer:
def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
self.logger = logger
self.normalizers: Dict[str, Callable[[str], str]] = {}
self.homophones_map = self._load_homophones_map(map_file_path)
"""
homophones_map
Replace the mispronounced characters with correctly pronounced ones.
Creation process of homophones_map.json:
1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
3. Create a mapping from pinyin to common characters, using characters that ChatTTS reads correctly.
4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
Thanks to:
[Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
[python-pinyin](https://github.com/mozillazg/python-pinyin)
"""
self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
self.accept_pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z，。、,\. ]')
self.sub_pattern = re.compile(r'\[uv_break\]|\[laugh\]|\[lbreak\]')
self.chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
self.english_word_pattern = re.compile(r'\b[A-Za-z]+\b')
self.character_simplifier = str.maketrans({
'：': '，',
'；': '，',
'！': '。',
'（': '，',
'）': '，',
'【': '，',
'】': '，',
'『': '，',
'』': '，',
'「': '，',
'」': '，',
'《': '，',
'》': '，',
'－': '，',
'‘': '',
'“': '',
'’': '',
'”': '',
':': ',',
';': ',',
'!': '.',
'(': ',',
')': ',',
'[': ',',
']': ',',
'>': ',',
'<': ',',
'-': ',',
})
self.halfwidth_2_fullwidth = str.maketrans({
'!': '！',
'"': '“',
"'": '‘',
'#': '＃',
'$': '＄',
'%': '％',
'&': '＆',
'(': '（',
')': '）',
',': '，',
'-': '－',
'*': '＊',
'+': '＋',
'.': '。',
'/': '／',
':': '：',
';': '；',
'<': '＜',
'=': '＝',
'>': '＞',
'?': '？',
'@': '＠',
# '[': '［',
'\\': '＼',
# ']': '］',
'^': '＾',
# '_': '＿',
'`': '｀',
'{': '｛',
'|': '｜',
'}': '｝',
'~': '～'
})

def __call__(
self,
text: str,
do_text_normalization=True,
do_homophone_replacement=True,
lang: Optional[Literal["zh", "en"]] = None,
) -> str:
if do_text_normalization:
_lang = self._detect_language(text) if lang is None else lang
if _lang in self.normalizers:
text = self.normalizers[_lang](text)
if _lang == 'zh':
text = self._apply_half2full_map(text)
invalid_characters = self._count_invalid_characters(text)
if len(invalid_characters):
self.logger.warn(f'found invalid characters: {invalid_characters}')
text = self._apply_character_map(text)
if do_homophone_replacement:
arr, replaced_words = _fast_replace(
self.homophones_map,
text.encode(self.coding),
)
if replaced_words:
text = arr.tobytes().decode(self.coding)
repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
self.logger.info(f'replace homophones: {repl_res}')
return text


def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
if name in self.normalizers:
self.logger.warn(f"name {name} has been registered")
return False
if not callable(normalizer):
self.logger.warn("normalizer must be a callable of type (str) -> str")
return False
self.normalizers[name] = normalizer
return True

def unregister(self, name: str):
if name in self.normalizers:
del self.normalizers[name]

def destroy(self):
del_all(self.normalizers)
del self.homophones_map

def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
with open(map_file_path, 'r', encoding='utf-8') as f:
homophones_map: Dict[str, str] = json.load(f)
map = np.empty((2, len(homophones_map)), dtype=np.uint32)
for i, k in enumerate(homophones_map.keys()):
map[:, i] = (ord(k), ord(homophones_map[k]))
del homophones_map
return map

def _count_invalid_characters(self, s: str):
s = self.sub_pattern.sub('', s)
non_alphabetic_chinese_chars = self.accept_pattern.findall(s)
return set(non_alphabetic_chinese_chars)

def _apply_half2full_map(self, text: str) -> str:
return text.translate(self.halfwidth_2_fullwidth)

def _apply_character_map(self, text: str) -> str:
return text.translate(self.character_simplifier)

def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
chinese_chars = self.chinese_char_pattern.findall(sentence)
english_words = self.english_word_pattern.findall(sentence)

if len(chinese_chars) > len(english_words):
return "zh"
else:
return "en"