enable custom cuda kernel for BigVGAN

This commit is contained in:
Yrom 2025-04-15 12:03:23 +08:00
parent 21a3212a34
commit 94d1353e4e
No known key found for this signature in database
5 changed files with 149 additions and 91 deletions

View File

@ -0,0 +1 @@
/build

View File

@ -4,8 +4,8 @@
import torch
import torch.nn as nn
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
from indextts.BigVGAN.alias_free_activation.cuda import load
from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d
anti_alias_activation_cuda = load.load()

View File

@ -44,10 +44,10 @@ namespace
__global__ void anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
const acc_t *up_ftr,
const acc_t *down_ftr,
const acc_t *alpha,
const acc_t *beta,
int batch_size,
int channels,
int seq_len)
@ -84,9 +84,10 @@ namespace
// Alpha and beta values for snake activations. Applies exp by default
alpha = alpha + blockIdx.y;
input_t alpha_val = expf(alpha[0]);
beta = beta + blockIdx.y;
input_t beta_val = expf(beta[0]);
acc_t alpha_val = expf(alpha[0]);
acc_t beta_val = expf(beta[0]);
#pragma unroll
for (int it = 0; it < FILTER_SIZE; it += 1)
@ -118,7 +119,7 @@ namespace
#pragma unroll
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
{
input_t acc = 0.0;
acc_t acc = 0.0;
int element_index = intermediate_seq_offset + it; // index for intermediate
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
@ -136,7 +137,8 @@ namespace
#pragma unroll
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
{
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a;
}
// Apply replication padding before downsampling conv from intermediates
@ -155,7 +157,7 @@ namespace
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += 1)
{
input_t acc = 0.0;
acc_t acc = 0.0;
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
@ -182,10 +184,10 @@ namespace
void dispatch_anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
const acc_t *up_ftr,
const acc_t *down_ftr,
const acc_t *alpha,
const acc_t *beta,
int batch_size,
int channels,
int seq_len)
@ -222,23 +224,31 @@ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor cons
torch::Tensor anti_alias_activation_results =
torch::empty({batches, channels, seq_len}, act_options);
using float32 = float;
// The dtype of input is float16, bfloat16, or float32
// The dtype of up_filter, down_filter, alpha, and beta is float32
// printf("input scalar type: %d\n", input.scalar_type());
// printf("up_filter scalar type: %d\n", up_filter.scalar_type());
// printf("down_filter scalar type: %d\n", down_filter.scalar_type());
// printf("alpha scalar type: %d\n", alpha.scalar_type());
// printf("beta scalar type: %d\n", beta.scalar_type());
void *input_ptr = static_cast<void *>(input.data_ptr());
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
void *beta_ptr = static_cast<void *>(beta.data_ptr());
float32 *up_filter_ptr = static_cast<float32 *>(up_filter.data_ptr());
float32 *down_filter_ptr = static_cast<float32 *>(down_filter.data_ptr());
float32 *alpha_ptr = static_cast<float32 *>(alpha.data_ptr());
float32 *beta_ptr = static_cast<float32 *>(beta.data_ptr());
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch anti alias activation_forward",
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float32>(
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
reinterpret_cast<const scalar_t *>(input_ptr),
reinterpret_cast<const scalar_t *>(up_filter_ptr),
reinterpret_cast<const scalar_t *>(down_filter_ptr),
reinterpret_cast<const scalar_t *>(alpha_ptr),
reinterpret_cast<const scalar_t *>(beta_ptr),
reinterpret_cast<const float32 *>(up_filter_ptr),
reinterpret_cast<const float32 *>(down_filter_ptr),
reinterpret_cast<const float32 *>(alpha_ptr),
reinterpret_cast<const float32 *>(beta_ptr),
batches,
channels,
seq_len););

View File

@ -3,12 +3,14 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.alias_free_torch import *
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.utils import get_padding, init_weights
@ -41,7 +43,10 @@ class AMPBlock1(torch.nn.Module):
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if self.h.get("use_cuda_kernel", False):
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
else:
from indextts.BigVGAN.alias_free_torch import Activation1d
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
@ -89,6 +94,10 @@ class AMPBlock2(torch.nn.Module):
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers
if self.h.get("use_cuda_kernel", False):
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
else:
from indextts.BigVGAN.alias_free_torch import Activation1d
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
@ -120,9 +129,15 @@ class AMPBlock2(torch.nn.Module):
class BigVGAN(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
def __init__(self, h, use_cuda_kernel=False):
"""
Args:
h (dict)
use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation
"""
super(BigVGAN, self).__init__()
self.h = h
self.h["use_cuda_kernel"] = use_cuda_kernel
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
@ -134,7 +149,7 @@ class BigVGAN(torch.nn.Module):
self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3))
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
@ -150,7 +165,11 @@ class BigVGAN(torch.nn.Module):
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation))
if use_cuda_kernel:
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
else:
from indextts.BigVGAN.alias_free_torch import Activation1d
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing

View File

@ -14,47 +14,57 @@ from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.common import tokenize_by_CJK_char
from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.front import TextNormalizer
class IndexTTS:
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints', is_fp16=True, device=None):
def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None,
):
"""
Args:
cfg_path (str): path to the config file.
model_dir (str): path to the model directory.
is_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
"""
if device is not None:
self.device = device
self.is_fp16 = False if device == 'cpu' else is_fp16
self.is_fp16 = False if device == "cpu" else is_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available():
self.device = 'cuda:0'
self.device = "cuda:0"
self.is_fp16 = is_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif torch.mps.is_available():
self.device = 'mps'
self.device = "mps"
self.is_fp16 = is_fp16
self.use_cuda_kernel = False
else:
self.device = 'cpu'
self.device = "cpu"
self.is_fp16 = False
self.use_cuda_kernel = False
print(">> Be patient, it may take a while to run in CPU mode.")
self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
self.dvae = DiscreteVAE(**self.cfg.vqvae)
self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
load_checkpoint(self.dvae, self.dvae_path)
self.dvae = self.dvae.to(self.device)
if self.is_fp16:
self.dvae.eval().half()
else:
self.dvae.eval()
print(">> vqvae weights restored from:", self.dvae_path)
# Comment-off to load the VQ-VAE model for debugging tokenizer
# https://github.com/index-tts/index-tts/issues/34
#
# from indextts.vqvae.xtts_dvae import DiscreteVAE
# self.dvae = DiscreteVAE(**self.cfg.vqvae)
# self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
# load_checkpoint(self.dvae, self.dvae_path)
# self.dvae = self.dvae.to(self.device)
# if self.is_fp16:
# self.dvae.eval().half()
# else:
# self.dvae.eval()
# print(">> vqvae weights restored from:", self.dvae_path)
self.gpt = UnifiedVoice(**self.cfg.gpt)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
@ -75,12 +85,23 @@ class IndexTTS:
self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
else:
self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
self.bigvgan = Generator(self.cfg.bigvgan)
if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN
try:
from indextts.BigVGAN.alias_free_activation.cuda import load
anti_alias_activation_cuda = load.load()
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
except:
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
self.use_cuda_kernel = False
self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
vocoder_dict = torch.load(self.bigvgan_path, map_location='cpu')
self.bigvgan.load_state_dict(vocoder_dict['generator'])
vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
self.bigvgan.load_state_dict(vocoder_dict["generator"])
self.bigvgan = self.bigvgan.to(self.device)
# remove weight norm on eval mode
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
print(">> bigvgan weights restored from:", self.bigvgan_path)
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset['bpe_model'])
@ -157,11 +178,13 @@ class IndexTTS:
sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
]
def infer(self, audio_prompt, text, output_path,verbose=False):
print(f"origin text:{text}")
text = self.preprocess_text(text)
print(f"normalized text:{text}")
def infer(self, audio_prompt, text, output_path, verbose=False):
print(">> start inference...")
if verbose:
print(f"origin text:{text}")
start_time = time.perf_counter()
normalized_text = self.preprocess_text(text)
print(f"normalized text:{normalized_text}")
audio, sr = torchaudio.load(audio_prompt)
audio = torch.mean(audio, dim=0, keepdim=True)
@ -169,12 +192,16 @@ class IndexTTS:
audio = audio[0].unsqueeze(0)
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
cond_mel_frame = cond_mel.shape[-1]
if verbose:
print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
auto_conditioning = cond_mel
sentences = self.split_sentences(text)
print("sentences:", sentences)
sentences = self.split_sentences(normalized_text)
if verbose:
print("sentences:", sentences)
top_p = .8
top_k = 30
temperature = 1.0
@ -187,27 +214,24 @@ class IndexTTS:
# lang = "EN"
# lang = "ZH"
wavs = []
print(">> start inference...")
start_time = time.time()
for sent in sentences:
print(sent)
# sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
cleand_text = tokenize_by_CJK_char(sent)
# cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
print("cleand_text:", cleand_text)
if verbose:
print("cleand_text:", cleand_text)
text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
# text_tokens = F.pad(text_tokens, (1, 0), value=0)
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
print(text_tokens)
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
print(text_token_syms)
if verbose:
print(text_tokens)
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
print(text_token_syms)
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
# print(text_len)
@ -229,15 +253,18 @@ class IndexTTS:
max_generate_length=max_mel_tokens)
#codes = codes[:, :-2]
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
print(codes, type(codes))
print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
if verbose:
print(codes, type(codes))
print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
# remove ultra-long silence if exists
# temporarily fix the long silence bug.
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
print(codes, type(codes))
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
if verbose:
print(codes, type(codes))
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
# latent, text_lens_out, code_lens_out = \
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
@ -247,31 +274,32 @@ class IndexTTS:
code_lens*self.gpt.mel_length_compression,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
latent = latent.transpose(1, 2)
wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
wav = wav.squeeze(1).cpu()
wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
wav = wav.squeeze(1)
wav = torch.clip(32767 * wav, -32767.0, 32767.0)
print(f"wav shape: {wav.shape}")
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
wavs.append(wav)
end_time = time.perf_counter()
end_time = time.time()
elapsed_time = end_time - start_time
minutes, seconds = divmod(int(elapsed_time), 60)
milliseconds = int((elapsed_time - int(elapsed_time)) * 1000)
print(f">> inference done. time: {minutes:02d}:{seconds:02d}.{milliseconds:03d}")
print(">> saving wav file")
wav = torch.cat(wavs, dim=1)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
wav_length = wav.shape[-1] / sampling_rate
print(f">> Reference audio length: {cond_mel_frame / sampling_rate:.2f} seconds")
print(f">> Total inference time: {end_time - start_time:.2f} seconds")
print(f">> Generated audio length: {wav_length:.2f} seconds")
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
if __name__ == "__main__":
prompt_wav="test_data/input.wav"
#text="晕 XUAN4 是 一 种 GAN3 觉"
#text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了'
text="There is a vehicle arriving in dock number 7?"
prompt_wav = "test_data/input.wav"
prompt_wav = "testwav/spk_1744181067_1.wav"
# text="晕 XUAN4 是 一 种 GAN3 觉"
# text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了'
text = "There is a vehicle arriving in dock number 7?"
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True)
tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)