Merge pull request #557 from yrom/feat/glossary-support

feature: 用户可自定义专业术语的读法,如 M.2 → "M dot two" (EN) / "M 二" (ZH)
This commit is contained in:
Vanka0051 2025-12-02 16:18:21 +08:00 committed by GitHub
commit 1698b32033
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 329 additions and 18 deletions

View File

@ -167,12 +167,18 @@ class IndexTTS2:
print(">> bigvgan weights restored from:", bigvgan_name)
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
self.normalizer = TextNormalizer()
self.normalizer = TextNormalizer(enable_glossary=True)
self.normalizer.load()
print(">> TextNormalizer loaded")
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
print(">> bpe model loaded from:", self.bpe_path)
# 加载术语词汇表(如果存在)
self.glossary_path = os.path.join(self.model_dir, "glossary.yaml")
if os.path.exists(self.glossary_path):
self.normalizer.load_glossary_from_yaml(self.glossary_path)
print(">> Glossary loaded from:", self.glossary_path)
emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix))
self.emo_matrix = emo_matrix.to(self.device)
self.emo_num = list(self.cfg.emo_num)

View File

@ -62,22 +62,23 @@ def de_tokenized_by_CJK_char(line: str, do_lower_case=False) -> str:
output = "see you!"
"""
# replace english words in the line with placeholders
english_word_pattern = re.compile(r"([A-Z]+(?:[\s-][A-Z-]+)*)", re.IGNORECASE)
english_word_pattern = re.compile(r"([A-Z]+(?:[\s'-][A-Z-]+)*)", re.IGNORECASE)
english_sents = english_word_pattern.findall(line)
for i, sent in enumerate(english_sents):
line = line.replace(sent, f"<sent_{i}>")
words = line.split()
# restore english sentences
sent_placeholder_pattern = re.compile(r"^.*?(<sent_(\d+)>)")
sent_placeholder_pattern = re.compile(r"(<sent_(\d+)>)")
for i in range(len(words)):
m = sent_placeholder_pattern.match(words[i])
if m:
all_matches = sent_placeholder_pattern.findall(words[i])
if len(all_matches) > 1:
# restore the english word
placeholder_index = int(m.group(2))
words[i] = words[i].replace(m.group(1), english_sents[placeholder_index])
if do_lower_case:
words[i] = words[i].lower()
for h,j in all_matches:
placeholder_index = int(j)
words[i] = words[i].replace(h, english_sents[placeholder_index])
if do_lower_case:
words[i] = words[i].lower()
return "".join(words)

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from functools import lru_cache
import os
import traceback
import re
@ -9,7 +10,7 @@ from sentencepiece import SentencePieceProcessor
class TextNormalizer:
def __init__(self):
def __init__(self, enable_glossary=False):
self.zh_normalizer = None
self.en_normalizer = None
self.char_rep_map = {
@ -53,6 +54,24 @@ class TextNormalizer:
"$": ".",
**self.char_rep_map,
}
self.enable_glossary = enable_glossary
# 术语词汇表:用户可自定义专业术语的读法
# 格式: {"原始术语": {"en": "英文读法", "zh": "中文读法"}}
# "M.2": {"en": "M dot two", "zh": "M 二"},
# "PCIe 5.0": {"en": "PCIE five", "zh": "PCIE 五点零"},
# "PCIe 4.0": {"en": "PCIE four", "zh": "PCIE 四点零"},
# "AHCI": "A H C I",
# "TTS": "T T S",
# "Inc.": {"en": "Ink"},
# ".json": {"en": " dot Jay-Son", "zh": "点 Jay-Son"},
# "C++": {"en": "C plus plus", "zh": "C 加加"},
# "C#": "C sharp"
# self.term_glossary = {
# "C++": {"en": "C plus plus", "zh": "C 加加"},
# "C#": "C sharp",
# "CMake": "C Make",
# }
self.term_glossary = dict()
def match_email(self, email):
# 正则表达式匹配邮箱格式:数字英文@数字英文.英文
@ -71,6 +90,14 @@ class TextNormalizer:
例如克里斯托弗·诺兰约瑟夫·高登-莱维特
"""
TECH_TERM_PATTERN = r"[A-Za-z][A-Za-z0-9]*(?:-[A-Za-z0-9]+)+"
"""
匹配技术术语格式字母开头+(字母或数字)*+(-字母或数字)+
例如GPT-5-nano, F5-TTS, Fish-Speech, GPT-5, CosyVoice-2
必须以字母开头避免匹配纯数字如电话号码 135-4567-8900
用于保护连字符结构防止中文normalizer将连字符解析为减号"负五减"
"""
# 匹配常见英语缩写 's仅用于替换为 is不匹配所有 's
ENGLISH_CONTRACTION_PATTERN = r"(what|where|who|which|how|t?here|it|s?he|that|this)'s"
@ -116,7 +143,12 @@ class TextNormalizer:
return ""
if self.use_chinese(text):
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
replaced_text, pinyin_list = self.save_pinyin_tones(text.rstrip())
# 应用术语词汇表(优先级最高,在所有保护之前)
if self.enable_glossary:
text = self.apply_glossary_terms(text, lang="zh")
# 保护技术术语(如 GPT-5-nano避免被中文normalizer错误处理
replaced_text, tech_list = self.save_tech_terms(text.rstrip())
replaced_text, pinyin_list = self.save_pinyin_tones(replaced_text)
replaced_text, original_name_list = self.save_names(replaced_text)
try:
@ -128,12 +160,21 @@ class TextNormalizer:
result = self.restore_names(result, original_name_list)
# 恢复拼音声调
result = self.restore_pinyin_tones(result, pinyin_list)
# 恢复技术术语
result = self.restore_tech_terms(result, tech_list)
pattern = re.compile("|".join(re.escape(p) for p in self.zh_char_rep_map.keys()))
result = pattern.sub(lambda x: self.zh_char_rep_map[x.group()], result)
else:
try:
text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE)
result = self.en_normalizer.normalize(text)
# 应用术语词汇表(优先级最高,在所有保护之前)
if self.enable_glossary:
text = self.apply_glossary_terms(text, lang="en")
# 保护技术术语(如 GPT-5-Nano避免被英文normalizer错误处理
replaced_text, tech_list = self.save_tech_terms(text)
result = self.en_normalizer.normalize(replaced_text)
# 恢复技术术语
result = self.restore_tech_terms(result, tech_list)
except Exception:
result = text
print(traceback.format_exc())
@ -188,6 +229,133 @@ class TextNormalizer:
transformed_text = transformed_text.replace(f"<n_{number}>", name)
return transformed_text
def save_tech_terms(self, original_text):
    """
    Shield hyphens inside technical terms so the Chinese normalizer does
    not read them as minus signs.

    Strategy: swap each hyphen for the placeholder ``<H>`` so the digits
    around it can still be normalized, e.g. ``GPT-5-nano`` becomes
    ``GPT<H>5<H>nano`` (the ``5`` can then be read as a number word) and
    is later re-hyphenated by restore_tech_terms().

    Args:
        original_text: raw input text.

    Returns:
        (text, terms): the text with placeholders applied plus the list
        of matched terms, or (original_text, None) when nothing matched.
    """
    matches = re.findall(TextNormalizer.TECH_TERM_PATTERN, original_text)
    if not matches:
        return (original_text, None)
    # Deduplicate and process the longest terms first, so a short term
    # never clobbers part of a longer one.
    terms = sorted(set(matches), key=len, reverse=True)
    shielded = original_text
    for term in terms:
        # e.g. GPT-5-nano -> GPT<H>5<H>nano
        shielded = shielded.replace(term, term.replace("-", "<H>"))
    return shielded, terms
def restore_tech_terms(self, normalized_text, original_tech_list):
    """
    Undo save_tech_terms(): turn every ``<H>`` placeholder back into a
    hyphen, swallowing any stray whitespace the normalizer may have
    inserted around it (" <H> ", "<H> ", " <H>" and "<H>" all map to "-").

    Args:
        normalized_text: text coming out of the normalizer.
        original_tech_list: the term list returned by save_tech_terms();
            when falsy, the text is returned untouched.

    Returns:
        The text with hyphenated terms restored.
    """
    if not original_tech_list:
        return normalized_text
    return re.sub(r"\s*<H>\s*", "-", normalized_text)
def apply_glossary_terms(self, text, lang="zh"):
    """
    Replace glossary terms in *text* with their spoken form for *lang*.

    Terms are matched case-insensitively and longer terms are applied
    first, so that e.g. a "PCIe 5.0" entry wins over a plain "PCIe" one.

    Args:
        text: input text.
        lang: "zh" or "en" — which reading to use for dict-valued entries.

    Returns:
        The text with every known term replaced. A glossary value may be
        a plain string (same reading for both languages) or a dict such
        as {"en": "M dot two", "zh": "M 二"}; when the requested language
        is missing from the dict, the term is left as-is.

    Example:
        "M.2 NVMe SSD" -> (zh) "M 二 NVMe SSD"
        "M.2 NVMe SSD" -> (en) "M dot two NVMe SSD"
    """
    if not self.term_glossary:
        return text
    # Longest terms first, so a short term cannot shadow a longer one.
    sorted_terms = sorted(self.term_glossary.keys(), key=len, reverse=True)
    transformed_text = text
    for term in sorted_terms:
        term_value = self.term_glossary[term]
        if isinstance(term_value, dict):
            # Fall back to the term itself when this language has no
            # reading (replaces the former redundant nested .get()).
            replacement = term_value.get(lang, term)
        else:
            replacement = term_value
        # Case-insensitive literal match. re.compile() is cheap here:
        # the re module caches compiled patterns internally, so the
        # former per-call lru_cache closure (rebuilt on every call,
        # hence never actually caching across calls) was dropped.
        pattern = re.compile(re.escape(term), re.IGNORECASE)
        # Substitute through a callable so backslashes in user-supplied
        # replacements are taken literally, not as regex group refs.
        transformed_text = pattern.sub(lambda _m, rep=replacement: rep, transformed_text)
    return transformed_text
def load_glossary(self, glossary_dict):
    """
    Merge an external glossary into the current one.

    Args:
        glossary_dict: mapping of term -> reading, where a reading is a
            string or a {"en": ..., "zh": ...} dict. Anything that is
            not a non-empty dict is silently ignored.

    Example:
        normalizer.load_glossary({
            "M.2": {"en": "M dot two", "zh": "M 二"},
            "PCIe": {"en": "PCIE", "zh": "PCIE"}
        })
    """
    if not isinstance(glossary_dict, dict) or not glossary_dict:
        return
    self.term_glossary.update(glossary_dict)
def load_glossary_from_yaml(self, glossary_path):
    """
    Load the term glossary from a YAML file, replacing the current one.

    NOTE: unlike load_glossary(), this REPLACES (does not merge) the
    in-memory glossary.

    Args:
        glossary_path: path to the YAML file.

    Returns:
        True when a non-empty mapping was loaded, False otherwise
        (missing path, empty file, or non-mapping content).

    Example:
        normalizer.load_glossary_from_yaml("checkpoints/glossary.yaml")

    YAML file format:
        M.2:
          en: M dot two
          zh: M 二
        NVMe: N-V-M-E   # same reading for both languages
    """
    if not glossary_path or not os.path.exists(glossary_path):
        return False
    import yaml  # local import: PyYAML only needed when a glossary file exists
    with open(glossary_path, "r", encoding="utf-8") as fp:
        loaded = yaml.safe_load(fp)
    if isinstance(loaded, dict) and loaded:
        self.term_glossary = loaded
        return True
    return False
def save_glossary_to_yaml(self, glossary_path):
    """
    Persist the in-memory term glossary to a YAML file (overwriting it).

    Args:
        glossary_path: destination YAML file path.
    """
    import yaml  # local import, mirrors load_glossary_from_yaml
    with open(glossary_path, "w", encoding="utf-8") as fp:
        yaml.dump(
            self.term_glossary,
            fp,
            allow_unicode=True,
            default_flow_style=False,
        )
def save_pinyin_tones(self, original_text):
"""
替换拼音声调为占位符 <pinyin_a>, <pinyin_b>, ...
@ -439,7 +607,7 @@ class TextTokenizer:
if __name__ == "__main__":
# 测试程序
text_normalizer = TextNormalizer()
text_normalizer = TextNormalizer(enable_glossary=True)
cases = [
"IndexTTS 正式发布1.0版本了效果666",
@ -474,12 +642,18 @@ if __name__ == "__main__":
"babala2是什么", # babala二是什么?
"用beta1测试", # 用beta一测试
"have you ever been to beta2?", # have you ever been to beta two?
"such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS", # such as xtts,cosyvoice two,fish-speech,and f five-tts
"where's the money?", # where is the money?
"who's there?", # who is there?
"which's the best?", # which is the best?
"how's it going?", # how is it going?
"今天是个好日子 it's a good day", # 今天是个好日子 it is a good day
# 术语
"such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS", # such as xtts,cosyvoice two,fish-speech,and f five-tts
"GPT-5-Nano is the smallest and fastest variant in the GPT-5 model family.", # GPT-five-Nano is the smallest and fastest variant in the GPT-five model family
"GPT-5-Nano 是 GPT-5 模型家族中最小且速度最快的变体", # GPT-五-Nano 是 GPT-五 系统中最小且速度最快的变体
"2025/09/08 IndexTTS-2 全球发布", # 二零二五年九月八日 IndexTTS-二全球发布
"Here are some highly-rated M.2 NVMe SSDs: Samsung 9100 PRO PCIe 5.0 SSD M.2, $139.99", # Here are some highly-rated M dot two NVMe SSD's, Samsung nine thousand one hundred PRO PCIE five SSD M dot two . one hundred and thirty nine dollars and ninety nine cents
"we dive deep into the showdown between DisplayPort 1.4 and HDMI 2.1 to determine which is the best choice for gaming enthusiasts",
# 人名
"约瑟夫·高登-莱维特Joseph Gordon-Levitt is an American actor",
"蒂莫西·唐纳德·库克英文名Timothy Donald Cook通称蒂姆·库克Tim Cook美国商业经理、工业工程师和工业开发商现任苹果公司首席执行官。",

View File

@ -46,5 +46,20 @@
"与音色参考音频相同": "Same as the voice reference",
"情感随机采样": "Randomize emotion sampling",
"显示实验功能": "Show experimental features",
"提示:此功能为实验版,结果尚不稳定,我们正在持续优化中。": "Note: This feature is currently experimental and may not produce satisfactory results. We're dedicated to improving its performance in a future release."
"提示:此功能为实验版,结果尚不稳定,我们正在持续优化中。": "Note: This feature is currently experimental and may not produce satisfactory results. We're dedicated to improving its performance in a future release.",
"自定义个别专业术语的读音": "Customize the pronunciation of individual terms, i.e. how each one should be \"said\".",
"请至少输入一种读法": "Please enter at least one pronunciation",
"已添加": "Added",
"开启术语词汇读音": "Enable custom term pronunciations",
"添加术语": "Add term",
"暂无术语": "No custom terms added yet",
"请输入术语": "Please input the term",
"术语": "Term",
"已删除": "Deleted",
"英文读法": "English pronunciation",
"自定义术语词汇读音": "Customize term pronunciations",
"中文读法": "Chinese pronunciation",
"词汇表已更新": "Glossary updated successfully",
"保存词汇表时出错": "Error saving glossary",
"加载词汇表时出错": "Error loading glossary"
}

119
webui.py
View File

@ -109,6 +109,21 @@ def get_example_cases(include_experimental = False):
# exclude emotion control mode 3 (emotion from text description)
return [x for x in example_cases if x[1] != EMO_CHOICES_ALL[3]]
def format_glossary_markdown():
    """Render the current glossary as a Markdown table (term / zh / en)."""
    glossary = tts.normalizer.term_glossary
    if not glossary:
        return i18n("暂无术语")
    rows = [
        f"| {i18n('术语')} | {i18n('中文读法')} | {i18n('英文读法')} |",
        "|---|---|---|",
    ]
    for term, reading in glossary.items():
        if isinstance(reading, dict):
            zh, en = reading.get("zh", ""), reading.get("en", "")
        else:
            # A plain-string reading is shown in both language columns.
            zh = en = reading
        rows.append(f"| {term} | {zh} | {en} |")
    return "\n".join(rows)
def gen_single(emo_control_method,prompt, text,
emo_ref_path, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
@ -195,8 +210,9 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
experimental_checkbox = gr.Checkbox(label=i18n("显示实验功能"), value=False)
with gr.Row():
experimental_checkbox = gr.Checkbox(label=i18n("显示实验功能"), value=False)
glossary_checkbox = gr.Checkbox(label=i18n("开启术语词汇读音"), value=tts.normalizer.enable_glossary)
with gr.Accordion(i18n("功能设置")):
# 情感控制选项部分
with gr.Row():
@ -246,6 +262,29 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
with gr.Row(visible=False) as emo_weight_group:
emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.0, value=0.65, step=0.01)
# 术语词汇表管理
with gr.Accordion(i18n("自定义术语词汇读音"), open=False, visible=tts.normalizer.enable_glossary) as glossary_accordion:
gr.Markdown(i18n("自定义个别专业术语的读音"))
with gr.Row():
with gr.Column(scale=1):
glossary_term = gr.Textbox(
label=i18n("术语"),
placeholder="IndexTTS2",
)
glossary_reading_zh = gr.Textbox(
label=i18n("中文读法"),
placeholder="Index T-T-S 二",
)
glossary_reading_en = gr.Textbox(
label=i18n("英文读法"),
placeholder="Index T-T-S two",
)
btn_add_term = gr.Button(i18n("添加术语"), scale=1)
with gr.Column(scale=2):
glossary_table = gr.Markdown(
value=format_glossary_markdown()
)
with gr.Accordion(i18n("高级生成参数设置"), open=False, visible=True) as advanced_settings_group:
with gr.Row():
with gr.Column(scale=1):
@ -353,6 +392,48 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
segments_preview: gr.update(value=df),
}
# 术语词汇表事件处理函数
def on_add_glossary_term(term, reading_zh, reading_en):
    """
    Add (or overwrite) a glossary entry from the UI and persist it.

    Args:
        term: the term to customize (e.g. "IndexTTS2").
        reading_zh: Chinese reading, may be empty.
        reading_en: English reading, may be empty.

    Returns:
        A gr.update() refreshing the Markdown table, or a no-op update
        when validation or saving fails.
    """
    # strip() (was rstrip()) so leading whitespace cannot create
    # visually-identical duplicate entries.
    term = term.strip()
    reading_zh = reading_zh.strip()
    reading_en = reading_en.strip()
    if not term:
        gr.Warning(i18n("请输入术语"))
        return gr.update()
    if not reading_zh and not reading_en:
        gr.Warning(i18n("请至少输入一种读法"))
        return gr.update()
    # Keep only the languages that were actually provided. (The former
    # trailing else branch was unreachable: both-empty is rejected above.)
    reading = {}
    if reading_zh:
        reading["zh"] = reading_zh
    if reading_en:
        reading["en"] = reading_en
    tts.normalizer.term_glossary[term] = reading
    # Persist immediately so the glossary survives restarts.
    try:
        tts.normalizer.save_glossary_to_yaml(tts.glossary_path)
        gr.Info(i18n("词汇表已更新"), duration=1)
    except Exception as e:
        gr.Error(i18n("保存词汇表时出错"))
        print(f"Error details: {e}")
        return gr.update()
    # Refresh the Markdown table.
    return gr.update(value=format_glossary_markdown())
def on_method_change(emo_control_method):
if emo_control_method == 1: # emotion reference audio
return (gr.update(visible=True),
@ -410,6 +491,17 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
outputs=[emo_control_method, example_table]
)
def on_glossary_checkbox_change(is_enabled):
    """Toggle glossary usage on the normalizer and show/hide the
    glossary accordion accordingly."""
    # Flip the runtime flag read by TextNormalizer before normalizing.
    tts.normalizer.enable_glossary = is_enabled
    # Mirror the flag in the UI: accordion visible only when enabled.
    return gr.update(visible=is_enabled)
glossary_checkbox.change(
on_glossary_checkbox_change,
inputs=[glossary_checkbox],
outputs=[glossary_accordion]
)
input_text_single.change(
on_input_text_change,
inputs=[input_text_single, max_text_tokens_per_segment],
@ -426,6 +518,29 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
inputs=[],
outputs=[gen_button])
def on_demo_load():
    """Reload the glossary from disk when the page loads, then refresh
    the Markdown table with whatever is currently in memory."""
    try:
        tts.normalizer.load_glossary_from_yaml(tts.glossary_path)
    except Exception as e:
        # Best-effort: surface the error in the UI but keep the page up.
        gr.Error(i18n("加载词汇表时出错"))
        print(f"Failed to reload glossary on page load: {e}")
    return gr.update(value=format_glossary_markdown())
# 术语词汇表事件绑定
btn_add_term.click(
on_add_glossary_term,
inputs=[glossary_term, glossary_reading_zh, glossary_reading_en],
outputs=[glossary_table]
)
# 页面加载时重新加载glossary
demo.load(
on_demo_load,
inputs=[],
outputs=[glossary_table]
)
gen_button.click(gen_single,
inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,