fix packages

This commit is contained in:
Emmanuel Schmidbauer 2025-03-25 14:03:29 -04:00
parent 4c086f954b
commit 2fe6a73ada
31 changed files with 171 additions and 148 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
venv/
__pycache__

View File

@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
from indextts.BigVGAN.nnet.linear import Linear
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
def length_to_mask(length, max_len=None, dtype=None, device=None):
"""Creates a binary mask for each sequence.

View File

View File

@ -2,7 +2,7 @@
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch import nn, pow, sin
from torch.nn import Parameter
@ -22,6 +22,7 @@ class Snake(nn.Module):
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
@ -36,9 +37,9 @@ class Snake(nn.Module):
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
@ -51,7 +52,7 @@ class Snake(nn.Module):
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@ -76,6 +77,7 @@ class SnakeBeta(nn.Module):
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
@ -92,10 +94,10 @@ class SnakeBeta(nn.Module):
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
@ -110,7 +112,7 @@ class SnakeBeta(nn.Module):
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)

View File

@ -3,10 +3,9 @@
import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
anti_alias_activation_cuda = load.load()

View File

@ -1,6 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
from .act import *
from .filter import *
from .resample import *
from .act import *

View File

@ -2,7 +2,8 @@
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
from .resample import DownSample1d, UpSample1d
class Activation1d(nn.Module):

View File

@ -1,10 +1,11 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if "sinc" in dir(torch):
sinc = torch.sinc

View File

@ -3,8 +3,8 @@
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
class UpSample1d(nn.Module):

View File

@ -1,6 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
from .act import *
from .filter import *
from .resample import *
from .act import *

View File

@ -2,7 +2,8 @@
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
from .resample import DownSample1d, UpSample1d
class Activation1d(nn.Module):

View File

@ -1,10 +1,11 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if 'sinc' in dir(torch):
sinc = torch.sinc

View File

@ -3,8 +3,8 @@
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
class UpSample1d(nn.Module):

View File

@ -4,23 +4,23 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import json
import os
from pathlib import Path
from typing import Optional, Union, Dict
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.nn.utils import remove_weight_norm, weight_norm
import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.utils import init_weights, get_padding
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d
from indextts.BigVGAN.env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from indextts.BigVGAN.alias_free_activation.torch.act import \
Activation1d as TorchActivation1d
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.env import AttrDict
from indextts.BigVGAN.utils import get_padding, init_weights
def load_hparams_from_json(path) -> AttrDict:
@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:
@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:
@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module):
for l in self.convs:
remove_weight_norm(l)
'''
PyTorchModelHubMixin,
library_name="bigvgan",
@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module):
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
'''
class BigVGAN(
torch.nn.Module,
):
@ -274,9 +274,8 @@ class BigVGAN(
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
from alias_free_activation.cuda.activation1d import \
Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:

View File

@ -4,14 +4,13 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
import indextts.BigVGAN.activations as activations
from indextts.BigVGAN.utils import init_weights, get_padding
from indextts.BigVGAN.alias_free_torch import *
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
from indextts.BigVGAN.utils import get_padding, init_weights
LRELU_SLOPE = 0.1
@ -41,19 +40,19 @@ class AMPBlock1(torch.nn.Module):
])
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
for _ in range(self.num_layers)
])
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
@ -89,25 +88,25 @@ class AMPBlock2(torch.nn.Module):
])
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
for _ in range(self.num_layers)
])
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
for c, a in zip (self.convs, self.activations):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
@ -154,10 +153,10 @@ class BigVGAN(torch.nn.Module):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module):
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, x, mel_ref, lens=None):
speaker_embedding = self.speaker_encoder(mel_ref, lens)
n_batch = x.size(0)
@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module):
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
speaker_embedding = speaker_embedding[:n_batch, :, :]
speaker_embedding = speaker_embedding.transpose(1,2)
speaker_embedding = speaker_embedding.transpose(1, 2)
# upsample feat
if self.feat_upsample:
@ -265,20 +263,20 @@ class DiscriminatorP(torch.nn.Module):
self.d_mult = h.discriminator_channel_mult
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module):
self.d_mult = cfg.mrd_channel_mult
self.convs = nn.ModuleList([
norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
])
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module):
x = x.squeeze(1)
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
x = torch.view_as_real(x) # [B, F, TT, 2]
mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
return mag
@ -376,9 +374,9 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert len(self.resolutions) == 3,\
assert len(self.resolutions) == 3, \
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
format(self.resolutions)
format(self.resolutions)
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
)
@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss*2
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1-dr)**2)
r_loss = torch.mean((1 - dr)**2)
g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
@ -427,9 +425,8 @@ def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1-dg)**2)
l = torch.mean((1 - dg)**2)
gen_losses.append(l)
loss += l
return loss, gen_losses

View File

@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class SincConv(nn.Module):
"""This function implements SincConv (SincNet).

View File

View File

@ -3,13 +3,14 @@
import glob
import os
import matplotlib
import matplotlib.pylab as plt
import torch
from scipy.io.wavfile import write
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
from scipy.io.wavfile import write
MAX_WAV_VALUE = 32768.0

0
indextts/__init__.py Normal file
View File

View File

@ -21,6 +21,7 @@ from typing import Tuple, Union
import torch
import torch.nn.functional as F
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.

View File

@ -3,11 +3,18 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \
Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2
from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding
from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
from utils.utils import make_pad_mask
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
PositionalEncoding,
RelPositionalEncoding)
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
LinearNoSubsampling)
from indextts.utils.utils import make_pad_mask
class PositionwiseFeedForward(torch.nn.Module):
@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module):
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(self,
idim: int,
hidden_units: int,
@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module):
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module):
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module):
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder):
concat_after,
) for _ in range(num_blocks)
])

View File

@ -5,11 +5,13 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from gpt.perceiver import PerceiverResampler
from gpt.conformer_encoder import ConformerEncoder
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.utils.arch_util import AttentionBlock
from utils.typical_sampling import TypicalLogitsWarper
from indextts.utils.typical_sampling import TypicalLogitsWarper
def null_position_embeddings(range, dim):
@ -20,14 +22,15 @@ class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan),
nn.GroupNorm(chan // 8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan)
nn.GroupNorm(chan // 8, chan)
)
def forward(self, x):
@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module):
return h.mean(dim=2)
else:
return h
#return h[:, :, 0]
# return h[:, :, 0]
class LearnedPositionEmbeddings(nn.Module):
@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
None, None
@ -274,14 +277,14 @@ class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 16, channels // 2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
if self.condition_type == "conformer_perceiver":
#conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
# conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
elif self.condition_type == "gst":
@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module):
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
# Types are expressed by expanding the text embedding space.
if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1)
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module):
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = mel_codes[:, :max_mel_len]
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
raw_mels = raw_mels[:, :, :max_mel_len * 4]
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
#mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module):
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
if text_first:
#print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
# print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module):
self.inference_model.store_mel_emb(emb)
# +1 for the start_audio_token
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long,
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
device=text_inputs.device)
fake_inputs[:, -1] = self.start_mel_token
@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module):
max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
return gen[:, trunc_index:]

View File

@ -1,16 +1,18 @@
import os
import re
import sys
import sentencepiece as spm
import torch
import torchaudio
from omegaconf import OmegaConf
import sentencepiece as spm
from utils.utils import tokenize_by_CJK_char
from utils.feature_extractors import MelSpectrogramFeatures
from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.checkpoint import load_checkpoint
from indextts.gpt.model import UnifiedVoice
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.utils import tokenize_by_CJK_char
from indextts.vqvae.xtts_dvae import DiscreteVAE
class IndexTTS:
@ -107,18 +109,18 @@ class IndexTTS:
print(text_len)
with torch.no_grad():
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
print(codes)
print(f"codes shape: {codes.shape}")
codes = codes[:, :-2]
@ -126,10 +128,10 @@ class IndexTTS:
# latent, text_lens_out, code_lens_out = \
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
latent = latent.transpose(1, 2)
'''
latent_list = []
@ -155,4 +157,4 @@ class IndexTTS:
if __name__ == "__main__":
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
tts.infer(audio_prompt='test_data/input.wav', text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了',output_path="gen.wav")
tts.infer(audio_prompt='test_data/input.wav', text='大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了', output_path="gen.wav")

View File

View File

@ -1,6 +1,8 @@
import math
import torch
import torch.nn as nn
import math
from indextts.utils.xtransformers import RelativePositionBias

View File

@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import os
import re
import yaml
import torch
from collections import OrderedDict
import datetime
import torch
import yaml
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:

View File

@ -1,7 +1,8 @@
import torch
import torchaudio
from torch import nn
from utils import safe_log
from indextts.utils.utils import safe_log
class FeatureExtractor(nn.Module):
@ -23,7 +24,7 @@ class FeatureExtractor(nn.Module):
class MelSpectrogramFeatures(FeatureExtractor):
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")

View File

@ -1,6 +1,7 @@
import os
import re
import random
import re
import torch
import torchaudio

View File

@ -6,7 +6,7 @@ from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum
from torch import einsum, nn
DEFAULT_DIM_HEAD = 64

View File

@ -14,8 +14,8 @@ matplotlib==3.8.2
opencv-python==4.9.0.80
vocos==0.1.0
accelerate==0.25.0
omegaconf==2.0.6
tensorboard==2.9.1
omegaconf
sentencepiece
pypinyin
librosa

View File

@ -1,16 +1,18 @@
import os
import shutil
import sys
import threading
import time
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
import gradio as gr
from utils.webui_utils import next_page, prev_page
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto
from utils.webui_utils import next_page, prev_page
i18n = I18nAuto(language="zh_CN")
MODE = 'local'