fix long silence bug.
This commit is contained in:
parent
19be5dba2d
commit
47ec591d40
@ -5,6 +5,7 @@ import sys
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||
@ -59,8 +60,10 @@ class IndexTTS:
|
||||
# return text.translate(punctuation_map)
|
||||
return self.normalizer.infer(text)
|
||||
|
||||
def remove_long_silence(self, codes):
|
||||
def remove_long_silence(self, codes, silent_token=52, max_consecutive=30):
|
||||
code_lens = []
|
||||
codes_list = []
|
||||
isfix = False
|
||||
for i in range(0, codes.shape[0]):
|
||||
code = codes[i]
|
||||
if self.cfg.gpt.stop_mel_token not in code:
|
||||
@ -71,26 +74,32 @@ class IndexTTS:
|
||||
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
|
||||
len_ = len_ - 2
|
||||
|
||||
count = torch.sum(code == 52).item()
|
||||
if count > 50:
|
||||
count = torch.sum(code == silent_token).item()
|
||||
if count > max_consecutive:
|
||||
code = code.cpu().tolist()
|
||||
ncode = []
|
||||
n = 0
|
||||
for k in range(0, len_):
|
||||
if code[k] != 52:
|
||||
if code[k] != silent_token:
|
||||
ncode.append(code[k])
|
||||
n = 0
|
||||
elif code[k] == 52 and n < 30:
|
||||
elif code[k] == silent_token and n < 10:
|
||||
ncode.append(code[k])
|
||||
n += 1
|
||||
# if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
|
||||
# n += 1
|
||||
len_ = len(ncode)
|
||||
ncode = torch.LongTensor(ncode)
|
||||
codes[i] = self.stop_mel_token
|
||||
codes[i, 0:len_] = ncode
|
||||
codes_list.append(ncode.cuda())
|
||||
isfix = True
|
||||
#codes[i] = self.stop_mel_token
|
||||
#codes[i, 0:len_] = ncode
|
||||
else:
|
||||
codes_list.append(codes[i])
|
||||
code_lens.append(len_)
|
||||
code_lens = torch.LongTensor(code_lens).cuda()
|
||||
if isfix:
|
||||
codes = pad_sequence(codes_list, batch_first=True)
|
||||
return codes, code_lens
|
||||
|
||||
def infer(self, audio_prompt, text, output_path):
|
||||
@ -164,13 +173,17 @@ class IndexTTS:
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
codes = codes[:, :-2]
|
||||
# code_lens = torch.tensor([codes.shape[-1]])
|
||||
#codes = codes[:, :-2]
|
||||
code_lens = torch.tensor([codes.shape[-1]])
|
||||
print(codes)
|
||||
print(f"codes shape: {codes.shape}")
|
||||
print(f"code len: {code_lens}")
|
||||
# remove ultra-long silence if exits
|
||||
codes, code_lens = self.remove_long_silence(codes)
|
||||
# temporarily fix the long silence bug.
|
||||
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
|
||||
print(codes)
|
||||
print(f"codes shape: {codes.shape}")
|
||||
print(f"code len: {code_lens}")
|
||||
|
||||
# latent, text_lens_out, code_lens_out = \
|
||||
latent = \
|
||||
@ -205,6 +218,7 @@ class IndexTTS:
|
||||
prompt_wav="test_data/input.wav"
|
||||
text="晕 XUAN4 是 一 种 GAN3 觉"
|
||||
text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
|
||||
text="There is a vehicle arriving in dock number 7?"
|
||||
|
||||
if __name__ == "__main__":
|
||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user