单句推理:RTF性能至少提升 10%
This commit is contained in:
parent
a26894de71
commit
3fc7b31e10
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ __pycache__
|
||||
*.egg-info
|
||||
*.DS_Store
|
||||
.idea/
|
||||
.vscode/
|
||||
checkpoints/*.pth
|
||||
checkpoints/*.vocab
|
||||
checkpoints/*.model
|
||||
|
||||
@ -19,6 +19,7 @@ gpt:
|
||||
use_mel_codes_as_input: true
|
||||
mel_length_compression: 1024
|
||||
layers: 20
|
||||
activation_function: "gelu_pytorch_tanh"
|
||||
number_text_tokens: 12000
|
||||
number_mel_codes: 8194
|
||||
start_mel_token: 8192
|
||||
|
||||
@ -250,7 +250,7 @@ class LearnedPositionEmbeddings(nn.Module):
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
||||
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
|
||||
"""
|
||||
GPT-2 implemented by the HuggingFace library.
|
||||
"""
|
||||
@ -261,6 +261,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
activation_function=activation_function or "gelu_new",
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
@ -301,7 +302,7 @@ class UnifiedVoice(nn.Module):
|
||||
mel_length_compression=1024, number_text_tokens=256,
|
||||
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
|
||||
train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||
checkpointing=True, types=1,
|
||||
checkpointing=True, types=1, activation_function=None,
|
||||
condition_num_latent=32, condition_type="perceiver", condition_module=None):
|
||||
"""
|
||||
Args:
|
||||
@ -365,7 +366,7 @@ class UnifiedVoice(nn.Module):
|
||||
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
|
||||
self.max_text_tokens + 2, checkpointing)
|
||||
self.max_text_tokens + 2, checkpointing, activation_function)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
|
||||
@ -11,6 +11,10 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
from indextts.BigVGAN.models import BigVGAN as Generator
|
||||
from indextts.gpt.model import UnifiedVoice
|
||||
from indextts.utils.checkpoint import load_checkpoint
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user