批次推理:修复(漏句/丢句/音频空白) (#100)
* 批次推理:重要修复(漏句/丢句/音频空白) * 批次推理:新增数据分桶机制,增强稳定性~
This commit is contained in:
parent
919062dfb0
commit
71c5295198
@ -185,6 +185,20 @@ class IndexTTS:
|
||||
sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
|
||||
]
|
||||
|
||||
def bucket_sentences(self, sentences, enable):
    """Split sentences into up to two length buckets for batched inference.

    Sentences no longer than half of the maximum sentence length go into the
    first ("short") bucket, the rest into the second ("long") bucket, which
    keeps padding waste low within each batch.  Each entry records its
    original position as ``idx`` so outputs can be restored to input order
    after per-bucket inference.

    Args:
        sentences: list of sentence strings.
        enable: when falsy, bucketing is disabled and every sentence is
            placed in the single "short" bucket in original order.

    Returns:
        A list of one or two non-empty buckets; each bucket is a list of
        ``{"idx": original_index, "sent": sentence}`` dicts.  Empty input
        yields an empty list.
    """
    if not sentences:
        # Guard: max() below would raise ValueError on an empty sequence.
        return []
    half = max(len(s) for s in sentences) // 2
    short, long_bucket = [], []
    for idx, sent in enumerate(sentences):
        bucket = short if (not enable or len(sent) <= half) else long_bucket
        bucket.append({"idx": idx, "sent": sent})
    # Drop empty buckets so callers can iterate without special-casing.
    return [bucket for bucket in (short, long_bucket) if bucket]
|
||||
|
||||
def pad_tokens_cat(self, tokens):
    """Right-pad a list of [1, T_i] token tensors to equal length and concat.

    Implements the batch-inference padding fix: for each short sequence, up
    to 8 trailing positions are filled with ``stop_text_token`` (an explicit
    end marker for the model) and any remaining positions with
    ``start_text_token``.
    # NOTE(review): the stop-then-start fill order is taken from the patch
    # as-is — confirm against upstream that start tokens after stop is
    # intentional rather than inverted.

    Args:
        tokens: non-empty list of int tensors, each shaped [1, T_i].

    Returns:
        A single [len(tokens), max_T] tensor; if only one tensor is given,
        it is returned unchanged (no padding needed).
    """
    if len(tokens) <= 1:
        return tokens[-1]
    max_len = max(t.size(1) for t in tokens)
    outputs = []
    for tensor in tokens:
        pad_len = max_len - tensor.size(1)
        if pad_len > 0:
            n = min(8, pad_len)
            tensor = torch.nn.functional.pad(
                tensor, (0, n), value=self.cfg.gpt.stop_text_token
            )
            tensor = torch.nn.functional.pad(
                tensor, (0, pad_len - n), value=self.cfg.gpt.start_text_token
            )
        # Defensive truncation: guarantees a uniform width even if padding
        # ever overshoots.
        tensor = tensor[:, :max_len]
        outputs.append(tensor)
    return torch.cat(outputs, dim=0)
|
||||
|
||||
@ -270,83 +288,98 @@ class IndexTTS:
|
||||
gpt_forward_time = 0
|
||||
bigvgan_time = 0
|
||||
|
||||
|
||||
# text processing
|
||||
all_text_tokens = []
|
||||
self._set_gr_progress(0.1, "text processing...")
|
||||
for sent in sentences:
|
||||
# sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
|
||||
cleand_text = tokenize_by_CJK_char(sent)
|
||||
# cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
|
||||
if verbose:
|
||||
print("cleand_text:", cleand_text)
|
||||
|
||||
text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
# text_tokens = F.pad(text_tokens, (1, 0), value=0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
|
||||
if verbose:
|
||||
print(text_tokens)
|
||||
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
|
||||
# debug tokenizer
|
||||
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
|
||||
print(text_token_syms)
|
||||
|
||||
all_text_tokens.append(text_tokens)
|
||||
bucket_enable = True # 预分桶开关,优先保证质量=True。优先保证速度=False。
|
||||
all_sentences = self.bucket_sentences(sentences, enable=bucket_enable)
|
||||
for sentences in all_sentences:
|
||||
temp_tokens = []
|
||||
all_text_tokens.append(temp_tokens)
|
||||
for item in sentences:
|
||||
sent = item["sent"]
|
||||
# sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
|
||||
cleand_text = tokenize_by_CJK_char(sent)
|
||||
# cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
|
||||
if verbose:
|
||||
print("cleand_text:", cleand_text)
|
||||
|
||||
text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
# text_tokens = F.pad(text_tokens, (1, 0), value=0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
|
||||
if verbose:
|
||||
print(text_tokens)
|
||||
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
|
||||
# debug tokenizer
|
||||
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
|
||||
print(text_token_syms)
|
||||
|
||||
temp_tokens.append(text_tokens)
|
||||
|
||||
batch_num = len(all_text_tokens)
|
||||
batch_text_tokens = self.pad_tokens_cat(all_text_tokens)
|
||||
batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
|
||||
batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
|
||||
|
||||
# gpt speech
|
||||
self._set_gr_progress(0.2, "gpt inference speech...")
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
|
||||
cond_mel_lengths=batch_cond_mel_lengths,
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
gpt_gen_time += time.perf_counter() - m_start_time
|
||||
|
||||
# clear cache
|
||||
batch_auto_conditioning = None
|
||||
batch_cond_mel_lengths = None
|
||||
batch_text_tokens = None
|
||||
self.torch_empty_cache()
|
||||
|
||||
# gpt latent
|
||||
self._set_gr_progress(0.5, "gpt inference latents...")
|
||||
all_latents = []
|
||||
for i in range(batch_codes.shape[0]):
|
||||
codes = batch_codes[i] # [x]
|
||||
codes = codes[codes != self.cfg.gpt.stop_mel_token]
|
||||
codes, _ = torch.unique_consecutive(codes, return_inverse=True)
|
||||
codes = codes.unsqueeze(0) # [x] -> [1, x]
|
||||
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
|
||||
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
|
||||
text_tokens = all_text_tokens[i]
|
||||
|
||||
# Sequential processing of bucketing data
|
||||
all_batch_num = 0
|
||||
all_batch_codes = []
|
||||
for item_tokens in all_text_tokens:
|
||||
batch_num = len(item_tokens)
|
||||
batch_text_tokens = self.pad_tokens_cat(item_tokens)
|
||||
batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
|
||||
batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
|
||||
all_batch_num += batch_num
|
||||
|
||||
# gpt speech
|
||||
self._set_gr_progress(0.2, "gpt inference speech...")
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
latent = \
|
||||
self.gpt(auto_conditioning, text_tokens,
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
code_lens*self.gpt.mel_length_compression,
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
gpt_forward_time += time.perf_counter() - m_start_time
|
||||
all_latents.append(latent)
|
||||
temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
|
||||
cond_mel_lengths=batch_cond_mel_lengths,
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
all_batch_codes.append(temp_codes)
|
||||
gpt_gen_time += time.perf_counter() - m_start_time
|
||||
|
||||
|
||||
# gpt latent
|
||||
self._set_gr_progress(0.5, "gpt inference latents...")
|
||||
all_idxs = []
|
||||
all_latents = []
|
||||
for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
|
||||
for i in range(batch_codes.shape[0]):
|
||||
codes = batch_codes[i] # [x]
|
||||
codes = codes[codes != self.cfg.gpt.stop_mel_token]
|
||||
codes, _ = torch.unique_consecutive(codes, return_inverse=True)
|
||||
codes = codes.unsqueeze(0) # [x] -> [1, x]
|
||||
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
|
||||
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
|
||||
text_tokens = batch_tokens[i]
|
||||
all_idxs.append(batch_sentences[i]["idx"])
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
latent = \
|
||||
self.gpt(auto_conditioning, text_tokens,
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
code_lens*self.gpt.mel_length_compression,
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
gpt_forward_time += time.perf_counter() - m_start_time
|
||||
all_latents.append(latent)
|
||||
|
||||
|
||||
# bigvgan chunk
|
||||
chunk_size = 2
|
||||
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
|
||||
chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
|
||||
chunk_length = len(chunk_latents)
|
||||
latent_length = len(all_latents)
|
||||
@ -385,7 +418,7 @@ class IndexTTS:
|
||||
print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
|
||||
print(f">> Generated audio length: {wav_length:.2f} seconds")
|
||||
print(f">> [fast] bigvgan chunk_length: {chunk_length}")
|
||||
print(f">> [fast] batch_num: {batch_num}")
|
||||
print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}")
|
||||
print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
|
||||
|
||||
# save audio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user