fix packages
This commit is contained in:
parent
4c086f954b
commit
2fe6a73ada
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
venv/
|
||||
__pycache__
|
||||
@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
|
||||
from indextts.BigVGAN.nnet.linear import Linear
|
||||
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
||||
"""Creates a binary mask for each sequence.
|
||||
|
||||
|
||||
0
indextts/BigVGAN/__init__.py
Normal file
0
indextts/BigVGAN/__init__.py
Normal file
@ -2,7 +2,7 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch import nn, pow, sin
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ class Snake(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
@ -76,6 +77,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
|
||||
@ -3,10 +3,9 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||
|
||||
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||
from alias_free_activation.cuda import load
|
||||
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
||||
|
||||
anti_alias_activation_cuda = load.load()
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .act import *
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
from .resample import DownSample1d, UpSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if "sinc" in dir(torch):
|
||||
sinc = torch.sinc
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .act import *
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
@ -2,7 +2,8 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
from .resample import DownSample1d, UpSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
|
||||
@ -4,23 +4,23 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
import indextts.BigVGAN.activations as activations
|
||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
||||
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
from indextts.BigVGAN.env import AttrDict
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from indextts.BigVGAN.alias_free_activation.torch.act import \
|
||||
Activation1d as TorchActivation1d
|
||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||
from indextts.BigVGAN.env import AttrDict
|
||||
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||
|
||||
|
||||
def load_hparams_from_json(path) -> AttrDict:
|
||||
@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module):
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module):
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
'''
|
||||
PyTorchModelHubMixin,
|
||||
library_name="bigvgan",
|
||||
@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||
'''
|
||||
|
||||
|
||||
class BigVGAN(
|
||||
torch.nn.Module,
|
||||
):
|
||||
@ -274,9 +274,8 @@ class BigVGAN(
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
|
||||
@ -4,14 +4,13 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
import indextts.BigVGAN.activations as activations
|
||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
||||
from indextts.BigVGAN.alias_free_torch import *
|
||||
|
||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
@ -107,7 +106,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip (self.convs, self.activations):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module):
|
||||
|
||||
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
|
||||
def forward(self, x, mel_ref, lens=None):
|
||||
speaker_embedding = self.speaker_encoder(mel_ref, lens)
|
||||
n_batch = x.size(0)
|
||||
@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module):
|
||||
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
|
||||
|
||||
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
||||
speaker_embedding = speaker_embedding.transpose(1,2)
|
||||
speaker_embedding = speaker_embedding.transpose(1, 2)
|
||||
|
||||
# upsample feat
|
||||
if self.feat_upsample:
|
||||
@ -265,13 +263,13 @@ class DiscriminatorP(torch.nn.Module):
|
||||
self.d_mult = h.discriminator_channel_mult
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
||||
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module):
|
||||
self.d_mult = cfg.mrd_channel_mult
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
|
||||
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
|
||||
])
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module):
|
||||
x = x.squeeze(1)
|
||||
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
|
||||
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||
mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
|
||||
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||
|
||||
return mag
|
||||
|
||||
@ -376,7 +374,7 @@ class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert len(self.resolutions) == 3,\
|
||||
assert len(self.resolutions) == 3, \
|
||||
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
|
||||
format(self.resolutions)
|
||||
self.discriminators = nn.ModuleList(
|
||||
@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss*2
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1-dr)**2)
|
||||
r_loss = torch.mean((1 - dr)**2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += (r_loss + g_loss)
|
||||
r_losses.append(r_loss.item())
|
||||
@ -427,9 +425,8 @@ def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1-dg)**2)
|
||||
l = torch.mean((1 - dg)**2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
|
||||
class SincConv(nn.Module):
|
||||
"""This function implements SincConv (SincNet).
|
||||
|
||||
|
||||
0
indextts/BigVGAN/nnet/__init__.py
Normal file
0
indextts/BigVGAN/nnet/__init__.py
Normal file
@ -3,13 +3,14 @@
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pylab as plt
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
0
indextts/__init__.py
Normal file
0
indextts/__init__.py
Normal file
@ -21,6 +21,7 @@ from typing import Tuple, Union
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
|
||||
@ -3,11 +3,18 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \
|
||||
Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2
|
||||
from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding
|
||||
from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
||||
from utils.utils import make_pad_mask
|
||||
|
||||
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention)
|
||||
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
|
||||
PositionalEncoding,
|
||||
RelPositionalEncoding)
|
||||
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
|
||||
Conv2dSubsampling4,
|
||||
Conv2dSubsampling6,
|
||||
Conv2dSubsampling8,
|
||||
LinearNoSubsampling)
|
||||
from indextts.utils.utils import make_pad_mask
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
||||
dropout_rate (float): Dropout rate.
|
||||
activation (torch.nn.Module): Activation function
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
idim: int,
|
||||
hidden_units: int,
|
||||
@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model."""
|
||||
|
||||
def __init__(self,
|
||||
channels: int,
|
||||
kernel_size: int = 15,
|
||||
@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
True: x -> x + linear(concat(x, att(x)))
|
||||
False: x -> x + att(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
|
||||
class ConformerEncoder(BaseEncoder):
|
||||
"""Conformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder):
|
||||
concat_after,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
|
||||
@ -5,11 +5,13 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
from gpt.perceiver import PerceiverResampler
|
||||
from gpt.conformer_encoder import ConformerEncoder
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from indextts.gpt.conformer_encoder import ConformerEncoder
|
||||
from indextts.gpt.perceiver import PerceiverResampler
|
||||
from indextts.utils.arch_util import AttentionBlock
|
||||
from utils.typical_sampling import TypicalLogitsWarper
|
||||
from indextts.utils.typical_sampling import TypicalLogitsWarper
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
@ -20,14 +22,15 @@ class ResBlock(nn.Module):
|
||||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(chan//8, chan),
|
||||
nn.GroupNorm(chan // 8, chan),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(chan//8, chan)
|
||||
nn.GroupNorm(chan // 8, chan)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module):
|
||||
return h.mean(dim=2)
|
||||
else:
|
||||
return h
|
||||
#return h[:, :, 0]
|
||||
# return h[:, :, 0]
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
||||
"""
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
||||
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
|
||||
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
|
||||
None, None
|
||||
|
||||
|
||||
@ -274,14 +277,14 @@ class MelEncoder(nn.Module):
|
||||
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//16, channels//2),
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 16, channels // 2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//8, channels),
|
||||
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module):
|
||||
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
|
||||
cond_mel_lengths) # (b, s, d), (b, 1, s)
|
||||
if self.condition_type == "conformer_perceiver":
|
||||
#conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
||||
# conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
||||
conds_mask = self.cond_mask_pad(mask.squeeze(1))
|
||||
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
|
||||
elif self.condition_type == "gst":
|
||||
@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module):
|
||||
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
|
||||
# Types are expressed by expanding the text embedding space.
|
||||
if types is not None:
|
||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
|
||||
|
||||
if clip_inputs:
|
||||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||
@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module):
|
||||
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||
mel_codes = mel_codes[:, :max_mel_len]
|
||||
if raw_mels is not None:
|
||||
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||
raw_mels = raw_mels[:, :, :max_mel_len * 4]
|
||||
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
#mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
|
||||
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
|
||||
|
||||
@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module):
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||
|
||||
if text_first:
|
||||
#print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
||||
# print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
||||
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module):
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
# +1 for the start_audio_token
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long,
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
|
||||
device=text_inputs.device)
|
||||
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module):
|
||||
max_length=max_length, logits_processor=logits_processor,
|
||||
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
||||
return gen[:, trunc_index:]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from omegaconf import OmegaConf
|
||||
import sentencepiece as spm
|
||||
from utils.utils import tokenize_by_CJK_char
|
||||
from utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
|
||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from indextts.utils.utils import tokenize_by_CJK_char
|
||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||
|
||||
|
||||
class IndexTTS:
|
||||
@ -155,4 +157,4 @@ class IndexTTS:
|
||||
|
||||
if __name__ == "__main__":
|
||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
||||
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav")
|
||||
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav")
|
||||
|
||||
0
indextts/utils/__init__.py
Normal file
0
indextts/utils/__init__.py
Normal file
@ -1,6 +1,8 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from indextts.utils.xtransformers import RelativePositionBias
|
||||
|
||||
|
||||
|
||||
@ -12,15 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
|
||||
import datetime
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from utils import safe_log
|
||||
|
||||
from indextts.utils.utils import safe_log
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from inspect import isfunction
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn, einsum
|
||||
from torch import einsum, nn
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@ matplotlib==3.8.2
|
||||
opencv-python==4.9.0.80
|
||||
vocos==0.1.0
|
||||
accelerate==0.25.0
|
||||
omegaconf==2.0.6
|
||||
tensorboard==2.9.1
|
||||
omegaconf
|
||||
sentencepiece
|
||||
pypinyin
|
||||
librosa
|
||||
|
||||
6
webui.py
6
webui.py
@ -1,16 +1,18 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import sys
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(current_dir)
|
||||
sys.path.append(os.path.join(current_dir, "indextts"))
|
||||
|
||||
import gradio as gr
|
||||
from utils.webui_utils import next_page, prev_page
|
||||
|
||||
from indextts.infer import IndexTTS
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from utils.webui_utils import next_page, prev_page
|
||||
|
||||
i18n = I18nAuto(language="zh_CN")
|
||||
MODE = 'local'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user