# -*- coding: utf-8 -*-
"""Text-normalization front end for TTS inference.

Maps CJK / full-width punctuation to ASCII equivalents, then delegates
number/date/unit expansion to WeTextProcessing's Chinese or English
normalizer, chosen per input string.
"""
import re
import traceback


class TextNormalizer:
    def __init__(self):
        # Both are populated by load(); kept as None so infer() can detect
        # an uninitialized normalizer instead of crashing.
        self.zh_normalizer = None  # tn.chinese.normalizer.Normalizer
        self.en_normalizer = None  # tn.english.normalizer.Normalizer
        # Punctuation replacement table: full-width / CJK marks -> ASCII.
        self.char_rep_map = {
            ":": ",",
            ";": ",",
            ";": ",",
            ",": ",",
            "。": ".",
            "!": "!",
            "?": "?",
            "\n": ".",
            "·": ",",
            "、": ",",
            "...": "…",
            "……": "…",
            "$": ".",
            "“": "'",
            "”": "'",
            '"': "'",
            "‘": "'",
            "’": "'",
            "(": "'",
            ")": "'",
            "(": "'",
            ")": "'",
            "《": "'",
            "》": "'",
            "【": "'",
            "】": "'",
            "[": "'",
            "]": "'",
            "—": "-",
            "~": "-",
            "~": "-",
            "「": "'",
            "」": "'",
            ":": ",",
        }
        # Compile the alternation once instead of on every infer() call.
        self._rep_pattern = re.compile(
            "|".join(re.escape(p) for p in self.char_rep_map)
        )

    def match_email(self, email):
        """Return True if *email* is a whole-string alnum@alnum.alpha address."""
        pattern = r'^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$'
        return re.match(pattern, email) is not None

    def use_chinese(self, s):
        """Return True if *s* should go through the Chinese normalizer.

        Chinese is chosen when the string contains CJK ideographs, contains
        no ASCII letters at all (e.g. pure digits/punctuation), or looks
        like an email address.
        """
        has_chinese = bool(re.search(r'[\u4e00-\u9fff]', s))
        has_alpha = bool(re.search(r'[a-zA-Z]', s))
        is_email = self.match_email(s)
        return has_chinese or not has_alpha or is_email

    def load(self):
        """Instantiate the zh/en normalizers.

        Imported lazily: WeTextProcessing is a heavy optional dependency and
        construction builds large FSTs, so callers opt in explicitly.
        """
        from tn.chinese.normalizer import Normalizer as NormalizerZh
        from tn.english.normalizer import Normalizer as NormalizerEn

        self.zh_normalizer = NormalizerZh(remove_interjections=False,
                                          remove_erhua=False)
        self.en_normalizer = NormalizerEn()

    def infer(self, text):
        """Normalize *text* for TTS; returns "" on error or if not load()ed."""
        if not self.zh_normalizer or not self.en_normalizer:
            print("Error, text normalizer is not initialized !!!")
            return ""
        # Replace CJK punctuation first so the normalizer sees ASCII marks.
        replaced_text = self._rep_pattern.sub(
            lambda m: self.char_rep_map[m.group()], text
        )
        try:
            normalizer = (self.zh_normalizer if self.use_chinese(replaced_text)
                          else self.en_normalizer)
            # BUG FIX: normalize the punctuation-replaced text; the original
            # code computed replaced_text and then discarded it.
            result = normalizer.normalize(replaced_text)
        except Exception:
            result = ""
            print(traceback.format_exc())
        return result


if __name__ == '__main__':
    # Demo / smoke test. BUG FIX: load() must run first, otherwise infer()
    # can only print the not-initialized error and return "".
    text_normalizer = TextNormalizer()
    text_normalizer.load()
    print(text_normalizer.infer("2.5平方电线"))