enable custom cuda kernel for BigVGAN
This commit is contained in:
parent
21a3212a34
commit
94d1353e4e
1
indextts/BigVGAN/alias_free_activation/cuda/.gitignore
vendored
Normal file
1
indextts/BigVGAN/alias_free_activation/cuda/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|
||||
@ -4,8 +4,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||
from alias_free_activation.cuda import load
|
||||
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
||||
from indextts.BigVGAN.alias_free_activation.cuda import load
|
||||
from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
||||
|
||||
anti_alias_activation_cuda = load.load()
|
||||
|
||||
|
||||
@ -44,10 +44,10 @@ namespace
|
||||
__global__ void anti_alias_activation_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t *up_ftr,
|
||||
const input_t *down_ftr,
|
||||
const input_t *alpha,
|
||||
const input_t *beta,
|
||||
const acc_t *up_ftr,
|
||||
const acc_t *down_ftr,
|
||||
const acc_t *alpha,
|
||||
const acc_t *beta,
|
||||
int batch_size,
|
||||
int channels,
|
||||
int seq_len)
|
||||
@ -84,9 +84,10 @@ namespace
|
||||
|
||||
// Alpha and beta values for snake activatons. Applies exp by default
|
||||
alpha = alpha + blockIdx.y;
|
||||
input_t alpha_val = expf(alpha[0]);
|
||||
beta = beta + blockIdx.y;
|
||||
input_t beta_val = expf(beta[0]);
|
||||
|
||||
acc_t alpha_val = expf(alpha[0]);
|
||||
acc_t beta_val = expf(beta[0]);
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < FILTER_SIZE; it += 1)
|
||||
@ -118,7 +119,7 @@ namespace
|
||||
#pragma unroll
|
||||
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
||||
{
|
||||
input_t acc = 0.0;
|
||||
acc_t acc = 0.0;
|
||||
int element_index = intermediate_seq_offset + it; // index for intermediate
|
||||
#pragma unroll
|
||||
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||
@ -136,7 +137,8 @@ namespace
|
||||
#pragma unroll
|
||||
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
||||
{
|
||||
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
||||
acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
||||
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a;
|
||||
}
|
||||
|
||||
// Apply replication padding before downsampling conv from intermediates
|
||||
@ -155,7 +157,7 @@ namespace
|
||||
#pragma unroll
|
||||
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
||||
{
|
||||
input_t acc = 0.0;
|
||||
acc_t acc = 0.0;
|
||||
#pragma unroll
|
||||
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||
{
|
||||
@ -182,10 +184,10 @@ namespace
|
||||
void dispatch_anti_alias_activation_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t *up_ftr,
|
||||
const input_t *down_ftr,
|
||||
const input_t *alpha,
|
||||
const input_t *beta,
|
||||
const acc_t *up_ftr,
|
||||
const acc_t *down_ftr,
|
||||
const acc_t *alpha,
|
||||
const acc_t *beta,
|
||||
int batch_size,
|
||||
int channels,
|
||||
int seq_len)
|
||||
@ -222,23 +224,31 @@ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor cons
|
||||
torch::Tensor anti_alias_activation_results =
|
||||
torch::empty({batches, channels, seq_len}, act_options);
|
||||
|
||||
using float32 = float;
|
||||
// The dtype of input is float16, bfloat16, or float32
|
||||
// The dtype of up_filter, down_filter, alpha, and beta is float32
|
||||
// printf("input scalar type: %d\n", input.scalar_type());
|
||||
// printf("up_filter scalar type: %d\n", up_filter.scalar_type());
|
||||
// printf("down_filter scalar type: %d\n", down_filter.scalar_type());
|
||||
// printf("alpha scalar type: %d\n", alpha.scalar_type());
|
||||
// printf("beta scalar type: %d\n", beta.scalar_type());
|
||||
void *input_ptr = static_cast<void *>(input.data_ptr());
|
||||
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
||||
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
||||
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
||||
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
||||
float32 *up_filter_ptr = static_cast<float32 *>(up_filter.data_ptr());
|
||||
float32 *down_filter_ptr = static_cast<float32 *>(down_filter.data_ptr());
|
||||
float32 *alpha_ptr = static_cast<float32 *>(alpha.data_ptr());
|
||||
float32 *beta_ptr = static_cast<float32 *>(beta.data_ptr());
|
||||
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
||||
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"dispatch anti alias activation_forward",
|
||||
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
||||
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float32>(
|
||||
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
||||
reinterpret_cast<const scalar_t *>(input_ptr),
|
||||
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
||||
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
||||
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
||||
reinterpret_cast<const scalar_t *>(beta_ptr),
|
||||
reinterpret_cast<const float32 *>(up_filter_ptr),
|
||||
reinterpret_cast<const float32 *>(down_filter_ptr),
|
||||
reinterpret_cast<const float32 *>(alpha_ptr),
|
||||
reinterpret_cast<const float32 *>(beta_ptr),
|
||||
batches,
|
||||
channels,
|
||||
seq_len););
|
||||
|
||||
@ -3,12 +3,14 @@
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
import indextts.BigVGAN.activations as activations
|
||||
from indextts.BigVGAN.alias_free_torch import *
|
||||
|
||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||
|
||||
@ -41,7 +43,10 @@ class AMPBlock1(torch.nn.Module):
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
||||
else:
|
||||
from indextts.BigVGAN.alias_free_torch import Activation1d
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
@ -89,6 +94,10 @@ class AMPBlock2(torch.nn.Module):
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
||||
else:
|
||||
from indextts.BigVGAN.alias_free_torch import Activation1d
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
@ -120,9 +129,15 @@ class AMPBlock2(torch.nn.Module):
|
||||
|
||||
class BigVGAN(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
def __init__(self, h, use_cuda_kernel=False):
|
||||
"""
|
||||
Args:
|
||||
h (dict)
|
||||
use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation
|
||||
"""
|
||||
super(BigVGAN, self).__init__()
|
||||
self.h = h
|
||||
self.h["use_cuda_kernel"] = use_cuda_kernel
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
@ -134,7 +149,7 @@ class BigVGAN(torch.nn.Module):
|
||||
self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
@ -150,7 +165,11 @@ class BigVGAN(torch.nn.Module):
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation))
|
||||
if use_cuda_kernel:
|
||||
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
||||
else:
|
||||
from indextts.BigVGAN.alias_free_torch import Activation1d
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
|
||||
@ -14,47 +14,57 @@ from indextts.gpt.model import UnifiedVoice
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.utils.common import tokenize_by_CJK_char
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
|
||||
from indextts.utils.front import TextNormalizer
|
||||
|
||||
class IndexTTS:
|
||||
def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints', is_fp16=True, device=None):
|
||||
def __init__(
|
||||
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg_path (str): path to the config file.
|
||||
model_dir (str): path to the model directory.
|
||||
is_fp16 (bool): whether to use fp16.
|
||||
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
|
||||
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
self.is_fp16 = False if device == 'cpu' else is_fp16
|
||||
self.is_fp16 = False if device == "cpu" else is_fp16
|
||||
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
|
||||
elif torch.cuda.is_available():
|
||||
self.device = 'cuda:0'
|
||||
self.device = "cuda:0"
|
||||
self.is_fp16 = is_fp16
|
||||
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
|
||||
elif torch.mps.is_available():
|
||||
self.device = 'mps'
|
||||
self.device = "mps"
|
||||
self.is_fp16 = is_fp16
|
||||
self.use_cuda_kernel = False
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
self.device = "cpu"
|
||||
self.is_fp16 = False
|
||||
self.use_cuda_kernel = False
|
||||
print(">> Be patient, it may take a while to run in CPU mode.")
|
||||
|
||||
self.cfg = OmegaConf.load(cfg_path)
|
||||
self.model_dir = model_dir
|
||||
self.dtype = torch.float16 if self.is_fp16 else None
|
||||
self.stop_mel_token = self.cfg.gpt.stop_mel_token
|
||||
|
||||
self.dvae = DiscreteVAE(**self.cfg.vqvae)
|
||||
self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
|
||||
load_checkpoint(self.dvae, self.dvae_path)
|
||||
self.dvae = self.dvae.to(self.device)
|
||||
if self.is_fp16:
|
||||
self.dvae.eval().half()
|
||||
else:
|
||||
self.dvae.eval()
|
||||
print(">> vqvae weights restored from:", self.dvae_path)
|
||||
|
||||
# Comment-off to load the VQ-VAE model for debugging tokenizer
|
||||
# https://github.com/index-tts/index-tts/issues/34
|
||||
#
|
||||
# from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
# self.dvae = DiscreteVAE(**self.cfg.vqvae)
|
||||
# self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
|
||||
# load_checkpoint(self.dvae, self.dvae_path)
|
||||
# self.dvae = self.dvae.to(self.device)
|
||||
# if self.is_fp16:
|
||||
# self.dvae.eval().half()
|
||||
# else:
|
||||
# self.dvae.eval()
|
||||
# print(">> vqvae weights restored from:", self.dvae_path)
|
||||
self.gpt = UnifiedVoice(**self.cfg.gpt)
|
||||
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
|
||||
load_checkpoint(self.gpt, self.gpt_path)
|
||||
@ -75,12 +85,23 @@ class IndexTTS:
|
||||
self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
|
||||
else:
|
||||
self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
|
||||
|
||||
self.bigvgan = Generator(self.cfg.bigvgan)
|
||||
|
||||
if self.use_cuda_kernel:
|
||||
# preload the CUDA kernel for BigVGAN
|
||||
try:
|
||||
from indextts.BigVGAN.alias_free_activation.cuda import load
|
||||
anti_alias_activation_cuda = load.load()
|
||||
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
|
||||
except:
|
||||
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
|
||||
self.use_cuda_kernel = False
|
||||
self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
|
||||
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
|
||||
vocoder_dict = torch.load(self.bigvgan_path, map_location='cpu')
|
||||
self.bigvgan.load_state_dict(vocoder_dict['generator'])
|
||||
vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
|
||||
self.bigvgan.load_state_dict(vocoder_dict["generator"])
|
||||
self.bigvgan = self.bigvgan.to(self.device)
|
||||
# remove weight norm on eval mode
|
||||
self.bigvgan.remove_weight_norm()
|
||||
self.bigvgan.eval()
|
||||
print(">> bigvgan weights restored from:", self.bigvgan_path)
|
||||
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset['bpe_model'])
|
||||
@ -157,11 +178,13 @@ class IndexTTS:
|
||||
sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in {"'", ".", ","}
|
||||
]
|
||||
|
||||
def infer(self, audio_prompt, text, output_path,verbose=False):
|
||||
print(f"origin text:{text}")
|
||||
text = self.preprocess_text(text)
|
||||
print(f"normalized text:{text}")
|
||||
|
||||
def infer(self, audio_prompt, text, output_path, verbose=False):
|
||||
print(">> start inference...")
|
||||
if verbose:
|
||||
print(f"origin text:{text}")
|
||||
start_time = time.perf_counter()
|
||||
normalized_text = self.preprocess_text(text)
|
||||
print(f"normalized text:{normalized_text}")
|
||||
|
||||
audio, sr = torchaudio.load(audio_prompt)
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
@ -169,12 +192,16 @@ class IndexTTS:
|
||||
audio = audio[0].unsqueeze(0)
|
||||
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
|
||||
cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
|
||||
print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
|
||||
cond_mel_frame = cond_mel.shape[-1]
|
||||
if verbose:
|
||||
print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
|
||||
|
||||
auto_conditioning = cond_mel
|
||||
|
||||
sentences = self.split_sentences(text)
|
||||
print("sentences:", sentences)
|
||||
sentences = self.split_sentences(normalized_text)
|
||||
if verbose:
|
||||
print("sentences:", sentences)
|
||||
|
||||
top_p = .8
|
||||
top_k = 30
|
||||
temperature = 1.0
|
||||
@ -187,27 +214,24 @@ class IndexTTS:
|
||||
# lang = "EN"
|
||||
# lang = "ZH"
|
||||
wavs = []
|
||||
print(">> start inference...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for sent in sentences:
|
||||
print(sent)
|
||||
# sent = " ".join([char for char in sent.upper()]) if lang == "ZH" else sent.upper()
|
||||
cleand_text = tokenize_by_CJK_char(sent)
|
||||
# cleand_text = "他 那 像 HONG3 小 孩 似 的 话 , 引 得 人 们 HONG1 堂 大 笑 , 大 家 听 了 一 HONG3 而 散 ."
|
||||
print("cleand_text:", cleand_text)
|
||||
if verbose:
|
||||
print("cleand_text:", cleand_text)
|
||||
|
||||
text_tokens = torch.tensor(self.tokenizer.EncodeAsIds(cleand_text),dtype=torch.int32, device=self.device).unsqueeze(0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
# text_tokens = F.pad(text_tokens, (1, 0), value=0)
|
||||
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
|
||||
|
||||
print(text_tokens)
|
||||
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
|
||||
# debug tokenizer
|
||||
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
|
||||
print(text_token_syms)
|
||||
if verbose:
|
||||
print(text_tokens)
|
||||
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
|
||||
# debug tokenizer
|
||||
text_token_syms = self.tokenizer.IdToPiece(text_tokens[0].tolist())
|
||||
print(text_token_syms)
|
||||
|
||||
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
|
||||
# print(text_len)
|
||||
@ -229,15 +253,18 @@ class IndexTTS:
|
||||
max_generate_length=max_mel_tokens)
|
||||
#codes = codes[:, :-2]
|
||||
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
|
||||
print(codes, type(codes))
|
||||
print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
|
||||
print(f"code len: {code_lens}")
|
||||
if verbose:
|
||||
print(codes, type(codes))
|
||||
print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
|
||||
print(f"code len: {code_lens}")
|
||||
|
||||
# remove ultra-long silence if exits
|
||||
# temporarily fix the long silence bug.
|
||||
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
|
||||
print(codes, type(codes))
|
||||
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
|
||||
print(f"code len: {code_lens}")
|
||||
if verbose:
|
||||
print(codes, type(codes))
|
||||
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
|
||||
print(f"code len: {code_lens}")
|
||||
|
||||
# latent, text_lens_out, code_lens_out = \
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
@ -247,31 +274,32 @@ class IndexTTS:
|
||||
code_lens*self.gpt.mel_length_compression,
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
latent = latent.transpose(1, 2)
|
||||
wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
|
||||
wav = wav.squeeze(1).cpu()
|
||||
wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
|
||||
wav = wav.squeeze(1)
|
||||
|
||||
wav = torch.clip(32767 * wav, -32767.0, 32767.0)
|
||||
print(f"wav shape: {wav.shape}")
|
||||
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
|
||||
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
|
||||
# wavs.append(wav[:, :-512])
|
||||
wavs.append(wav)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
minutes, seconds = divmod(int(elapsed_time), 60)
|
||||
milliseconds = int((elapsed_time - int(elapsed_time)) * 1000)
|
||||
print(f">> inference done. time: {minutes:02d}:{seconds:02d}.{milliseconds:03d}")
|
||||
print(">> saving wav file")
|
||||
wav = torch.cat(wavs, dim=1)
|
||||
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
|
||||
wav_length = wav.shape[-1] / sampling_rate
|
||||
print(f">> Reference audio length: {cond_mel_frame / sampling_rate:.2f} seconds")
|
||||
print(f">> Total inference time: {end_time - start_time:.2f} seconds")
|
||||
print(f">> Generated audio length: {wav_length:.2f} seconds")
|
||||
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
|
||||
|
||||
torchaudio.save(output_path, wav.cpu().type(torch.int16), sampling_rate)
|
||||
print(">> wav file saved to:", output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_wav="test_data/input.wav"
|
||||
#text="晕 XUAN4 是 一 种 GAN3 觉"
|
||||
#text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
|
||||
text="There is a vehicle arriving in dock number 7?"
|
||||
prompt_wav = "test_data/input.wav"
|
||||
prompt_wav = "testwav/spk_1744181067_1.wav"
|
||||
# text="晕 XUAN4 是 一 种 GAN3 觉"
|
||||
# text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
|
||||
text = "There is a vehicle arriving in dock number 7?"
|
||||
|
||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True)
|
||||
tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user