DeepSpeed acceleration and FP16 inference support, but BigVGAN FP16 disabled
parent 397fef2f14
commit e92bf90235
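In practice the new flag is passed at construction time. A minimal usage sketch (assuming this file is importable as indextts.infer, which the diff itself does not show; the English test sentence is likewise illustrative):

from indextts.infer import IndexTTS

# is_fp16=True halves the DVAE/GPT weights and re-initializes the GPT-2 stack
# with DeepSpeed and KV-cache enabled; is_fp16=False keeps the previous
# full-precision path. BigVGAN stays FP32 either way for now.
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True)
tts.infer(audio_prompt="test_data/input.wav", text="Hello from FP16 inference.", output_path="gen.wav")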
@@ -1,6 +1,7 @@
 import os
 import re
 import sys
+import time
 
 import sentencepiece as spm
 import torch
@@ -16,24 +17,38 @@ from indextts.vqvae.xtts_dvae import DiscreteVAE
 
 from indextts.utils.front import TextNormalizer
 class IndexTTS:
-    def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'):
+    def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints', is_fp16=True):
         self.cfg = OmegaConf.load(cfg_path)
         self.device = 'cuda:0'
         self.model_dir = model_dir
+        self.is_fp16 = is_fp16
+        if self.is_fp16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = None
         self.dvae = DiscreteVAE(**self.cfg.vqvae)
         self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
         load_checkpoint(self.dvae, self.dvae_path)
         self.dvae = self.dvae.to(self.device)
-        self.dvae.eval()
+        if self.is_fp16:
+            self.dvae.eval().half()
+        else:
+            self.dvae.eval()
         print(">> vqvae weights restored from:", self.dvae_path)
 
         self.gpt = UnifiedVoice(**self.cfg.gpt)
         self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
         load_checkpoint(self.gpt, self.gpt_path)
         self.gpt = self.gpt.to(self.device)
-        self.gpt.eval()
+        if self.is_fp16:
+            self.gpt.eval().half()
+        else:
+            self.gpt.eval()
         print(">> GPT weights restored from:", self.gpt_path)
-        self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
+        if self.is_fp16:
+            self.gpt.post_init_gpt2_config(use_deepspeed=True, kv_cache=True, half=True)
+        else:
+            self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
 
         self.bigvgan = Generator(self.cfg.bigvgan)
         self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
@@ -108,49 +123,83 @@ class IndexTTS:
         # text_tokens = F.pad(text_tokens, (0, 1), value=1)
         text_tokens = text_tokens.to(self.device)
         print(text_tokens)
-        print(f"text_tokens shape: {text_tokens.shape}")
+        print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
         text_token_syms = [tokenizer.IdToPiece(idx) for idx in text_tokens[0].tolist()]
         print(text_token_syms)
         text_len = [text_tokens.size(1)]
         text_len = torch.IntTensor(text_len).to(self.device)
         print(text_len)
         with torch.no_grad():
-            codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
-                                              cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
-                                                                            device=text_tokens.device),
-                                              # text_lengths=text_len,
-                                              do_sample=True,
-                                              top_p=top_p,
-                                              top_k=top_k,
-                                              temperature=temperature,
-                                              num_return_sequences=autoregressive_batch_size,
-                                              length_penalty=length_penalty,
-                                              num_beams=num_beams,
-                                              repetition_penalty=repetition_penalty,
-                                              max_generate_length=max_mel_tokens)
-            print(codes)
-            print(f"codes shape: {codes.shape}")
+            if self.is_fp16:
+                with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype):
+                    codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
+                                                      cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
+                                                                                    device=text_tokens.device),
+                                                      # text_lengths=text_len,
+                                                      do_sample=True,
+                                                      top_p=top_p,
+                                                      top_k=top_k,
+                                                      temperature=temperature,
+                                                      num_return_sequences=autoregressive_batch_size,
+                                                      length_penalty=length_penalty,
+                                                      num_beams=num_beams,
+                                                      repetition_penalty=repetition_penalty,
+                                                      max_generate_length=max_mel_tokens)
+            else:
+                codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
+                                                  cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
+                                                                                device=text_tokens.device),
+                                                  # text_lengths=text_len,
+                                                  do_sample=True,
+                                                  top_p=top_p,
+                                                  top_k=top_k,
+                                                  temperature=temperature,
+                                                  num_return_sequences=autoregressive_batch_size,
+                                                  length_penalty=length_penalty,
+                                                  num_beams=num_beams,
+                                                  repetition_penalty=repetition_penalty,
+                                                  max_generate_length=max_mel_tokens)
+            print(codes, type(codes))
+            print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
             codes = codes[:, :-2]
 
-            # latent, text_lens_out, code_lens_out = \
-            latent = \
-                self.gpt(auto_conditioning, text_tokens,
-                         torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
-                         torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
-                         cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
-                         return_latent=True, clip_inputs=False)
-            latent = latent.transpose(1, 2)
-            '''
-            latent_list = []
-            for lat, t_len in zip(latent, text_lens_out):
-                lat = lat[:, t_len:]
-                latent_list.append(lat)
-            latent = torch.stack(latent_list)
-            print(f"latent shape: {latent.shape}")
-            '''
-
-            wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
-            wav = wav.squeeze(1).cpu()
+            if self.is_fp16:
+                with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype):
+                    latent = \
+                        self.gpt(auto_conditioning, text_tokens,
+                                 torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
+                                 torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
+                                 cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
+                                 return_latent=True, clip_inputs=False)
+                latent = latent.transpose(1, 2)
+                print(f'latent shape: {latent.shape}, latent type: {latent.dtype}')
+                print(f'auto_conditioning shape: {auto_conditioning.shape}, auto_conditioning type: {auto_conditioning.dtype}')
+                # FP16 BigVGAN is disabled for now (see commit title): the halved
+                # conditioning path is kept for reference only, and the FP32
+                # vocoder call below is the one whose output is used.
+                # fp16_auto_conditioning = auto_conditioning.half()
+                # wav, _ = self.bigvgan(latent.transpose(1, 2), fp16_auto_conditioning.transpose(1, 2))
+                # wav = wav.squeeze(1).cpu()
+
+                wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
+                wav = wav.squeeze(1).cpu()
+            else:
+                latent = \
+                    self.gpt(auto_conditioning, text_tokens,
+                             torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
+                             torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
+                             cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
+                             return_latent=True, clip_inputs=False)
+                print(f'latent shape: {latent.shape}, latent type: {latent.dtype}')
+                print(f'auto_conditioning shape: {auto_conditioning.shape}, auto_conditioning type: {auto_conditioning.dtype}')
+                latent = latent.transpose(1, 2)
+                '''
+                latent_list = []
+                for lat, t_len in zip(latent, text_lens_out):
+                    lat = lat[:, t_len:]
+                    latent_list.append(lat)
+                latent = torch.stack(latent_list)
+                print(f"latent shape: {latent.shape}")
+                '''
+
+                wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
+                wav = wav.squeeze(1).cpu()
 
             wav = 32767 * wav
             torch.clip(wav, -32767.0, 32767.0)
@@ -163,5 +212,6 @@ class IndexTTS:
 
 
 if __name__ == "__main__":
-    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
+    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True)
     tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav")
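The FP16 path reduces to one PyTorch pattern: halve the module weights once, then run inference under autocast only when FP16 is enabled. A self-contained sketch of that pattern with a toy module standing in for the repo's UnifiedVoice (Toy, is_fp16, and the shapes are illustrative assumptions):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

is_fp16 = torch.cuda.is_available()            # FP16 inference needs a GPU
dtype = torch.float16 if is_fp16 else None     # mirrors self.dtype in the diff
device = 'cuda:0' if is_fp16 else 'cpu'

model = Toy().to(device).eval()
if is_fp16:
    model.half()                               # mirrors self.gpt.eval().half()

x = torch.randn(1, 16, device=device)          # autocast casts this to FP16 when enabled
with torch.no_grad():
    if is_fp16:
        with torch.cuda.amp.autocast(enabled=dtype is not None, dtype=dtype):
            y = model(x)
    else:
        y = model(x)
print(y.dtype)  # torch.float16 when enabled, torch.float32 otherwise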