Feature/kemurin (#99)

* deepspeed无法使用时回退到通常路径

* ninja支持中文路径编译补丁:BigVGAN fused cuda kernel

* 缓存参考音频的Mel

* ninja支持中文路径编译方案2:BigVGAN fused cuda kernel

* 增加批次推理:长句实现至少 2~10 倍以上的速度提升~

* fix上层目录为空时报错

---------

Co-authored-by: kemuriririn <10inspiral@gmail.com>
Co-authored-by: sunnyboxs <sjt2000@qq.com>
This commit is contained in:
kemuriririn 2025-04-17 15:12:45 +08:00 committed by GitHub
parent 91b7fa6148
commit 6783f22fe4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 271 additions and 19 deletions

View File

@ -3,11 +3,13 @@ import re
import time
from subprocess import CalledProcessError
import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
@ -113,6 +115,8 @@ class IndexTTS:
# 缓存参考音频mel
self.cache_audio_prompt = None
self.cache_cond_mel = None
# 进度引用显示(可选)
self.gr_progress = None
def preprocess_text(self, text):
# chinese_punctuation = ",。!?;:“”‘’()【】《》"
@ -180,9 +184,230 @@ class IndexTTS:
return [
sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
]
def pad_tokens_cat(self, tokens):
    """Right-pad a list of [1, T_i] token tensors to a common length and
    concatenate them into a single [B, T_max] batch.

    Padding uses ``self.cfg.gpt.stop_text_token`` so the GPT model treats
    filler positions as end-of-text.

    Args:
        tokens: non-empty list of 2-D integer tensors, each shaped [1, T_i].

    Returns:
        A [len(tokens), max_i T_i] tensor. A single-element list returns
        its tensor unchanged (no copy).

    Raises:
        ValueError: if ``tokens`` is empty (previously an opaque IndexError).
    """
    if not tokens:
        raise ValueError("pad_tokens_cat() requires at least one token tensor")
    if len(tokens) == 1:
        return tokens[0]
    max_len = max(t.size(1) for t in tokens)
    pad_value = self.cfg.gpt.stop_text_token
    padded = [
        torch.nn.functional.pad(t, (0, max_len - t.size(1)), value=pad_value)
        if t.size(1) < max_len
        else t
        for t in tokens
    ]
    return torch.cat(padded, dim=0)
def torch_empty_cache(self):
    """Best-effort release of the cached allocator memory on ``self.device``.

    Freeing the cache is an optimization, never a correctness requirement,
    so any failure (e.g. backend unavailable) is deliberately ignored.
    """
    try:
        if "cuda" in str(self.device):
            torch.cuda.empty_cache()
        elif "mps" in str(self.device):
            torch.mps.empty_cache()
    except Exception:
        # Intentional silent swallow: cache clearing is purely opportunistic.
        pass
def _set_gr_progress(self, value, desc):
if self.gr_progress is not None:self.gr_progress(value, desc=desc)
# Fast inference: for "long multi-sentence text" this achieves at least a
# 2~10x speedup over infer().  First modified by sunnyboxs 2025-04-16
def infer_fast(self, audio_prompt, text, output_path, verbose=False):
    """Batched ("fast") text-to-speech inference.

    Splits the normalized text into sentences, generates GPT speech tokens
    for all sentences as one right-padded batch, computes per-sentence
    latents, then decodes them through BigVGAN in small chunks and
    concatenates the audio.

    Args:
        audio_prompt: path to the reference (speaker) audio file.
        text: raw input text to synthesize.
        output_path: wav file path; if falsy, the audio is returned
            in-memory in Gradio's (sample_rate, ndarray) format instead.
        verbose: print extra debugging information.

    Returns:
        ``output_path`` (str) when a path was given, otherwise a tuple
        ``(sampling_rate, int16 numpy array)``.
    """
    print(">> start fast inference...")
    self._set_gr_progress(0, "start fast inference...")
    if verbose:
        print(f"origin text:{text}")
    start_time = time.perf_counter()
    normalized_text = self.preprocess_text(text)
    print(f"normalized text:{normalized_text}")
    # Recompute the conditioning mel only when the reference audio changed;
    # a cache hit skips reloading/resampling the prompt entirely.
    if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
        audio, sr = torchaudio.load(audio_prompt)
        audio = torch.mean(audio, dim=0, keepdim=True)
        # NOTE(review): after the mean above, audio.shape[0] is always 1,
        # so this branch looks unreachable — confirm and consider removing.
        if audio.shape[0] > 1:
            audio = audio[0].unsqueeze(0)
        audio = torchaudio.transforms.Resample(sr, 24000)(audio)
        cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
        cond_mel_frame = cond_mel.shape[-1]
        if verbose:
            print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
        self.cache_audio_prompt = audio_prompt
        self.cache_cond_mel = cond_mel
    else:
        cond_mel = self.cache_cond_mel
        cond_mel_frame = cond_mel.shape[-1]
    auto_conditioning = cond_mel
    cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device)
    # text_tokens
    sentences = self.split_sentences(normalized_text)
    if verbose:
        print("sentences:", sentences)
    # Sampling / decoding hyperparameters for GPT speech-token generation.
    top_p = .8
    top_k = 30
    temperature = 1.0
    autoregressive_batch_size = 1
    length_penalty = 0.0
    num_beams = 3
    repetition_penalty = 10.0
    max_mel_tokens = 600
    sampling_rate = 24000
    # lang = "EN"
    # lang = "ZH"
    wavs = []
    gpt_gen_time = 0
    gpt_forward_time = 0
    bigvgan_time = 0
    all_text_tokens = []
    self._set_gr_progress(0.1, "text processing...")
    for sent in sentences:
        # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
        cleand_text = tokenize_by_CJK_char(sent)
        # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
        if verbose:
            print("cleand_text:", cleand_text)
        text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text), dtype=torch.int32, device=self.device).unsqueeze(0)
        # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
        # text_tokens = F.pad(text_tokens, (1, 0), value=0)
        # text_tokens = F.pad(text_tokens, (0, 1), value=1)
        if verbose:
            print(text_tokens)
            print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
            # debug tokenizer
            text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
            print(text_token_syms)
        all_text_tokens.append(text_tokens)
    batch_num = len(all_text_tokens)
    # Pad every sentence's tokens to equal length and stack into one batch;
    # replicate the conditioning mel/lengths to match the batch size.
    batch_text_tokens = self.pad_tokens_cat(all_text_tokens)
    batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
    batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
    # gpt speech
    self._set_gr_progress(0.2, "gpt inference speech...")
    m_start_time = time.perf_counter()
    with torch.no_grad():
        # NOTE(review): torch.amp.autocast's first argument is a device-type
        # string ("cuda"/"cpu") — confirm self.device is a string here.
        with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
            batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
                                                    cond_mel_lengths=batch_cond_mel_lengths,
                                                    # text_lengths=text_len,
                                                    do_sample=True,
                                                    top_p=top_p,
                                                    top_k=top_k,
                                                    temperature=temperature,
                                                    num_return_sequences=autoregressive_batch_size,
                                                    length_penalty=length_penalty,
                                                    num_beams=num_beams,
                                                    repetition_penalty=repetition_penalty,
                                                    max_generate_length=max_mel_tokens)
    gpt_gen_time += time.perf_counter() - m_start_time
    # Drop the batched inputs before the next stage to reduce peak memory.
    batch_auto_conditioning = None
    batch_cond_mel_lengths = None
    batch_text_tokens = None
    self.torch_empty_cache()
    # gpt latent
    self._set_gr_progress(0.5, "gpt inference latents...")
    all_latents = []
    for i in range(batch_codes.shape[0]):
        codes = batch_codes[i]  # [x]
        # Strip the stop token and collapse consecutive duplicates, then
        # trim over-long runs of the silence token (52).
        codes = codes[codes != self.cfg.gpt.stop_mel_token]
        codes, _ = torch.unique_consecutive(codes, return_inverse=True)
        codes = codes.unsqueeze(0)  # [x] -> [1, x]
        code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
        codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
        text_tokens = all_text_tokens[i]
        m_start_time = time.perf_counter()
        with torch.no_grad():
            with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
                latent = \
                    self.gpt(auto_conditioning, text_tokens,
                             torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                             code_lens * self.gpt.mel_length_compression,
                             cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
                             return_latent=True, clip_inputs=False)
        gpt_forward_time += time.perf_counter() - m_start_time
        all_latents.append(latent)
    # bigvgan chunk: decode a few sentences at a time to bound peak memory.
    chunk_size = 2
    chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
    chunk_length = len(chunk_latents)
    latent_length = len(all_latents)
    all_latents = None
    # bigvgan chunk decode
    self._set_gr_progress(0.7, "bigvgan decode...")
    tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
    for items in chunk_latents:
        tqdm_progress.update(len(items))
        latent = torch.cat(items, dim=1)
        with torch.no_grad():
            with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
                m_start_time = time.perf_counter()
                wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                bigvgan_time += time.perf_counter() - m_start_time
                wav = wav.squeeze(1)
        # Scale to int16 range and clip.
        wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
        wavs.append(wav)
    # clear cache
    tqdm_progress.close()  # make sure the progress bar is closed
    chunk_latents.clear()
    end_time = time.perf_counter()
    self.torch_empty_cache()
    # wav audio output
    self._set_gr_progress(0.9, "save audio...")
    wav = torch.cat(wavs, dim=1)
    wav_length = wav.shape[-1] / sampling_rate
    print(f">> Reference audio length: {cond_mel_frame*256 / sampling_rate:.2f} seconds")
    print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
    print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
    print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
    print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
    print(f">> Generated audio length: {wav_length:.2f} seconds")
    print(f">> [fast] bigvgan chunk_length: {chunk_length}")
    print(f">> [fast] batch_num: {batch_num}")
    print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
    # save audio
    wav = wav.cpu()  # to cpu
    if output_path:
        # Save directly to the requested path, creating parent dirs as needed.
        # Guard against an empty dirname (bare filename in cwd) — makedirs("")
        # raises FileNotFoundError; mirrors the same fix applied in infer().
        if os.path.dirname(output_path) != "":
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
        print(">> wav file saved to:", output_path)
        return output_path
    else:
        # Return in the (sample_rate, ndarray) format Gradio expects.
        wav_data = wav.type(torch.int16)
        wav_data = wav_data.numpy().T
        return (sampling_rate, wav_data)
# 原始推理模式
def infer(self, audio_prompt, text, output_path, verbose=False):
print(">> start inference...")
self._set_gr_progress(0, "start inference...")
if verbose:
print(f"origin text:{text}")
start_time = time.perf_counter()
@ -317,16 +542,33 @@ class IndexTTS:
print(f">> Generated audio length: {wav_length:.2f} seconds")
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
# torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
# print(">> wav file saved to:", output_path)
# save audio
wav = wav.cpu() # to cpu
if output_path:
# 直接保存音频到指定路径中
if os.path.isfile(output_path):
os.remove(output_path)
print(">> remove old wav file:", output_path)
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path),exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
return output_path
else:
# 返回以符合Gradio的格式要求
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
return (sampling_rate, wav_data)
if __name__ == "__main__":
    # Demo entry point: synthesize one sentence from a reference prompt wav.
    # The original script reassigned prompt_wav/text several times, leaving
    # dead debug values; only the effective (last) assignments are kept, with
    # the alternates preserved as comments for manual testing.
    # prompt_wav = "testwav/input.wav"
    # prompt_wav = "testwav/spk_1744181067_1.wav"
    # text = "晕 XUAN4 是 一 种 GAN3 觉"
    # text = '大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了'
    prompt_wav = "test_data/input.wav"
    text = "There is a vehicle arriving in dock number 7?"
    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, use_cuda_kernel=False)
    tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)

View File

@ -4,12 +4,16 @@ import sys
import threading
import time
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
import gradio as gr
from utils.webui_utils import next_page, prev_page
from indextts.utils.webui_utils import next_page, prev_page
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto
@ -21,15 +25,18 @@ tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
# Ensure the output-task and prompt directories exist before the UI starts.
os.makedirs("outputs/tasks",exist_ok=True)
os.makedirs("prompts",exist_ok=True)
def gen_single(prompt, text, infer_mode, progress=gr.Progress()):
    """Generate one audio clip for the Gradio UI.

    This span previously contained interleaved pre-/post-change diff lines
    (the deleted `infer()` helper plus old and new `gen_single()` bodies);
    reconstructed here as the coherent post-change function.

    Args:
        prompt: reference speaker audio (Gradio Audio value).
        text: target text to synthesize.
        infer_mode: "普通推理" selects tts.infer, anything else tts.infer_fast.
        progress: Gradio progress tracker, forwarded to the TTS engine.

    Returns:
        A gr.update making the output Audio component show the result.
    """
    output_path = None
    if not output_path:
        # Timestamped filename under outputs/ (path is always generated here).
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Attach the Gradio progress callback so the engine can report progress.
    tts.gr_progress = progress
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path)  # normal inference
    else:
        output = tts.infer_fast(prompt, text, output_path)  # batch inference
    return gr.update(value=output, visible=True)
def update_prompt_audio():
update_button = gr.update(interactive=True)
@ -40,6 +47,7 @@ with gr.Blocks() as demo:
mutex = threading.Lock()
gr.HTML('''
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
<p align="center">
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
@ -53,16 +61,18 @@ with gr.Blocks() as demo:
default = ''
if prompt_list:
default = prompt_list[0]
input_text_single = gr.Textbox(label="请输入目标文本",key="input_text_single")
gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
output_audio = gr.Audio(label="生成结果", visible=False,key="output_audio")
with gr.Column():
input_text_single = gr.TextArea(label="请输入目标文本",key="input_text_single")
infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="选择推理模式(批次推理:更适合长句,性能翻倍)",value="普通推理")
gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio")
prompt_audio.upload(update_prompt_audio,
inputs=[],
outputs=[gen_button])
gen_button.click(gen_single,
inputs=[prompt_audio, input_text_single],
inputs=[prompt_audio, input_text_single, infer_mode],
outputs=[output_audio])