From 2523001bb4ff87c8a1547049e9662dd76ebed5ea Mon Sep 17 00:00:00 2001 From: boostpapa Date: Tue, 8 Apr 2025 11:23:11 +0800 Subject: [PATCH 1/4] support ultra-long silence filtering --- indextts/infer.py | 51 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/indextts/infer.py b/indextts/infer.py index f1f419e..ff707f5 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -42,6 +42,7 @@ class IndexTTS: self.bigvgan = self.bigvgan.to(self.device) self.bigvgan.eval() print(">> bigvgan weights restored from:", self.bigvgan_path) + self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset['bpe_model']) self.normalizer = TextNormalizer() self.normalizer.load() print(">> TextNormalizer loaded") @@ -57,6 +58,40 @@ class IndexTTS: # return text.translate(punctuation_map) return self.normalizer.infer(text) + def remove_long_silence(self, codes): + code_lens = [] + for i in range(0, codes.shape[0]): + code = codes[i] + if self.cfg.gpt.stop_mel_token not in code: + code_lens.append(len(code)) + len_ = len(code) + else: + # len_ = code.cpu().tolist().index(8193)+1 + len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1 + len_ = len_ - 2 + + count = torch.sum(code == 52).item() + if count > 50: + code = code.cpu().tolist() + ncode = [] + n = 0 + for k in range(0, len_): + if code[k] != 52: + ncode.append(code[k]) + n = 0 + elif code[k] == 52 and n < 30: + ncode.append(code[k]) + n += 1 + # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52): + # n += 1 + len_ = len(ncode) + ncode = torch.LongTensor(ncode) + codes[i] = self.stop_mel_token + codes[i, 0:len_] = ncode + code_lens.append(len_) + code_lens = torch.LongTensor(code_lens).cuda() + return codes, code_lens + def infer(self, audio_prompt, text, output_path): print(f"origin text:{text}") text = self.preprocess_text(text) @@ -74,7 +109,7 @@ class IndexTTS: auto_conditioning = cond_mel tokenizer = spm.SentencePieceProcessor() - tokenizer.load(os.path.join(self.model_dir,self.cfg.dataset['bpe_model'])) + tokenizer.load(self.bpe_path) punctuation = ["!", "?", ".", ";", "!", "?", "。", ";"] pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) @@ -128,15 +163,19 @@ class IndexTTS: num_beams=num_beams, repetition_penalty=repetition_penalty, max_generate_length=max_mel_tokens) + codes = codes[:, :-2] + # code_lens = torch.tensor([codes.shape[-1]]) print(codes) print(f"codes shape: {codes.shape}") - codes = codes[:, :-2] + # remove ultra-long silence if exits + codes, code_lens = self.remove_long_silence(codes) + print(f"codes shape: {codes.shape}") # latent, text_lens_out, code_lens_out = \ latent = \ self.gpt(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, - torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), + code_lens*self.gpt.mel_length_compression, cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), return_latent=True, clip_inputs=False) latent = latent.transpose(1, 2) @@ -162,6 +201,10 @@ class IndexTTS: torchaudio.save(output_path, wav.type(torch.int16), 24000) +prompt_wav="/juicefs/users/wd007/work2024/tts/indextts/testwav/spk_1743041132.wav" +text="晕 XUAN4 是 一 种 GAN3 觉" +text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' if __name__ == "__main__": tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints") - tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav") + #tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav") + tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav") From ae395dc4166aa8874fba037a4dc71e4d8bbfb38e Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 11:54:31 +0800 Subject: [PATCH 2/4] cleanup code --- indextts/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/indextts/infer.py b/indextts/infer.py index ff707f5..f67625e 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -201,10 +201,10 @@ class IndexTTS: torchaudio.save(output_path, wav.type(torch.int16), 24000) -prompt_wav="/juicefs/users/wd007/work2024/tts/indextts/testwav/spk_1743041132.wav" +prompt_wav="test_data/input.wav" text="晕 XUAN4 是 一 种 GAN3 觉" text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' + if __name__ == "__main__": tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints") - #tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav") tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav") From 18e20ccbb4e33ba137385a8fe32f9a793a44a5aa Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 10:35:47 +0800 Subject: [PATCH 3/4] enable front-end caching to speed up startup. --- indextts/utils/front.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/indextts/utils/front.py b/indextts/utils/front.py index cf5a7e9..a43d4cc 100644 --- a/indextts/utils/front.py +++ b/indextts/utils/front.py @@ -70,8 +70,8 @@ class TextNormalizer: else: from tn.chinese.normalizer import Normalizer as NormalizerZh from tn.english.normalizer import Normalizer as NormalizerEn - self.zh_normalizer = NormalizerZh(remove_interjections=False, remove_erhua=False,overwrite_cache=True) - self.en_normalizer = NormalizerEn(overwrite_cache=True) + self.zh_normalizer = NormalizerZh(remove_interjections=False, remove_erhua=False,overwrite_cache=False) + self.en_normalizer = NormalizerEn(overwrite_cache=False) def infer(self, text): pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys())) @@ -163,4 +163,4 @@ class TextNormalizer: if __name__ == '__main__': # 测试程序 text_normalizer = TextNormalizer() - print(text_normalizer.infer("2.5平方电线")) \ No newline at end of file + print(text_normalizer.infer("2.5平方电线")) From 19be5dba2d16e910f182bc04f150c98b6f81be4d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 10:38:51 +0800 Subject: [PATCH 4/4] fix bug. --- indextts/infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/indextts/infer.py b/indextts/infer.py index f67625e..cf15856 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -20,6 +20,7 @@ class IndexTTS: self.cfg = OmegaConf.load(cfg_path) self.device = 'cuda:0' self.model_dir = model_dir + self.stop_mel_token = self.cfg.gpt.stop_mel_token self.dvae = DiscreteVAE(**self.cfg.vqvae) self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) load_checkpoint(self.dvae, self.dvae_path)