Feature/kemurin (#99)
* deepspeed无法使用时回退到通常路径 * ninja支持中文路径编译补丁:BigVGAN fused cuda kernel * 缓存参考音频的Mel * ninja支持中文路径编译方案2:BigVGAN fused cuda kernel * 增加批次推理:长句实现至少 2~10 倍以上的速度提升~ * fix上层目录为空时报错 --------- Co-authored-by: kemuriririn <10inspiral@gmail.com> Co-authored-by: sunnyboxs <sjt2000@qq.com>
This commit is contained in:
parent
91b7fa6148
commit
6783f22fe4
@ -3,11 +3,13 @@ import re
|
||||
import time
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
@ -113,6 +115,8 @@ class IndexTTS:
|
||||
# 缓存参考音频mel:
|
||||
self.cache_audio_prompt = None
|
||||
self.cache_cond_mel = None
|
||||
# 进度引用显示(可选)
|
||||
self.gr_progress = None
|
||||
|
||||
def preprocess_text(self, text):
|
||||
# chinese_punctuation = ",。!?;:“”‘’()【】《》"
|
||||
@ -180,9 +184,230 @@ class IndexTTS:
|
||||
return [
|
||||
sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
|
||||
]
|
||||
|
||||
def pad_tokens_cat(self, tokens):
    """Right-pad a list of ``[1, T_i]`` text-token tensors to a common length
    and stack them into one ``[B, T_max]`` batch tensor.

    Padding uses ``self.cfg.gpt.stop_text_token`` so the GPT model treats the
    filler positions as end-of-text.

    Args:
        tokens: non-empty list of int tensors, each shaped ``[1, T_i]``.

    Returns:
        The lone tensor unchanged when ``len(tokens) == 1``, otherwise a
        ``[len(tokens), T_max]`` concatenation.

    Raises:
        ValueError: if ``tokens`` is empty (previously a bare IndexError).
    """
    if not tokens:
        raise ValueError("pad_tokens_cat() requires at least one token tensor")
    if len(tokens) == 1:
        return tokens[0]
    max_len = max(t.size(1) for t in tokens)
    outputs = []
    for tensor in tokens:
        pad_len = max_len - tensor.size(1)
        if pad_len > 0:
            # Pad only on the right of the time dimension.
            tensor = torch.nn.functional.pad(
                tensor, (0, pad_len), value=self.cfg.gpt.stop_text_token
            )
        outputs.append(tensor)
    return torch.cat(outputs, dim=0)
|
||||
|
||||
def torch_empty_cache(self):
    """Best-effort release of cached accelerator memory.

    Frees the CUDA or MPS allocator cache depending on ``self.device``.
    Failures are deliberately ignored: cache release is an optimization
    and must never abort an inference run.
    """
    try:
        device = str(self.device)
        if "cuda" in device:
            torch.cuda.empty_cache()
        elif "mps" in device:
            # torch.mps may be absent on older PyTorch builds — covered
            # by the best-effort except below.
            torch.mps.empty_cache()
    except Exception:
        pass
|
||||
|
||||
def _set_gr_progress(self, value, desc):
|
||||
if self.gr_progress is not None:self.gr_progress(value, desc=desc)
|
||||
|
||||
|
||||
|
||||
# Fast inference: for "multi-sentence long text", achieves at least a 2-10x
# speedup over the standard path. (First modified by sunnyboxs 2025-04-16)
def infer_fast(self, audio_prompt, text, output_path, verbose=False):
    """Batched ("fast") inference.

    Splits the text into sentences, runs GPT speech-token generation for all
    sentences as one padded batch, then decodes latents with BigVGAN in small
    chunks to bound memory use.

    Args:
        audio_prompt: path to the reference (speaker) audio file.
        text: text to synthesize.
        output_path: target wav path; when falsy, a Gradio-style
            ``(sampling_rate, int16 ndarray)`` tuple is returned instead.
        verbose: print intermediate tensors and debug info.

    Returns:
        ``output_path`` when a file was written, else ``(sampling_rate, wav_data)``.
    """
    print(">> start fast inference...")
    self._set_gr_progress(0, "start fast inference...")
    if verbose:
        print(f"origin text:{text}")
    start_time = time.perf_counter()
    normalized_text = self.preprocess_text(text)
    print(f"normalized text:{normalized_text}")

    # Recompute the conditioning mel only when the reference audio changed;
    # otherwise reuse the cached one for speed.
    if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
        audio, sr = torchaudio.load(audio_prompt)
        # Downmix to mono; keepdim=True guarantees shape [1, samples], so no
        # further channel-selection is needed.
        audio = torch.mean(audio, dim=0, keepdim=True)
        audio = torchaudio.transforms.Resample(sr, 24000)(audio)
        cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
        cond_mel_frame = cond_mel.shape[-1]
        if verbose:
            print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)

        self.cache_audio_prompt = audio_prompt
        self.cache_cond_mel = cond_mel
    else:
        cond_mel = self.cache_cond_mel
        cond_mel_frame = cond_mel.shape[-1]

    auto_conditioning = cond_mel
    cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device)

    # text_tokens
    sentences = self.split_sentences(normalized_text)
    if verbose:
        print("sentences:", sentences)

    # Sampling / beam-search hyper-parameters for GPT speech-token generation.
    top_p = .8
    top_k = 30
    temperature = 1.0
    autoregressive_batch_size = 1
    length_penalty = 0.0
    num_beams = 3
    repetition_penalty = 10.0
    max_mel_tokens = 600
    sampling_rate = 24000
    # lang = "EN"
    # lang = "ZH"
    wavs = []
    gpt_gen_time = 0
    gpt_forward_time = 0
    bigvgan_time = 0

    # Tokenize every sentence up front so they can be batched together.
    all_text_tokens = []
    self._set_gr_progress(0.1, "text processing...")
    for sent in sentences:
        # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
        cleand_text = tokenize_by_CJK_char(sent)
        if verbose:
            print("cleand_text:", cleand_text)

        text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text), dtype=torch.int32, device=self.device).unsqueeze(0)
        # text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
        # text_tokens = F.pad(text_tokens, (1, 0), value=0)
        # text_tokens = F.pad(text_tokens, (0, 1), value=1)
        if verbose:
            print(text_tokens)
            print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
            # debug tokenizer
            text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
            print(text_token_syms)

        all_text_tokens.append(text_tokens)

    # Pad to a common length and stack: one GPT call instead of one per sentence.
    batch_num = len(all_text_tokens)
    batch_text_tokens = self.pad_tokens_cat(all_text_tokens)
    batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
    batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)

    # gpt speech: generate speech tokens for the whole batch at once.
    self._set_gr_progress(0.2, "gpt inference speech...")
    m_start_time = time.perf_counter()
    with torch.no_grad():
        with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
            batch_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
                                                    cond_mel_lengths=batch_cond_mel_lengths,
                                                    # text_lengths=text_len,
                                                    do_sample=True,
                                                    top_p=top_p,
                                                    top_k=top_k,
                                                    temperature=temperature,
                                                    num_return_sequences=autoregressive_batch_size,
                                                    length_penalty=length_penalty,
                                                    num_beams=num_beams,
                                                    repetition_penalty=repetition_penalty,
                                                    max_generate_length=max_mel_tokens)
    gpt_gen_time += time.perf_counter() - m_start_time

    # Drop the batch tensors before the next stage to lower peak memory.
    batch_auto_conditioning = None
    batch_cond_mel_lengths = None
    batch_text_tokens = None
    self.torch_empty_cache()

    # gpt latent: one GPT forward pass per sentence to obtain latents.
    self._set_gr_progress(0.5, "gpt inference latents...")
    all_latents = []
    for i in range(batch_codes.shape[0]):
        codes = batch_codes[i]  # [x]
        # Strip stop tokens and collapse consecutive duplicate codes.
        codes = codes[codes != self.cfg.gpt.stop_mel_token]
        codes, _ = torch.unique_consecutive(codes, return_inverse=True)
        codes = codes.unsqueeze(0)  # [x] -> [1, x]
        code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
        codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
        text_tokens = all_text_tokens[i]
        m_start_time = time.perf_counter()
        with torch.no_grad():
            with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
                latent = \
                    self.gpt(auto_conditioning, text_tokens,
                             torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                             code_lens*self.gpt.mel_length_compression,
                             cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
                             return_latent=True, clip_inputs=False)
        gpt_forward_time += time.perf_counter() - m_start_time
        all_latents.append(latent)

    # bigvgan chunk: group latents so the vocoder decodes a few at a time.
    chunk_size = 2
    chunk_latents = [all_latents[i:i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
    chunk_length = len(chunk_latents)
    latent_length = len(all_latents)
    all_latents = None

    # bigvgan chunk decode
    self._set_gr_progress(0.7, "bigvgan decode...")
    tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
    for items in chunk_latents:
        tqdm_progress.update(len(items))
        latent = torch.cat(items, dim=1)
        with torch.no_grad():
            with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
                m_start_time = time.perf_counter()
                wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                bigvgan_time += time.perf_counter() - m_start_time
                wav = wav.squeeze(1)
        # Scale to the int16 range and clip.
        wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
        wavs.append(wav)

    # clear cache
    tqdm_progress.close()  # make sure the progress bar is closed
    chunk_latents.clear()
    end_time = time.perf_counter()
    self.torch_empty_cache()

    # wav audio output
    self._set_gr_progress(0.9, "save audio...")
    wav = torch.cat(wavs, dim=1)
    wav_length = wav.shape[-1] / sampling_rate
    # NOTE(review): 256 is presumably the mel hop length — confirm against
    # MelSpectrogramFeatures before relying on this figure.
    print(f">> Reference audio length: {cond_mel_frame*256 / sampling_rate:.2f} seconds")
    print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
    print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
    print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
    print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
    print(f">> Generated audio length: {wav_length:.2f} seconds")
    print(f">> [fast] bigvgan chunk_length: {chunk_length}")
    print(f">> [fast] batch_num: {batch_num}")
    print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")

    # save audio
    wav = wav.cpu()  # to cpu
    if output_path:
        # Save directly to the requested path. Only create the parent
        # directory when the path actually has one — os.makedirs("") raises
        # FileNotFoundError. This matches the guard already used in infer().
        if os.path.dirname(output_path) != "":
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
        print(">> wav file saved to:", output_path)
        return output_path
    else:
        # Return in the format Gradio expects.
        wav_data = wav.type(torch.int16)
        wav_data = wav_data.numpy().T
        return (sampling_rate, wav_data)
|
||||
|
||||
|
||||
|
||||
# 原始推理模式
|
||||
def infer(self, audio_prompt, text, output_path, verbose=False):
|
||||
print(">> start inference...")
|
||||
self._set_gr_progress(0, "start inference...")
|
||||
if verbose:
|
||||
print(f"origin text:{text}")
|
||||
start_time = time.perf_counter()
|
||||
@ -317,16 +542,33 @@ class IndexTTS:
|
||||
print(f">> Generated audio length: {wav_length:.2f} seconds")
|
||||
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
|
||||
|
||||
torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
|
||||
print(">> wav file saved to:", output_path)
|
||||
# torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
|
||||
# print(">> wav file saved to:", output_path)
|
||||
|
||||
# save audio
|
||||
wav = wav.cpu() # to cpu
|
||||
if output_path:
|
||||
# 直接保存音频到指定路径中
|
||||
if os.path.isfile(output_path):
|
||||
os.remove(output_path)
|
||||
print(">> remove old wav file:", output_path)
|
||||
if os.path.dirname(output_path) != "":
|
||||
os.makedirs(os.path.dirname(output_path),exist_ok=True)
|
||||
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
|
||||
print(">> wav file saved to:", output_path)
|
||||
return output_path
|
||||
else:
|
||||
# 返回以符合Gradio的格式要求
|
||||
wav_data = wav.type(torch.int16)
|
||||
wav_data = wav_data.numpy().T
|
||||
return (sampling_rate, wav_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc smoke test. The successive reassignments below are leftover test
    # configurations — only the LAST prompt_wav / text values take effect.
    prompt_wav = "testwav/input.wav"
    prompt_wav = "testwav/spk_1744181067_1.wav"
    text="晕 XUAN4 是 一 种 GAN3 觉"
    text = "There is a vehicle arriving in dock number 7?"
    text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
    prompt_wav="test_data/input.wav"
    #text="晕 XUAN4 是 一 种 GAN3 觉"
    #text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
    text="There is a vehicle arriving in dock number 7?"

    # Build the TTS engine from local checkpoints and run the standard
    # (non-batched) inference path with verbose logging.
    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, use_cuda_kernel=False)
    tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
||||
34
webui.py
34
webui.py
@ -4,12 +4,16 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(current_dir)
|
||||
sys.path.append(os.path.join(current_dir, "indextts"))
|
||||
|
||||
import gradio as gr
|
||||
from utils.webui_utils import next_page, prev_page
|
||||
from indextts.utils.webui_utils import next_page, prev_page
|
||||
|
||||
from indextts.infer import IndexTTS
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
@ -21,15 +25,18 @@ tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
|
||||
os.makedirs("outputs/tasks",exist_ok=True)
|
||||
os.makedirs("prompts",exist_ok=True)
|
||||
|
||||
def gen_single(prompt, text, infer_mode, progress=gr.Progress()):
    """Gradio handler: synthesize `text` using reference audio `prompt`.

    Args:
        prompt: reference (speaker) audio from the upload component.
        text: text to synthesize.
        infer_mode: "普通推理" selects the standard path; any other value
            selects the batched fast path (tts.infer_fast).
        progress: Gradio progress tracker, forwarded to the TTS engine.

    Returns:
        A gr.update making the output audio component visible with the result.
    """
    # One unique output file per request.
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # set gradio progress
    tts.gr_progress = progress
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path)  # standard inference
    else:
        output = tts.infer_fast(prompt, text, output_path)  # batched inference
    return gr.update(value=output, visible=True)
|
||||
|
||||
def update_prompt_audio():
|
||||
update_button = gr.update(interactive=True)
|
||||
@ -40,6 +47,7 @@ with gr.Blocks() as demo:
|
||||
mutex = threading.Lock()
|
||||
gr.HTML('''
|
||||
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
|
||||
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
|
||||
|
||||
<p align="center">
|
||||
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
|
||||
@ -53,16 +61,18 @@ with gr.Blocks() as demo:
|
||||
default = ''
|
||||
if prompt_list:
|
||||
default = prompt_list[0]
|
||||
input_text_single = gr.Textbox(label="请输入目标文本",key="input_text_single")
|
||||
gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
|
||||
output_audio = gr.Audio(label="生成结果", visible=False,key="output_audio")
|
||||
with gr.Column():
|
||||
input_text_single = gr.TextArea(label="请输入目标文本",key="input_text_single")
|
||||
infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="选择推理模式(批次推理:更适合长句,性能翻倍)",value="普通推理")
|
||||
gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
|
||||
output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio")
|
||||
|
||||
prompt_audio.upload(update_prompt_audio,
|
||||
inputs=[],
|
||||
outputs=[gen_button])
|
||||
|
||||
gen_button.click(gen_single,
|
||||
inputs=[prompt_audio, input_text_single],
|
||||
inputs=[prompt_audio, input_text_single, infer_mode],
|
||||
outputs=[output_audio])
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user