From 47ec591d40a3991c93be03221cb56ce5aaabae1e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 19:45:18 +0800 Subject: [PATCH] fix long silence bug. --- indextts/infer.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/indextts/infer.py b/indextts/infer.py index cf15856..84728e3 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -5,6 +5,7 @@ import sys import sentencepiece as spm import torch import torchaudio +from torch.nn.utils.rnn import pad_sequence from omegaconf import OmegaConf from indextts.BigVGAN.models import BigVGAN as Generator @@ -59,8 +60,10 @@ class IndexTTS: # return text.translate(punctuation_map) return self.normalizer.infer(text) - def remove_long_silence(self, codes): + def remove_long_silence(self, codes, silent_token=52, max_consecutive=30): code_lens = [] + codes_list = [] + isfix = False for i in range(0, codes.shape[0]): code = codes[i] if self.cfg.gpt.stop_mel_token not in code: @@ -71,26 +74,32 @@ class IndexTTS: len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1 len_ = len_ - 2 - count = torch.sum(code == 52).item() - if count > 50: + count = torch.sum(code == silent_token).item() + if count > max_consecutive: code = code.cpu().tolist() ncode = [] n = 0 for k in range(0, len_): - if code[k] != 52: + if code[k] != silent_token: ncode.append(code[k]) n = 0 - elif code[k] == 52 and n < 30: + elif code[k] == silent_token and n < 10: ncode.append(code[k]) n += 1 # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52): # n += 1 len_ = len(ncode) ncode = torch.LongTensor(ncode) - codes[i] = self.stop_mel_token - codes[i, 0:len_] = ncode + codes_list.append(ncode.cuda()) + isfix = True + #codes[i] = self.stop_mel_token + #codes[i, 0:len_] = ncode + else: + codes_list.append(codes[i]) code_lens.append(len_) code_lens = torch.LongTensor(code_lens).cuda() + if isfix: + codes = pad_sequence(codes_list, batch_first=True) return codes, code_lens def infer(self, audio_prompt, text, output_path): @@ -164,13 +173,17 @@ class IndexTTS: num_beams=num_beams, repetition_penalty=repetition_penalty, max_generate_length=max_mel_tokens) - codes = codes[:, :-2] - # code_lens = torch.tensor([codes.shape[-1]]) + #codes = codes[:, :-2] + code_lens = torch.tensor([codes.shape[-1]]) print(codes) print(f"codes shape: {codes.shape}") + print(f"code len: {code_lens}") # remove ultra-long silence if exits - codes, code_lens = self.remove_long_silence(codes) + # temporarily fix the long silence bug. + codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + print(codes) print(f"codes shape: {codes.shape}") + print(f"code len: {code_lens}") # latent, text_lens_out, code_lens_out = \ latent = \ @@ -205,6 +218,7 @@ class IndexTTS: prompt_wav="test_data/input.wav" text="晕 XUAN4 是 一 种 GAN3 觉" text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' +text="There is a vehicle arriving in dock number 7?" if __name__ == "__main__": tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")