Batch inference: fix missing/dropped sentences and blank audio (#100)

* Batch inference: critical fixes (missing sentences / dropped sentences / blank audio)

* Batch inference: add a data-bucketing mechanism for better stability
sunnyboxs 2025-04-18 17:57:07 +08:00 committed by GitHub
parent 919062dfb0
commit 71c5295198


@@ -185,6 +185,20 @@ class IndexTTS:
             sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
         ]
 
+    def bucket_sentences(self, sentences, enable):
+        """
+        Sentence data bucketing
+        """
+        max_len = max(len(s) for s in sentences)
+        half = max_len // 2
+        outputs = [[],[]]
+        for idx, sent in enumerate(sentences):
+            if enable == False or len(sent) <= half:
+                outputs[0].append({"idx":idx,"sent":sent})
+            else:
+                outputs[1].append({"idx":idx,"sent":sent})
+        return [item for item in outputs if item]
+
     def pad_tokens_cat(self, tokens):
         if len(tokens) <= 1:return tokens[-1]
         max_len = max(t.size(1) for t in tokens)
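
Aside (editor's note, not part of the diff): below is a minimal standalone sketch of what the new bucket_sentences helper does, re-implemented outside the class with a made-up sentence list. Sentences no longer than half of the longest sentence go into the first bucket, the rest into the second, and every entry carries its original position as "idx" so results can be put back in order later.

# Illustrative re-implementation of the bucketing logic (toy input, not repo code).
def bucket_sentences(sentences, enable=True):
    max_len = max(len(s) for s in sentences)
    half = max_len // 2
    outputs = [[], []]
    for idx, sent in enumerate(sentences):
        # With bucketing disabled, everything lands in bucket 0 (the original behavior).
        if not enable or len(sent) <= half:
            outputs[0].append({"idx": idx, "sent": sent})
        else:
            outputs[1].append({"idx": idx, "sent": sent})
    return [bucket for bucket in outputs if bucket]  # drop empty buckets

print(bucket_sentences(["hi", "a much longer sentence", "ok"]))
# [[{'idx': 0, 'sent': 'hi'}, {'idx': 2, 'sent': 'ok'}],
#  [{'idx': 1, 'sent': 'a much longer sentence'}]]

Batching sentences of similar length means far less padding inside each batch, which is presumably where the stability gain claimed in the commit message comes from.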
@@ -192,13 +206,17 @@ class IndexTTS:
         for tensor in tokens:
             pad_len = max_len - tensor.size(1)
             if pad_len > 0:
-                padded = torch.nn.functional.pad(tensor,
-                    (0, pad_len),
+                n = min(8, pad_len)
+                tensor = torch.nn.functional.pad(tensor,
+                    (0, n),
                     value=self.cfg.gpt.stop_text_token
                 )
-                outputs.append(padded)
-            else:
-                outputs.append(tensor)
+                tensor = torch.nn.functional.pad(tensor,
+                    (0, pad_len - n),
+                    value=self.cfg.gpt.start_text_token
+                )
+            tensor = tensor[:,:max_len]
+            outputs.append(tensor)
         tokens = torch.cat(outputs, dim=0)
         return tokens
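
Aside (editor's note, not part of the diff): the padding change above caps the run of stop tokens at 8 and fills the rest of the tail with start tokens, instead of padding the whole tail with stop tokens. A minimal sketch with stand-in token IDs (STOP and START below are placeholders, not the model's actual cfg.gpt values):

import torch
import torch.nn.functional as F

STOP, START = 8193, 8192  # placeholders for cfg.gpt.stop_text_token / start_text_token

def pad_tokens_cat(tokens):
    """Right-pad each [1, L] tensor to the batch max length, then concat to [B, max_len]."""
    if len(tokens) <= 1:
        return tokens[-1]
    max_len = max(t.size(1) for t in tokens)
    outputs = []
    for tensor in tokens:
        pad_len = max_len - tensor.size(1)
        if pad_len > 0:
            n = min(8, pad_len)
            tensor = F.pad(tensor, (0, n), value=STOP)             # at most 8 stop tokens...
            tensor = F.pad(tensor, (0, pad_len - n), value=START)  # ...then neutral filler to max_len
        outputs.append(tensor[:, :max_len])
    return torch.cat(outputs, dim=0)

batch = pad_tokens_cat([torch.ones(1, 3, dtype=torch.long), torch.ones(1, 15, dtype=torch.long)])
print(batch[0].tolist())
# [1, 1, 1, 8193, 8193, 8193, 8193, 8193, 8193, 8193, 8193, 8192, 8192, 8192, 8192]

Long uniform runs of stop tokens in the shorter members of a batch are a plausible cause of the truncated or blank outputs the commit message describes, so limiting that run looks like the core of this part of the fix (my reading, not stated in the diff).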
@@ -270,83 +288,98 @@ class IndexTTS:
         gpt_forward_time = 0
         bigvgan_time = 0
 
         # text processing
         all_text_tokens = []
         self._set_gr_progress(0.1, "text processing...")
-        for sent in sentences:
-            # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
-            cleand_text = tokenize_by_CJK_char(sent)
-            # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
-            if verbose:
-                print("cleand_text:", cleand_text)
-            text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
-            # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
-            # text_tokens = F.pad(text_tokens, (1, 0), value=0)
-            # text_tokens = F.pad(text_tokens, (0, 1), value=1)
-            if verbose:
-                print(text_tokens)
-                print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
-                # debug tokenizer
-                text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
-                print(text_token_syms)
-            all_text_tokens.append(text_tokens)
+        bucket_enable = True  # pre-bucketing switch: True = prioritize quality, False = prioritize speed
+        all_sentences = self.bucket_sentences(sentences, enable=bucket_enable)
+        for sentences in all_sentences:
+            temp_tokens = []
+            all_text_tokens.append(temp_tokens)
+            for item in sentences:
+                sent = item["sent"]
+                # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
+                cleand_text = tokenize_by_CJK_char(sent)
+                # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
+                if verbose:
+                    print("cleand_text:", cleand_text)
+                text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
+                # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
+                # text_tokens = F.pad(text_tokens, (1, 0), value=0)
+                # text_tokens = F.pad(text_tokens, (0, 1), value=1)
+                if verbose:
+                    print(text_tokens)
+                    print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
+                    # debug tokenizer
+                    text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
+                    print(text_token_syms)
+                temp_tokens.append(text_tokens)
-        batch_num = len(all_text_tokens)
-        batch_text_tokens = self.pad_tokens_cat(all_text_tokens)
-        batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
-        batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
-        # gpt speech
-        self._set_gr_progress(0.2, "gpt inference speech...")
-        m_start_time = time.perf_counter()
-        with torch.no_grad():
-            with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
-                batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
-                                                        cond_mel_lengths=batch_cond_mel_lengths,
-                                                        # text_lengths=text_len,
-                                                        do_sample=True,
-                                                        top_p=top_p,
-                                                        top_k=top_k,
-                                                        temperature=temperature,
-                                                        num_return_sequences=autoregressive_batch_size,
-                                                        length_penalty=length_penalty,
-                                                        num_beams=num_beams,
-                                                        repetition_penalty=repetition_penalty,
-                                                        max_generate_length=max_mel_tokens)
-        gpt_gen_time += time.perf_counter() - m_start_time
-        # clear cache
-        batch_auto_conditioning = None
-        batch_cond_mel_lengths = None
-        batch_text_tokens = None
-        self.torch_empty_cache()
-        # gpt latent
-        self._set_gr_progress(0.5, "gpt inference latents...")
-        all_latents = []
-        for i in range(batch_codes.shape[0]):
-            codes = batch_codes[i]  # [x]
-            codes = codes[codes != self.cfg.gpt.stop_mel_token]
-            codes, _ = torch.unique_consecutive(codes, return_inverse=True)
-            codes = codes.unsqueeze(0)  # [x] -> [1, x]
-            code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
-            codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
-            text_tokens = all_text_tokens[i]
+        # Sequential processing of bucketing data
+        all_batch_num = 0
+        all_batch_codes = []
+        for item_tokens in all_text_tokens:
+            batch_num = len(item_tokens)
+            batch_text_tokens = self.pad_tokens_cat(item_tokens)
+            batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
+            batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
+            all_batch_num += batch_num
+            # gpt speech
+            self._set_gr_progress(0.2, "gpt inference speech...")
             m_start_time = time.perf_counter()
             with torch.no_grad():
                 with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
-                    latent = \
-                        self.gpt(auto_conditioning, text_tokens,
-                                 torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
-                                 code_lens*self.gpt.mel_length_compression,
-                                 cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
-                                 return_latent=True, clip_inputs=False)
-            gpt_forward_time += time.perf_counter() - m_start_time
-            all_latents.append(latent)
+                    temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
+                                                           cond_mel_lengths=batch_cond_mel_lengths,
+                                                           # text_lengths=text_len,
+                                                           do_sample=True,
+                                                           top_p=top_p,
+                                                           top_k=top_k,
+                                                           temperature=temperature,
+                                                           num_return_sequences=autoregressive_batch_size,
+                                                           length_penalty=length_penalty,
+                                                           num_beams=num_beams,
+                                                           repetition_penalty=repetition_penalty,
+                                                           max_generate_length=max_mel_tokens)
+                    all_batch_codes.append(temp_codes)
+            gpt_gen_time += time.perf_counter() - m_start_time
+        # gpt latent
+        self._set_gr_progress(0.5, "gpt inference latents...")
+        all_idxs = []
+        all_latents = []
+        for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
+            for i in range(batch_codes.shape[0]):
+                codes = batch_codes[i]  # [x]
+                codes = codes[codes != self.cfg.gpt.stop_mel_token]
+                codes, _ = torch.unique_consecutive(codes, return_inverse=True)
+                codes = codes.unsqueeze(0)  # [x] -> [1, x]
+                code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
+                codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
+                text_tokens = batch_tokens[i]
+                all_idxs.append(batch_sentences[i]["idx"])
+                m_start_time = time.perf_counter()
+                with torch.no_grad():
+                    with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+                        latent = \
+                            self.gpt(auto_conditioning, text_tokens,
+                                     torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
+                                     code_lens*self.gpt.mel_length_compression,
+                                     cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
+                                     return_latent=True, clip_inputs=False)
+                gpt_forward_time += time.perf_counter() - m_start_time
+                all_latents.append(latent)
         # bigvgan chunk
         chunk_size = 2
+        all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
         chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
         chunk_length = len(chunk_latents)
         latent_length = len(all_latents)
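
Aside (editor's note, not part of the diff): because the buckets are processed one after another, latents come back out of original sentence order; the new all_idxs bookkeeping plus the added all_latents re-indexing line restores that order before chunking, which is what fixes the missing/shuffled sentences. A toy illustration with strings standing in for latent tensors:

# all_idxs[k] = original sentence index of the k-th latent produced.
all_idxs = [0, 2, 1]                       # bucket 0 yielded sentences 0 and 2; bucket 1 yielded sentence 1
all_latents = ["lat_0", "lat_2", "lat_1"]  # bucket order, not sentence order

# For each original position i, find where it landed and pull it back.
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
print(all_latents)  # ['lat_0', 'lat_1', 'lat_2'] -- original order restored, nothing dropped

# The existing chunking step then groups latents for BigVGAN in pairs:
chunk_size = 2
chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
print(chunk_latents)  # [['lat_0', 'lat_1'], ['lat_2']]

Note that list.index makes the re-ordering O(n^2); a precomputed position map would be linear, though for typical sentence counts the difference is negligible.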
@@ -385,7 +418,7 @@ class IndexTTS:
         print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
         print(f">> Generated audio length: {wav_length:.2f} seconds")
         print(f">> [fast] bigvgan chunk_length: {chunk_length}")
-        print(f">> [fast] batch_num: {batch_num}")
+        print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}")
         print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
         # save audio