+简单前端

This commit is contained in:
wangyining02 2025-03-26 19:14:47 +08:00
parent bd7530a7d9
commit 46630ca45b
3 changed files with 109 additions and 8 deletions

View File

@ -14,7 +14,7 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.common import tokenize_by_CJK_char
from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.front import TextNormalizer
class IndexTTS:
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'):
self.cfg = OmegaConf.load(cfg_path)
@ -42,16 +42,20 @@ class IndexTTS:
self.bigvgan = self.bigvgan.to(self.device)
self.bigvgan.eval()
print(">> bigvgan weights restored from:", self.bigvgan_path)
self.normalizer = TextNormalizer()
self.normalizer.load()
print(">> TextNormalizer loaded")
def preprocess_text(self, text):
    """Normalize *text* through the instance's text normalizer.

    The earlier implementation translated a fixed set of Chinese punctuation
    marks to ASCII with ``str.translate``; that work is now delegated to
    ``self.normalizer`` (created and loaded in ``__init__``).

    Args:
        text: Raw input text.

    Returns:
        Whatever ``self.normalizer.infer(text)`` returns.
    """
    return self.normalizer.infer(text)
def infer(self, audio_prompt, text, output_path):
text = self.preprocess_text(text)

96
indextts/utils/front.py Normal file
View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
import traceback
import os
import sys
import re
import re
class TextNormalizer:
    """Text front end: language-specific normalization plus punctuation mapping.

    Call :meth:`load` once before :meth:`infer`. Until then both normalizers
    are ``None`` and :meth:`infer` reports an error and returns ``""``.
    """

    def __init__(self):
        # Populated by load(); left as None so construction stays cheap.
        self.zh_normalizer = None
        self.en_normalizer = None
        # Replacement table used by infer(): substring -> replacement.
        # NOTE(review): the repeated empty-string keys (and the empty values on
        # the two ellipsis entries) look like non-ASCII punctuation that was
        # lost in transit -- confirm against the repository copy and restore
        # them. As written, duplicate keys collapse to the last entry; empty
        # keys are skipped when the replacement pattern is built.
        self.char_rep_map = {
            "": ",",
            "": ",",
            ";": ",",
            "": ",",
            "": ".",
            "": "!",
            "": "?",
            "\n": ".",
            "·": ",",
            "": ",",
            "...": "",
            "……": "",
            "$": ".",
            "": "'",
            "": "'",
            '"': "'",
            "": "'",
            "": "'",
            "": "'",
            "": "'",
            "(": "'",
            ")": "'",
            "": "'",
            "": "'",
            "": "'",
            "": "'",
            "[": "'",
            "]": "'",
            "": "-",
            "": "-",
            "~": "-",
            "": "'",
            "": "'",
            ":": ",",
        }

    def match_email(self, email):
        """Return True if *email* has the form ``alnum@alnum.letters``.

        Only ASCII letters and digits are accepted around the ``@``, and only
        ASCII letters after the single dot.
        """
        pattern = r'^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$'
        return re.match(pattern, email) is not None

    def use_chinese(self, s):
        """Return True when *s* should go through the Chinese normalizer.

        That is the case when the text contains a CJK ideograph
        (U+4E00-U+9FFF), contains no ASCII letter at all, or is an e-mail
        address as recognised by :meth:`match_email`.
        """
        if re.search(r'[\u4e00-\u9fff]', s):
            return True
        if not re.search(r'[a-zA-Z]', s):
            return True
        return self.match_email(s)

    def load(self):
        """Create the Chinese and English normalizers from the ``tn`` package.

        The imports are local so that importing this module does not require
        ``tn`` to be importable.
        """
        from tn.chinese.normalizer import Normalizer as NormalizerZh
        from tn.english.normalizer import Normalizer as NormalizerEn
        self.zh_normalizer = NormalizerZh(remove_interjections=False, remove_erhua=False)
        self.en_normalizer = NormalizerEn()

    def _replace_chars(self, text):
        """Apply ``char_rep_map`` to *text* in a single regex pass."""
        # Skip empty keys (an empty alternative would match at every position)
        # and try longer keys first so multi-character entries win.
        keys = sorted((k for k in self.char_rep_map if k), key=len, reverse=True)
        if not keys:
            return text
        pattern = re.compile("|".join(re.escape(k) for k in keys))
        return pattern.sub(lambda m: self.char_rep_map[m.group()], text)

    def infer(self, text):
        """Normalize *text* and map its punctuation through ``char_rep_map``.

        The language-specific normalizer runs first (Chinese or English, as
        chosen by :meth:`use_chinese`); the replacement table is then applied
        to its output. Previously the replaced text was computed but thrown
        away, so the table never took effect.

        Returns:
            The normalized text, or ``""`` when :meth:`load` has not been
            called or the normalizer raised (the traceback is printed).
        """
        if not self.zh_normalizer or not self.en_normalizer:
            print("Error, text normalizer is not initialized !!!")
            return ""
        try:
            normalizer = self.zh_normalizer if self.use_chinese(text) else self.en_normalizer
            result = normalizer.normalize(text)
        except Exception:
            # Best effort: report the failure and hand back an empty string,
            # as callers already expect.
            print(traceback.format_exc())
            return ""
        return self._replace_chars(result)
if __name__ == '__main__':
    # Smoke test. The normalizers must be loaded first; without load(),
    # infer() only prints the "not initialized" error and returns "".
    text_normalizer = TextNormalizer()
    text_normalizer.load()
    print(text_normalizer.infer("2.5平方电线"))

View File

@ -20,4 +20,5 @@ sentencepiece
pypinyin
librosa
gradio
tqdm
tqdm
WeTextProcessing