fix long silence bug.

This commit is contained in:
root 2025-04-09 19:45:18 +08:00
parent 19be5dba2d
commit 47ec591d40

View File

@ -5,6 +5,7 @@ import sys
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from indextts.BigVGAN.models import BigVGAN as Generator
@ -59,8 +60,10 @@ class IndexTTS:
# return text.translate(punctuation_map)
return self.normalizer.infer(text)
def remove_long_silence(self, codes):
def remove_long_silence(self, codes, silent_token=52, max_consecutive=30):
code_lens = []
codes_list = []
isfix = False
for i in range(0, codes.shape[0]):
code = codes[i]
if self.cfg.gpt.stop_mel_token not in code:
@ -71,26 +74,32 @@ class IndexTTS:
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
len_ = len_ - 2
count = torch.sum(code == 52).item()
if count > 50:
count = torch.sum(code == silent_token).item()
if count > max_consecutive:
code = code.cpu().tolist()
ncode = []
n = 0
for k in range(0, len_):
if code[k] != 52:
if code[k] != silent_token:
ncode.append(code[k])
n = 0
elif code[k] == 52 and n < 30:
elif code[k] == silent_token and n < 10:
ncode.append(code[k])
n += 1
# if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
# n += 1
len_ = len(ncode)
ncode = torch.LongTensor(ncode)
codes[i] = self.stop_mel_token
codes[i, 0:len_] = ncode
codes_list.append(ncode.cuda())
isfix = True
#codes[i] = self.stop_mel_token
#codes[i, 0:len_] = ncode
else:
codes_list.append(codes[i])
code_lens.append(len_)
code_lens = torch.LongTensor(code_lens).cuda()
if isfix:
codes = pad_sequence(codes_list, batch_first=True)
return codes, code_lens
def infer(self, audio_prompt, text, output_path):
@ -164,13 +173,17 @@ class IndexTTS:
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
codes = codes[:, :-2]
# code_lens = torch.tensor([codes.shape[-1]])
#codes = codes[:, :-2]
code_lens = torch.tensor([codes.shape[-1]])
print(codes)
print(f"codes shape: {codes.shape}")
print(f"code len: {code_lens}")
# remove ultra-long silence if exits
codes, code_lens = self.remove_long_silence(codes)
# temporarily fix the long silence bug.
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
print(codes)
print(f"codes shape: {codes.shape}")
print(f"code len: {code_lens}")
# latent, text_lens_out, code_lens_out = \
latent = \
@ -205,6 +218,7 @@ class IndexTTS:
prompt_wav="test_data/input.wav"
text="晕 XUAN4 是 一 种 GAN3 觉"
text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了'
text="There is a vehicle arriving in dock number 7?"
if __name__ == "__main__":
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")