diff --git a/indextts/infer.py b/indextts/infer.py index 1019242..3452868 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -75,6 +75,8 @@ class IndexTTS: self.bigvgan.eval() print(">> bigvgan weights restored from:", self.bigvgan_path) self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset['bpe_model']) + self.tokenizer = spm.SentencePieceProcessor(model_file=self.bpe_path) + print(">> bpe model loaded from:", self.bpe_path) self.normalizer = TextNormalizer() self.normalizer.load() print(">> TextNormalizer loaded") @@ -134,6 +136,18 @@ class IndexTTS: code_lens = torch.LongTensor(code_lens).to(device, dtype=dtype) return codes, code_lens + def split_sentences(self, text): + """ + Split the text into sentences based on punctuation marks. + """ + # 匹配标点符号(包括中英文标点) + pattern = r'(?<=[.!?;。!?;])\s*' + sentences = re.split(pattern, text) + # 过滤掉空字符串和仅包含标点符号的字符串 + return [ + sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","} + ] + def infer(self, audio_prompt, text, output_path): print(f"origin text:{text}") text = self.preprocess_text(text) @@ -150,14 +164,8 @@ class IndexTTS: auto_conditioning = cond_mel - tokenizer = spm.SentencePieceProcessor() - tokenizer.load(self.bpe_path) - - punctuation = ["!", "?", ".", ";", "!", "?", "。", ";"] - pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) - sentences = [i for i in re.split(pattern, text) if i.strip() != ""] + sentences = self.split_sentences(text) print("sentences:", sentences) - top_p = .8 top_k = 30 temperature = 1.0 @@ -167,8 +175,8 @@ class IndexTTS: repetition_penalty = 10.0 max_mel_tokens = 600 sampling_rate = 24000 - lang = "EN" - lang = "ZH" + # lang = "EN" + # lang = "ZH" wavs = [] print(">> start inference...") @@ -181,19 +189,19 @@ class IndexTTS: # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ." 
print("cleand_text:", cleand_text) - text_tokens = torch.IntTensor(tokenizer.encode(cleand_text)).unsqueeze(0).to(self.device) - + text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0) # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. # text_tokens = F.pad(text_tokens, (1, 0), value=0) # text_tokens = F.pad(text_tokens, (0, 1), value=1) - # text_tokens = text_tokens.to(self.device) + print(text_tokens) print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") - text_token_syms = [tokenizer.IdToPiece(idx) for idx in text_tokens[0].tolist()] + # debug tokenizer + text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) print(text_token_syms) - text_len = [text_tokens.size(1)] - text_len = torch.IntTensor(text_len).to(self.device) - print(text_len) + + # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device) + # print(text_len) with torch.no_grad(): with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): @@ -234,8 +242,7 @@ class IndexTTS: wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2)) wav = wav.squeeze(1).cpu() - wav = 32767 * wav - torch.clip(wav, -32767.0, 32767.0) + wav = torch.clip(32767 * wav, -32767.0, 32767.0) print(f"wav shape: {wav.shape}") # wavs.append(wav[:, :-512]) wavs.append(wav) @@ -244,7 +251,7 @@ class IndexTTS: elapsed_time = end_time - start_time minutes, seconds = divmod(int(elapsed_time), 60) milliseconds = int((elapsed_time - int(elapsed_time)) * 1000) - print(f">> inference done. time: {minutes}:{seconds}.{milliseconds}") + print(f">> inference done. time: {minutes:02d}:{seconds:02d}.{milliseconds:03d}") print(">> saving wav file") wav = torch.cat(wavs, dim=1) torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) @@ -259,4 +266,4 @@ if __name__ == "__main__": text="There is a vehicle arriving in dock number 7?" 
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True) - tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav") + tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav") diff --git a/indextts/utils/front.py b/indextts/utils/front.py index a43d4cc..93df6f0 100644 --- a/indextts/utils/front.py +++ b/indextts/utils/front.py @@ -4,7 +4,6 @@ import re class TextNormalizer: def __init__(self): - # self.normalizer = Normalizer(cache_dir="textprocessing/tn") self.zh_normalizer = None self.en_normalizer = None self.char_rep_map = { @@ -15,8 +14,8 @@ class TextNormalizer: "。": ".", "!": "!", "?": "?", - "\n": ".", - "·": ",", + "\n": " ", + "·": "-", "、": ",", "...": "…", "……": "…", @@ -48,16 +47,20 @@ class TextNormalizer: # 正则表达式匹配邮箱格式:数字英文@数字英文.英文 pattern = r'^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$' return re.match(pattern, email) is not None - + """ + 匹配拼音声调格式:pinyin+数字,声调1-5,5表示轻声 + 例如:xuan4, jve2, ying1, zhong4, shang5 + """ + PINYIN_TONE_PATTERN = r"([bmnpqdfghjklzcsxwy]?h?[aeiouüv]{1,2}[ng]*|ng)([1-5])" def use_chinese(self, s): has_chinese = bool(re.search(r'[\u4e00-\u9fff]', s)) - has_digit = bool(re.search(r'\d', s)) has_alpha = bool(re.search(r'[a-zA-Z]', s)) is_email = self.match_email(s) if has_chinese or not has_alpha or is_email: return True - else: - return False + + has_pinyin = bool(re.search(self.PINYIN_TONE_PATTERN, s, re.IGNORECASE)) + return has_pinyin def load(self): # print(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) @@ -73,94 +76,102 @@ class TextNormalizer: self.zh_normalizer = NormalizerZh(remove_interjections=False, remove_erhua=False,overwrite_cache=False) self.en_normalizer = NormalizerEn(overwrite_cache=False) - def infer(self, text): - pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys())) - replaced_text = pattern.sub(lambda x: self.char_rep_map[x.group()], text) + def infer(self, text: str): if not self.zh_normalizer or not 
self.en_normalizer: print("Error, text normalizer is not initialized !!!") return "" + replaced_text, pinyin_list = self.save_pinyin_tones(text.rstrip()) + try: normalizer = self.zh_normalizer if self.use_chinese(replaced_text) else self.en_normalizer result = normalizer.normalize(replaced_text) except Exception: result = "" print(traceback.format_exc()) - result = self.restore_pinyin_tone_numbers(replaced_text, result) + result = self.restore_pinyin_tones(result, pinyin_list) + pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys())) + result = pattern.sub(lambda x: self.char_rep_map[x.group()], result) return result - def pinyin_match(self, pinyin): - pattern = r"(qun)(\d)" - repl = r"qvn\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(quan)(\d)" - repl = r"qvan\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(que)(\d)" - repl = r"qve\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(qu)(\d)" - repl = r"qv\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(ju)(\d)" - repl = r"jv\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(jue)(\d)" - repl = r"jve\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(xun)(\d)" - repl = r"xvn\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(xue)(\d)" - repl = r"xve\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(xu)(\d)" - repl = r"xv\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(juan)(\d)" - repl = r"jvan\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(jun)(\d)" - repl = r"jvn\g<2>" - pinyin = re.sub(pattern, repl, pinyin) - - pattern = r"(xuan)(\d)" - repl = r"xvan\g<2>" + def correct_pinyin(self, pinyin): + """ + 将 jqx 的韵母为 u/ü 的拼音转换为 v + 如:ju -> jv , que -> qve, xün -> xvn + """ + if pinyin[0] not in "jqx": + return pinyin + # 匹配 jqx 的韵母为 u/ü 的拼音 + pattern = r"([jqx])[uü](n|e|an)*(\d)" + repl = r"\g<1>v\g<2>\g<3>" pinyin = re.sub(pattern, repl, pinyin) 
return pinyin - def restore_pinyin_tone_numbers(self,original_text, processed_text): - # 第一步:恢复拼音后的音调数字(1-4) - # 建立中文数字到阿拉伯数字的映射 - chinese_to_num = {'一': '1', '二': '2', '三': '3', '四': '4'} + def save_pinyin_tones(self, original_text): + """ + 替换拼音声调为占位符 <a>, <b>, ... + 例如:xuan4 -> <a> + """ + # 声母韵母+声调数字 + origin_pinyin_pattern = re.compile(self.PINYIN_TONE_PATTERN, re.IGNORECASE) + original_pinyin_list = re.findall(origin_pinyin_pattern, original_text) + if len(original_pinyin_list) == 0: + return (original_text, None) + original_pinyin_list = list(set(''.join(p) for p in original_pinyin_list)) + transformed_text = original_text + # 替换为占位符 <a>, <b>, ... + for i, pinyin in enumerate(original_pinyin_list): + number = chr(ord("a") + i) + transformed_text = transformed_text.replace(pinyin, f"<{number}>") + + # print("original_text: ", original_text) + # print("transformed_text: ", transformed_text) + return transformed_text, original_pinyin_list - # 使用正则表达式找到拼音+中文数字的组合(如 "xuan四") - def replace_tone(match): - pinyin = match.group(1) # 拼音部分 - chinese_num = match.group(2) # 中文数字部分 - # 将中文数字转换为阿拉伯数字 - num = chinese_to_num.get(chinese_num, chinese_num) - return f"{pinyin}{num}" - - # 匹配拼音后跟中文数字(一、二、三、四)的情况 - pattern = r'([a-zA-Z]+)([一二三四])' - restored_text = re.sub(pattern, replace_tone, processed_text) - restored_text = restored_text.lower() - restored_text = self.pinyin_match(restored_text) - - return restored_text + def restore_pinyin_tones(self, normalized_text, original_pinyin_list): + """ + 恢复拼音中的音调数字(1-5)为原来的拼音 + 例如:<a> -> original_pinyin_list[0] + """ + if not original_pinyin_list or len(original_pinyin_list) == 0: + return normalized_text + transformed_text = normalized_text + # 替换为占位符 <a>, <b>, ... 
+ for i, pinyin in enumerate(original_pinyin_list): + number = chr(ord("a") + i) + pinyin = self.correct_pinyin(pinyin) + transformed_text = transformed_text.replace(f"<{number}>", pinyin) + # print("normalized_text: ", normalized_text) + # print("transformed_text: ", transformed_text) + return transformed_text if __name__ == '__main__': # 测试程序 text_normalizer = TextNormalizer() - print(text_normalizer.infer("2.5平方电线")) + text_normalizer.load() + cases = [ + "我爱你!", + "I love you!", + "我爱你的英语是”I love you“", + "2.5平方电线", + "共465篇,约315万字", + "2002年的第一场雪,下在了2003年", + "速度是10km/h", + "现在是北京时间2025年01月11日 20:00", + "他这条裤子是2012年买的,花了200块钱", + "电话:135-4567-8900", + "1键3连", + "他这条视频点赞3000+,评论1000+,收藏500+", + "这是1024元的手机,你要吗?", + "受不liao3你了", + "”衣裳“不读衣chang2,而是读衣shang5", + "最zhong4要的是:不要chong2蹈覆辙", + "IndexTTS 正式发布1.0版本了,效果666", + "See you at 8:00 AM", + "8:00 AM 开会", + "苹果于2030/1/2发布新 iPhone 2X 系列手机,最低售价仅 ¥12999", + ] + for case in cases: + print(f"原始文本: {case}") + print(f"处理后文本: {text_normalizer.infer(case)}") + print("-" * 50)