fix packages
This commit is contained in:
parent
4c086f954b
commit
2fe6a73ada
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
venv/
|
||||||
|
__pycache__
|
||||||
@ -12,6 +12,7 @@ from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
|
|||||||
from indextts.BigVGAN.nnet.linear import Linear
|
from indextts.BigVGAN.nnet.linear import Linear
|
||||||
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
||||||
|
|
||||||
|
|
||||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
||||||
"""Creates a binary mask for each sequence.
|
"""Creates a binary mask for each sequence.
|
||||||
|
|
||||||
|
|||||||
0
indextts/BigVGAN/__init__.py
Normal file
0
indextts/BigVGAN/__init__.py
Normal file
@ -2,7 +2,7 @@
|
|||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, sin, pow
|
from torch import nn, pow, sin
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
@ -22,6 +22,7 @@ class Snake(nn.Module):
|
|||||||
>>> x = torch.randn(256)
|
>>> x = torch.randn(256)
|
||||||
>>> x = a1(x)
|
>>> x = a1(x)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
'''
|
'''
|
||||||
Initialization.
|
Initialization.
|
||||||
@ -36,9 +37,9 @@ class Snake(nn.Module):
|
|||||||
|
|
||||||
# initialize alpha
|
# initialize alpha
|
||||||
self.alpha_logscale = alpha_logscale
|
self.alpha_logscale = alpha_logscale
|
||||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
else: # linear scale alphas initialized to ones
|
else: # linear scale alphas initialized to ones
|
||||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
self.alpha.requires_grad = alpha_trainable
|
self.alpha.requires_grad = alpha_trainable
|
||||||
@ -51,7 +52,7 @@ class Snake(nn.Module):
|
|||||||
Applies the function to the input elementwise.
|
Applies the function to the input elementwise.
|
||||||
Snake ∶= x + 1/a * sin^2 (xa)
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
'''
|
'''
|
||||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
alpha = torch.exp(alpha)
|
alpha = torch.exp(alpha)
|
||||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
@ -76,6 +77,7 @@ class SnakeBeta(nn.Module):
|
|||||||
>>> x = torch.randn(256)
|
>>> x = torch.randn(256)
|
||||||
>>> x = a1(x)
|
>>> x = a1(x)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
'''
|
'''
|
||||||
Initialization.
|
Initialization.
|
||||||
@ -92,10 +94,10 @@ class SnakeBeta(nn.Module):
|
|||||||
|
|
||||||
# initialize alpha
|
# initialize alpha
|
||||||
self.alpha_logscale = alpha_logscale
|
self.alpha_logscale = alpha_logscale
|
||||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||||
else: # linear scale alphas initialized to ones
|
else: # linear scale alphas initialized to ones
|
||||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
self.beta = Parameter(torch.ones(in_features) * alpha)
|
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
@ -110,11 +112,11 @@ class SnakeBeta(nn.Module):
|
|||||||
Applies the function to the input elementwise.
|
Applies the function to the input elementwise.
|
||||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
'''
|
'''
|
||||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
alpha = torch.exp(alpha)
|
alpha = torch.exp(alpha)
|
||||||
beta = torch.exp(beta)
|
beta = torch.exp(beta)
|
||||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -3,10 +3,9 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
|
||||||
|
|
||||||
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||||
from alias_free_activation.cuda import load
|
from alias_free_activation.cuda import load
|
||||||
|
from alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
||||||
|
|
||||||
anti_alias_activation_cuda = load.load()
|
anti_alias_activation_cuda = load.load()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .act import *
|
||||||
from .filter import *
|
from .filter import *
|
||||||
from .resample import *
|
from .resample import *
|
||||||
from .act import *
|
|
||||||
|
|||||||
@ -2,7 +2,8 @@
|
|||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .resample import UpSample1d, DownSample1d
|
|
||||||
|
from .resample import DownSample1d, UpSample1d
|
||||||
|
|
||||||
|
|
||||||
class Activation1d(nn.Module):
|
class Activation1d(nn.Module):
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
|
||||||
|
|
||||||
if "sinc" in dir(torch):
|
if "sinc" in dir(torch):
|
||||||
sinc = torch.sinc
|
sinc = torch.sinc
|
||||||
|
|||||||
@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from .filter import LowPassFilter1d
|
|
||||||
from .filter import kaiser_sinc_filter1d
|
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||||
|
|
||||||
|
|
||||||
class UpSample1d(nn.Module):
|
class UpSample1d(nn.Module):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .act import *
|
||||||
from .filter import *
|
from .filter import *
|
||||||
from .resample import *
|
from .resample import *
|
||||||
from .act import *
|
|
||||||
@ -2,7 +2,8 @@
|
|||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .resample import UpSample1d, DownSample1d
|
|
||||||
|
from .resample import DownSample1d, UpSample1d
|
||||||
|
|
||||||
|
|
||||||
class Activation1d(nn.Module):
|
class Activation1d(nn.Module):
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
|
||||||
|
|
||||||
if 'sinc' in dir(torch):
|
if 'sinc' in dir(torch):
|
||||||
sinc = torch.sinc
|
sinc = torch.sinc
|
||||||
|
|||||||
@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from .filter import LowPassFilter1d
|
|
||||||
from .filter import kaiser_sinc_filter1d
|
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
||||||
|
|
||||||
|
|
||||||
class UpSample1d(nn.Module):
|
class UpSample1d(nn.Module):
|
||||||
|
|||||||
@ -4,23 +4,23 @@
|
|||||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Dict
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||||
from torch.nn import Conv1d, ConvTranspose1d
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||||
|
|
||||||
import indextts.BigVGAN.activations as activations
|
import indextts.BigVGAN.activations as activations
|
||||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
from indextts.BigVGAN.alias_free_activation.torch.act import \
|
||||||
from indextts.BigVGAN.alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
Activation1d as TorchActivation1d
|
||||||
from indextts.BigVGAN.env import AttrDict
|
|
||||||
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
|
||||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||||
|
from indextts.BigVGAN.env import AttrDict
|
||||||
|
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||||
|
|
||||||
|
|
||||||
def load_hparams_from_json(path) -> AttrDict:
|
def load_hparams_from_json(path) -> AttrDict:
|
||||||
@ -51,7 +51,7 @@ class AMPBlock1(torch.nn.Module):
|
|||||||
activation: str = None,
|
activation: str = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.h = h
|
self.h = h
|
||||||
|
|
||||||
self.convs1 = nn.ModuleList(
|
self.convs1 = nn.ModuleList(
|
||||||
@ -94,9 +94,8 @@ class AMPBlock1(torch.nn.Module):
|
|||||||
|
|
||||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
if self.h.get("use_cuda_kernel", False):
|
if self.h.get("use_cuda_kernel", False):
|
||||||
from alias_free_activation.cuda.activation1d import (
|
from alias_free_activation.cuda.activation1d import \
|
||||||
Activation1d as CudaActivation1d,
|
Activation1d as CudaActivation1d
|
||||||
)
|
|
||||||
|
|
||||||
Activation1d = CudaActivation1d
|
Activation1d = CudaActivation1d
|
||||||
else:
|
else:
|
||||||
@ -170,7 +169,7 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
activation: str = None,
|
activation: str = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.h = h
|
self.h = h
|
||||||
|
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
@ -194,9 +193,8 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
|
|
||||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
if self.h.get("use_cuda_kernel", False):
|
if self.h.get("use_cuda_kernel", False):
|
||||||
from alias_free_activation.cuda.activation1d import (
|
from alias_free_activation.cuda.activation1d import \
|
||||||
Activation1d as CudaActivation1d,
|
Activation1d as CudaActivation1d
|
||||||
)
|
|
||||||
|
|
||||||
Activation1d = CudaActivation1d
|
Activation1d = CudaActivation1d
|
||||||
else:
|
else:
|
||||||
@ -241,6 +239,7 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
remove_weight_norm(l)
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
PyTorchModelHubMixin,
|
PyTorchModelHubMixin,
|
||||||
library_name="bigvgan",
|
library_name="bigvgan",
|
||||||
@ -251,6 +250,7 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
class BigVGAN(
|
class BigVGAN(
|
||||||
torch.nn.Module,
|
torch.nn.Module,
|
||||||
):
|
):
|
||||||
@ -274,9 +274,8 @@ class BigVGAN(
|
|||||||
|
|
||||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
if self.h.get("use_cuda_kernel", False):
|
if self.h.get("use_cuda_kernel", False):
|
||||||
from alias_free_activation.cuda.activation1d import (
|
from alias_free_activation.cuda.activation1d import \
|
||||||
Activation1d as CudaActivation1d,
|
Activation1d as CudaActivation1d
|
||||||
)
|
|
||||||
|
|
||||||
Activation1d = CudaActivation1d
|
Activation1d = CudaActivation1d
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,17 +1,16 @@
|
|||||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||||
|
|
||||||
import indextts.BigVGAN.activations as activations
|
import indextts.BigVGAN.activations as activations
|
||||||
from indextts.BigVGAN.utils import init_weights, get_padding
|
|
||||||
from indextts.BigVGAN.alias_free_torch import *
|
from indextts.BigVGAN.alias_free_torch import *
|
||||||
|
|
||||||
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
||||||
|
from indextts.BigVGAN.utils import get_padding, init_weights
|
||||||
|
|
||||||
LRELU_SLOPE = 0.1
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
@ -41,19 +40,19 @@ class AMPBlock1(torch.nn.Module):
|
|||||||
])
|
])
|
||||||
self.convs2.apply(init_weights)
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||||
|
|
||||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
self.activations = nn.ModuleList([
|
self.activations = nn.ModuleList([
|
||||||
Activation1d(
|
Activation1d(
|
||||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
])
|
])
|
||||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
self.activations = nn.ModuleList([
|
self.activations = nn.ModuleList([
|
||||||
Activation1d(
|
Activation1d(
|
||||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
||||||
@ -89,25 +88,25 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
])
|
])
|
||||||
self.convs.apply(init_weights)
|
self.convs.apply(init_weights)
|
||||||
|
|
||||||
self.num_layers = len(self.convs) # total number of conv layers
|
self.num_layers = len(self.convs) # total number of conv layers
|
||||||
|
|
||||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
self.activations = nn.ModuleList([
|
self.activations = nn.ModuleList([
|
||||||
Activation1d(
|
Activation1d(
|
||||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
])
|
])
|
||||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
self.activations = nn.ModuleList([
|
self.activations = nn.ModuleList([
|
||||||
Activation1d(
|
Activation1d(
|
||||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
for _ in range(self.num_layers)
|
for _ in range(self.num_layers)
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for c, a in zip (self.convs, self.activations):
|
for c, a in zip(self.convs, self.activations):
|
||||||
xt = a(x)
|
xt = a(x)
|
||||||
xt = c(xt)
|
xt = c(xt)
|
||||||
x = xt + x
|
x = xt + x
|
||||||
@ -127,7 +126,7 @@ class BigVGAN(torch.nn.Module):
|
|||||||
|
|
||||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(h.upsample_rates)
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
|
||||||
self.feat_upsample = h.feat_upsample
|
self.feat_upsample = h.feat_upsample
|
||||||
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
|
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
|
||||||
|
|
||||||
@ -154,10 +153,10 @@ class BigVGAN(torch.nn.Module):
|
|||||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
# post conv
|
# post conv
|
||||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||||
self.activation_post = Activation1d(activation=activation_post)
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||||
self.activation_post = Activation1d(activation=activation_post)
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
else:
|
else:
|
||||||
@ -180,7 +179,6 @@ class BigVGAN(torch.nn.Module):
|
|||||||
|
|
||||||
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, mel_ref, lens=None):
|
def forward(self, x, mel_ref, lens=None):
|
||||||
speaker_embedding = self.speaker_encoder(mel_ref, lens)
|
speaker_embedding = self.speaker_encoder(mel_ref, lens)
|
||||||
n_batch = x.size(0)
|
n_batch = x.size(0)
|
||||||
@ -190,7 +188,7 @@ class BigVGAN(torch.nn.Module):
|
|||||||
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
|
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
|
||||||
|
|
||||||
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
||||||
speaker_embedding = speaker_embedding.transpose(1,2)
|
speaker_embedding = speaker_embedding.transpose(1, 2)
|
||||||
|
|
||||||
# upsample feat
|
# upsample feat
|
||||||
if self.feat_upsample:
|
if self.feat_upsample:
|
||||||
@ -265,20 +263,20 @@ class DiscriminatorP(torch.nn.Module):
|
|||||||
self.d_mult = h.discriminator_channel_mult
|
self.d_mult = h.discriminator_channel_mult
|
||||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||||
self.convs = nn.ModuleList([
|
self.convs = nn.ModuleList([
|
||||||
norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||||
norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
||||||
])
|
])
|
||||||
self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
fmap = []
|
fmap = []
|
||||||
|
|
||||||
# 1d to 2d
|
# 1d to 2d
|
||||||
b, c, t = x.shape
|
b, c, t = x.shape
|
||||||
if t % self.period != 0: # pad first
|
if t % self.period != 0: # pad first
|
||||||
n_pad = self.period - (t % self.period)
|
n_pad = self.period - (t % self.period)
|
||||||
x = F.pad(x, (0, n_pad), "reflect")
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
t = t + n_pad
|
t = t + n_pad
|
||||||
@ -338,11 +336,11 @@ class DiscriminatorR(nn.Module):
|
|||||||
self.d_mult = cfg.mrd_channel_mult
|
self.d_mult = cfg.mrd_channel_mult
|
||||||
|
|
||||||
self.convs = nn.ModuleList([
|
self.convs = nn.ModuleList([
|
||||||
norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
|
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
|
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
|
||||||
])
|
])
|
||||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||||
|
|
||||||
@ -367,7 +365,7 @@ class DiscriminatorR(nn.Module):
|
|||||||
x = x.squeeze(1)
|
x = x.squeeze(1)
|
||||||
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
|
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
|
||||||
x = torch.view_as_real(x) # [B, F, TT, 2]
|
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||||
mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
|
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||||
|
|
||||||
return mag
|
return mag
|
||||||
|
|
||||||
@ -376,9 +374,9 @@ class MultiResolutionDiscriminator(nn.Module):
|
|||||||
def __init__(self, cfg, debug=False):
|
def __init__(self, cfg, debug=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.resolutions = cfg.resolutions
|
self.resolutions = cfg.resolutions
|
||||||
assert len(self.resolutions) == 3,\
|
assert len(self.resolutions) == 3, \
|
||||||
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
|
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
|
||||||
format(self.resolutions)
|
format(self.resolutions)
|
||||||
self.discriminators = nn.ModuleList(
|
self.discriminators = nn.ModuleList(
|
||||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||||
)
|
)
|
||||||
@ -406,7 +404,7 @@ def feature_loss(fmap_r, fmap_g):
|
|||||||
for rl, gl in zip(dr, dg):
|
for rl, gl in zip(dr, dg):
|
||||||
loss += torch.mean(torch.abs(rl - gl))
|
loss += torch.mean(torch.abs(rl - gl))
|
||||||
|
|
||||||
return loss*2
|
return loss * 2
|
||||||
|
|
||||||
|
|
||||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||||
@ -414,7 +412,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
|||||||
r_losses = []
|
r_losses = []
|
||||||
g_losses = []
|
g_losses = []
|
||||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||||
r_loss = torch.mean((1-dr)**2)
|
r_loss = torch.mean((1 - dr)**2)
|
||||||
g_loss = torch.mean(dg**2)
|
g_loss = torch.mean(dg**2)
|
||||||
loss += (r_loss + g_loss)
|
loss += (r_loss + g_loss)
|
||||||
r_losses.append(r_loss.item())
|
r_losses.append(r_loss.item())
|
||||||
@ -427,9 +425,8 @@ def generator_loss(disc_outputs):
|
|||||||
loss = 0
|
loss = 0
|
||||||
gen_losses = []
|
gen_losses = []
|
||||||
for dg in disc_outputs:
|
for dg in disc_outputs:
|
||||||
l = torch.mean((1-dg)**2)
|
l = torch.mean((1 - dg)**2)
|
||||||
gen_losses.append(l)
|
gen_losses.append(l)
|
||||||
loss += l
|
loss += l
|
||||||
|
|
||||||
return loss, gen_losses
|
return loss, gen_losses
|
||||||
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
class SincConv(nn.Module):
|
class SincConv(nn.Module):
|
||||||
"""This function implements SincConv (SincNet).
|
"""This function implements SincConv (SincNet).
|
||||||
|
|
||||||
|
|||||||
0
indextts/BigVGAN/nnet/__init__.py
Normal file
0
indextts/BigVGAN/nnet/__init__.py
Normal file
@ -3,13 +3,14 @@
|
|||||||
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
import matplotlib.pylab as plt
|
||||||
import torch
|
import torch
|
||||||
|
from scipy.io.wavfile import write
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils import weight_norm
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pylab as plt
|
|
||||||
from scipy.io.wavfile import write
|
|
||||||
|
|
||||||
MAX_WAV_VALUE = 32768.0
|
MAX_WAV_VALUE = 32768.0
|
||||||
|
|
||||||
|
|||||||
0
indextts/__init__.py
Normal file
0
indextts/__init__.py
Normal file
@ -21,6 +21,7 @@ from typing import Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoding(torch.nn.Module):
|
class PositionalEncoding(torch.nn.Module):
|
||||||
"""Positional encoding.
|
"""Positional encoding.
|
||||||
|
|
||||||
|
|||||||
@ -3,11 +3,18 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from gpt.conformer.subsampling import Conv2dSubsampling4, Conv2dSubsampling6, \
|
|
||||||
Conv2dSubsampling8, LinearNoSubsampling, Conv2dSubsampling2
|
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
|
||||||
from gpt.conformer.embedding import PositionalEncoding, RelPositionalEncoding, NoPositionalEncoding
|
RelPositionMultiHeadedAttention)
|
||||||
from gpt.conformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
|
||||||
from utils.utils import make_pad_mask
|
PositionalEncoding,
|
||||||
|
RelPositionalEncoding)
|
||||||
|
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
|
||||||
|
Conv2dSubsampling4,
|
||||||
|
Conv2dSubsampling6,
|
||||||
|
Conv2dSubsampling8,
|
||||||
|
LinearNoSubsampling)
|
||||||
|
from indextts.utils.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class PositionwiseFeedForward(torch.nn.Module):
|
class PositionwiseFeedForward(torch.nn.Module):
|
||||||
@ -22,6 +29,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
|||||||
dropout_rate (float): Dropout rate.
|
dropout_rate (float): Dropout rate.
|
||||||
activation (torch.nn.Module): Activation function
|
activation (torch.nn.Module): Activation function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
idim: int,
|
idim: int,
|
||||||
hidden_units: int,
|
hidden_units: int,
|
||||||
@ -47,6 +55,7 @@ class PositionwiseFeedForward(torch.nn.Module):
|
|||||||
|
|
||||||
class ConvolutionModule(nn.Module):
|
class ConvolutionModule(nn.Module):
|
||||||
"""ConvolutionModule in Conformer model."""
|
"""ConvolutionModule in Conformer model."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
channels: int,
|
channels: int,
|
||||||
kernel_size: int = 15,
|
kernel_size: int = 15,
|
||||||
@ -181,6 +190,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
True: x -> x + linear(concat(x, att(x)))
|
True: x -> x + linear(concat(x, att(x)))
|
||||||
False: x -> x + att(x)
|
False: x -> x + att(x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
@ -428,6 +438,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
|
|
||||||
class ConformerEncoder(BaseEncoder):
|
class ConformerEncoder(BaseEncoder):
|
||||||
"""Conformer encoder module."""
|
"""Conformer encoder module."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_size: int,
|
input_size: int,
|
||||||
@ -507,4 +518,3 @@ class ConformerEncoder(BaseEncoder):
|
|||||||
concat_after,
|
concat_after,
|
||||||
) for _ in range(num_blocks)
|
) for _ in range(num_blocks)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,13 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||||
from gpt.perceiver import PerceiverResampler
|
get_device_map)
|
||||||
from gpt.conformer_encoder import ConformerEncoder
|
|
||||||
|
from indextts.gpt.conformer_encoder import ConformerEncoder
|
||||||
|
from indextts.gpt.perceiver import PerceiverResampler
|
||||||
from indextts.utils.arch_util import AttentionBlock
|
from indextts.utils.arch_util import AttentionBlock
|
||||||
from utils.typical_sampling import TypicalLogitsWarper
|
from indextts.utils.typical_sampling import TypicalLogitsWarper
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
def null_position_embeddings(range, dim):
|
||||||
@ -20,14 +22,15 @@ class ResBlock(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Basic residual convolutional block that uses GroupNorm.
|
Basic residual convolutional block that uses GroupNorm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chan):
|
def __init__(self, chan):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
nn.GroupNorm(chan//8, chan),
|
nn.GroupNorm(chan // 8, chan),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
nn.GroupNorm(chan//8, chan)
|
nn.GroupNorm(chan // 8, chan)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -229,7 +232,7 @@ class ConditioningEncoder(nn.Module):
|
|||||||
return h.mean(dim=2)
|
return h.mean(dim=2)
|
||||||
else:
|
else:
|
||||||
return h
|
return h
|
||||||
#return h[:, :, 0]
|
# return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
class LearnedPositionEmbeddings(nn.Module):
|
class LearnedPositionEmbeddings(nn.Module):
|
||||||
@ -253,8 +256,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||||||
"""
|
"""
|
||||||
from transformers import GPT2Config, GPT2Model
|
from transformers import GPT2Config, GPT2Model
|
||||||
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
||||||
n_positions=max_mel_seq_len+max_text_seq_len,
|
n_positions=max_mel_seq_len + max_text_seq_len,
|
||||||
n_ctx=max_mel_seq_len+max_text_seq_len,
|
n_ctx=max_mel_seq_len + max_text_seq_len,
|
||||||
n_embd=model_dim,
|
n_embd=model_dim,
|
||||||
n_layer=layers,
|
n_layer=layers,
|
||||||
n_head=heads,
|
n_head=heads,
|
||||||
@ -266,7 +269,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
# Built-in token embeddings are unused.
|
# Built-in token embeddings are unused.
|
||||||
del gpt.wte
|
del gpt.wte
|
||||||
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
|
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
|
||||||
None, None
|
None, None
|
||||||
|
|
||||||
|
|
||||||
@ -274,14 +277,14 @@ class MelEncoder(nn.Module):
|
|||||||
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
|
||||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
|
||||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
|
||||||
nn.GroupNorm(channels//16, channels//2),
|
nn.GroupNorm(channels // 16, channels // 2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
|
||||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
|
||||||
nn.GroupNorm(channels//8, channels),
|
nn.GroupNorm(channels // 8, channels),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||||
)
|
)
|
||||||
@ -493,7 +496,7 @@ class UnifiedVoice(nn.Module):
|
|||||||
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
|
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
|
||||||
cond_mel_lengths) # (b, s, d), (b, 1, s)
|
cond_mel_lengths) # (b, s, d), (b, 1, s)
|
||||||
if self.condition_type == "conformer_perceiver":
|
if self.condition_type == "conformer_perceiver":
|
||||||
#conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
# conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
|
||||||
conds_mask = self.cond_mask_pad(mask.squeeze(1))
|
conds_mask = self.cond_mask_pad(mask.squeeze(1))
|
||||||
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
|
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
|
||||||
elif self.condition_type == "gst":
|
elif self.condition_type == "gst":
|
||||||
@ -536,7 +539,7 @@ class UnifiedVoice(nn.Module):
|
|||||||
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
|
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
|
||||||
# Types are expressed by expanding the text embedding space.
|
# Types are expressed by expanding the text embedding space.
|
||||||
if types is not None:
|
if types is not None:
|
||||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
|
||||||
|
|
||||||
if clip_inputs:
|
if clip_inputs:
|
||||||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
@ -546,10 +549,10 @@ class UnifiedVoice(nn.Module):
|
|||||||
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||||
mel_codes = mel_codes[:, :max_mel_len]
|
mel_codes = mel_codes[:, :max_mel_len]
|
||||||
if raw_mels is not None:
|
if raw_mels is not None:
|
||||||
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
raw_mels = raw_mels[:, :, :max_mel_len * 4]
|
||||||
|
|
||||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||||
#mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||||
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
|
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
|
||||||
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
|
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
|
||||||
|
|
||||||
@ -569,7 +572,7 @@ class UnifiedVoice(nn.Module):
|
|||||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||||
|
|
||||||
if text_first:
|
if text_first:
|
||||||
#print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
# print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
|
||||||
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
||||||
if return_latent:
|
if return_latent:
|
||||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||||
@ -598,7 +601,7 @@ class UnifiedVoice(nn.Module):
|
|||||||
self.inference_model.store_mel_emb(emb)
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
|
||||||
# +1 for the start_audio_token
|
# +1 for the start_audio_token
|
||||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long,
|
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
|
||||||
device=text_inputs.device)
|
device=text_inputs.device)
|
||||||
|
|
||||||
fake_inputs[:, -1] = self.start_mel_token
|
fake_inputs[:, -1] = self.start_mel_token
|
||||||
@ -619,7 +622,3 @@ class UnifiedVoice(nn.Module):
|
|||||||
max_length=max_length, logits_processor=logits_processor,
|
max_length=max_length, logits_processor=logits_processor,
|
||||||
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
||||||
return gen[:, trunc_index:]
|
return gen[:, trunc_index:]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,16 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
import sentencepiece as spm
|
|
||||||
from utils.utils import tokenize_by_CJK_char
|
|
||||||
from utils.feature_extractors import MelSpectrogramFeatures
|
|
||||||
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
|
||||||
from indextts.utils.checkpoint import load_checkpoint
|
|
||||||
from indextts.gpt.model import UnifiedVoice
|
|
||||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||||
|
from indextts.gpt.model import UnifiedVoice
|
||||||
|
from indextts.utils.checkpoint import load_checkpoint
|
||||||
|
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||||
|
from indextts.utils.utils import tokenize_by_CJK_char
|
||||||
|
from indextts.vqvae.xtts_dvae import DiscreteVAE
|
||||||
|
|
||||||
|
|
||||||
class IndexTTS:
|
class IndexTTS:
|
||||||
@ -107,18 +109,18 @@ class IndexTTS:
|
|||||||
print(text_len)
|
print(text_len)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
|
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
|
||||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
||||||
device=text_tokens.device),
|
device=text_tokens.device),
|
||||||
# text_lengths=text_len,
|
# text_lengths=text_len,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
num_return_sequences=autoregressive_batch_size,
|
num_return_sequences=autoregressive_batch_size,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
max_generate_length=max_mel_tokens)
|
max_generate_length=max_mel_tokens)
|
||||||
print(codes)
|
print(codes)
|
||||||
print(f"codes shape: {codes.shape}")
|
print(f"codes shape: {codes.shape}")
|
||||||
codes = codes[:, :-2]
|
codes = codes[:, :-2]
|
||||||
@ -126,10 +128,10 @@ class IndexTTS:
|
|||||||
# latent, text_lens_out, code_lens_out = \
|
# latent, text_lens_out, code_lens_out = \
|
||||||
latent = \
|
latent = \
|
||||||
self.gpt(auto_conditioning, text_tokens,
|
self.gpt(auto_conditioning, text_tokens,
|
||||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||||
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
|
torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression], device=text_tokens.device),
|
||||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device),
|
||||||
return_latent=True, clip_inputs=False)
|
return_latent=True, clip_inputs=False)
|
||||||
latent = latent.transpose(1, 2)
|
latent = latent.transpose(1, 2)
|
||||||
'''
|
'''
|
||||||
latent_list = []
|
latent_list = []
|
||||||
@ -155,4 +157,4 @@ class IndexTTS:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
|
||||||
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',output_path="gen.wav")
|
tts.infer(audio_prompt='test_data/input.wav', text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!', output_path="gen.wav")
|
||||||
|
|||||||
0
indextts/utils/__init__.py
Normal file
0
indextts/utils/__init__.py
Normal file
@ -1,6 +1,8 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import math
|
|
||||||
from indextts.utils.xtransformers import RelativePositionBias
|
from indextts.utils.xtransformers import RelativePositionBias
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -12,15 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import yaml
|
|
||||||
import torch
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import datetime
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:
|
def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from utils import safe_log
|
|
||||||
|
from indextts.utils.utils import safe_log
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractor(nn.Module):
|
class FeatureExtractor(nn.Module):
|
||||||
@ -23,7 +24,7 @@ class FeatureExtractor(nn.Module):
|
|||||||
|
|
||||||
class MelSpectrogramFeatures(FeatureExtractor):
|
class MelSpectrogramFeatures(FeatureExtractor):
|
||||||
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
|
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None,
|
||||||
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
|
n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if padding not in ["center", "same"]:
|
if padding not in ["center", "same"]:
|
||||||
raise ValueError("Padding must be 'center' or 'same'.")
|
raise ValueError("Padding must be 'center' or 'same'.")
|
||||||
@ -47,4 +48,4 @@ class MelSpectrogramFeatures(FeatureExtractor):
|
|||||||
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
|
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
|
||||||
mel = self.mel_spec(audio)
|
mel = self.mel_spec(audio)
|
||||||
mel = safe_log(mel)
|
mel = safe_log(mel)
|
||||||
return mel
|
return mel
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from inspect import isfunction
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import nn, einsum
|
from torch import einsum, nn
|
||||||
|
|
||||||
DEFAULT_DIM_HEAD = 64
|
DEFAULT_DIM_HEAD = 64
|
||||||
|
|
||||||
|
|||||||
@ -14,8 +14,8 @@ matplotlib==3.8.2
|
|||||||
opencv-python==4.9.0.80
|
opencv-python==4.9.0.80
|
||||||
vocos==0.1.0
|
vocos==0.1.0
|
||||||
accelerate==0.25.0
|
accelerate==0.25.0
|
||||||
omegaconf==2.0.6
|
|
||||||
tensorboard==2.9.1
|
tensorboard==2.9.1
|
||||||
|
omegaconf
|
||||||
sentencepiece
|
sentencepiece
|
||||||
pypinyin
|
pypinyin
|
||||||
librosa
|
librosa
|
||||||
|
|||||||
6
webui.py
6
webui.py
@ -1,16 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import sys
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(current_dir)
|
sys.path.append(current_dir)
|
||||||
sys.path.append(os.path.join(current_dir, "indextts"))
|
sys.path.append(os.path.join(current_dir, "indextts"))
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from utils.webui_utils import next_page, prev_page
|
||||||
|
|
||||||
from indextts.infer import IndexTTS
|
from indextts.infer import IndexTTS
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
from utils.webui_utils import next_page, prev_page
|
|
||||||
|
|
||||||
i18n = I18nAuto(language="zh_CN")
|
i18n = I18nAuto(language="zh_CN")
|
||||||
MODE = 'local'
|
MODE = 'local'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user