DeepSpeed acceleration and FP16 inference support, but bigvgan FP16 disabled
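
Add an is_fp16 flag (default True) to IndexTTS. When enabled, the DVAE and GPT weights are cast to half precision, the GPT-2 decoder is initialized with DeepSpeed inference and KV caching, and both the autoregressive sampling and the latent pass run under torch.cuda.amp.autocast. BigVGAN keeps running in full precision.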

shujingchen 2025-04-03 16:30:39 +08:00
parent 397fef2f14
commit e92bf90235


@@ -1,6 +1,7 @@
import os
import re
import sys
import time
import sentencepiece as spm
import torch
@@ -16,24 +17,38 @@ from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.front import TextNormalizer
class IndexTTS:
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'):
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints', is_fp16=True):
self.cfg = OmegaConf.load(cfg_path)
self.device = 'cuda:0'
self.model_dir = model_dir
self.is_fp16 = is_fp16
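# fp16 halves weight memory and speeds up GPU inference; the fp32 path is kept as a fallback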
if self.is_fp16:
self.dtype = torch.float16
else:
self.dtype = None
self.dvae = DiscreteVAE(**self.cfg.vqvae)
self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
load_checkpoint(self.dvae, self.dvae_path)
self.dvae = self.dvae.to(self.device)
self.dvae.eval()
if self.is_fp16:
self.dvae.eval().half()
else:
self.dvae.eval()
print(">> vqvae weights restored from:", self.dvae_path)
self.gpt = UnifiedVoice(**self.cfg.gpt)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device)
self.gpt.eval()
if self.is_fp16:
self.gpt.eval().half()
else:
self.gpt.eval()
print(">> GPT weights restored from:", self.gpt_path)
self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
if self.is_fp16:
self.gpt.post_init_gpt2_config(use_deepspeed=True, kv_cache=True, half=True)
else:
self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
self.bigvgan = Generator(self.cfg.bigvgan)
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
@@ -108,49 +123,83 @@ class IndexTTS:
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
text_tokens = text_tokens.to(self.device)
print(text_tokens)
print(f"text_tokens shape: {text_tokens.shape}")
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
text_token_syms = [tokenizer.IdToPiece(idx) for idx in text_tokens[0].tolist()]
print(text_token_syms)
text_len = [text_tokens.size(1)]
text_len = torch.IntTensor(text_len).to(self.device)
print(text_len)
with torch.no_grad():
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
print(codes)
print(f"codes shape: {codes.shape}")
if self.is_fp16:
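# autoregressive sampling runs under autocast so the heavy matmuls execute in fp16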
with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype):
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
else:
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
print(codes, type(codes))
print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
codes = codes[:, :-2]
# latent, text_lens_out, code_lens_out = \
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
latent = latent.transpose(1, 2)
'''
latent_list = []
for lat, t_len in zip(latent, text_lens_out):
lat = lat[:, t_len:]
latent_list.append(lat)
latent = torch.stack(latent_list)
print(f"latent shape: {latent.shape}")
'''
if self.is_fp16:
with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype):
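# second GPT forward pass: turn the sampled codes into latents for the BigVGAN vocoder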
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
latent = latent.transpose(1, 2)
print(f'latent shape: {latent.shape}, latent type: {latent.dtype}')
print(f'auto_conditioning shape: {auto_conditioning.shape}, auto_conditioning type: {auto_conditioning.dtype}')
# fp16 BigVGAN path disabled for now; run the vocoder in full precision
# fp16_auto_conditioning = auto_conditioning.half()
# wav, _ = self.bigvgan(latent.transpose(1, 2), fp16_auto_conditioning.transpose(1, 2))
# wav = wav.squeeze(1).cpu()
wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
wav = wav.squeeze(1).cpu()
else:
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
print(f'latent shape: {latent.shape}, latent type: {latent.dtype}')
print(f'auto_conditioning shape: {auto_conditioning.shape}, auto_conditioning type: {auto_conditioning.dtype}')
latent = latent.transpose(1, 2)
'''
latent_list = []
for lat, t_len in zip(latent, text_lens_out):
lat = lat[:, t_len:]
latent_list.append(lat)
latent = torch.stack(latent_list)
print(f"latent shape: {latent.shape}")
'''
wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
wav = wav.squeeze(1).cpu()
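# scale float audio in [-1, 1] to the 16-bit PCM range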
wav = 32767 * wav
wav = torch.clip(wav, -32767.0, 32767.0)
@@ -163,5 +212,6 @@ class IndexTTS:
if __name__ == "__main__":
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True)
tts.infer(audio_prompt='test_data/input.wav', text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了', output_path="gen.wav")
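
Note: this commit adds import time but does not use it in the hunks shown; a minimal sketch of the wall-clock timing it presumably enables (illustrative only, not part of the commit):

import time
start = time.perf_counter()
tts.infer(audio_prompt='test_data/input.wav', text='hello from fp16 inference', output_path='gen.wav')
print(f">> infer() took {time.perf_counter() - start:.2f}s")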