From 3fc7b31e10f0166d244279053b27b2259c76453a Mon Sep 17 00:00:00 2001 From: sunnyboxs Date: Sun, 20 Apr 2025 14:12:38 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8D=95=E5=8F=A5=E6=8E=A8=E7=90=86=EF=BC=9ART?= =?UTF-8?q?F=E6=80=A7=E8=83=BD=E8=87=B3=E5=B0=91=E6=8F=90=E5=8D=87=2010%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + checkpoints/config.yaml | 1 + indextts/gpt/model.py | 7 ++++--- indextts/infer.py | 4 ++++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 23e0d7b..42f0240 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ *.egg-info *.DS_Store .idea/ +.vscode/ checkpoints/*.pth checkpoints/*.vocab checkpoints/*.model diff --git a/checkpoints/config.yaml b/checkpoints/config.yaml index e24336d..ca2d2cc 100644 --- a/checkpoints/config.yaml +++ b/checkpoints/config.yaml @@ -19,6 +19,7 @@ gpt: use_mel_codes_as_input: true mel_length_compression: 1024 layers: 20 + activation_function: "gelu_pytorch_tanh" number_text_tokens: 12000 number_mel_codes: 8194 start_mel_token: 8192 diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index a605f7e..2906b31 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -250,7 +250,7 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) -def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function): """ GPT-2 implemented by the HuggingFace library. """ @@ -261,6 +261,7 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text n_embd=model_dim, n_layer=layers, n_head=heads, + activation_function=activation_function or "gelu_new", gradient_checkpointing=checkpointing, use_cache=not checkpointing) gpt = GPT2Model(gpt_config) @@ -301,7 +302,7 @@ class UnifiedVoice(nn.Module): mel_length_compression=1024, number_text_tokens=256, start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, - checkpointing=True, types=1, + checkpointing=True, types=1, activation_function=None, condition_num_latent=32, condition_type="perceiver", condition_module=None): """ Args: @@ -365,7 +366,7 @@ class UnifiedVoice(nn.Module): self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs, - self.max_text_tokens + 2, checkpointing) + self.max_text_tokens + 2, checkpointing, activation_function) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) diff --git a/indextts/infer.py b/indextts/infer.py index 6c7f5f0..b2233b4 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -11,6 +11,10 @@ from torch.nn.utils.rnn import pad_sequence from omegaconf import OmegaConf from tqdm import tqdm +import warnings +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + from indextts.BigVGAN.models import BigVGAN as Generator from indextts.gpt.model import UnifiedVoice from indextts.utils.checkpoint import load_checkpoint