Merge pull request #160 from yrom/fix-infer

适配1.5版本模型,优化Webui,适配新版本transformers
This commit is contained in:
index-tts 2025-05-18 22:14:32 +08:00 committed by GitHub
commit c0c17fe387
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 578 additions and 145 deletions

View File

@ -149,6 +149,7 @@ wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P che
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints
```
4. Run test script:
@ -180,6 +181,9 @@ indextts --help
```bash
pip install -e ".[webui]"
python webui.py
# use another model version:
python webui.py --model_dir IndexTTS-1.5
```
Open your browser and visit `http://127.0.0.1:7860` to see the demo.

View File

@ -40,6 +40,7 @@ class ResBlock(nn.Module):
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
super().__init__(config)
# Note: the argument named `text_pos_emb` here actually represents the mel position embedding
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
@ -97,7 +98,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids.masked_fill_(attention_mask == 0, 0)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
@ -134,7 +135,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
@ -388,7 +388,7 @@ class UnifiedVoice(nn.Module):
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
@ -556,7 +556,6 @@ class UnifiedVoice(nn.Module):
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
text_inputs = self.set_text_padding(text_inputs, text_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
@ -589,37 +588,121 @@ class UnifiedVoice(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
conds = speech_conditioning_latent
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
# +1 for the start_audio_token
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
device=text_inputs.device)
def prepare_gpt_inputs(
self,
conditional_latents: torch.Tensor,
text_inputs: torch.Tensor,
):
"""
Prepare the inputs for the GPT2InferenceModel to generate.
Args:
conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()`
text_inputs: (b, L)
Returns:
input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
"""
b, L = text_inputs.shape[:2]
device = text_inputs.device
single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
if not single_cond:
assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
batched_mel_emb = []
attention_masks = []
target_len = conditional_latents.shape[1] + L + 2
for i in range(b):
valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
text_input = text_inputs[i][valid_mask]
text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
text_input_pos = torch.arange(0, text_input.size(-1), device=device)
text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
# concatenate [conditional latents][text embeddings]
conds_text_emb = [
conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
text_emb,
]
# +1 for the start_mel_token
attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
# check this text input is padded
padding: int = L + 2 - text_input.size(-1)
# pad left of [cond][text] -> [pad][cond][text]
if padding > 0:
pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
conds_text_emb.insert(0, pad)
attention_mask[:padding] = 0
mel_emb = torch.cat(conds_text_emb) #[s, dim]
assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
batched_mel_emb.append(mel_emb)
attention_masks.append(attention_mask)
# [b, s, dim]
batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
# [b, s+1]
attention_mask = torch.stack(attention_masks, dim=0)
# [b, s+1]
fake_inputs = torch.ones(
(
batched_mel_emb.shape[0],
batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
),
dtype=torch.long,
device=device,
)
fake_inputs[:, -1] = self.start_mel_token
trunc_index = fake_inputs.shape[1]
return fake_inputs, batched_mel_emb, attention_mask
def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
"""
Args:
speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
text_inputs: (b, L)
cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
input_tokens: additional tokens for generation in shape (b, s) or (s,)
max_generate_length: limit the number of generated tokens
hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
"""
if speech_conditioning_mel.ndim == 2:
speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
if cond_mel_lengths is None:
cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
self.inference_model.store_mel_emb(inputs_embeds)
if input_tokens is None:
inputs = fake_inputs
inputs = input_ids
else:
assert num_return_sequences % input_tokens.shape[
0] == 0, "The number of return sequences must be divisible by the number of input sequences"
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
if input_tokens.ndim == 1:
input_tokens = input_tokens.unsqueeze(0)
assert num_return_sequences % input_tokens.shape[0] == 0, \
"The num_return_sequences must be divisible by the batch number of input_tokens"
assert num_return_sequences % text_inputs.shape[0] == 0, \
"The num_return_sequences must be divisible by the batch number of text_inputs"
b = num_return_sequences // input_ids.shape[0]
if b > 1:
input_ids = input_ids.repeat(b, 1)
attention_mask = attention_mask.repeat(b, 1)
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
inputs = torch.cat([input_ids, input_tokens], dim=1)
attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
trunc_index = inputs.shape[1]
logits_processor = LogitsProcessorList()
if typical_sampling:
# employ custom typical sampling
if not (typical_mass > 0.0 and typical_mass < 1.0):
raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
output = self.inference_model.generate(inputs,
bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
return gen[:, trunc_index:]
num_return_sequences=num_return_sequences,
**hf_generate_kwargs)
if isinstance(output, torch.Tensor):
return output[:, trunc_index:]
# GenerateOutput
output.sequences = output.sequences[:, trunc_index:]
return output

View File

@ -2,7 +2,7 @@ import os
import re
import time
from subprocess import CalledProcessError
from typing import List
from typing import Dict, List, Tuple
import numpy as np
import sentencepiece as spm
@ -125,8 +125,13 @@ class IndexTTS:
self.cache_cond_mel = None
# 进度引用显示(可选)
self.gr_progress = None
self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None
def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
"""
Shrink special tokens (silent_token and stop_mel_token) in codes
codes: [B, T]
"""
code_lens = []
codes_list = []
device = codes.device
@ -134,59 +139,113 @@ class IndexTTS:
isfix = False
for i in range(0, codes.shape[0]):
code = codes[i]
if self.cfg.gpt.stop_mel_token not in code:
code_lens.append(len(code))
len_ = len(code)
if not torch.any(code == self.stop_mel_token).item():
len_ = code.size(0)
else:
# len_ = code.cpu().tolist().index(8193)+1
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
len_ = len_ - 2
stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)
count = torch.sum(code == silent_token).item()
if count > max_consecutive:
code = code.cpu().tolist()
ncode = []
# code = code.cpu().tolist()
ncode_idx = []
n = 0
for k in range(0, len_):
for k in range(len_):
assert code[k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here"
if code[k] != silent_token:
ncode.append(code[k])
ncode_idx.append(k)
n = 0
elif code[k] == silent_token and n < 10:
ncode.append(code[k])
ncode_idx.append(k)
n += 1
# if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
# n += 1
len_ = len(ncode)
ncode = torch.LongTensor(ncode)
codes_list.append(ncode.to(device, dtype=dtype))
# new code
len_ = len(ncode_idx)
codes_list.append(code[ncode_idx])
isfix = True
# codes[i] = self.stop_mel_token
# codes[i, 0:len_] = ncode
else:
codes_list.append(codes[i])
# shrink to len_
codes_list.append(code[:len_])
code_lens.append(len_)
codes = pad_sequence(codes_list, batch_first=True) if isfix else codes[:, :-2]
code_lens = torch.LongTensor(code_lens).to(device, dtype=dtype)
if isfix:
if len(codes_list) > 1:
codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
else:
codes = codes_list[0].unsqueeze(0)
else:
# unchanged
pass
# clip codes to max length
max_len = max(code_lens)
if max_len < codes.shape[1]:
codes = codes[:, :max_len]
code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
return codes, code_lens
def bucket_sentences(self, sentences, enable=False):
def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]:
"""
Sentence data bucketing
Sentence data bucketing.
if ``bucket_max_size=1``, return all sentences in one bucket.
"""
max_len = max(len(s) for s in sentences)
half = max_len // 2
outputs = [[], []]
outputs: List[Dict] = []
for idx, sent in enumerate(sentences):
if enable is False or len(sent) <= half:
outputs[0].append({"idx": idx, "sent": sent})
else:
outputs[1].append({"idx": idx, "sent": sent})
return [item for item in outputs if item]
outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
if len(outputs) > bucket_max_size:
# split sentences into buckets by sentence length
buckets: List[List[Dict]] = []
factor = 1.5
last_bucket = None
last_bucket_sent_len_median = 0
def pad_tokens_cat(self, tokens: List[torch.Tensor]):
if len(tokens) <= 1:
return tokens[-1]
for sent in sorted(outputs, key=lambda x: x["len"]):
current_sent_len = sent["len"]
if current_sent_len == 0:
print(">> skip empty sentence")
continue
if last_bucket is None \
or current_sent_len >= int(last_bucket_sent_len_median * factor) \
or len(last_bucket) >= bucket_max_size:
# new bucket
buckets.append([sent])
last_bucket = buckets[-1]
last_bucket_sent_len_median = current_sent_len
else:
# current bucket can hold more sentences
last_bucket.append(sent) # sorted
mid = len(last_bucket) // 2
last_bucket_sent_len_median = last_bucket[mid]["len"]
last_bucket=None
# merge all buckets with size 1
out_buckets: List[List[Dict]] = []
only_ones: List[Dict] = []
for b in buckets:
if len(b) == 1:
only_ones.append(b[0])
else:
out_buckets.append(b)
if len(only_ones) > 0:
# merge into previous buckets if possible
# print("only_ones:", [(o["idx"], o["len"]) for o in only_ones])
for i in range(len(out_buckets)):
b = out_buckets[i]
if len(b) < bucket_max_size:
b.append(only_ones.pop(0))
if len(only_ones) == 0:
break
# combined all remaining sized 1 buckets
if len(only_ones) > 0:
out_buckets.extend([only_ones[i:i+bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
return out_buckets
return [outputs]
def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
if self.model_version and self.model_version >= 1.5:
# 1.5版本以上直接使用stop_text_token 右侧填充,填充到最大长度
# [1, N] -> [N,]
tokens = [t.squeeze(0) for t in tokens]
return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right")
max_len = max(t.size(1) for t in tokens)
outputs = []
for tensor in tokens:
@ -214,8 +273,18 @@ class IndexTTS:
self.gr_progress(value, desc=desc)
# 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ First modified by sunnyboxs 2025-04-16
def infer_fast(self, audio_prompt, text, output_path, verbose=False):
def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4, **generation_kwargs):
"""
Args:
``max_text_tokens_per_sentence``: 分句的最大token数默认``100``可以根据GPU硬件情况调整
- 越小batch 越多推理速度越*快*占用内存更多可能影响质量
- 越大batch 越少推理速度越*慢*占用内存和质量更接近于非快速推理
``sentences_bucket_max_size``: 分句分桶的最大容量默认``4``可以根据GPU内存调整
- 越大bucket数量越少batch越多推理速度越*快*占用内存更多可能影响质量
- 越小bucket数量越多batch越少推理速度越*慢*占用内存和质量更接近于非快速推理
"""
print(">> start fast inference...")
self._set_gr_progress(0, "start fast inference...")
if verbose:
print(f"origin text:{text}")
@ -245,20 +314,22 @@ class IndexTTS:
# text_tokens
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list)
if verbose:
print("text token count:", len(text_tokens_list))
print("sentences count:", len(sentences))
print(*sentences, sep="\n")
top_p = 0.8
top_k = 30
temperature = 1.0
sentences = self.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=max_text_tokens_per_sentence)
if verbose:
print(">> text token count:", len(text_tokens_list))
print(" splited sentences count:", len(sentences))
print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
print(*sentences, sep="\n")
do_sample = generation_kwargs.pop("do_sample", True)
top_p = generation_kwargs.pop("top_p", 0.8)
top_k = generation_kwargs.pop("top_k", 30)
temperature = generation_kwargs.pop("temperature", 1.0)
autoregressive_batch_size = 1
length_penalty = 0.0
num_beams = 3
repetition_penalty = 10.0
max_mel_tokens = 600
length_penalty = generation_kwargs.pop("length_penalty", 0.0)
num_beams = generation_kwargs.pop("num_beams", 3)
repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
sampling_rate = 24000
# lang = "EN"
# lang = "ZH"
@ -270,8 +341,13 @@ class IndexTTS:
# text processing
all_text_tokens: List[List[torch.Tensor]] = []
self._set_gr_progress(0.1, "text processing...")
bucket_enable = True # 预分桶开关,优先保证质量=True。优先保证速度=False。
all_sentences = self.bucket_sentences(sentences, enable=bucket_enable)
bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1
all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size)
bucket_count = len(all_sentences)
if verbose:
print(">> sentences bucket_count:", bucket_count,
"bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences],
"bucket_max_size:", bucket_max_size)
for sentences in all_sentences:
temp_tokens: List[torch.Tensor] = []
all_text_tokens.append(temp_tokens)
@ -289,24 +365,25 @@ class IndexTTS:
# Sequential processing of bucketing data
all_batch_num = 0
all_batch_num = sum(len(s) for s in all_sentences)
all_batch_codes = []
processed_num = 0
for item_tokens in all_text_tokens:
batch_num = len(item_tokens)
batch_text_tokens = self.pad_tokens_cat(item_tokens)
batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
all_batch_num += batch_num
if batch_num > 1:
batch_text_tokens = self.pad_tokens_cat(item_tokens)
else:
batch_text_tokens = item_tokens[0]
processed_num += batch_num
# gpt speech
self._set_gr_progress(0.2, "gpt inference speech...")
self._set_gr_progress(0.2 + 0.3 * processed_num/all_batch_num, f"gpt inference speech... {processed_num}/{all_batch_num}")
m_start_time = time.perf_counter()
with torch.no_grad():
with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
cond_mel_lengths=batch_cond_mel_lengths,
temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens,
cond_mel_lengths=cond_mel_lengths,
# text_lengths=text_len,
do_sample=True,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@ -314,7 +391,8 @@ class IndexTTS:
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
max_generate_length=max_mel_tokens,
**generation_kwargs)
all_batch_codes.append(temp_codes)
gpt_gen_time += time.perf_counter() - m_start_time
@ -322,14 +400,26 @@ class IndexTTS:
self._set_gr_progress(0.5, "gpt inference latents...")
all_idxs = []
all_latents = []
has_warned = False
for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
for i in range(batch_codes.shape[0]):
codes = batch_codes[i] # [x]
codes = codes[codes != self.cfg.gpt.stop_mel_token]
codes, _ = torch.unique_consecutive(codes, return_inverse=True)
if not has_warned and codes[-1] != self.stop_mel_token:
warnings.warn(
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
category=RuntimeWarning
)
has_warned = True
codes = codes.unsqueeze(0) # [x] -> [1, x]
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
if verbose:
print("codes:", codes.shape)
print(codes)
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
if verbose:
print("fix codes:", codes.shape)
print(codes)
print("code_lens:", code_lens)
text_tokens = batch_tokens[i]
all_idxs.append(batch_sentences[i]["idx"])
m_start_time = time.perf_counter()
@ -343,14 +433,16 @@ class IndexTTS:
return_latent=True, clip_inputs=False)
gpt_forward_time += time.perf_counter() - m_start_time
all_latents.append(latent)
del all_batch_codes, all_text_tokens, all_sentences
# bigvgan chunk
chunk_size = 2
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
if verbose:
print(">> all_latents:", len(all_latents))
print(" latents length:", [l.shape[1] for l in all_latents])
chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
chunk_length = len(chunk_latents)
latent_length = len(all_latents)
all_latents = None
# bigvgan chunk decode
self._set_gr_progress(0.7, "bigvgan decode...")
@ -370,7 +462,7 @@ class IndexTTS:
# clear cache
tqdm_progress.close() # 确保进度条被关闭
chunk_latents.clear()
del all_latents, chunk_latents
end_time = time.perf_counter()
self.torch_empty_cache()
@ -385,7 +477,7 @@ class IndexTTS:
print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
print(f">> Generated audio length: {wav_length:.2f} seconds")
print(f">> [fast] bigvgan chunk_length: {chunk_length}")
print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}")
print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}", f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "")
print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
# save audio
@ -403,7 +495,7 @@ class IndexTTS:
return (sampling_rate, wav_data)
# 原始推理模式
def infer(self, audio_prompt, text, output_path, verbose=False):
def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs):
print(">> start inference...")
self._set_gr_progress(0, "start inference...")
if verbose:
@ -429,21 +521,24 @@ class IndexTTS:
cond_mel_frame = cond_mel.shape[-1]
pass
self._set_gr_progress(0.1, "text processing...")
auto_conditioning = cond_mel
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list)
sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
if verbose:
print("text token count:", len(text_tokens_list))
print("sentences count:", len(sentences))
print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
print(*sentences, sep="\n")
top_p = 0.8
top_k = 30
temperature = 1.0
do_sample = generation_kwargs.pop("do_sample", True)
top_p = generation_kwargs.pop("top_p", 0.8)
top_k = generation_kwargs.pop("top_k", 30)
temperature = generation_kwargs.pop("temperature", 1.0)
autoregressive_batch_size = 1
length_penalty = 0.0
num_beams = 3
repetition_penalty = 10.0
max_mel_tokens = 600
length_penalty = generation_kwargs.pop("length_penalty", 0.0)
num_beams = generation_kwargs.pop("num_beams", 3)
repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
sampling_rate = 24000
# lang = "EN"
# lang = "ZH"
@ -451,7 +546,8 @@ class IndexTTS:
gpt_gen_time = 0
gpt_forward_time = 0
bigvgan_time = 0
progress = 0
has_warned = False
for sent in sentences:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
@ -467,7 +563,8 @@ class IndexTTS:
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
# print(text_len)
progress += 1
self._set_gr_progress(0.2 + 0.4 * (progress-1) / len(sentences), f"gpt inference latent... {progress}/{len(sentences)}")
m_start_time = time.perf_counter()
with torch.no_grad():
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@ -475,7 +572,7 @@ class IndexTTS:
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@ -483,9 +580,18 @@ class IndexTTS:
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
max_generate_length=max_mel_tokens,
**generation_kwargs)
gpt_gen_time += time.perf_counter() - m_start_time
# codes = codes[:, :-2]
if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
warnings.warn(
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
f"Input text tokens: {text_tokens.shape[1]}. "
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
category=RuntimeWarning
)
has_warned = True
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
if verbose:
print(codes, type(codes))
@ -499,7 +605,7 @@ class IndexTTS:
print(codes, type(codes))
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
self._set_gr_progress(0.2 + 0.4 * progress / len(sentences), f"gpt inference speech... {progress}/{len(sentences)}")
m_start_time = time.perf_counter()
# latent, text_lens_out, code_lens_out = \
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@ -517,11 +623,12 @@ class IndexTTS:
wav = wav.squeeze(1)
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
if verbose:
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
wavs.append(wav.cpu()) # to cpu before saving
end_time = time.perf_counter()
self._set_gr_progress(0.9, "save audio...")
wav = torch.cat(wavs, dim=1)
wav_length = wav.shape[-1] / sampling_rate
print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")

View File

@ -84,7 +84,8 @@ class TextNormalizer:
# print(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
# sys.path.append(model_dir)
import platform
if self.zh_normalizer is not None and self.en_normalizer is not None:
return
if platform.system() == "Darwin":
from wetext import Normalizer
@ -355,22 +356,24 @@ class TextTokenizer:
sentences.append(current_sentence)
else:
# 如果当前tokens的长度超过最大限制
if "," in current_sentence or "▁," in current_sentence:
if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence):
# 如果当前tokens中有,,则按,分割
sub_sentences = TextTokenizer.split_sentences_by_token(
current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
)
elif "-" in current_sentence:
elif not ("-" in split_tokens ) and "-" in current_sentence:
# 没有,,则按-分割
sub_sentences = TextTokenizer.split_sentences_by_token(
current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
)
else:
# 按照长度分割
sub_sentences = [
current_sentence[:max_tokens_per_sentence],
current_sentence[max_tokens_per_sentence:],
]
sub_sentences = []
for j in range(0, len(current_sentence), max_tokens_per_sentence):
if j + max_tokens_per_sentence < len(current_sentence):
sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence])
else:
sub_sentences.append(current_sentence[j:])
warnings.warn(
f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, "
f"Tokens in sentence: {current_sentence}."
@ -448,6 +451,7 @@ if __name__ == "__main__":
"蒂莫西·唐纳德·库克英文名Timothy Donald Cook通称蒂姆·库克Tim Cook美国商业经理、工业工程师和工业开发商现任苹果公司首席执行官。",
# 长句子
"《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧,莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演2010年7月16日在美国上映2010年9月1日在中国内地上映2020年8月28日在中国内地重映。影片剧情游走于梦境与现实之间被定义为“发生在意识结构内的当代动作科幻片”讲述了由莱昂纳多·迪卡普里奥扮演的造梦师带领特工团队进入他人梦境从他人的潜意识中盗取机密并重塑他人梦境的故事。",
"清晨拉开窗帘阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节手工陶瓷花瓶可作首饰收纳香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》让每个平凡日子都有花开仪式感。\n宴会厅灯光暗下的刹那Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动钛合金骨架仅3.2g无负重感。设计师秘密内置微型重力感应器随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。",
]
# 测试分词器
tokenizer = TextTokenizer(

View File

@ -1,12 +1,9 @@
import torch
from transformers import LogitsWarper
from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper
class TypicalLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(BaseTypicalLogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy

8
tests/cases.jsonl Normal file
View File

@ -0,0 +1,8 @@
{"prompt_audio":"sample_prompt.wav","text":"IndexTTS 正式发布1.0版本了效果666","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"晕XUAN4是一种GAN3觉","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"最zhong4要的是不要chong2蹈覆辙","infer_mode":0}
{"prompt_audio":"sample_prompt.wav","text":"Matt Hougan, chief investment officer at Bitwise, predicts Bitcoin (BTC) will reach $200,000 by the end of 2025 due to a supply shock from heightened institutional demand. In an interview with Cointelegraph at Consensus 2025 in Toronto, the executive said that Bitwise's Bitcoin price prediction model is driven exclusively by supply and demand metrics. \"I think eventually that will exhaust sellers at the $100,000 level where we have been stuck, and I think the next stopping point above that is $200,000,\" the executive said.","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"《盗梦空间》英语Inception是由美国华纳兄弟影片公司出品的电影由克里斯托弗·诺兰Christopher Edward Nolan执导并编剧莱昂纳多·迪卡普里奥Leonardo Wilhelm DiCaprio、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演2010年7月16日在美国上映2010年9月1日在中国内地上映2020年8月28日在中国内地重映。豆瓣评分9.4IMDB 8.8。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由 Leonardo 扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"清晨拉开窗帘阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节手工陶瓷花瓶可作首饰收纳香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》让每个平凡日子都有花开仪式感。宴会厅灯光暗下的刹那Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动钛合金骨架仅3.2g无负重感。设计师秘密内置微型重力感应器随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。","infer_mode":1}
{"prompt_audio":"sample_prompt.wav","text":"当地时间15日随着特朗普与阿联酋敲定2000亿美元协议特朗普的中东之行正式收官。特朗普已宣布获得沙特6000亿美元和卡塔尔2430亿美元投资承诺。商业协议成为特朗普重返白宫后首次外访的核心成果。香港英文媒体《南华早报》South China Morning Post特朗普访问期间提出了以经济合作为驱动的中东及南亚和平计划。分析人士称该战略包含多项旨在遏制中国在这些地区影响力的措施。英国伦敦国王学院安全研究教授安德烈亚斯·克里格(Dr. Andreas Krieg)对半岛电视台表示特朗普访问海湾地区的主要目标有三个其一以军工投资和能源合作的形式获得海湾国家的切实承诺其二加强与“让美国再次伟大”运动结盟的外交伙伴关系维持美国外交影响力其三将海湾国家重新定位为美国在从加沙到伊朗等地区危机前线的调解人这样就可以不用增强军事部署。安德烈亚斯·克里格直言“海湾国家不会为美国牺牲与中国的关系他们的战略自主性远超特朗普想象。”美国有线电视新闻网CNN报道称特朗普到访的三个能源富国每个国家都对美国有着长长的诉求清单。尽管这些国家豪掷重金但美国并未实现所有诉求。","infer_mode":1}

99
tests/padding_test.py Normal file
View File

@ -0,0 +1,99 @@
import torch
import torchaudio
from indextts.infer import IndexTTS
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from torch.nn import functional as F
if __name__ == "__main__":
"""
Test the padding of text tokens in inference.
```
python tests/padding_test.py checkpoints
python tests/padding_test.py IndexTTS-1.5
```
"""
import transformers
transformers.set_seed(42)
import sys
sys.path.append("..")
if len(sys.argv) > 1:
model_dir = sys.argv[1]
else:
model_dir = "checkpoints"
audio_prompt="tests/sample_prompt.wav"
tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
text_tokens = tts.tokenizer.encode(text)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]
audio, sr = torchaudio.load(audio_prompt)
audio = torch.mean(audio, dim=0, keepdim=True)
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device)
cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device)
with torch.no_grad():
kwargs = {
"cond_mel_lengths": cond_mel_lengths,
"do_sample": False,
"top_p": 0.8,
"top_k": None,
"temperature": 1.0,
"num_return_sequences": 1,
"length_penalty": 0.0,
"num_beams": 1,
"repetition_penalty": 10.0,
"max_generate_length": 100,
}
# baseline for non-pad
baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
baseline = baseline.squeeze(0)
print("Inference padded text tokens...")
pad_text_tokens = [
F.pad(text_tokens, (8, 0), value=0), # left bos
F.pad(text_tokens, (0, 8), value=1), # right eos
F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1), # both side
F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1),
F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1),
]
output_for_padded = []
for t in pad_text_tokens:
# test for each padded text
out = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
output_for_padded.append(out.squeeze(0))
# batched inference
print("Inference padded text tokens as one batch...")
batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device)
assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2
batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs)
del pad_text_tokens
mismatch_idx = []
print("baseline:", baseline.shape, baseline)
print("--"*10)
print("baseline vs padded output:")
for i in range(len(output_for_padded)):
if not baseline.equal(output_for_padded[i]):
mismatch_idx.append(i)
if len(mismatch_idx) > 0:
print("mismatch:", mismatch_idx)
for i in mismatch_idx:
print(f"[{i}]: {output_for_padded[i]}")
else:
print("all matched")
del output_for_padded
print("--"*10)
print("baseline vs batched output:")
mismatch_idx = []
for i in range(batch_output.shape[0]):
if not baseline.equal(batch_output[i]):
mismatch_idx.append(i)
if len(mismatch_idx) > 0:
print("mismatch:", mismatch_idx)
for i in mismatch_idx:
print(f"[{i}]: {batch_output[i]}")
else:
print("all matched")
print("Test finished.")

161
webui.py
View File

@ -1,5 +1,5 @@
import json
import os
import shutil
import sys
import threading
import time

import pandas as pd
@ -12,70 +12,201 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))

import argparse
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()

# Fail fast with a clear message when the model directory or any required
# checkpoint file is missing, before importing the heavy ML stack.
if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

for file in [
    "bigvgan_generator.pth",
    "bpe.model",
    "gpt.pth",
    "config.yaml",
]:
    file_path = os.path.join(cmd_args.model_dir, file)
    if not os.path.exists(file_path):
        print(f"Required file {file_path} does not exist. Please download it.")
        sys.exit(1)

import gradio as gr

from indextts.utils.webui_utils import next_page, prev_page
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language="zh_CN")
MODE = 'local'
# Construct the model exactly once from the user-selected directory.
# (The old code first built a throwaway instance from the hard-coded
# "checkpoints" path, loading the large model twice.)
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"))

os.makedirs("outputs/tasks", exist_ok=True)
os.makedirs("prompts", exist_ok=True)

# Load example cases for the Gradio Examples widget; the infer_mode index
# 0/1 maps to the radio choices "普通推理" / "批次推理".
with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
    example_cases = []
    for line in f:
        line = line.strip()
        if not line:
            continue
        example = json.loads(line)
        example_cases.append([os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")),
                              example.get("text"), ["普通推理", "批次推理"][example.get("infer_mode", 0)]])
def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
               *args, progress=gr.Progress()):
    """Synthesize speech for `text` using the reference audio `prompt`.

    Args:
        prompt: path to the reference (speaker) audio file.
        text: target text to synthesize.
        infer_mode: "普通推理" (plain inference) or "批次推理" (batched, faster for long text).
        max_text_tokens_per_sentence: sentence-splitting limit in tokens.
        sentences_bucket_max_size: max sentences per bucket (batched mode only).
        *args: advanced GPT2 sampling values, in the fixed order unpacked below
            (must match the `advanced_params` component list in the UI).
        progress: Gradio progress reporter, forwarded to the TTS engine.

    Returns:
        A gr.update making the output audio component visible with the result.
    """
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Let the engine report progress to the Gradio UI.
    tts.gr_progress = progress
    do_sample, top_p, top_k, temperature, \
        length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        "top_k": int(top_k) if int(top_k) > 0 else None,  # 0 disables top-k
        "temperature": float(temperature),
        "length_penalty": float(length_penalty),
        "num_beams": num_beams,
        "repetition_penalty": float(repetition_penalty),
        "max_mel_tokens": int(max_mel_tokens),
        # "typical_sampling": bool(typical_sampling),
        # "typical_mass": float(typical_mass),
    }
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
                           max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                           **kwargs)
    else:
        # 批次推理
        output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
                                max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                                sentences_bucket_max_size=int(sentences_bucket_max_size),
                                **kwargs)
    return gr.update(value=output, visible=True)
def update_prompt_audio():
    """Re-enable the generate button once a new prompt audio is provided."""
    return gr.update(interactive=True)
# Build the Gradio UI. (The merged diff left both the old and the new
# versions of several lines in place — duplicate `with gr.Blocks`, duplicate
# widget constructions, a duplicate `inputs=` keyword and a double
# `demo.launch` — which is a SyntaxError / double-launch; only the new
# versions are kept here.)
with gr.Blocks(title="IndexTTS Demo") as demo:
    mutex = threading.Lock()
    gr.HTML('''
    <h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
    <h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
    <p align="center">
    <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
    </p>
    ''')
    with gr.Tab("音频生成"):
        with gr.Row():
            prompt_audio = gr.Audio(label="参考音频", key="prompt_audio",
                                    sources=["upload", "microphone"], type="filepath")
            prompt_list = os.listdir("prompts")
            default = ''
            if prompt_list:
                default = prompt_list[0]
            with gr.Column():
                input_text_single = gr.TextArea(label="文本", key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
                infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式", info="批次推理:更适合长句,性能翻倍", value="普通推理")
                gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
            output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
        with gr.Accordion("高级生成参数设置", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
                    with gr.Row():
                        do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
                        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
                    with gr.Row():
                        top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
                        num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
                    with gr.Row():
                        repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
                        length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量过小导致音频被截断", key="max_mel_tokens")
                    # with gr.Row():
                    #     typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
                    #     typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
                with gr.Column(scale=2):
                    gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
                    with gr.Row():
                        max_text_tokens_per_sentence = gr.Slider(
                            label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
                            info="建议80~200之间值越大分句越长值越小分句越碎过小过大都可能导致音频质量不高",
                        )
                        sentences_bucket_max_size = gr.Slider(
                            label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
                            info="建议2-8之间值越大一批次推理包含的分句数越多过大可能导致内存溢出",
                        )
                    with gr.Accordion("预览分句结果", open=True) as sentences_settings:
                        sentences_preview = gr.Dataframe(
                            headers=["序号", "分句内容", "Token数"],
                            key="sentences_preview",
                            wrap=True,
                        )
        # Order must match the unpacking of *args inside gen_single.
        advanced_params = [
            do_sample, top_p, top_k, temperature,
            length_penalty, num_beams, repetition_penalty, max_mel_tokens,
            # typical_sampling, typical_mass,
        ]

    if len(example_cases) > 0:
        gr.Examples(
            examples=example_cases,
            inputs=[prompt_audio, input_text_single, infer_mode],
        )

    def on_input_text_change(text, max_tokens_per_sentence):
        """Preview how the current text will be split into sentences."""
        if text and len(text) > 0:
            text_tokens_list = tts.tokenizer.tokenize(text)
            sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
            data = []
            for i, s in enumerate(sentences):
                sentence_str = ''.join(s)
                tokens_count = len(s)
                data.append([i, sentence_str, tokens_count])
            return {
                sentences_preview: gr.update(value=data, visible=True, type="array"),
            }
        else:
            # Empty text: clear the preview table (requires `import pandas as pd`).
            df = pd.DataFrame([], columns=["序号", "分句内容", "Token数"])
            return {
                sentences_preview: gr.update(value=df)
            }

    # Refresh the sentence preview whenever the text or the split limit changes.
    input_text_single.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    max_text_tokens_per_sentence.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    prompt_audio.upload(update_prompt_audio,
                        inputs=[],
                        outputs=[gen_button])
    gen_button.click(gen_single,
                     inputs=[prompt_audio, input_text_single, infer_mode,
                             max_text_tokens_per_sentence, sentences_bucket_max_size,
                             *advanced_params,
                             ],
                     outputs=[output_audio])


if __name__ == "__main__":
    demo.queue(20)
    demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)