diff --git a/indextts/infer.py b/indextts/infer.py index f1f419e..673433d 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -1,6 +1,7 @@ import os import re import sys +import time import sentencepiece as spm import torch @@ -16,24 +17,38 @@ from indextts.vqvae.xtts_dvae import DiscreteVAE from indextts.utils.front import TextNormalizer class IndexTTS: - def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'): + def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints', is_fp16=True): self.cfg = OmegaConf.load(cfg_path) self.device = 'cuda:0' self.model_dir = model_dir + self.is_fp16 = is_fp16 + if self.is_fp16: + self.dtype = torch.float16 + else: + self.dtype = None self.dvae = DiscreteVAE(**self.cfg.vqvae) self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) load_checkpoint(self.dvae, self.dvae_path) self.dvae = self.dvae.to(self.device) - self.dvae.eval() + if self.is_fp16: + self.dvae.eval().half() + else: + self.dvae.eval() print(">> vqvae weights restored from:", self.dvae_path) self.gpt = UnifiedVoice(**self.cfg.gpt) self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) load_checkpoint(self.gpt, self.gpt_path) self.gpt = self.gpt.to(self.device) - self.gpt.eval() + if self.is_fp16: + self.gpt.eval().half() + else: + self.gpt.eval() print(">> GPT weights restored from:", self.gpt_path) - self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False) + if self.is_fp16: + self.gpt.post_init_gpt2_config(use_deepspeed=True, kv_cache=True, half=True) + else: + self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False) self.bigvgan = Generator(self.cfg.bigvgan) self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint) @@ -108,49 +123,83 @@ class IndexTTS: # text_tokens = F.pad(text_tokens, (0, 1), value=1) text_tokens = text_tokens.to(self.device) print(text_tokens) - print(f"text_tokens shape: {text_tokens.shape}") + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") text_token_syms = [tokenizer.IdToPiece(idx) for idx in text_tokens[0].tolist()] print(text_token_syms) text_len = [text_tokens.size(1)] text_len = torch.IntTensor(text_len).to(self.device) print(text_len) with torch.no_grad(): - codes = self.gpt.inference_speech(auto_conditioning, text_tokens, - cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], - device=text_tokens.device), - # text_lengths=text_len, - do_sample=True, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=autoregressive_batch_size, - length_penalty=length_penalty, - num_beams=num_beams, - repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens) - print(codes) - print(f"codes shape: {codes.shape}") + if self.is_fp16: + with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype): + codes = self.gpt.inference_speech(auto_conditioning, text_tokens, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + # text_lengths=text_len, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens) + else: + codes = self.gpt.inference_speech(auto_conditioning, text_tokens, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + # text_lengths=text_len, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens) + print(codes, type(codes)) + print(f"codes shape: {codes.shape}, codes type: {codes.dtype}") codes = codes[:, :-2] # latent, text_lens_out, code_lens_out = \ - latent = \ - self.gpt(auto_conditioning, text_tokens, - torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, - torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), - cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), - return_latent=True, clip_inputs=False) - latent = latent.transpose(1, 2) - ''' - latent_list = [] - for lat, t_len in zip(latent, text_lens_out): - lat = lat[:, t_len:] - latent_list.append(lat) - latent = torch.stack(latent_list) - print(f"latent shape: {latent.shape}") - ''' + if self.is_fp16: + with torch.cuda.amp.autocast(enabled=self.dtype is not None, dtype=self.dtype): + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), + return_latent=True, clip_inputs=False) + latent = latent.transpose(1, 2) + print(f'latent shape: {latent.shape}, latent type: {latent.dtype}') + print(f'auto_conditioning shape: {auto_conditioning.shape}, auto_conditioning type: {auto_conditioning.dtype}') + fp16_auto_conditioning = auto_conditioning.half() + wav, _ = self.bigvgan(latent.transpose(1, 2), fp16_auto_conditioning.transpose(1, 2)) + wav = wav.squeeze(1).cpu() - wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2)) - wav = wav.squeeze(1).cpu() + else: + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), + return_latent=True, clip_inputs=False) + print(f'latent shape: {latent.shape}, latent type: {latent.dtype}') + print(f'auto_conditioning: {auto_conditioning.shape}, auto_conditioning: {auto_conditioning.dtype}') + latent = latent.transpose(1, 2) + ''' + latent_list = [] + for lat, t_len in zip(latent, text_lens_out): + lat = lat[:, t_len:] + latent_list.append(lat) + latent = torch.stack(latent_list) + print(f"latent shape: {latent.shape}") + ''' + + wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2)) + wav = wav.squeeze(1).cpu() wav = 32767 * wav torch.clip(wav, -32767.0, 32767.0) @@ -163,5 +212,6 @@ class IndexTTS: if __name__ == "__main__": - tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints") + tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True) tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav") +