diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9f2d87 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +venv/ +__pycache__ +*.DS_Store +.idea/ \ No newline at end of file diff --git a/indextts/BigVGAN/ECAPA_TDNN.py b/indextts/BigVGAN/ECAPA_TDNN.py index caf9e48..beea8ca 100644 --- a/indextts/BigVGAN/ECAPA_TDNN.py +++ b/indextts/BigVGAN/ECAPA_TDNN.py @@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d from indextts.BigVGAN.nnet.linear import Linear from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d + def length_to_mask(length, max_len=None, dtype=None, device=None): """Creates a binary mask for each sequence. diff --git a/indextts/BigVGAN/__init__.py b/indextts/BigVGAN/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/indextts/BigVGAN/activations.py b/indextts/BigVGAN/activations.py index 61f2808..1962c53 100644 --- a/indextts/BigVGAN/activations.py +++ b/indextts/BigVGAN/activations.py @@ -2,7 +2,7 @@ # LICENSE is in incl_licenses directory. import torch -from torch import nn, sin, pow +from torch import nn, pow, sin from torch.nn import Parameter @@ -22,6 +22,7 @@ class Snake(nn.Module): >>> x = torch.randn(256) >>> x = a1(x) ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): ''' Initialization. @@ -36,9 +37,9 @@ class Snake(nn.Module): # initialize alpha self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros + if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones + else: # linear scale alphas initialized to ones self.alpha = Parameter(torch.ones(in_features) * alpha) self.alpha.requires_grad = alpha_trainable @@ -51,7 +52,7 @@ class Snake(nn.Module): Applies the function to the input elementwise. Snake ∶= x + 1/a * sin^2 (xa) ''' - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] if self.alpha_logscale: alpha = torch.exp(alpha) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) @@ -76,6 +77,7 @@ class SnakeBeta(nn.Module): >>> x = torch.randn(256) >>> x = a1(x) ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): ''' Initialization. @@ -92,10 +94,10 @@ class SnakeBeta(nn.Module): # initialize alpha self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros + if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones + else: # linear scale alphas initialized to ones self.alpha = Parameter(torch.ones(in_features) * alpha) self.beta = Parameter(torch.ones(in_features) * alpha) @@ -110,11 +112,11 @@ class SnakeBeta(nn.Module): Applies the function to the input elementwise. SnakeBeta ∶= x + 1/b * sin^2 (xa) ''' - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] beta = self.beta.unsqueeze(0).unsqueeze(-1) if self.alpha_logscale: alpha = torch.exp(alpha) beta = torch.exp(beta) x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) - return x \ No newline at end of file + return x diff --git a/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py index fbc0fd8..e0c4ff7 100644 --- a/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +++ b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py @@ -3,10 +3,9 @@ import torch import torch.nn as nn -from alias_free_activation.torch.resample import UpSample1d, DownSample1d - # load fused CUDA kernel: this enables importing anti_alias_activation_cuda from alias_free_activation.cuda import load +from alias_free_activation.torch.resample import DownSample1d, UpSample1d anti_alias_activation_cuda = load.load() diff --git a/indextts/BigVGAN/alias_free_activation/torch/__init__.py b/indextts/BigVGAN/alias_free_activation/torch/__init__.py index 8f756ed..117e5ac 100644 --- a/indextts/BigVGAN/alias_free_activation/torch/__init__.py +++ b/indextts/BigVGAN/alias_free_activation/torch/__init__.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. +from .act import * from .filter import * from .resample import * -from .act import * diff --git a/indextts/BigVGAN/alias_free_activation/torch/act.py b/indextts/BigVGAN/alias_free_activation/torch/act.py index a6693aa..d46808d 100644 --- a/indextts/BigVGAN/alias_free_activation/torch/act.py +++ b/indextts/BigVGAN/alias_free_activation/torch/act.py @@ -2,7 +2,8 @@ # LICENSE is in incl_licenses directory. import torch.nn as nn -from .resample import UpSample1d, DownSample1d + +from .resample import DownSample1d, UpSample1d class Activation1d(nn.Module): diff --git a/indextts/BigVGAN/alias_free_activation/torch/filter.py b/indextts/BigVGAN/alias_free_activation/torch/filter.py index 0fa35b0..172cfc6 100644 --- a/indextts/BigVGAN/alias_free_activation/torch/filter.py +++ b/indextts/BigVGAN/alias_free_activation/torch/filter.py @@ -1,10 +1,11 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. +import math + import torch import torch.nn as nn import torch.nn.functional as F -import math if "sinc" in dir(torch): sinc = torch.sinc diff --git a/indextts/BigVGAN/alias_free_activation/torch/resample.py b/indextts/BigVGAN/alias_free_activation/torch/resample.py index a35380f..46c4770 100644 --- a/indextts/BigVGAN/alias_free_activation/torch/resample.py +++ b/indextts/BigVGAN/alias_free_activation/torch/resample.py @@ -3,8 +3,8 @@ import torch.nn as nn from torch.nn import functional as F -from .filter import LowPassFilter1d -from .filter import kaiser_sinc_filter1d + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d class UpSample1d(nn.Module): diff --git a/indextts/BigVGAN/alias_free_torch/__init__.py b/indextts/BigVGAN/alias_free_torch/__init__.py index a2318b6..117e5ac 100644 --- a/indextts/BigVGAN/alias_free_torch/__init__.py +++ b/indextts/BigVGAN/alias_free_torch/__init__.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. +from .act import * from .filter import * from .resample import * -from .act import * \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_torch/act.py b/indextts/BigVGAN/alias_free_torch/act.py index 028debd..e6798bf 100644 --- a/indextts/BigVGAN/alias_free_torch/act.py +++ b/indextts/BigVGAN/alias_free_torch/act.py @@ -2,7 +2,8 @@ # LICENSE is in incl_licenses directory. import torch.nn as nn -from .resample import UpSample1d, DownSample1d + +from .resample import DownSample1d, UpSample1d class Activation1d(nn.Module): diff --git a/indextts/BigVGAN/alias_free_torch/filter.py b/indextts/BigVGAN/alias_free_torch/filter.py index 7ad6ea8..2a90bda 100644 --- a/indextts/BigVGAN/alias_free_torch/filter.py +++ b/indextts/BigVGAN/alias_free_torch/filter.py @@ -1,10 +1,11 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. +import math + import torch import torch.nn as nn import torch.nn.functional as F -import math if 'sinc' in dir(torch): sinc = torch.sinc diff --git a/indextts/BigVGAN/alias_free_torch/resample.py b/indextts/BigVGAN/alias_free_torch/resample.py index 750e6c3..1cf4d54 100644 --- a/indextts/BigVGAN/alias_free_torch/resample.py +++ b/indextts/BigVGAN/alias_free_torch/resample.py @@ -3,8 +3,8 @@ import torch.nn as nn from torch.nn import functional as F -from .filter import LowPassFilter1d -from .filter import kaiser_sinc_filter1d + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d class UpSample1d(nn.Module): diff --git a/indextts/BigVGAN/bigvgan.py b/indextts/BigVGAN/bigvgan.py index 7c9a6f5..58e4ba1 100644 --- a/indextts/BigVGAN/bigvgan.py +++ b/indextts/BigVGAN/bigvgan.py @@ -4,23 +4,23 @@ # Adapted from https://github.com/jik876/hifi-gan under the MIT license. # LICENSE is in incl_licenses directory. -import os import json +import os from pathlib import Path -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union import torch import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm +from torch.nn.utils import remove_weight_norm, weight_norm import indextts.BigVGAN.activations as activations -from indextts.BigVGAN.utils import init_weights, get_padding -from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d -from indextts.BigVGAN.env import AttrDict - -from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from indextts.BigVGAN.alias_free_activation.torch.act import \ + Activation1d as TorchActivation1d from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN +from indextts.BigVGAN.env import AttrDict +from indextts.BigVGAN.utils import get_padding, init_weights def load_hparams_from_json(path) -> AttrDict: @@ -51,7 +51,7 @@ class AMPBlock1(torch.nn.Module): activation: str = None, ): super().__init__() - + self.h = h self.convs1 = nn.ModuleList( @@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( - Activation1d as CudaActivation1d, - ) + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: @@ -170,7 +169,7 @@ class AMPBlock2(torch.nn.Module): activation: str = None, ): super().__init__() - + self.h = h self.convs = nn.ModuleList( @@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( - Activation1d as CudaActivation1d, - ) + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: @@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module): for l in self.convs: remove_weight_norm(l) + ''' PyTorchModelHubMixin, library_name="bigvgan", @@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module): tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], ''' + class BigVGAN( torch.nn.Module, ): @@ -274,9 +274,8 @@ class BigVGAN( # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( - Activation1d as CudaActivation1d, - ) + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: diff --git a/indextts/BigVGAN/models.py b/indextts/BigVGAN/models.py index b822027..602f103 100644 --- a/indextts/BigVGAN/models.py +++ b/indextts/BigVGAN/models.py @@ -1,17 +1,16 @@ -# Copyright (c) 2022 NVIDIA CORPORATION. +# Copyright (c) 2022 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license. # LICENSE is in incl_licenses directory. -from torch.nn import Conv1d, ConvTranspose1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm import indextts.BigVGAN.activations as activations -from indextts.BigVGAN.utils import init_weights, get_padding from indextts.BigVGAN.alias_free_torch import * - from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN +from indextts.BigVGAN.utils import get_padding, init_weights LRELU_SLOPE = 0.1 @@ -41,19 +40,19 @@ class AMPBlock1(torch.nn.Module): ]) self.convs2.apply(init_weights) - self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers - if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) - elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) - for _ in range(self.num_layers) + for _ in range(self.num_layers) ]) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") @@ -89,25 +88,25 @@ class AMPBlock2(torch.nn.Module): ]) self.convs.apply(init_weights) - self.num_layers = len(self.convs) # total number of conv layers + self.num_layers = len(self.convs) # total number of conv layers - if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) - elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) - for _ in range(self.num_layers) + for _ in range(self.num_layers) ]) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") def forward(self, x): - for c, a in zip (self.convs, self.activations): + for c, a in zip(self.convs, self.activations): xt = a(x) xt = c(xt) x = xt + x @@ -127,7 +126,7 @@ class BigVGAN(torch.nn.Module): self.num_kernels = len(h.resblock_kernel_sizes) self.num_upsamples = len(h.upsample_rates) - + self.feat_upsample = h.feat_upsample self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer @@ -154,10 +153,10 @@ class BigVGAN(torch.nn.Module): self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) # post conv - if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) self.activation_post = Activation1d(activation=activation_post) - elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) self.activation_post = Activation1d(activation=activation_post) else: @@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module): # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - def forward(self, x, mel_ref, lens=None): speaker_embedding = self.speaker_encoder(mel_ref, lens) n_batch = x.size(0) @@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module): contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp()) speaker_embedding = speaker_embedding[:n_batch, :, :] - speaker_embedding = speaker_embedding.transpose(1,2) + speaker_embedding = speaker_embedding.transpose(1, 2) # upsample feat if self.feat_upsample: @@ -265,20 +263,20 @@ class DiscriminatorP(torch.nn.Module): self.d_mult = h.discriminator_channel_mult norm_f = weight_norm if use_spectral_norm == False else spectral_norm self.convs = nn.ModuleList([ - norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), + norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), ]) - self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): fmap = [] # 1d to 2d b, c, t = x.shape - if t % self.period != 0: # pad first + if t % self.period != 0: # pad first n_pad = self.period - (t % self.period) x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad @@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module): self.d_mult = cfg.mrd_channel_mult self.convs = nn.ModuleList([ - norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))), - norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), - norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), - norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), - norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))), + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))), ]) self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) @@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module): x = x.squeeze(1) x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) x = torch.view_as_real(x) # [B, F, TT, 2] - mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] return mag @@ -376,9 +374,9 @@ class MultiResolutionDiscriminator(nn.Module): def __init__(self, cfg, debug=False): super().__init__() self.resolutions = cfg.resolutions - assert len(self.resolutions) == 3,\ + assert len(self.resolutions) == 3, \ "MRD requires list of list with len=3, each element having a list with len=3. got {}".\ - format(self.resolutions) + format(self.resolutions) self.discriminators = nn.ModuleList( [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] ) @@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g): for rl, gl in zip(dr, dg): loss += torch.mean(torch.abs(rl - gl)) - return loss*2 + return loss * 2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): @@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): r_losses = [] g_losses = [] for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1-dr)**2) + r_loss = torch.mean((1 - dr)**2) g_loss = torch.mean(dg**2) loss += (r_loss + g_loss) r_losses.append(r_loss.item()) @@ -427,9 +425,8 @@ def generator_loss(disc_outputs): loss = 0 gen_losses = [] for dg in disc_outputs: - l = torch.mean((1-dg)**2) + l = torch.mean((1 - dg)**2) gen_losses.append(l) loss += l return loss, gen_losses - diff --git a/indextts/BigVGAN/nnet/CNN.py b/indextts/BigVGAN/nnet/CNN.py index 454568c..fa79dc1 100644 --- a/indextts/BigVGAN/nnet/CNN.py +++ b/indextts/BigVGAN/nnet/CNN.py @@ -19,6 +19,7 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio + class SincConv(nn.Module): """This function implements SincConv (SincNet). diff --git a/indextts/BigVGAN/nnet/__init__.py b/indextts/BigVGAN/nnet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/indextts/BigVGAN/utils.py b/indextts/BigVGAN/utils.py index 0d1aa97..e968fd4 100644 --- a/indextts/BigVGAN/utils.py +++ b/indextts/BigVGAN/utils.py @@ -3,13 +3,14 @@ import glob import os + import matplotlib +import matplotlib.pylab as plt import torch +from scipy.io.wavfile import write from torch.nn.utils import weight_norm matplotlib.use("Agg") -import matplotlib.pylab as plt -from scipy.io.wavfile import write MAX_WAV_VALUE = 32768.0 diff --git a/indextts/__init__.py b/indextts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/indextts/gpt/conformer/embedding.py b/indextts/gpt/conformer/embedding.py index 5311d43..97565e9 100644 --- a/indextts/gpt/conformer/embedding.py +++ b/indextts/gpt/conformer/embedding.py @@ -21,6 +21,7 @@ from typing import Tuple, Union import torch import torch.nn.functional as F + class PositionalEncoding(torch.nn.Module): """Positional encoding. diff --git a/indextts/gpt/conformer_encoder.py b/indextts/gpt/conformer_encoder.py index 05f896f..a6967e6 100644 --- a/indextts/gpt/conformer_encoder.py +++ b/indextts/gpt/conformer_encoder.py @@ -3,11 +3,18 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \ - Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2 -from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding -from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention -from utils.common import make_pad_mask + +from indextts.gpt.conformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +from indextts.gpt.conformer.embedding import (NoPositionalEncoding, + PositionalEncoding, + RelPositionalEncoding) +from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, + LinearNoSubsampling) +from indextts.utils.common import make_pad_mask class PositionwiseFeedForward(torch.nn.Module): @@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module): dropout_rate (float): Dropout rate. activation (torch.nn.Module): Activation function """ + def __init__(self, idim: int, hidden_units: int, @@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" + def __init__(self, channels: int, kernel_size: int = 15, @@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module): True: x -> x + linear(concat(x, att(x))) False: x -> x + att(x) """ + def __init__( self, size: int, @@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module): class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" + def __init__( self, input_size: int, @@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder): concat_after, ) for _ in range(num_blocks) ]) - diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index 007c07e..becf037 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -5,11 +5,13 @@ import torch.nn as nn import torch.nn.functional as F from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map -from gpt.perceiver import PerceiverResampler -from gpt.conformer_encoder import ConformerEncoder +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from indextts.gpt.conformer_encoder import ConformerEncoder +from indextts.gpt.perceiver import PerceiverResampler from indextts.utils.arch_util import AttentionBlock -from utils.typical_sampling import TypicalLogitsWarper +from indextts.utils.typical_sampling import TypicalLogitsWarper def null_position_embeddings(range, dim): @@ -20,14 +22,15 @@ class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan), + nn.GroupNorm(chan // 8, chan), nn.ReLU(), nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan) + nn.GroupNorm(chan // 8, chan) ) def forward(self, x): @@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module): return h.mean(dim=2) else: return h - #return h[:, :, 0] + # return h[:, :, 0] class LearnedPositionEmbeddings(nn.Module): @@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text """ from transformers import GPT2Config, GPT2Model gpt_config = GPT2Config(vocab_size=256, # Unused. - n_positions=max_mel_seq_len+max_text_seq_len, - n_ctx=max_mel_seq_len+max_text_seq_len, + n_positions=max_mel_seq_len + max_text_seq_len, + n_ctx=max_mel_seq_len + max_text_seq_len, n_embd=model_dim, n_layer=layers, n_head=heads, @@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Built-in token embeddings are unused. del gpt.wte - return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\ + return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \ None, None @@ -274,14 +277,14 @@ class MelEncoder(nn.Module): def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): super().__init__() self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//16, channels//2), + self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//8, channels), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), nn.ReLU(), nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), ) @@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module): speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2), cond_mel_lengths) # (b, s, d), (b, 1, s) if self.condition_type == "conformer_perceiver": - #conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1) + # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1) conds_mask = self.cond_mask_pad(mask.squeeze(1)) conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d) elif self.condition_type == "gst": @@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module): speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) # Types are expressed by expanding the text embedding space. if types is not None: - text_inputs = text_inputs * (1+types).unsqueeze(-1) + text_inputs = text_inputs * (1 + types).unsqueeze(-1) if clip_inputs: # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by @@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module): max_mel_len = wav_lengths.max() // self.mel_length_compression mel_codes = mel_codes[:, :max_mel_len] if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] + raw_mels = raw_mels[:, :, :max_mel_len * 4] # Set padding areas within MEL (currently it is coded with the MEL code for ). - #mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') + # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1 mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths) @@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module): mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) if text_first: - #print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}") + # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}") text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. @@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module): self.inference_model.store_mel_emb(emb) # +1 for the start_audio_token - fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, + fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long, device=text_inputs.device) fake_inputs[:, -1] = self.start_mel_token @@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module): max_length=max_length, logits_processor=logits_processor, num_return_sequences=num_return_sequences, **hf_generate_kwargs) return gen[:, trunc_index:] - - - - diff --git a/indextts/infer.py b/indextts/infer.py index 3a29b0c..eed25d4 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -1,16 +1,18 @@ import os import re import sys + +import sentencepiece as spm import torch import torchaudio from omegaconf import OmegaConf -import sentencepiece as spm -from utils.common import tokenize_by_CJK_char -from utils.feature_extractors import MelSpectrogramFeatures -from indextts.vqvae.xtts_dvae import DiscreteVAE -from indextts.utils.checkpoint import load_checkpoint -from indextts.gpt.model import UnifiedVoice + from indextts.BigVGAN.models import BigVGAN as Generator +from indextts.gpt.model import UnifiedVoice +from indextts.utils.checkpoint import load_checkpoint +from indextts.utils.feature_extractors import MelSpectrogramFeatures +from indextts.utils.common import tokenize_by_CJK_char +from indextts.vqvae.xtts_dvae import DiscreteVAE class IndexTTS: @@ -107,18 +109,18 @@ class IndexTTS: print(text_len) with torch.no_grad(): codes = self.gpt.inference_speech(auto_conditioning, text_tokens, - cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], - device=text_tokens.device), - # text_lengths=text_len, - do_sample=True, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=autoregressive_batch_size, - length_penalty=length_penalty, - num_beams=num_beams, - repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens) + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + # text_lengths=text_len, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens) print(codes) print(f"codes shape: {codes.shape}") codes = codes[:, :-2] @@ -126,10 +128,10 @@ class IndexTTS: # latent, text_lens_out, code_lens_out = \ latent = \ self.gpt(auto_conditioning, text_tokens, - torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, - torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), - cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), - return_latent=True, clip_inputs=False) + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device), + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), + return_latent=True, clip_inputs=False) latent = latent.transpose(1, 2) ''' latent_list = [] @@ -155,4 +157,4 @@ class IndexTTS: if __name__ == "__main__": tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints") - tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav") + tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav") diff --git a/indextts/utils/__init__.py b/indextts/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/indextts/utils/arch_util.py b/indextts/utils/arch_util.py index c8ca1cb..01eac10 100644 --- a/indextts/utils/arch_util.py +++ b/indextts/utils/arch_util.py @@ -1,6 +1,8 @@ +import math + import torch import torch.nn as nn -import math + from indextts.utils.xtransformers import RelativePositionBias diff --git a/indextts/utils/checkpoint.py b/indextts/utils/checkpoint.py index 35a91e6..b8e34d4 100644 --- a/indextts/utils/checkpoint.py +++ b/indextts/utils/checkpoint.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import logging import os import re - -import yaml -import torch from collections import OrderedDict -import datetime +import torch +import yaml def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict: diff --git a/indextts/utils/common.py b/indextts/utils/common.py index 82cf878..b22d064 100644 --- a/indextts/utils/common.py +++ b/indextts/utils/common.py @@ -1,6 +1,7 @@ import os -import re import random +import re + import torch import torchaudio diff --git a/indextts/utils/feature_extractors.py b/indextts/utils/feature_extractors.py index 112af0d..c3af3e0 100644 --- a/indextts/utils/feature_extractors.py +++ b/indextts/utils/feature_extractors.py @@ -1,7 +1,7 @@ import torch import torchaudio from torch import nn -from utils.common import safe_log +from indextts.utils.common import safe_log class FeatureExtractor(nn.Module): @@ -23,7 +23,7 @@ class FeatureExtractor(nn.Module): class MelSpectrogramFeatures(FeatureExtractor): def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None, - n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"): + n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") @@ -47,4 +47,4 @@ class MelSpectrogramFeatures(FeatureExtractor): audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") mel = self.mel_spec(audio) mel = safe_log(mel) - return mel \ No newline at end of file + return mel diff --git a/indextts/utils/xtransformers.py b/indextts/utils/xtransformers.py index 704298b..5470476 100644 --- a/indextts/utils/xtransformers.py +++ b/indextts/utils/xtransformers.py @@ -6,7 +6,7 @@ from inspect import isfunction import torch import torch.nn.functional as F from einops import rearrange, repeat -from torch import nn, einsum +from torch import einsum, nn DEFAULT_DIM_HEAD = 64 diff --git a/requirements.txt b/requirements.txt index a91570f..846e188 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,8 +14,8 @@ matplotlib==3.8.2 opencv-python==4.9.0.80 vocos==0.1.0 accelerate==0.25.0 -omegaconf tensorboard==2.9.1 +omegaconf sentencepiece pypinyin librosa diff --git a/webui.py b/webui.py index 4584a75..1d3c7dd 100644 --- a/webui.py +++ b/webui.py @@ -1,13 +1,16 @@ import os import shutil +import sys import threading import time -import sys + current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(current_dir) sys.path.append(os.path.join(current_dir, "indextts")) import gradio as gr +from utils.webui_utils import next_page, prev_page + from indextts.infer import IndexTTS from tools.i18n.i18n import I18nAuto