diff --git a/indextts/infer.py b/indextts/infer.py index d6a6cd2..84bd56b 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -3,11 +3,13 @@ import re import time from subprocess import CalledProcessError +import numpy as np import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence from omegaconf import OmegaConf +from tqdm import tqdm from indextts.BigVGAN.models import BigVGAN as Generator from indextts.gpt.model import UnifiedVoice @@ -113,6 +115,8 @@ class IndexTTS: # 缓存参考音频mel: self.cache_audio_prompt = None self.cache_cond_mel = None + # 进度引用显示(可选) + self.gr_progress = None def preprocess_text(self, text): # chinese_punctuation = ",。!?;:“”‘’()【】《》" @@ -180,9 +184,230 @@ class IndexTTS: return [ sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","} ] + + def pad_tokens_cat(self, tokens): + if len(tokens) <= 1:return tokens[-1] + max_len = max(t.size(1) for t in tokens) + outputs = [] + for tensor in tokens: + pad_len = max_len - tensor.size(1) + if pad_len > 0: + padded = torch.nn.functional.pad(tensor, + (0, pad_len), + value=self.cfg.gpt.stop_text_token + ) + outputs.append(padded) + else: + outputs.append(tensor) + tokens = torch.cat(outputs, dim=0) + return tokens + + def torch_empty_cache(self): + try: + if "cuda" in str(self.device): + torch.cuda.empty_cache() + elif "mps" in str(self.device): + torch.mps.empty_cache() + except Exception as e: + pass + + def _set_gr_progress(self, value, desc): + if self.gr_progress is not None:self.gr_progress(value, desc=desc) + + + + # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16) + def infer_fast(self, audio_prompt, text, output_path, verbose=False): + print(">> start fast inference...") + self._set_gr_progress(0, "start fast inference...") + if verbose: + print(f"origin text:{text}") + start_time = time.perf_counter() + normalized_text = self.preprocess_text(text) + 
print(f"normalized text:{normalized_text}") + + # 如果参考音频改变了,才需要重新生成 cond_mel, 提升速度 + if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt: + audio, sr = torchaudio.load(audio_prompt) + audio = torch.mean(audio, dim=0, keepdim=True) + if audio.shape[0] > 1: + audio = audio[0].unsqueeze(0) + audio = torchaudio.transforms.Resample(sr, 24000)(audio) + cond_mel = MelSpectrogramFeatures()(audio).to(self.device) + cond_mel_frame = cond_mel.shape[-1] + if verbose: + print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) + + self.cache_audio_prompt = audio_prompt + self.cache_cond_mel = cond_mel + else: + cond_mel = self.cache_cond_mel + cond_mel_frame = cond_mel.shape[-1] + pass + + auto_conditioning = cond_mel + cond_mel_lengths = torch.tensor([cond_mel_frame],device=self.device) + + # text_tokens + sentences = self.split_sentences(normalized_text) + if verbose: + print("sentences:", sentences) + + top_p = .8 + top_k = 30 + temperature = 1.0 + autoregressive_batch_size = 1 + length_penalty = 0.0 + num_beams = 3 + repetition_penalty = 10.0 + max_mel_tokens = 600 + sampling_rate = 24000 + # lang = "EN" + # lang = "ZH" + wavs = [] + gpt_gen_time = 0 + gpt_forward_time = 0 + bigvgan_time = 0 + + all_text_tokens = [] + self._set_gr_progress(0.1, "text processing...") + for sent in sentences: + # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper() + cleand_text = tokenize_by_CJK_char(sent) + # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ." + if verbose: + print("cleand_text:", cleand_text) + + text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0) + # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. 
+ # text_tokens = F.pad(text_tokens, (1, 0), value=0) + # text_tokens = F.pad(text_tokens, (0, 1), value=1) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) + print(text_token_syms) + + all_text_tokens.append(text_tokens) + + batch_num = len(all_text_tokens) + batch_text_tokens = self.pad_tokens_cat(all_text_tokens) + batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0) + batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0) + + # gpt speech + self._set_gr_progress(0.2, "gpt inference speech...") + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens, + cond_mel_lengths=batch_cond_mel_lengths, + # text_lengths=text_len, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens) + gpt_gen_time += time.perf_counter() - m_start_time + + # clear cache + batch_auto_conditioning = None + batch_cond_mel_lengths = None + batch_text_tokens = None + self.torch_empty_cache() + + # gpt latent + self._set_gr_progress(0.5, "gpt inference latents...") + all_latents = [] + for i in range(batch_codes.shape[0]): + codes = batch_codes[i] # [x] + codes = codes[codes != self.cfg.gpt.stop_mel_token] + codes, _ = torch.unique_consecutive(codes, return_inverse=True) + codes = codes.unsqueeze(0) # [x] -> [1, x] + code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) + codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + text_tokens = 
all_text_tokens[i] + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + code_lens*self.gpt.mel_length_compression, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), + return_latent=True, clip_inputs=False) + gpt_forward_time += time.perf_counter() - m_start_time + all_latents.append(latent) + + # bigvgan chunk + chunk_size = 2 + chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)] + chunk_length = len(chunk_latents) + latent_length = len(all_latents) + all_latents = None + + # bigvgan chunk decode + self._set_gr_progress(0.7, "bigvgan decode...") + tqdm_progress = tqdm(total=latent_length, desc="bigvgan") + for items in chunk_latents: + tqdm_progress.update(len(items)) + latent = torch.cat(items, dim=1) + with torch.no_grad(): + with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + m_start_time = time.perf_counter() + wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2)) + bigvgan_time += time.perf_counter() - m_start_time + wav = wav.squeeze(1) + pass + wav = torch.clamp(32767 * wav, -32767.0, 32767.0) + wavs.append(wav) + + # clear cache + tqdm_progress.close() # 确保进度条被关闭 + chunk_latents.clear() + end_time = time.perf_counter() + self.torch_empty_cache() + + # wav audio output + self._set_gr_progress(0.9, "save audio...") + wav = torch.cat(wavs, dim=1) + wav_length = wav.shape[-1] / sampling_rate + print(f">> Reference audio length: {cond_mel_frame*256 / sampling_rate:.2f} seconds") + print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds") + print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds") + print(f">> bigvgan_time: {bigvgan_time:.2f} seconds") + print(f">> Total fast inference time: {end_time - 
start_time:.2f} seconds") + print(f">> Generated audio length: {wav_length:.2f} seconds") + print(f">> [fast] bigvgan chunk_length: {chunk_length}") + print(f">> [fast] batch_num: {batch_num}") + print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}") + + # save audio + wav = wav.cpu() # to cpu + if output_path: + # 直接保存音频到指定路径中 + os.makedirs(os.path.dirname(output_path),exist_ok=True) + torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + print(">> wav file saved to:", output_path) + return output_path + else: + # 返回以符合Gradio的格式要求 + wav_data = wav.type(torch.int16) + wav_data = wav_data.numpy().T + return (sampling_rate, wav_data) + + + + # 原始推理模式 def infer(self, audio_prompt, text, output_path, verbose=False): print(">> start inference...") + self._set_gr_progress(0, "start inference...") if verbose: print(f"origin text:{text}") start_time = time.perf_counter() @@ -317,16 +542,33 @@ class IndexTTS: print(f">> Generated audio length: {wav_length:.2f} seconds") print(f">> RTF: {(end_time - start_time) / wav_length:.4f}") - torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate) - print(">> wav file saved to:", output_path) + # torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate) + # print(">> wav file saved to:", output_path) + + # save audio + wav = wav.cpu() # to cpu + if output_path: + # 直接保存音频到指定路径中 + if os.path.isfile(output_path): + os.remove(output_path) + print(">> remove old wav file:", output_path) + if os.path.dirname(output_path) != "": + os.makedirs(os.path.dirname(output_path),exist_ok=True) + torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + print(">> wav file saved to:", output_path) + return output_path + else: + # 返回以符合Gradio的格式要求 + wav_data = wav.type(torch.int16) + wav_data = wav_data.numpy().T + return (sampling_rate, wav_data) if __name__ == "__main__": - prompt_wav = "testwav/input.wav" - prompt_wav = "testwav/spk_1744181067_1.wav" - text="晕 XUAN4 是 一 种 GAN3 
# ---------------------------------------------------------------------------
# Reconstructed post-patch image of two regions of the collapsed git patch:
#   1. the script entry point at the bottom of indextts/infer.py
#   2. the Gradio callbacks added to webui.py
# Not reconstructed (only partially visible in the hunks): webui.py's new
# file-top `warnings.filterwarnings(...)` calls and the corrected import
# `from indextts.utils.webui_utils import next_page, prev_page`, plus the
# `gr.HTML('''...` banner whose body the diff elides.
# ---------------------------------------------------------------------------

# ---- indextts/infer.py : script entry point ----
if __name__ == "__main__":
    prompt_wav = "test_data/input.wav"
    # Alternative sample texts kept from the original for manual testing:
    # text = "晕 XUAN4 是 一 种 GAN3 觉"
    # text = '大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
    text = "There is a vehicle arriving in dock number 7?"
    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints",
                   is_fp16=True, use_cuda_kernel=False)
    tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)


# ---- webui.py : generation callbacks ----
def gen_single(prompt, text, infer_mode, progress=gr.Progress()):
    """Gradio click handler: synthesize `text` with reference audio `prompt`.

    `infer_mode` selects between the normal path ("普通推理") and the batched
    fast path ("批次推理"). Returns a gr.update making the audio widget show
    the generated file.
    """
    # FIX: the original set `output_path = None` and then tested
    # `if not output_path:` — a dead check left over from the removed
    # `infer()` wrapper. Assign the timestamped path directly.
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Hand the Gradio progress callback to the TTS engine so long runs
    # report status in the UI.
    tts.gr_progress = progress
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path)  # 普通推理
    else:
        output = tts.infer_fast(prompt, text, output_path)  # 批次推理
    return gr.update(value=output, visible=True)


def update_prompt_audio():
    """Re-enable the generate button after a new prompt audio is uploaded."""
    update_button = gr.update(interactive=True)
    # NOTE(review): the `return update_button` line falls outside the visible
    # hunk; it is reconstructed here because the result is wired to
    # outputs=[gen_button] — confirm against the full file.
    return update_button
# webui.py — gr.Blocks UI wiring, post-patch image of the collapsed diff.
# NOTE(review): this fragment executes inside `with gr.Blocks() as demo:`
# (and an enclosing row/column whose header lies outside this view); the
# indentation below is relative to that context. The collapse destroyed the
# original leading whitespace, so `output_audio` is placed inside the new
# gr.Column() with its sibling widgets — confirm against the full file.
default = ''
if prompt_list:
    default = prompt_list[0]
with gr.Column():
    input_text_single = gr.TextArea(label="请输入目标文本", key="input_text_single")
    infer_mode = gr.Radio(
        choices=["普通推理", "批次推理"],
        label="选择推理模式(批次推理:更适合长句,性能翻倍)",
        value="普通推理",
    )
    gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
    output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
# Re-enable generation once a prompt file is uploaded.
prompt_audio.upload(
    update_prompt_audio,
    inputs=[],
    outputs=[gen_button],
)
# The click handler now also receives the selected inference mode.
gen_button.click(
    gen_single,
    inputs=[prompt_audio, input_text_single, infer_mode],
    outputs=[output_audio],
)