def bucket_sentences(self, sentences, enable):
    """Split sentences into (short, long) buckets for batched GPT inference.

    Sentences whose length is at most half of the longest sentence go into
    the first bucket, the rest into the second, so each batch pads against
    peers of similar length (less wasted padding, better quality).  Each
    entry keeps its original position as ``idx`` so the caller can restore
    input order after per-bucket processing.

    Args:
        sentences: list of sentence strings; may be empty.
        enable: when falsy, bucketing is disabled and every sentence goes
            into a single bucket in original order (favors speed over
            quality, per the original author's note).

    Returns:
        A list of up to two non-empty buckets, each a list of
        ``{"idx": original_index, "sent": sentence}`` dicts.  Empty
        buckets are dropped; an empty input yields ``[]``.
    """
    if not sentences:
        # Guard: max() over an empty sequence would raise ValueError.
        return []
    half = max(len(s) for s in sentences) // 2
    short, long_ = [], []
    for idx, sent in enumerate(sentences):
        bucket = short if not enable or len(sent) <= half else long_
        bucket.append({"idx": idx, "sent": sent})
    # Drop empty buckets so callers can iterate without emptiness checks.
    return [bucket for bucket in (short, long_) if bucket]

def pad_tokens_cat(self, tokens):
    """Right-pad a list of ``[1, T_i]`` token tensors to a common length
    and concatenate them into one ``[B, max_T]`` batch.

    Padding appends up to 8 ``stop_text_token`` values first, then fills
    any remainder with ``start_text_token`` — presumably matching the text
    framing the GPT was trained with (TODO confirm against gpt config).

    Args:
        tokens: non-empty list of integer tensors shaped ``[1, T_i]``.

    Returns:
        The single tensor unchanged when ``len(tokens) <= 1``; otherwise
        the concatenated ``[len(tokens), max_T]`` tensor.

    Raises:
        IndexError: if ``tokens`` is empty (``tokens[-1]`` on ``[]``).
    """
    if len(tokens) <= 1:
        return tokens[-1]
    max_len = max(t.size(1) for t in tokens)
    # Hoist config lookups out of the loop.
    stop_token = self.cfg.gpt.stop_text_token
    start_token = self.cfg.gpt.start_text_token
    outputs = []
    for tensor in tokens:
        pad_len = max_len - tensor.size(1)
        if pad_len > 0:
            n = min(8, pad_len)
            tensor = torch.nn.functional.pad(tensor, (0, n), value=stop_token)
            tensor = torch.nn.functional.pad(tensor, (0, pad_len - n), value=start_token)
            # Width is now exactly max_len (T_i + pad_len), so no
            # truncation is needed (dead `[:, :max_len]` slice removed).
        outputs.append(tensor)
    return torch.cat(outputs, dim=0)
- # text_tokens = F.pad(text_tokens, (1, 0), value=0) - # text_tokens = F.pad(text_tokens, (0, 1), value=1) - if verbose: - print(text_tokens) - print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") - # debug tokenizer - text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) - print(text_token_syms) - - all_text_tokens.append(text_tokens) + bucket_enable = True # 预分桶开关,优先保证质量=True。优先保证速度=False。 + all_sentences = self.bucket_sentences(sentences, enable=bucket_enable) + for sentences in all_sentences: + temp_tokens = [] + all_text_tokens.append(temp_tokens) + for item in sentences: + sent = item["sent"] + # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper() + cleand_text = tokenize_by_CJK_char(sent) + # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ." + if verbose: + print("cleand_text:", cleand_text) + + text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0) + # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. 
+ # text_tokens = F.pad(text_tokens, (1, 0), value=0) + # text_tokens = F.pad(text_tokens, (0, 1), value=1) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) + print(text_token_syms) + + temp_tokens.append(text_tokens) - batch_num = len(all_text_tokens) - batch_text_tokens = self.pad_tokens_cat(all_text_tokens) - batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0) - batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0) - - # gpt speech - self._set_gr_progress(0.2, "gpt inference speech...") - m_start_time = time.perf_counter() - with torch.no_grad(): - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): - batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens, - cond_mel_lengths=batch_cond_mel_lengths, - # text_lengths=text_len, - do_sample=True, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=autoregressive_batch_size, - length_penalty=length_penalty, - num_beams=num_beams, - repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens) - gpt_gen_time += time.perf_counter() - m_start_time - - # clear cache - batch_auto_conditioning = None - batch_cond_mel_lengths = None - batch_text_tokens = None - self.torch_empty_cache() - - # gpt latent - self._set_gr_progress(0.5, "gpt inference latents...") - all_latents = [] - for i in range(batch_codes.shape[0]): - codes = batch_codes[i] # [x] - codes = codes[codes != self.cfg.gpt.stop_mel_token] - codes, _ = torch.unique_consecutive(codes, return_inverse=True) - codes = codes.unsqueeze(0) # [x] -> [1, x] - code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) - codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) - text_tokens = 
all_text_tokens[i] + + # Sequential processing of bucketing data + all_batch_num = 0 + all_batch_codes = [] + for item_tokens in all_text_tokens: + batch_num = len(item_tokens) + batch_text_tokens = self.pad_tokens_cat(item_tokens) + batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0) + batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0) + all_batch_num += batch_num + + # gpt speech + self._set_gr_progress(0.2, "gpt inference speech...") m_start_time = time.perf_counter() with torch.no_grad(): with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): - latent = \ - self.gpt(auto_conditioning, text_tokens, - torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, - code_lens*self.gpt.mel_length_compression, - cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), - return_latent=True, clip_inputs=False) - gpt_forward_time += time.perf_counter() - m_start_time - all_latents.append(latent) + temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens, + cond_mel_lengths=batch_cond_mel_lengths, + # text_lengths=text_len, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens) + all_batch_codes.append(temp_codes) + gpt_gen_time += time.perf_counter() - m_start_time + + + # gpt latent + self._set_gr_progress(0.5, "gpt inference latents...") + all_idxs = [] + all_latents = [] + for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences): + for i in range(batch_codes.shape[0]): + codes = batch_codes[i] # [x] + codes = codes[codes != self.cfg.gpt.stop_mel_token] + codes, _ = torch.unique_consecutive(codes, return_inverse=True) + codes = codes.unsqueeze(0) # [x] -> [1, x] + code_lens = 
torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) + codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + text_tokens = batch_tokens[i] + all_idxs.append(batch_sentences[i]["idx"]) + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + code_lens*self.gpt.mel_length_compression, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), + return_latent=True, clip_inputs=False) + gpt_forward_time += time.perf_counter() - m_start_time + all_latents.append(latent) + # bigvgan chunk chunk_size = 2 + all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))] chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)] chunk_length = len(chunk_latents) latent_length = len(all_latents) @@ -385,7 +418,7 @@ class IndexTTS: print(f">> Total fast inference time: {end_time - start_time:.2f} seconds") print(f">> Generated audio length: {wav_length:.2f} seconds") print(f">> [fast] bigvgan chunk_length: {chunk_length}") - print(f">> [fast] batch_num: {batch_num}") + print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}") print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}") # save audio