fix packages

This commit is contained in:
Emmanuel Schmidbauer 2025-03-25 14:03:29 -04:00
parent 4c086f954b
commit 2fe6a73ada
31 changed files with 171 additions and 148 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
venv/
__pycache__

View File

@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
from indextts.BigVGAN.nnet.linear import Linear from indextts.BigVGAN.nnet.linear import Linear
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
def length_to_mask(length, max_len=None, dtype=None, device=None): def length_to_mask(length, max_len=None, dtype=None, device=None):
"""Creates a binary mask for each sequence. """Creates a binary mask for each sequence.

View File

View File

@ -2,7 +2,7 @@
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import torch import torch
from torch import nn, sin, pow from torch import nn, pow, sin
from torch.nn import Parameter from torch.nn import Parameter
@ -22,6 +22,7 @@ class Snake(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -36,9 +37,9 @@ class Snake(nn.Module):
# initialize alpha # initialize alpha
self.alpha_logscale = alpha_logscale self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha) self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha) self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable self.alpha.requires_grad = alpha_trainable
@ -51,7 +52,7 @@ class Snake(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa) Snake = x + 1/a * sin^2 (xa)
''' '''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@ -76,6 +77,7 @@ class SnakeBeta(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -92,10 +94,10 @@ class SnakeBeta(nn.Module):
# initialize alpha # initialize alpha
self.alpha_logscale = alpha_logscale self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha) self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha) self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha) self.beta = Parameter(torch.ones(in_features) * alpha)
@ -110,11 +112,11 @@ class SnakeBeta(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa) SnakeBeta = x + 1/b * sin^2 (xa)
''' '''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1) beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
beta = torch.exp(beta) beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x return x

View File

@ -3,10 +3,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load from alias_free_activation.cuda import load
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
anti_alias_activation_cuda = load.load() anti_alias_activation_cuda = load.load()

View File

@ -1,6 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
from .act import *
from .filter import * from .filter import *
from .resample import * from .resample import *
from .act import *

View File

@ -2,7 +2,8 @@
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import torch.nn as nn import torch.nn as nn
from .resample import UpSample1d, DownSample1d
from .resample import DownSample1d, UpSample1d
class Activation1d(nn.Module): class Activation1d(nn.Module):

View File

@ -1,10 +1,11 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math
if "sinc" in dir(torch): if "sinc" in dir(torch):
sinc = torch.sinc sinc = torch.sinc

View File

@ -3,8 +3,8 @@
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d from .filter import LowPassFilter1d, kaiser_sinc_filter1d
class UpSample1d(nn.Module): class UpSample1d(nn.Module):

View File

@ -1,6 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
from .act import *
from .filter import * from .filter import *
from .resample import * from .resample import *
from .act import *

View File

@ -2,7 +2,8 @@
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import torch.nn as nn import torch.nn as nn
from .resample import UpSample1d, DownSample1d
from .resample import DownSample1d, UpSample1d
class Activation1d(nn.Module): class Activation1d(nn.Module):

View File

@ -1,10 +1,11 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math
if 'sinc' in dir(torch): if 'sinc' in dir(torch):
sinc = torch.sinc sinc = torch.sinc

View File

@ -3,8 +3,8 @@
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d from .filter import LowPassFilter1d, kaiser_sinc_filter1d
class UpSample1d(nn.Module): class UpSample1d(nn.Module):

View File

@ -4,23 +4,23 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import os
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Dict from typing import Dict, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
import indextts.BigVGAN.activations as activations import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.utils import init_weights, get_padding from indextts.BigVGAN.alias_free_activation.torch.act import \
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d Activation1d as TorchActivation1d
from indextts.BigVGAN.env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.env import AttrDict
from indextts.BigVGAN.utils import get_padding, init_weights
def load_hparams_from_json(path) -> AttrDict: def load_hparams_from_json(path) -> AttrDict:
@ -51,7 +51,7 @@ class AMPBlock1(torch.nn.Module):
activation: str = None, activation: str = None,
): ):
super().__init__() super().__init__()
self.h = h self.h = h
self.convs1 = nn.ModuleList( self.convs1 = nn.ModuleList(
@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d
)
Activation1d = CudaActivation1d Activation1d = CudaActivation1d
else: else:
@ -170,7 +169,7 @@ class AMPBlock2(torch.nn.Module):
activation: str = None, activation: str = None,
): ):
super().__init__() super().__init__()
self.h = h self.h = h
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d
)
Activation1d = CudaActivation1d Activation1d = CudaActivation1d
else: else:
@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module):
for l in self.convs: for l in self.convs:
remove_weight_norm(l) remove_weight_norm(l)
''' '''
PyTorchModelHubMixin, PyTorchModelHubMixin,
library_name="bigvgan", library_name="bigvgan",
@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module):
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
''' '''
class BigVGAN( class BigVGAN(
torch.nn.Module, torch.nn.Module,
): ):
@ -274,9 +274,8 @@ class BigVGAN(
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d
)
Activation1d = CudaActivation1d Activation1d = CudaActivation1d
else: else:

View File

@ -1,17 +1,16 @@
# Copyright (c) 2022 NVIDIA CORPORATION. # Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license. # Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
from torch.nn import Conv1d, ConvTranspose1d, Conv2d from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
import indextts.BigVGAN.activations as activations import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.utils import init_weights, get_padding
from indextts.BigVGAN.alias_free_torch import * from indextts.BigVGAN.alias_free_torch import *
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.utils import get_padding, init_weights
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -41,19 +40,19 @@ class AMPBlock1(torch.nn.Module):
]) ])
self.convs2.apply(init_weights) self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([ self.activations = nn.ModuleList([
Activation1d( Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers) for _ in range(self.num_layers)
]) ])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([ self.activations = nn.ModuleList([
Activation1d( Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers) for _ in range(self.num_layers)
]) ])
else: else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
@ -89,25 +88,25 @@ class AMPBlock2(torch.nn.Module):
]) ])
self.convs.apply(init_weights) self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([ self.activations = nn.ModuleList([
Activation1d( Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers) for _ in range(self.num_layers)
]) ])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([ self.activations = nn.ModuleList([
Activation1d( Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers) for _ in range(self.num_layers)
]) ])
else: else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x): def forward(self, x):
for c, a in zip (self.convs, self.activations): for c, a in zip(self.convs, self.activations):
xt = a(x) xt = a(x)
xt = c(xt) xt = c(xt)
x = xt + x x = xt + x
@ -127,7 +126,7 @@ class BigVGAN(torch.nn.Module):
self.num_kernels = len(h.resblock_kernel_sizes) self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates) self.num_upsamples = len(h.upsample_rates)
self.feat_upsample = h.feat_upsample self.feat_upsample = h.feat_upsample
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
@ -154,10 +153,10 @@ class BigVGAN(torch.nn.Module):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv # post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post) self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post) self.activation_post = Activation1d(activation=activation_post)
else: else:
@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module):
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, x, mel_ref, lens=None): def forward(self, x, mel_ref, lens=None):
speaker_embedding = self.speaker_encoder(mel_ref, lens) speaker_embedding = self.speaker_encoder(mel_ref, lens)
n_batch = x.size(0) n_batch = x.size(0)
@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module):
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp()) contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
speaker_embedding = speaker_embedding[:n_batch, :, :] speaker_embedding = speaker_embedding[:n_batch, :, :]
speaker_embedding = speaker_embedding.transpose(1,2) speaker_embedding = speaker_embedding.transpose(1, 2)
# upsample feat # upsample feat
if self.feat_upsample: if self.feat_upsample:
@ -265,20 +263,20 @@ class DiscriminatorP(torch.nn.Module):
self.d_mult = h.discriminator_channel_mult self.d_mult = h.discriminator_channel_mult
norm_f = weight_norm if use_spectral_norm == False else spectral_norm norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([ self.convs = nn.ModuleList([
norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
]) ])
self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0))) self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x): def forward(self, x):
fmap = [] fmap = []
# 1d to 2d # 1d to 2d
b, c, t = x.shape b, c, t = x.shape
if t % self.period != 0: # pad first if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period) n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect") x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad t = t + n_pad
@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module):
self.d_mult = cfg.mrd_channel_mult self.d_mult = cfg.mrd_channel_mult
self.convs = nn.ModuleList([ self.convs = nn.ModuleList([
norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))), norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))), norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
]) ])
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module):
x = x.squeeze(1) x = x.squeeze(1)
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
x = torch.view_as_real(x) # [B, F, TT, 2] x = torch.view_as_real(x) # [B, F, TT, 2]
mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
return mag return mag
@ -376,9 +374,9 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False): def __init__(self, cfg, debug=False):
super().__init__() super().__init__()
self.resolutions = cfg.resolutions self.resolutions = cfg.resolutions
assert len(self.resolutions) == 3,\ assert len(self.resolutions) == 3, \
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\ "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
format(self.resolutions) format(self.resolutions)
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions] [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
) )
@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g):
for rl, gl in zip(dr, dg): for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl)) loss += torch.mean(torch.abs(rl - gl))
return loss*2 return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs): def discriminator_loss(disc_real_outputs, disc_generated_outputs):
@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
r_losses = [] r_losses = []
g_losses = [] g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs): for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1-dr)**2) r_loss = torch.mean((1 - dr)**2)
g_loss = torch.mean(dg**2) g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss) loss += (r_loss + g_loss)
r_losses.append(r_loss.item()) r_losses.append(r_loss.item())
@ -427,9 +425,8 @@ def generator_loss(disc_outputs):
loss = 0 loss = 0
gen_losses = [] gen_losses = []
for dg in disc_outputs: for dg in disc_outputs:
l = torch.mean((1-dg)**2) l = torch.mean((1 - dg)**2)
gen_losses.append(l) gen_losses.append(l)
loss += l loss += l
return loss, gen_losses return loss, gen_losses

View File

@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
class SincConv(nn.Module): class SincConv(nn.Module):
"""This function implements SincConv (SincNet). """This function implements SincConv (SincNet).

View File

View File

@ -3,13 +3,14 @@
import glob import glob
import os import os
import matplotlib import matplotlib
import matplotlib.pylab as plt
import torch import torch
from scipy.io.wavfile import write
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
matplotlib.use("Agg") matplotlib.use("Agg")
import matplotlib.pylab as plt
from scipy.io.wavfile import write
MAX_WAV_VALUE = 32768.0 MAX_WAV_VALUE = 32768.0

0
indextts/__init__.py Normal file
View File

View File

@ -21,6 +21,7 @@ from typing import Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
class PositionalEncoding(torch.nn.Module): class PositionalEncoding(torch.nn.Module):
"""Positional encoding. """Positional encoding.

View File

@ -3,11 +3,18 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \
Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2 from indextts.gpt.conformer.attention import (MultiHeadedAttention,
from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding RelPositionMultiHeadedAttention)
from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
from utils.utils import make_pad_mask PositionalEncoding,
RelPositionalEncoding)
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
LinearNoSubsampling)
from indextts.utils.utils import make_pad_mask
class PositionwiseFeedForward(torch.nn.Module): class PositionwiseFeedForward(torch.nn.Module):
@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module):
dropout_rate (float): Dropout rate. dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function activation (torch.nn.Module): Activation function
""" """
def __init__(self, def __init__(self,
idim: int, idim: int,
hidden_units: int, hidden_units: int,
@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module):
class ConvolutionModule(nn.Module): class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.""" """ConvolutionModule in Conformer model."""
def __init__(self, def __init__(self,
channels: int, channels: int,
kernel_size: int = 15, kernel_size: int = 15,
@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module):
True: x -> x + linear(concat(x, att(x))) True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x) False: x -> x + att(x)
""" """
def __init__( def __init__(
self, self,
size: int, size: int,
@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module):
class ConformerEncoder(BaseEncoder): class ConformerEncoder(BaseEncoder):
"""Conformer encoder module.""" """Conformer encoder module."""
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder):
concat_after, concat_after,
) for _ in range(num_blocks) ) for _ in range(num_blocks)
]) ])

View File

@ -5,11 +5,13 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map from transformers.utils.model_parallel_utils import (assert_device_map,
from gpt.perceiver import PerceiverResampler get_device_map)
from gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.utils.arch_util import AttentionBlock from indextts.utils.arch_util import AttentionBlock
from utils.typical_sampling import TypicalLogitsWarper from indextts.utils.typical_sampling import TypicalLogitsWarper
def null_position_embeddings(range, dim): def null_position_embeddings(range, dim):
@ -20,14 +22,15 @@ class ResBlock(nn.Module):
""" """
Basic residual convolutional block that uses GroupNorm. Basic residual convolutional block that uses GroupNorm.
""" """
def __init__(self, chan): def __init__(self, chan):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1), nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan), nn.GroupNorm(chan // 8, chan),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1), nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan) nn.GroupNorm(chan // 8, chan)
) )
def forward(self, x): def forward(self, x):
@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module):
return h.mean(dim=2) return h.mean(dim=2)
else: else:
return h return h
#return h[:, :, 0] # return h[:, :, 0]
class LearnedPositionEmbeddings(nn.Module): class LearnedPositionEmbeddings(nn.Module):
@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
""" """
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused. gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len, n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len, n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim, n_embd=model_dim,
n_layer=layers, n_layer=layers,
n_head=heads, n_head=heads,
@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused. # Built-in token embeddings are unused.
del gpt.wte del gpt.wte
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\ return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
None, None None, None
@ -274,14 +277,14 @@ class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2), nn.GroupNorm(channels // 16, channels // 2),
nn.ReLU(), nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels), nn.GroupNorm(channels // 8, channels),
nn.ReLU(), nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
) )
@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2), speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s) cond_mel_lengths) # (b, s, d), (b, 1, s)
if self.condition_type == "conformer_perceiver": if self.condition_type == "conformer_perceiver":
#conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1) # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
conds_mask = self.cond_mask_pad(mask.squeeze(1)) conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d) conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
elif self.condition_type == "gst": elif self.condition_type == "gst":
@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module):
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
# Types are expressed by expanding the text embedding space. # Types are expressed by expanding the text embedding space.
if types is not None: if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1) text_inputs = text_inputs * (1 + types).unsqueeze(-1)
if clip_inputs: if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module):
max_mel_len = wav_lengths.max() // self.mel_length_compression max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = mel_codes[:, :max_mel_len] mel_codes = mel_codes[:, :max_mel_len]
if raw_mels is not None: if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4] raw_mels = raw_mels[:, :, :max_mel_len * 4]
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>). # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
#mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1 mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths) mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module):
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
if text_first: if text_first:
#print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}") # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent: if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module):
self.inference_model.store_mel_emb(emb) self.inference_model.store_mel_emb(emb)
# +1 for the start_audio_token # +1 for the start_audio_token
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
device=text_inputs.device) device=text_inputs.device)
fake_inputs[:, -1] = self.start_mel_token fake_inputs[:, -1] = self.start_mel_token
@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module):
max_length=max_length, logits_processor=logits_processor, max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences, **hf_generate_kwargs) num_return_sequences=num_return_sequences, **hf_generate_kwargs)
return gen[:, trunc_index:] return gen[:, trunc_index:]

View File

@ -1,16 +1,18 @@
import os import os
import re import re
import sys import sys
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from omegaconf import OmegaConf from omegaconf import OmegaConf
import sentencepiece as spm
from utils.utils import tokenize_by_CJK_char
from utils.feature_extractors import MelSpectrogramFeatures
from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.checkpoint import load_checkpoint
from indextts.gpt.model import UnifiedVoice
from indextts.BigVGAN.models import BigVGAN as Generator from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.utils import tokenize_by_CJK_char
from indextts.vqvae.xtts_dvae import DiscreteVAE
class IndexTTS: class IndexTTS:
@ -107,18 +109,18 @@ class IndexTTS:
print(text_len) print(text_len)
with torch.no_grad(): with torch.no_grad():
codes = self.gpt.inference_speech(auto_conditioning, text_tokens, codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device), device=text_tokens.device),
# text_lengths=text_len, # text_lengths=text_len,
do_sample=True, do_sample=True,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
temperature=temperature, temperature=temperature,
num_return_sequences=autoregressive_batch_size, num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty, length_penalty=length_penalty,
num_beams=num_beams, num_beams=num_beams,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens) max_generate_length=max_mel_tokens)
print(codes) print(codes)
print(f"codes shape: {codes.shape}") print(f"codes shape: {codes.shape}")
codes = codes[:, :-2] codes = codes[:, :-2]
@ -126,10 +128,10 @@ class IndexTTS:
# latent, text_lens_out, code_lens_out = \ # latent, text_lens_out, code_lens_out = \
latent = \ latent = \
self.gpt(auto_conditioning, text_tokens, self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False) return_latent=True, clip_inputs=False)
latent = latent.transpose(1, 2) latent = latent.transpose(1, 2)
''' '''
latent_list = [] latent_list = []
@ -155,4 +157,4 @@ class IndexTTS:
if __name__ == "__main__":
    # Script entry point: construct an IndexTTS from the local "checkpoints"
    # directory and run a single infer() call with a sample audio prompt,
    # a sample sentence, and an output path.
    tts = IndexTTS(
        cfg_path="checkpoints/config.yaml",
        model_dir="checkpoints",
    )
    tts.infer(
        audio_prompt='test_data/input.wav',
        text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了',
        output_path="gen.wav",
    )

View File

View File

@ -1,6 +1,8 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
from indextts.utils.xtransformers import RelativePositionBias from indextts.utils.xtransformers import RelativePositionBias

View File

@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
import logging import logging
import os import os
import re import re
import yaml
import torch
from collections import OrderedDict from collections import OrderedDict
import datetime import torch
import yaml
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict: def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:

View File

@ -1,7 +1,8 @@
import torch import torch
import torchaudio import torchaudio
from torch import nn from torch import nn
from utils import safe_log
from indextts.utils.utils import safe_log
class FeatureExtractor(nn.Module): class FeatureExtractor(nn.Module):
@ -23,7 +24,7 @@ class FeatureExtractor(nn.Module):
class MelSpectrogramFeatures(FeatureExtractor): class MelSpectrogramFeatures(FeatureExtractor):
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None, def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"): n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
super().__init__() super().__init__()
if padding not in ["center", "same"]: if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.") raise ValueError("Padding must be 'center' or 'same'.")
@ -47,4 +48,4 @@ class MelSpectrogramFeatures(FeatureExtractor):
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
mel = self.mel_spec(audio) mel = self.mel_spec(audio)
mel = safe_log(mel) mel = safe_log(mel)
return mel return mel

View File

@ -1,6 +1,7 @@
import os import os
import re
import random import random
import re
import torch import torch
import torchaudio import torchaudio

View File

@ -6,7 +6,7 @@ from inspect import isfunction
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import nn, einsum from torch import einsum, nn
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64

View File

@ -14,8 +14,8 @@ matplotlib==3.8.2
opencv-python==4.9.0.80 opencv-python==4.9.0.80
vocos==0.1.0 vocos==0.1.0
accelerate==0.25.0 accelerate==0.25.0
omegaconf==2.0.6
tensorboard==2.9.1 tensorboard==2.9.1
omegaconf
sentencepiece sentencepiece
pypinyin pypinyin
librosa librosa

View File

@ -1,16 +1,18 @@
import os import os
import shutil import shutil
import sys
import threading import threading
import time import time
import sys
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir) sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts")) sys.path.append(os.path.join(current_dir, "indextts"))
import gradio as gr import gradio as gr
from utils.webui_utils import next_page, prev_page
from indextts.infer import IndexTTS from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
from utils.webui_utils import next_page, prev_page
i18n = I18nAuto(language="zh_CN") i18n = I18nAuto(language="zh_CN")
MODE = 'local' MODE = 'local'