Merge pull request #39 from index-tts/feature/kemurin

推理时加载bpe model使用相对于模型根目录的路径
This commit is contained in:
index-tts 2025-04-02 17:42:39 +08:00 committed by GitHub
commit 397fef2f14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 113 additions and 3 deletions

View File

@ -105,11 +105,9 @@ apt-get install ffmpeg
```
3. Download models:
```bash
mkdir checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_discriminator.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_generator.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bpe.model -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/config.yaml -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/dvae.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/gpt.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/unigram_12000.vocab -P checkpoints

112
checkpoints/config.yaml Normal file
View File

@ -0,0 +1,112 @@
dataset:
bpe_model: bpe.model
sample_rate: 24000
squeeze: false
mel:
sample_rate: 24000
n_fft: 1024
hop_length: 256
win_length: 1024
n_mels: 100
mel_fmin: 0
normalize: false
gpt:
model_dim: 1024
max_mel_tokens: 605
max_text_tokens: 402
heads: 16
use_mel_codes_as_input: true
mel_length_compression: 1024
layers: 20
number_text_tokens: 12000
number_mel_codes: 8194
start_mel_token: 8192
stop_mel_token: 8193
start_text_token: 0
stop_text_token: 1
train_solo_embeddings: false
condition_type: "conformer_perceiver"
condition_module:
output_size: 512
linear_units: 2048
attention_heads: 8
num_blocks: 6
input_layer: "conv2d2"
perceiver_mult: 2
vqvae:
channels: 100
num_tokens: 8192
hidden_dim: 512
num_resnet_blocks: 3
codebook_dim: 512
num_layers: 2
positional_dims: 1
kernel_size: 3
smooth_l1_loss: true
use_transposed_convs: false
bigvgan:
adam_b1: 0.8
adam_b2: 0.99
lr_decay: 0.999998
seed: 1234
resblock: "1"
upsample_rates: [4,4,4,4,2,2]
upsample_kernel_sizes: [8,8,4,4,4,4]
upsample_initial_channel: 1536
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
feat_upsample: false
speaker_embedding_dim: 512
cond_d_vector_in_each_upsampling_layer: true
gpt_dim: 1024
activation: "snakebeta"
snake_logscale: true
use_cqtd_instead_of_mrd: true
cqtd_filters: 128
cqtd_max_filters: 1024
cqtd_filters_scale: 1
cqtd_dilations: [1, 2, 4]
cqtd_hop_lengths: [512, 256, 256]
cqtd_n_octaves: [9, 9, 9]
cqtd_bins_per_octaves: [24, 36, 48]
resolutions: [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
mpd_reshapes: [2, 3, 5, 7, 11]
use_spectral_norm: false
discriminator_channel_mult: 1
use_multiscale_melloss: true
lambda_melloss: 15
clip_grad_norm: 1000
segment_size: 16384
num_mels: 100
num_freq: 1025
n_fft: 1024
hop_size: 256
win_size: 1024
sampling_rate: 24000
fmin: 0
fmax: null
fmax_for_loss: null
mel_type: "pytorch"
num_workers: 2
dist_config:
dist_backend: "nccl"
dist_url: "tcp://localhost:54321"
world_size: 1
dvae_checkpoint: dvae.pth
gpt_checkpoint: gpt.pth
bigvgan_checkpoint: bigvgan_generator.pth

View File

@ -74,7 +74,7 @@ class IndexTTS:
auto_conditioning = cond_mel
tokenizer = spm.SentencePieceProcessor()
tokenizer.load(self.cfg.dataset['bpe_model'])
tokenizer.load(os.path.join(self.model_dir,self.cfg.dataset['bpe_model']))
punctuation = ["!", "?", ".", ";", "", "", "", ""]
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))