fix packages
This commit is contained in:
parent
4c086f954b
commit
2fe6a73ada
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
venv/
|
||||
__pycache__
|
||||
@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
|
||||
from indextts.BigVGAN.nnet.linear import Linear
|
||||
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
||||
"""Creates a binary mask for each sequence.
|
||||
|
||||
|
||||
0
indextts/BigVGAN/__init__.py
Normal file
0
indextts/BigVGAN/__init__.py
Normal file
@ -2,7 +2,7 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch import nn, pow, sin
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ class Snake(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
@ -36,9 +37,9 @@ class Snake(nn.Module):
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
@ -51,7 +52,7 @@ class Snake(nn.Module):
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
@ -76,6 +77,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
@ -92,10 +94,10 @@ class SnakeBeta(nn.Module):
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
@ -110,11 +112,11 @@ class SnakeBeta(nn.Module):
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
return x
|
||||
|
||||
@ -3,10 +3,9 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||
|
||||
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||
from alias_free_activation.cuda import load
|
||||
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
||||
|
||||
anti_alias_activation_cuda = load.load()
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .act import *
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
from .resample import DownSample1d, UpSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if "sinc" in dir(torch):
|
||||
sinc = torch.sinc
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .act import *
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
@ -2,7 +2,8 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
from .resample import DownSample1d, UpSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
|
||||
@ -4,23 +4,23 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
import indextts.BigVGAN.activations as activations
|
||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
||||
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
from indextts.BigVGAN.env import AttrDict
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from indextts.BigVGAN.alias_free_activation.torch.act import \
|
||||
Activation1d as TorchActivation1d
|
||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||
from indextts.BigVGAN.env import AttrDict
|
||||
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||
|
||||
|
||||
def load_hparams_from_json(path) -> AttrDict:
|
||||
@ -51,7 +51,7 @@ class AMPBlock1(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module):
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
@ -170,7 +169,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module):
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
'''
|
||||
PyTorchModelHubMixin,
|
||||
library_name="bigvgan",
|
||||
@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||
'''
|
||||
|
||||
|
||||
class BigVGAN(
|
||||
torch.nn.Module,
|
||||
):
|
||||
@ -274,9 +274,8 @@ class BigVGAN(
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
import indextts.BigVGAN.activations as activations
|
||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
||||
from indextts.BigVGAN.alias_free_torch import *
|
||||
|
||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
@ -41,19 +40,19 @@ class AMPBlock1(torch.nn.Module):
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
||||
@ -89,25 +88,25 @@ class AMPBlock2(torch.nn.Module):
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip (self.convs, self.activations):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
@ -127,7 +126,7 @@ class BigVGAN(torch.nn.Module):
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
|
||||
self.feat_upsample = h.feat_upsample
|
||||
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
|
||||
|
||||
@ -154,10 +153,10 @@ class BigVGAN(torch.nn.Module):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module):
|
||||
|
||||
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
|
||||
def forward(self, x, mel_ref, lens=None):
|
||||
speaker_embedding = self.speaker_encoder(mel_ref, lens)
|
||||
n_batch = x.size(0)
|
||||
@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module):
|
||||
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
|
||||
|
||||
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
||||
speaker_embedding = speaker_embedding.transpose(1,2)
|
||||
speaker_embedding = speaker_embedding.transpose(1, 2)
|
||||
|
||||
# upsample feat
|
||||
if self.feat_upsample:
|
||||
@ -265,20 +263,20 @@ class DiscriminatorP(torch.nn.Module):
|
||||
self.d_mult = h.discriminator_channel_mult
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
||||
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module):
|
||||
self.d_mult = cfg.mrd_channel_mult
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
|
||||
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
|
||||
])
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module):
|
||||
x = x.squeeze(1)
|
||||
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
|
||||
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||
mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
|
||||
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||
|
||||
return mag
|
||||
|
||||
@ -376,9 +374,9 @@ class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert len(self.resolutions) == 3,\
|
||||
assert len(self.resolutions) == 3, \
|
||||
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
|
||||
format(self.resolutions)
|
||||
format(self.resolutions)
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||
)
|
||||
@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss*2
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1-dr)**2)
|
||||
r_loss = torch.mean((1 - dr)**2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += (r_loss + g_loss)
|
||||
r_losses.append(r_loss.item())
|
||||
@ -427,9 +425,8 @@ def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1-dg)**2)
|
||||
l = torch.mean((1 - dg)**2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
|
||||
class SincConv(nn.Module):
|
||||
"""This function implements SincConv (SincNet).
|
||||
|
||||
|
||||
0
indextts/BigVGAN/nnet/__init__.py
Normal file
0
indextts/BigVGAN/nnet/__init__.py
Normal file
@ -3,13 +3,14 @@
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pylab as plt
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
0
indextts/__init__.py
Normal file
0
indextts/__init__.py
Normal file
@ -21,6 +21,7 @@ from typing import Tuple, Union
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
|
||||
@ -3,11 +3,18 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \
|
||||
Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2
|
||||
from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding
|
||||
from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
||||
from utils.utils import make_pad_mask
|
||||
|
||||
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention)
|
||||
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
|
||||
PositionalEncoding,
|
||||
RelPositionalEncoding)
|
||||
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
|
||||
Conv2dSubsampling4,
|
||||
Conv2dSubsampling6,
|
||||
Conv2dSubsampling8,
|
||||
LinearNoSubsampling)
|
||||
from indextts.utils.utils import make_pad_mask
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
||||
dropout_rate (float): Dropout rate.
|
||||
activation (torch.nn.Module): Activation function
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
idim: int,
|
||||
hidden_units: int,
|
||||
@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model."""
|
||||
|
||||
def __init__(self,
|
||||
channels: int,
|
||||
kernel_size: int = 15,
|
||||
@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
True: x -> x + linear(concat(x, att(x)))
|
||||
False: x -> x + att(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
|
||||
class ConformerEncoder(BaseEncoder):
|
||||
"""Conformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder):
|
||||
concat_after,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
|
||||
@ -5,11 +5,13 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
from gpt.perceiver import PerceiverResampler
|
||||
from gpt.conformer_encoder import ConformerEncoder
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from indextts.gpt.conformer_encoder import ConformerEncoder
|
||||
from indextts.gpt.perceiver import PerceiverResampler
|
||||
from indextts.utils.arch_util import AttentionBlock
|
||||
from utils.typical_sampling import TypicalLogitsWarper
|
||||
from indextts.utils.typical_sampling import TypicalLogitsWarper
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
@ -20,14 +22,15 @@ class ResBlock(nn.Module):
|
||||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(chan//8, chan),
|
||||
nn.GroupNorm(chan // 8, chan),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(chan//8, chan)
|
||||
nn.GroupNorm(chan // 8, chan)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module):
|
||||
return h.mean(dim=2)
|
||||
else:
|
||||
return h
|
||||
#return h[:, :, 0]
|
||||
# return h[:, :, 0]
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
||||
"""
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
||||
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
|
||||
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
|
||||
None, None
|
||||
|
||||
|
||||
@ -274,14 +277,14 @@ class MelEncoder(nn.Module):
|
||||
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//16, channels//2),
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 16, channels // 2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//8, channels),
|
||||
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module):
|
||||
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
|
||||
cond_mel_lengths) # (b, s, d), (b, 1, s)
|
||||
if self.condition_type == "conformer_perceiver":
|
||||
#conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
||||
# conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
||||
conds_mask = self.cond_mask_pad(mask.squeeze(1))
|
||||
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
|
||||
elif self.condition_type == "gst":
|
||||
@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module):
|
||||
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
|
||||
# Types are expressed by expanding the text embedding space.
|
||||
if types is not None:
|
||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
|
||||
|
||||
if clip_inputs:
|
||||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||
@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module):
|
||||
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||
mel_codes = mel_codes[:, :max_mel_len]
|
||||
if raw_mels is not None:
|
||||
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||
raw_mels = raw_mels[:, :, :max_mel_len * 4]
|
||||
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
#mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
|
||||
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
|
||||
|
||||
@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module):
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||
|
||||
if text_first:
|
||||
#print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
||||
# print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
||||
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module):
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
# +1 for the start_audio_token
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long,
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
|
||||
device=text_inputs.device)
|
||||
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module):
|
||||
max_length=max_length, logits_processor=logits_processor,
|
||||
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
||||
return gen[:, trunc_index:]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from omegaconf import OmegaConf
|
||||
import sentencepiece as spm
|
||||
from utils.utils import tokenize_by_CJK_char
|
||||
from utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
|
||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.utils.utils import tokenize_by_CJK_char
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
|
||||
|
||||
class IndexTTS:
|
||||
@ -107,18 +109,18 @@ class IndexTTS:
|
||||
print(text_len)
|
||||
with torch.no_grad():
|
||||
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
||||
device=text_tokens.device),
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
||||
device=text_tokens.device),
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
print(codes)
|
||||
print(f"codes shape: {codes.shape}")
|
||||
codes = codes[:, :-2]
|
||||
@ -126,10 +128,10 @@ class IndexTTS:
|
||||
# latent, text_lens_out, code_lens_out = \
|
||||
latent = \
|
||||
self.gpt(auto_conditioning, text_tokens,
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
latent = latent.transpose(1, 2)
|
||||
'''
|
||||
latent_list = []
|
||||
@ -155,4 +157,4 @@ class IndexTTS:
|
||||
|
||||
if __name__ == "__main__":
|
||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
||||
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav")
|
||||
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav")
|
||||
|
||||
0
indextts/utils/__init__.py
Normal file
0
indextts/utils/__init__.py
Normal file
@ -1,6 +1,8 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from indextts.utils.xtransformers import RelativePositionBias
|
||||
|
||||
|
||||
|
||||
@ -12,15 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
|
||||
import datetime
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from utils import safe_log
|
||||
|
||||
from indextts.utils.utils import safe_log
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
|
||||
@ -23,7 +24,7 @@ class FeatureExtractor(nn.Module):
|
||||
|
||||
class MelSpectrogramFeatures(FeatureExtractor):
|
||||
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
|
||||
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
|
||||
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
@ -47,4 +48,4 @@ class MelSpectrogramFeatures(FeatureExtractor):
|
||||
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
|
||||
mel = self.mel_spec(audio)
|
||||
mel = safe_log(mel)
|
||||
return mel
|
||||
return mel
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from inspect import isfunction
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn, einsum
|
||||
from torch import einsum, nn
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@ matplotlib==3.8.2
|
||||
opencv-python==4.9.0.80
|
||||
vocos==0.1.0
|
||||
accelerate==0.25.0
|
||||
omegaconf==2.0.6
|
||||
tensorboard==2.9.1
|
||||
omegaconf
|
||||
sentencepiece
|
||||
pypinyin
|
||||
librosa
|
||||
|
||||
6
webui.py
6
webui.py
@ -1,16 +1,18 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import sys
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(current_dir)
|
||||
sys.path.append(os.path.join(current_dir, "indextts"))
|
||||
|
||||
import gradio as gr
|
||||
from utils.webui_utils import next_page, prev_page
|
||||
|
||||
from indextts.infer import IndexTTS
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from utils.webui_utils import next_page, prev_page
|
||||
|
||||
i18n = I18nAuto(language="zh_CN")
|
||||
MODE = 'local'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user