From 94d1353e4e46c0b39ae1ea568d2e99f130d64a24 Mon Sep 17 00:00:00 2001 From: Yrom Date: Tue, 15 Apr 2025 12:03:23 +0800 Subject: [PATCH] enable custom cuda kernel for BigVGAN --- .../alias_free_activation/cuda/.gitignore | 1 + .../cuda/activation1d.py | 4 +- .../cuda/anti_alias_activation_cuda.cu | 54 ++++--- indextts/BigVGAN/models.py | 31 +++- indextts/infer.py | 150 +++++++++++------- 5 files changed, 149 insertions(+), 91 deletions(-) create mode 100644 indextts/BigVGAN/alias_free_activation/cuda/.gitignore diff --git a/indextts/BigVGAN/alias_free_activation/cuda/.gitignore b/indextts/BigVGAN/alias_free_activation/cuda/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py index e0c4ff7..d05f179 100644 --- a/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +++ b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn # load fused CUDA kernel: this enables importing anti_alias_activation_cuda -from alias_free_activation.cuda import load -from alias_free_activation.torch.resample import DownSample1d, UpSample1d +from indextts.BigVGAN.alias_free_activation.cuda import load +from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d anti_alias_activation_cuda = load.load() diff --git a/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu index 8c44233..a36d917 100644 --- a/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +++ b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu @@ -44,10 +44,10 @@ namespace __global__ void anti_alias_activation_forward( output_t *dst, 
const input_t *src, - const input_t *up_ftr, - const input_t *down_ftr, - const input_t *alpha, - const input_t *beta, + const acc_t *up_ftr, + const acc_t *down_ftr, + const acc_t *alpha, + const acc_t *beta, int batch_size, int channels, int seq_len) @@ -84,9 +84,10 @@ namespace // Alpha and beta values for snake activatons. Applies exp by default alpha = alpha + blockIdx.y; - input_t alpha_val = expf(alpha[0]); beta = beta + blockIdx.y; - input_t beta_val = expf(beta[0]); + + acc_t alpha_val = expf(alpha[0]); + acc_t beta_val = expf(beta[0]); #pragma unroll for (int it = 0; it < FILTER_SIZE; it += 1) @@ -118,7 +119,7 @@ namespace #pragma unroll for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) { - input_t acc = 0.0; + acc_t acc = 0.0; int element_index = intermediate_seq_offset + it; // index for intermediate #pragma unroll for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) @@ -136,7 +137,8 @@ namespace #pragma unroll for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) { - intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); + acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a; } // Apply replication padding before downsampling conv from intermediates @@ -155,7 +157,7 @@ namespace #pragma unroll for (int it = 0; it < BUFFER_SIZE; it += 1) { - input_t acc = 0.0; + acc_t acc = 0.0; #pragma unroll for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) { @@ -182,10 +184,10 @@ namespace void dispatch_anti_alias_activation_forward( output_t *dst, const input_t *src, - const input_t *up_ftr, - const input_t *down_ftr, - const input_t *alpha, - const input_t *beta, + const acc_t *up_ftr, + const acc_t *down_ftr, + const acc_t 
*alpha, + const acc_t *beta, int batch_size, int channels, int seq_len) @@ -222,23 +224,31 @@ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor cons torch::Tensor anti_alias_activation_results = torch::empty({batches, channels, seq_len}, act_options); + using float32 = float; + // The dtype of input is float16, bfloat16, or float32 + // The dtype of up_filter, down_filter, alpha, and beta is float32 + // printf("input scalar type: %d\n", input.scalar_type()); + // printf("up_filter scalar type: %d\n", up_filter.scalar_type()); + // printf("down_filter scalar type: %d\n", down_filter.scalar_type()); + // printf("alpha scalar type: %d\n", alpha.scalar_type()); + // printf("beta scalar type: %d\n", beta.scalar_type()); void *input_ptr = static_cast(input.data_ptr()); - void *up_filter_ptr = static_cast(up_filter.data_ptr()); - void *down_filter_ptr = static_cast(down_filter.data_ptr()); - void *alpha_ptr = static_cast(alpha.data_ptr()); - void *beta_ptr = static_cast(beta.data_ptr()); + float32 *up_filter_ptr = static_cast(up_filter.data_ptr()); + float32 *down_filter_ptr = static_cast(down_filter.data_ptr()); + float32 *alpha_ptr = static_cast(alpha.data_ptr()); + float32 *beta_ptr = static_cast(beta.data_ptr()); void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); DISPATCH_FLOAT_HALF_AND_BFLOAT( input.scalar_type(), "dispatch anti alias activation_forward", - dispatch_anti_alias_activation_forward( + dispatch_anti_alias_activation_forward( reinterpret_cast(anti_alias_activation_results_ptr), reinterpret_cast(input_ptr), - reinterpret_cast(up_filter_ptr), - reinterpret_cast(down_filter_ptr), - reinterpret_cast(alpha_ptr), - reinterpret_cast(beta_ptr), + reinterpret_cast(up_filter_ptr), + reinterpret_cast(down_filter_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(beta_ptr), batches, channels, seq_len);); diff --git a/indextts/BigVGAN/models.py b/indextts/BigVGAN/models.py index 
602f103..771b89b 100644 --- a/indextts/BigVGAN/models.py +++ b/indextts/BigVGAN/models.py @@ -3,12 +3,14 @@ # Adapted from https://github.com/jik876/hifi-gan under the MIT license. # LICENSE is in incl_licenses directory. - +import torch +import torch.nn as nn +import torch.nn.functional as F from torch.nn import Conv1d, Conv2d, ConvTranspose1d from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm import indextts.BigVGAN.activations as activations -from indextts.BigVGAN.alias_free_torch import * + from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN from indextts.BigVGAN.utils import get_padding, init_weights @@ -41,7 +43,10 @@ class AMPBlock1(torch.nn.Module): self.convs2.apply(init_weights) self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers - + if self.h.get("use_cuda_kernel", False): + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( @@ -89,6 +94,10 @@ class AMPBlock2(torch.nn.Module): self.convs.apply(init_weights) self.num_layers = len(self.convs) # total number of conv layers + if self.h.get("use_cuda_kernel", False): + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ @@ -120,9 +129,15 @@ class AMPBlock2(torch.nn.Module): class BigVGAN(torch.nn.Module): # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. 
- def __init__(self, h): + def __init__(self, h, use_cuda_kernel=False): + """ + Args: + h (dict) + use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation + """ super(BigVGAN, self).__init__() self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel self.num_kernels = len(h.resblock_kernel_sizes) self.num_upsamples = len(h.upsample_rates) @@ -134,7 +149,7 @@ class BigVGAN(torch.nn.Module): self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3)) # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default - resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2 # transposed conv-based upsamplers. does not apply anti-aliasing self.ups = nn.ModuleList() @@ -150,7 +165,11 @@ class BigVGAN(torch.nn.Module): for i in range(len(self.ups)): ch = h.upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): - self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation)) + if use_cuda_kernel: + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d # post conv if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing diff --git a/indextts/infer.py b/indextts/infer.py index 2d82767..aab51d6 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -14,47 +14,57 @@ from indextts.gpt.model import UnifiedVoice from indextts.utils.checkpoint import load_checkpoint from indextts.utils.feature_extractors import MelSpectrogramFeatures from indextts.utils.common import tokenize_by_CJK_char -from indextts.vqvae.xtts_dvae import DiscreteVAE from indextts.utils.front import TextNormalizer class IndexTTS: - def __init__(self, cfg_path='checkpoints/config.yaml', 
model_dir='checkpoints', is_fp16=True, device=None): + def __init__( + self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None, + ): """ Args: cfg_path (str): path to the config file. model_dir (str): path to the model directory. is_fp16 (bool): whether to use fp16. device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. + use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. """ if device is not None: self.device = device - self.is_fp16 = False if device == 'cpu' else is_fp16 + self.is_fp16 = False if device == "cpu" else is_fp16 + self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda") elif torch.cuda.is_available(): - self.device = 'cuda:0' + self.device = "cuda:0" self.is_fp16 = is_fp16 + self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel elif torch.mps.is_available(): - self.device = 'mps' + self.device = "mps" self.is_fp16 = is_fp16 + self.use_cuda_kernel = False else: - self.device = 'cpu' + self.device = "cpu" self.is_fp16 = False + self.use_cuda_kernel = False print(">> Be patient, it may take a while to run in CPU mode.") self.cfg = OmegaConf.load(cfg_path) self.model_dir = model_dir self.dtype = torch.float16 if self.is_fp16 else None self.stop_mel_token = self.cfg.gpt.stop_mel_token - - self.dvae = DiscreteVAE(**self.cfg.vqvae) - self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) - load_checkpoint(self.dvae, self.dvae_path) - self.dvae = self.dvae.to(self.device) - if self.is_fp16: - self.dvae.eval().half() - else: - self.dvae.eval() - print(">> vqvae weights restored from:", self.dvae_path) + + # Comment-off to load the VQ-VAE model for debugging tokenizer + # https://github.com/index-tts/index-tts/issues/34 + # + # from indextts.vqvae.xtts_dvae import DiscreteVAE + # self.dvae = 
DiscreteVAE(**self.cfg.vqvae) + # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) + # load_checkpoint(self.dvae, self.dvae_path) + # self.dvae = self.dvae.to(self.device) + # if self.is_fp16: + # self.dvae.eval().half() + # else: + # self.dvae.eval() + # print(">> vqvae weights restored from:", self.dvae_path) self.gpt = UnifiedVoice(**self.cfg.gpt) self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) load_checkpoint(self.gpt, self.gpt_path) @@ -75,12 +85,23 @@ class IndexTTS: self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True) else: self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False) - - self.bigvgan = Generator(self.cfg.bigvgan) + + if self.use_cuda_kernel: + # preload the CUDA kernel for BigVGAN + try: + from indextts.BigVGAN.alias_free_activation.cuda import load + anti_alias_activation_cuda = load.load() + print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda) + except Exception: + print(">> Failed to load custom CUDA kernel for BigVGAN. 
Falling back to torch.") + self.use_cuda_kernel = False + self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel) self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint) - vocoder_dict = torch.load(self.bigvgan_path, map_location='cpu') - self.bigvgan.load_state_dict(vocoder_dict['generator']) + vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu") + self.bigvgan.load_state_dict(vocoder_dict["generator"]) self.bigvgan = self.bigvgan.to(self.device) + # remove weight norm on eval mode + self.bigvgan.remove_weight_norm() self.bigvgan.eval() print(">> bigvgan weights restored from:", self.bigvgan_path) self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset['bpe_model']) @@ -157,11 +178,13 @@ class IndexTTS: sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","} ] - def infer(self, audio_prompt, text, output_path,verbose=False): - print(f"origin text:{text}") - text = self.preprocess_text(text) - print(f"normalized text:{text}") - + def infer(self, audio_prompt, text, output_path, verbose=False): + print(">> start inference...") + if verbose: + print(f"origin text:{text}") + start_time = time.perf_counter() + normalized_text = self.preprocess_text(text) + print(f"normalized text:{normalized_text}") audio, sr = torchaudio.load(audio_prompt) audio = torch.mean(audio, dim=0, keepdim=True) @@ -169,12 +192,16 @@ class IndexTTS: audio = audio[0].unsqueeze(0) audio = torchaudio.transforms.Resample(sr, 24000)(audio) cond_mel = MelSpectrogramFeatures()(audio).to(self.device) - print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) + cond_mel_frame = cond_mel.shape[-1] + if verbose: + print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) auto_conditioning = cond_mel - sentences = self.split_sentences(text) - print("sentences:", sentences) + sentences = self.split_sentences(normalized_text) + if verbose: + print("sentences:", sentences) 
+ top_p = .8 top_k = 30 temperature = 1.0 @@ -187,27 +214,24 @@ class IndexTTS: # lang = "EN" # lang = "ZH" wavs = [] - print(">> start inference...") - - start_time = time.time() for sent in sentences: - print(sent) # sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper() cleand_text = tokenize_by_CJK_char(sent) # cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ." - print("cleand_text:", cleand_text) + if verbose: + print("cleand_text:", cleand_text) text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0) # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. # text_tokens = F.pad(text_tokens, (1, 0), value=0) # text_tokens = F.pad(text_tokens, (0, 1), value=1) - - print(text_tokens) - print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") - # debug tokenizer - text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) - print(text_token_syms) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist()) + print(text_token_syms) # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device) # print(text_len) @@ -229,15 +253,18 @@ class IndexTTS: max_generate_length=max_mel_tokens) #codes = codes[:, :-2] code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) - print(codes, type(codes)) - print(f"codes shape: {codes.shape}, codes type: {codes.dtype}") - print(f"code len: {code_lens}") + if verbose: + print(codes, type(codes)) + print(f"codes shape: {codes.shape}, codes type: {codes.dtype}") + print(f"code len: {code_lens}") + # remove ultra-long silence if exits # temporarily fix the long silence bug. 
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) - print(codes, type(codes)) - print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") - print(f"code len: {code_lens}") + if verbose: + print(codes, type(codes)) + print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") + print(f"code len: {code_lens}") # latent, text_lens_out, code_lens_out = \ with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): @@ -247,31 +274,32 @@ class IndexTTS: code_lens*self.gpt.mel_length_compression, cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), return_latent=True, clip_inputs=False) - latent = latent.transpose(1, 2) - wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2)) - wav = wav.squeeze(1).cpu() + wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2)) + wav = wav.squeeze(1) - wav = torch.clip(32767 * wav, -32767.0, 32767.0) - print(f"wav shape: {wav.shape}") + wav = torch.clamp(32767 * wav, -32767.0, 32767.0) + print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max()) # wavs.append(wav[:, :-512]) wavs.append(wav) + end_time = time.perf_counter() - end_time = time.time() - elapsed_time = end_time - start_time - minutes, seconds = divmod(int(elapsed_time), 60) - milliseconds = int((elapsed_time - int(elapsed_time)) * 1000) - print(f">> inference done. 
time: {minutes:02d}:{seconds:02d}.{milliseconds:03d}") - print(">> saving wav file") wav = torch.cat(wavs, dim=1) - torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + wav_length = wav.shape[-1] / sampling_rate + print(f">> Reference audio length: {cond_mel_frame / sampling_rate:.2f} seconds")  # NOTE(review): cond_mel_frame is a mel-frame count, not samples — duration likely needs a hop_length factor; verify + print(f">> Total inference time: {end_time - start_time:.2f} seconds") + print(f">> Generated audio length: {wav_length:.2f} seconds") + print(f">> RTF: {(end_time - start_time) / wav_length:.4f}") + + torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate) print(">> wav file saved to:", output_path) if __name__ == "__main__": - prompt_wav="test_data/input.wav" - #text="晕 XUAN4 是 一 种 GAN3 觉" - #text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' - text="There is a vehicle arriving in dock number 7?" + prompt_wav = "test_data/input.wav" + # prompt_wav = "testwav/spk_1744181067_1.wav"  # debug override of the example input; keep disabled + # text="晕 XUAN4 是 一 种 GAN3 觉" + # text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' + text = "There is a vehicle arriving in dock number 7?" tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True) tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)