feat(indextts): add glossary support for custom term pronunciations
This commit is contained in:
parent
82a5b9004a
commit
6deed97efe
@ -167,12 +167,18 @@ class IndexTTS2:
|
||||
print(">> bigvgan weights restored from:", bigvgan_name)
|
||||
|
||||
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
|
||||
self.normalizer = TextNormalizer()
|
||||
self.normalizer = TextNormalizer(enable_glossary=True)
|
||||
self.normalizer.load()
|
||||
print(">> TextNormalizer loaded")
|
||||
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
|
||||
print(">> bpe model loaded from:", self.bpe_path)
|
||||
|
||||
# 加载术语词汇表(如果存在)
|
||||
self.glossary_path = os.path.join(self.model_dir, "glossary.yaml")
|
||||
if os.path.exists(self.glossary_path):
|
||||
self.normalizer.load_glossary_from_yaml(self.glossary_path)
|
||||
print(">> Glossary loaded from:", self.glossary_path)
|
||||
|
||||
emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix))
|
||||
self.emo_matrix = emo_matrix.to(self.device)
|
||||
self.emo_num = list(self.cfg.emo_num)
|
||||
|
||||
@ -62,7 +62,7 @@ def de_tokenized_by_CJK_char(line: str, do_lower_case=False) -> str:
|
||||
output = "see you!"
|
||||
"""
|
||||
# replace english words in the line with placeholders
|
||||
english_word_pattern = re.compile(r"([A-Z]+(?:[\s-][A-Z-]+)*)", re.IGNORECASE)
|
||||
english_word_pattern = re.compile(r"([A-Z]+(?:[\s'-][A-Z-]+)*)", re.IGNORECASE)
|
||||
english_sents = english_word_pattern.findall(line)
|
||||
for i, sent in enumerate(english_sents):
|
||||
line = line.replace(sent, f"<sent_{i}>")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from functools import lru_cache
|
||||
import os
|
||||
import traceback
|
||||
import re
|
||||
@ -9,7 +10,7 @@ from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self):
|
||||
def __init__(self, enable_glossary=False):
|
||||
self.zh_normalizer = None
|
||||
self.en_normalizer = None
|
||||
self.char_rep_map = {
|
||||
@ -53,6 +54,18 @@ class TextNormalizer:
|
||||
"$": ".",
|
||||
**self.char_rep_map,
|
||||
}
|
||||
self.enable_glossary = enable_glossary
|
||||
# 术语词汇表:用户可自定义专业术语的读法
|
||||
# 格式: {"原始术语": {"en": "英文读法", "zh": "中文读法"}}
|
||||
self.term_glossary = {
|
||||
# "M.2": {"en": "M dot two", "zh": "M 二"},
|
||||
# "PCIe 5.0": {"en": "PCIE five", "zh": "PCIE 五点零"},
|
||||
# "PCIe 4.0": {"en": "PCIE four", "zh": "PCIE 四点零"},
|
||||
# "AHCI": "A H C I",
|
||||
# "TTS": "T T S",
|
||||
# "Inc.": {"en": "Ink"},
|
||||
# ".md": {"en": "dot M D", "zh": "点 M D"},
|
||||
}
|
||||
|
||||
def match_email(self, email):
|
||||
# 正则表达式匹配邮箱格式:数字英文@数字英文.英文
|
||||
@ -124,6 +137,9 @@ class TextNormalizer:
|
||||
return ""
|
||||
if self.use_chinese(text):
|
||||
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
|
||||
# 应用术语词汇表(优先级最高,在所有保护之前)
|
||||
if self.enable_glossary:
|
||||
text = self.apply_glossary_terms(text, lang="zh")
|
||||
# 保护技术术语(如 GPT-5-nano)避免被中文normalizer错误处理
|
||||
replaced_text, tech_list = self.save_tech_terms(text.rstrip())
|
||||
replaced_text, pinyin_list = self.save_pinyin_tones(replaced_text)
|
||||
@ -145,6 +161,9 @@ class TextNormalizer:
|
||||
else:
|
||||
try:
|
||||
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
|
||||
# 应用术语词汇表(优先级最高,在所有保护之前)
|
||||
if self.enable_glossary:
|
||||
text = self.apply_glossary_terms(text, lang="en")
|
||||
# 保护技术术语(如 GPT-5-Nano)避免被英文normalizer错误处理
|
||||
replaced_text, tech_list = self.save_tech_terms(text)
|
||||
result = self.en_normalizer.normalize(replaced_text)
|
||||
@ -242,6 +261,95 @@ class TextNormalizer:
|
||||
transformed_text = re.sub(r'\s*<H>\s*', '-', normalized_text)
|
||||
return transformed_text
|
||||
|
||||
def apply_glossary_terms(self, text, lang="zh"):
|
||||
"""
|
||||
应用术语词汇表,将专业术语替换为对应语言的读法
|
||||
|
||||
Args:
|
||||
text: 待处理文本
|
||||
lang: 语言类型 "zh" 或 "en"
|
||||
|
||||
Returns:
|
||||
处理后的文本
|
||||
|
||||
Example:
|
||||
"M.2 NVMe SSD" -> (zh) "M 二 NVMe SSD"
|
||||
"M.2 NVMe SSD" -> (en) "M dot two NVMe SSD"
|
||||
"""
|
||||
if not self.term_glossary:
|
||||
return text
|
||||
|
||||
# 按术语长度降序排列,避免短术语先匹配导致长术语无法匹配
|
||||
# 例如:"PCIe 5.0" 应该在 "PCIe" 之前匹配
|
||||
sorted_terms = sorted(self.term_glossary.keys(), key=len, reverse=True)
|
||||
@lru_cache(maxsize=42)
|
||||
def get_term_pattern(term: str):
|
||||
return re.compile(re.escape(term), re.IGNORECASE)
|
||||
transformed_text = text
|
||||
for term in sorted_terms:
|
||||
term_value = self.term_glossary[term]
|
||||
if isinstance(term_value, dict):
|
||||
replacement = term_value.get(lang, term_value.get(lang, term))
|
||||
else:
|
||||
replacement = term_value
|
||||
# 使用正则进行大小写不敏感的替换
|
||||
pattern = get_term_pattern(term)
|
||||
transformed_text = pattern.sub(replacement, transformed_text)
|
||||
|
||||
return transformed_text
|
||||
|
||||
def load_glossary(self, glossary_dict):
|
||||
"""
|
||||
加载外部术语词汇表
|
||||
|
||||
Args:
|
||||
glossary_dict: 术语词典,格式为 {"术语": {"en": "英文读法", "zh": "中文读法"}}
|
||||
|
||||
Example:
|
||||
normalizer.load_glossary({
|
||||
"M.2": {"en": "M dot two", "zh": "M 二"},
|
||||
"PCIe": {"en": "PCIE", "zh": "PCIE"}
|
||||
})
|
||||
"""
|
||||
if glossary_dict and isinstance(glossary_dict, dict):
|
||||
self.term_glossary.update(glossary_dict)
|
||||
|
||||
def load_glossary_from_yaml(self, glossary_path):
|
||||
"""
|
||||
从 YAML 文件加载术语词汇表
|
||||
|
||||
Args:
|
||||
glossary_path: YAML 文件路径
|
||||
|
||||
Example:
|
||||
normalizer.load_glossary_from_yaml("checkpoints/glossary.yaml")
|
||||
|
||||
YAML 文件格式:
|
||||
M.2:
|
||||
en: M dot two
|
||||
zh: M 二
|
||||
NVMe: N-V-M-E # 中英文相同读法
|
||||
"""
|
||||
if glossary_path and os.path.exists(glossary_path):
|
||||
import yaml
|
||||
with open(glossary_path, 'r', encoding='utf-8') as f:
|
||||
external_glossary = yaml.safe_load(f)
|
||||
if external_glossary and isinstance(external_glossary, dict):
|
||||
self.term_glossary.update(external_glossary)
|
||||
return True
|
||||
return False
|
||||
|
||||
def save_glossary_to_yaml(self, glossary_path):
    """
    Persist the current term glossary to a YAML file.

    Args:
        glossary_path: Destination YAML file path.
    """
    import yaml  # local import mirrors the loader; PyYAML only needed here
    with open(glossary_path, 'w', encoding='utf-8') as f:
        yaml.dump(
            self.term_glossary,
            f,
            allow_unicode=True,
            default_flow_style=False,
        )
|
||||
|
||||
def save_pinyin_tones(self, original_text):
|
||||
"""
|
||||
替换拼音声调为占位符 <pinyin_a>, <pinyin_b>, ...
|
||||
@ -493,7 +601,7 @@ class TextTokenizer:
|
||||
if __name__ == "__main__":
|
||||
# 测试程序
|
||||
|
||||
text_normalizer = TextNormalizer()
|
||||
text_normalizer = TextNormalizer(enable_glossary=True)
|
||||
|
||||
cases = [
|
||||
"IndexTTS 正式发布1.0版本了,效果666",
|
||||
@ -535,9 +643,11 @@ if __name__ == "__main__":
|
||||
"今天是个好日子 it's a good day", # 今天是个好日子 it is a good day
|
||||
# 术语
|
||||
"such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS", # such as xtts,cosyvoice two,fish-speech,and f five-tts
|
||||
"GPT-5-Nano is the smallest and fastest variant in the GPT-5 model family.",
|
||||
"GPT-5-Nano 是 GPT-5 模型家族中最小且速度最快的变体",
|
||||
"2025/09/08 IndexTTS-2 全球发布"
|
||||
"GPT-5-Nano is the smallest and fastest variant in the GPT-5 model family.", # GPT-five-Nano is the smallest and fastest variant in the GPT-five model family
|
||||
"GPT-5-Nano 是 GPT-5 模型家族中最小且速度最快的变体", # GPT-五-Nano 是 GPT-五 系统中最小且速度最快的变体
|
||||
"2025/09/08 IndexTTS-2 全球发布", # 二零二五年九月八日 IndexTTS-二全球发布
|
||||
"Here are some highly-rated M.2 NVMe SSDs: Samsung 9100 PRO PCIe 5.0 SSD M.2, $139.99", # Here are some highly-rated M dot two NVMe SSD's, Samsung nine thousand one hundred PRO PCIE five SSD M dot two . one hundred and thirty nine dollars and ninety nine cents
|
||||
"we dive deep into the showdown between DisplayPort 1.4 and HDMI 2.1 to determine which is the best choice for gaming enthusiasts",
|
||||
# 人名
|
||||
"约瑟夫·高登-莱维特(Joseph Gordon-Levitt is an American actor)",
|
||||
"蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),美国商业经理、工业工程师和工业开发商,现任苹果公司首席执行官。",
|
||||
|
||||
@ -46,5 +46,16 @@
|
||||
"与音色参考音频相同": "Same as the voice reference",
|
||||
"情感随机采样": "Randomize emotion sampling",
|
||||
"显示实验功能": "Show experimental features",
|
||||
"提示:此功能为实验版,结果尚不稳定,我们正在持续优化中。": "Note: This feature is currently experimental and may not produce satisfactory results. We're dedicated to improving its performance in a future release."
|
||||
"提示:此功能为实验版,结果尚不稳定,我们正在持续优化中。": "Note: This feature is currently experimental and may not produce satisfactory results. We're dedicated to improving its performance in a future release.",
|
||||
"自定义个别专业术语的读音": "Customize how individual technical terms are pronounced (\"say as\").",
|
||||
"请至少输入一种读法": "Please enter at least one pronunciation",
|
||||
"已添加": "Added",
|
||||
"开启术语词汇读音": "Enable custom term pronunciations",
|
||||
"添加术语": "Add term",
|
||||
"暂无术语": "No custom terms added yet",
|
||||
"请输入术语": "Please input the term",
|
||||
"术语": "Term",
|
||||
"已删除": "Deleted",
|
||||
"英文读法": "English pronunciation",
"中文读法": "Chinese pronunciation",
"请输入要删除的术语": "Please enter the term to delete",
"术语不存在": "Term not found",
|
||||
"自定义术语词汇读音": "Customize term pronunciations"
|
||||
}
|
||||
125
webui.py
125
webui.py
@ -109,6 +109,21 @@ def get_example_cases(include_experimental = False):
|
||||
# exclude emotion control mode 3 (emotion from text description)
|
||||
return [x for x in example_cases if x[1] != EMO_CHOICES_ALL[3]]
|
||||
|
||||
def format_glossary_markdown():
    """Render the current term glossary as a Markdown table."""
    glossary = tts.normalizer.term_glossary
    if not glossary:
        return i18n("暂无术语")

    rows = [
        f"| {i18n('术语')} | {i18n('中文读法')} | {i18n('英文读法')} |",
        "|---|---|---|",
    ]
    for term, reading in glossary.items():
        if isinstance(reading, dict):
            zh = reading.get("zh", "")
            en = reading.get("en", "")
        else:
            # A plain-string reading is shared by both languages.
            zh = en = reading
        rows.append(f"| {term} | {zh} | {en} |")

    return "\n".join(rows)
|
||||
|
||||
def gen_single(emo_control_method,prompt, text,
|
||||
emo_ref_path, emo_weight,
|
||||
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
|
||||
@ -195,8 +210,9 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
|
||||
output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
|
||||
|
||||
with gr.Row():
|
||||
experimental_checkbox = gr.Checkbox(label=i18n("显示实验功能"), value=False)
|
||||
|
||||
glossary_checkbox = gr.Checkbox(label=i18n("开启术语词汇读音"), value=tts.normalizer.enable_glossary)
|
||||
with gr.Accordion(i18n("功能设置")):
|
||||
# 情感控制选项部分
|
||||
with gr.Row():
|
||||
@ -246,6 +262,29 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
with gr.Row(visible=False) as emo_weight_group:
|
||||
emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.0, value=0.65, step=0.01)
|
||||
|
||||
# 术语词汇表管理
|
||||
with gr.Accordion(i18n("自定义术语词汇读音"), open=False, visible=tts.normalizer.enable_glossary) as glossary_accordion:
|
||||
gr.Markdown(i18n("自定义个别专业术语的读音"))
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
glossary_term = gr.Textbox(
|
||||
label=i18n("术语"),
|
||||
placeholder="IndexTTS2",
|
||||
)
|
||||
glossary_reading_zh = gr.Textbox(
|
||||
label=i18n("中文读法"),
|
||||
placeholder="Index T-T-S 二",
|
||||
)
|
||||
glossary_reading_en = gr.Textbox(
|
||||
label=i18n("英文读法"),
|
||||
placeholder="Index T-T-S two",
|
||||
)
|
||||
btn_add_term = gr.Button(i18n("添加术语"), scale=1)
|
||||
with gr.Column(scale=2):
|
||||
glossary_table = gr.Markdown(
|
||||
value=format_glossary_markdown()
|
||||
)
|
||||
|
||||
with gr.Accordion(i18n("高级生成参数设置"), open=False, visible=True) as advanced_settings_group:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
@ -353,6 +392,71 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
segments_preview: gr.update(value=df),
|
||||
}
|
||||
|
||||
# 术语词汇表事件处理函数
|
||||
def on_add_glossary_term(term, reading_zh, reading_en):
    """
    Add a term to the glossary, persist it, and refresh the table.

    Returns a SINGLE gr.update for the glossary Markdown table: the
    `btn_add_term.click` binding registers only `outputs=[glossary_table]`,
    so the original 2-tuple return (status update + table update) did not
    match the output count and failed at runtime. Status and error
    feedback is rendered above the table instead.
    """
    if not term:
        return gr.update(
            value=f"**{i18n('请输入术语')}**\n\n{format_glossary_markdown()}"
        )
    if not reading_zh and not reading_en:
        return gr.update(
            value=f"**{i18n('请至少输入一种读法')}**\n\n{format_glossary_markdown()}"
        )

    # Build the reading payload. At least one of zh/en is non-empty here,
    # so the dict is never empty (the original's trailing `else` branch
    # was unreachable and has been removed).
    reading = {}
    if reading_zh:
        reading["zh"] = reading_zh
    if reading_en:
        reading["en"] = reading_en

    tts.normalizer.term_glossary[term] = reading
    # Persist immediately so the glossary survives restarts.
    tts.normalizer.save_glossary_to_yaml(tts.glossary_path)

    return gr.update(
        value=f"**{i18n('已添加')}: {term}**\n\n{format_glossary_markdown()}"
    )
|
||||
|
||||
def on_delete_glossary_term(table_data):
    """Delete the selected term (placeholder; not wired to any event yet)."""
    # NOTE: Gradio Dataframe selection support is limited, so deletion is
    # instead driven by the term textbox (see on_delete_by_term).
    pass  # actual deletion requires a concrete selection mechanism
|
||||
|
||||
def on_delete_by_term(term):
    """
    Remove *term* from the glossary by name, persist the change, and
    return (status update, table update) for the bound outputs.
    """
    # Guard: nothing to delete without a term name.
    if not term:
        return (
            gr.update(value=i18n("请输入要删除的术语")),
            gr.update(),
        )

    # Guard: unknown term — report and leave the table untouched.
    if term not in tts.normalizer.term_glossary:
        return (
            gr.update(value=f"{i18n('术语不存在')}: {term}"),
            gr.update(),
        )

    del tts.normalizer.term_glossary[term]
    # Persist immediately so the deletion survives restarts.
    tts.normalizer.save_glossary_to_yaml(tts.glossary_path)

    return (
        gr.update(value=f"{i18n('已删除')}: {term}"),
        gr.update(value=format_glossary_markdown()),
    )
|
||||
|
||||
def on_method_change(emo_control_method):
|
||||
if emo_control_method == 1: # emotion reference audio
|
||||
return (gr.update(visible=True),
|
||||
@ -410,6 +514,17 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
outputs=[emo_control_method, example_table]
|
||||
)
|
||||
|
||||
def on_glossary_checkbox_change(is_enabled):
    """Sync the normalizer's glossary flag and toggle the accordion's visibility."""
    normalizer = tts.normalizer
    normalizer.enable_glossary = is_enabled
    return gr.update(visible=is_enabled)
|
||||
|
||||
glossary_checkbox.change(
|
||||
on_glossary_checkbox_change,
|
||||
inputs=[glossary_checkbox],
|
||||
outputs=[glossary_accordion]
|
||||
)
|
||||
|
||||
input_text_single.change(
|
||||
on_input_text_change,
|
||||
inputs=[input_text_single, max_text_tokens_per_segment],
|
||||
@ -426,6 +541,14 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
inputs=[],
|
||||
outputs=[gen_button])
|
||||
|
||||
# 术语词汇表事件绑定
|
||||
btn_add_term.click(
|
||||
on_add_glossary_term,
|
||||
inputs=[glossary_term, glossary_reading_zh, glossary_reading_en],
|
||||
outputs=[glossary_table]
|
||||
)
|
||||
|
||||
|
||||
gen_button.click(gen_single,
|
||||
inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
|
||||
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user