+简单前端
This commit is contained in:
parent
bd7530a7d9
commit
46630ca45b
@ -14,7 +14,7 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.utils.common import tokenize_by_CJK_char
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
|
||||
|
||||
from indextts.utils.front import TextNormalizer
|
||||
class IndexTTS:
|
||||
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'):
|
||||
self.cfg = OmegaConf.load(cfg_path)
|
||||
@ -42,16 +42,20 @@ class IndexTTS:
|
||||
self.bigvgan = self.bigvgan.to(self.device)
|
||||
self.bigvgan.eval()
|
||||
print(">> bigvgan weights restored from:", self.bigvgan_path)
|
||||
self.normalizer = TextNormalizer()
|
||||
self.normalizer.load()
|
||||
print(">> TextNormalizer loaded")
|
||||
|
||||
def preprocess_text(self, text):
|
||||
chinese_punctuation = ",。!?;:“”‘’()【】《》"
|
||||
english_punctuation = ",.!?;:\"\"''()[]<>"
|
||||
|
||||
# 创建一个映射字典
|
||||
punctuation_map = str.maketrans(chinese_punctuation, english_punctuation)
|
||||
# chinese_punctuation = ",。!?;:“”‘’()【】《》"
|
||||
# english_punctuation = ",.!?;:\"\"''()[]<>"
|
||||
#
|
||||
# # 创建一个映射字典
|
||||
# punctuation_map = str.maketrans(chinese_punctuation, english_punctuation)
|
||||
|
||||
# 使用translate方法替换标点符号
|
||||
return text.translate(punctuation_map)
|
||||
# return text.translate(punctuation_map)
|
||||
return self.normalizer.infer(text)
|
||||
|
||||
def infer(self, audio_prompt, text, output_path):
|
||||
text = self.preprocess_text(text)
|
||||
|
||||
96
indextts/utils/front.py
Normal file
96
indextts/utils/front.py
Normal file
@ -0,0 +1,96 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import traceback
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import re
|
||||
|
||||
|
||||
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self):
|
||||
# self.normalizer = Normalizer(cache_dir="textprocessing/tn")
|
||||
self.zh_normalizer = None
|
||||
self.en_normalizer = None
|
||||
self.char_rep_map = {
|
||||
":": ",",
|
||||
";": ",",
|
||||
";": ",",
|
||||
",": ",",
|
||||
"。": ".",
|
||||
"!": "!",
|
||||
"?": "?",
|
||||
"\n": ".",
|
||||
"·": ",",
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
"……": "…",
|
||||
"$": ".",
|
||||
"“": "'",
|
||||
"”": "'",
|
||||
'"': "'",
|
||||
"‘": "'",
|
||||
"’": "'",
|
||||
"(": "'",
|
||||
")": "'",
|
||||
"(": "'",
|
||||
")": "'",
|
||||
"《": "'",
|
||||
"》": "'",
|
||||
"【": "'",
|
||||
"】": "'",
|
||||
"[": "'",
|
||||
"]": "'",
|
||||
"—": "-",
|
||||
"~": "-",
|
||||
"~": "-",
|
||||
"「": "'",
|
||||
"」": "'",
|
||||
":": ",",
|
||||
}
|
||||
|
||||
def match_email(self, email):
|
||||
# 正则表达式匹配邮箱格式:数字英文@数字英文.英文
|
||||
pattern = r'^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$'
|
||||
return re.match(pattern, email) is not None
|
||||
|
||||
def use_chinese(self, s):
|
||||
has_chinese = bool(re.search(r'[\u4e00-\u9fff]', s))
|
||||
has_digit = bool(re.search(r'\d', s))
|
||||
has_alpha = bool(re.search(r'[a-zA-Z]', s))
|
||||
is_email = self.match_email(s)
|
||||
if has_chinese or not has_alpha or is_email:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def load(self):
|
||||
# print(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
||||
# sys.path.append(model_dir)
|
||||
|
||||
from tn.chinese.normalizer import Normalizer as NormalizerZh
|
||||
from tn.english.normalizer import Normalizer as NormalizerEn
|
||||
|
||||
self.zh_normalizer = NormalizerZh(remove_interjections=False, remove_erhua=False)
|
||||
self.en_normalizer = NormalizerEn()
|
||||
|
||||
def infer(self, text):
|
||||
pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys()))
|
||||
replaced_text = pattern.sub(lambda x: self.char_rep_map[x.group()], text)
|
||||
if not self.zh_normalizer or not self.en_normalizer:
|
||||
print("Error, text normalizer is not initialized !!!")
|
||||
return ""
|
||||
try:
|
||||
normalizer = self.zh_normalizer if self.use_chinese(text) else self.en_normalizer
|
||||
result = normalizer.normalize(text)
|
||||
except Exception:
|
||||
result = ""
|
||||
print(traceback.format_exc())
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试程序
|
||||
text_normalizer = TextNormalizer()
|
||||
print(text_normalizer.infer("2.5平方电线"))
|
||||
@ -20,4 +20,5 @@ sentencepiece
|
||||
pypinyin
|
||||
librosa
|
||||
gradio
|
||||
tqdm
|
||||
tqdm
|
||||
WeTextProcessing
|
||||
Loading…
x
Reference in New Issue
Block a user