单句推理:RTF性能至少提升 10%

This commit is contained in:
sunnyboxs 2025-04-20 14:12:38 +08:00
parent a26894de71
commit 3fc7b31e10
4 changed files with 10 additions and 3 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@ __pycache__
*.egg-info
*.DS_Store
.idea/
.vscode/
checkpoints/*.pth
checkpoints/*.vocab
checkpoints/*.model

View File

@@ -19,6 +19,7 @@ gpt:
use_mel_codes_as_input: true
mel_length_compression: 1024
layers: 20
activation_function: "gelu_pytorch_tanh"
number_text_tokens: 12000
number_mel_codes: 8194
start_mel_token: 8192

View File

@@ -250,7 +250,7 @@ class LearnedPositionEmbeddings(nn.Module):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
"""
GPT-2 implemented by the HuggingFace library.
"""
@@ -261,6 +261,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
n_embd=model_dim,
n_layer=layers,
n_head=heads,
activation_function=activation_function or "gelu_new",
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
gpt = GPT2Model(gpt_config)
@@ -301,7 +302,7 @@ class UnifiedVoice(nn.Module):
mel_length_compression=1024, number_text_tokens=256,
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1,
checkpointing=True, types=1, activation_function=None,
condition_num_latent=32, condition_type="perceiver", condition_module=None):
"""
Args:
@@ -365,7 +366,7 @@
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
self.max_text_tokens + 2, checkpointing)
self.max_text_tokens + 2, checkpointing, activation_function)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)

View File

@@ -11,6 +11,10 @@ from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint