Merge pull request #160 from yrom/fix-infer
适配1.5版本模型,优化Webui,适配新版本transformers
This commit is contained in:
commit
c0c17fe387
@ -149,6 +149,7 @@ wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P che
|
||||
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints
|
||||
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints
|
||||
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints
|
||||
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints
|
||||
```
|
||||
|
||||
4. Run test script:
|
||||
@ -180,6 +181,9 @@ indextts --help
|
||||
```bash
|
||||
pip install -e ".[webui]"
|
||||
python webui.py
|
||||
|
||||
# use another model version:
|
||||
python webui.py --model_dir IndexTTS-1.5
|
||||
```
|
||||
Open your browser and visit `http://127.0.0.1:7860` to see the demo.
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ class ResBlock(nn.Module):
|
||||
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
|
||||
super().__init__(config)
|
||||
# Note: the argument named `text_pos_emb` here actually represents the mel position embedding
|
||||
self.transformer = gpt
|
||||
self.text_pos_embedding = text_pos_emb
|
||||
self.embeddings = embeddings
|
||||
@ -97,7 +98,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
position_ids.masked_fill_(attention_mask == 0, 0)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
@ -134,7 +135,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# Create embedding
|
||||
mel_len = self.cached_mel_emb.shape[1]
|
||||
if input_ids.shape[1] != 1:
|
||||
@ -388,7 +388,7 @@ class UnifiedVoice(nn.Module):
|
||||
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
|
||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_mel_tokens,
|
||||
vocab_size=self.number_mel_codes,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=self.model_dim,
|
||||
@ -556,7 +556,6 @@ class UnifiedVoice(nn.Module):
|
||||
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
||||
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
|
||||
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
|
||||
|
||||
text_inputs = self.set_text_padding(text_inputs, text_lengths)
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
|
||||
@ -589,37 +588,121 @@ class UnifiedVoice(nn.Module):
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
|
||||
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
|
||||
conds = speech_conditioning_latent
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
# +1 for the start_audio_token
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long,
|
||||
device=text_inputs.device)
|
||||
|
||||
def prepare_gpt_inputs(
    self,
    conditional_latents: torch.Tensor,
    text_inputs: torch.Tensor,
):
    """
    Prepare the inputs for the GPT2InferenceModel to generate.

    Every row of ``text_inputs`` is stripped of start/stop text tokens (this
    also drops any stop-token padding), re-wrapped as ``[start] tokens [stop]``,
    embedded, and appended to the conditioning latents. Rows that end up
    shorter than ``L + 2`` tokens are LEFT-padded with zero embeddings whose
    attention-mask entries are 0, so all rows share the same length and the
    real content is right-aligned.

    Args:
        conditional_latents: (b, n_cond, dim) or (1, n_cond, dim) audio
            conditioning embedding by `get_conditioning()`. A leading
            dimension of 1 is shared by every text row.
        text_inputs: (b, L) text token ids.
    Returns:
        input_ids: (b, s+1) placeholder ids for `GPT2InferenceModel.generate()`;
            all ones except the last column, which is ``start_mel_token``.
        inputs_embeds: (b, s, dim) the `[pad][cond][text]` embeddings for
            `GPT2InferenceModel.forward()`.
        attention_mask: (b, s+1) mask for `GPT2InferenceModel.generate()`;
            0 on left padding, 1 elsewhere (including the start_mel_token slot).

        Here ``s = n_cond + L + 2``.
    """
    b, L = text_inputs.shape[:2]
    device = text_inputs.device
    single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
    if not single_cond:
        assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"

    # Loop-invariant values, computed once instead of per row.
    dim = conditional_latents.size(-1)
    shared_cond = conditional_latents.squeeze(0) if single_cond else None
    # +2 for the start/stop text tokens wrapped around every row
    target_len = conditional_latents.shape[1] + L + 2

    batched_mel_emb = []
    attention_masks = []
    for i in range(b):
        row = text_inputs[i]
        # drop existing start/stop tokens (stop tokens are also the padding value)
        valid_mask = (row != self.stop_text_token) & (row != self.start_text_token)
        text_input = F.pad(row[valid_mask], (1, 0), value=self.start_text_token)
        text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
        text_input_pos = torch.arange(0, text_input.size(-1), device=device)
        text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
        # concatenate [conditional latents][text embeddings]
        conds_text_emb = [
            shared_cond if single_cond else conditional_latents[i],
            text_emb,
        ]
        # +1 for the start_mel_token
        attention_mask = torch.ones(target_len + 1, dtype=torch.long, device=device)
        # number of positions this row is short of the full length
        padding: int = L + 2 - text_input.size(-1)
        # pad left of [cond][text] -> [pad][cond][text]
        if padding > 0:
            pad = torch.zeros((padding, dim), dtype=text_emb.dtype, device=device)  # [p, dim]
            conds_text_emb.insert(0, pad)
            attention_mask[:padding] = 0
        mel_emb = torch.cat(conds_text_emb)  # [s, dim]
        assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
        batched_mel_emb.append(mel_emb)
        attention_masks.append(attention_mask)

    # [b, s, dim]
    batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
    # [b, s+1]
    attention_mask = torch.stack(attention_masks, dim=0)
    # [b, s+1]; the values are placeholders, only the trailing start_mel_token is meaningful
    fake_inputs = torch.ones(
        (
            batched_mel_emb.shape[0],
            batched_mel_emb.shape[1] + 1,  # +1 for the start_mel_token
        ),
        dtype=torch.long,
        device=device,
    )
    fake_inputs[:, -1] = self.start_mel_token
    return fake_inputs, batched_mel_emb, attention_mask
|
||||
def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
|
||||
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
||||
"""
|
||||
Args:
|
||||
speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
|
||||
text_inputs: (b, L)
|
||||
cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
|
||||
input_tokens: additional tokens for generation in shape (b, s) or (s,)
|
||||
max_generate_length: limit the number of generated tokens
|
||||
hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
|
||||
"""
|
||||
if speech_conditioning_mel.ndim == 2:
|
||||
speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
|
||||
if cond_mel_lengths is None:
|
||||
cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
|
||||
conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
|
||||
input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
|
||||
self.inference_model.store_mel_emb(inputs_embeds)
|
||||
if input_tokens is None:
|
||||
inputs = fake_inputs
|
||||
inputs = input_ids
|
||||
else:
|
||||
assert num_return_sequences % input_tokens.shape[
|
||||
0] == 0, "The number of return sequences must be divisible by the number of input sequences"
|
||||
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
|
||||
if input_tokens.ndim == 1:
|
||||
input_tokens = input_tokens.unsqueeze(0)
|
||||
assert num_return_sequences % input_tokens.shape[0] == 0, \
|
||||
"The num_return_sequences must be divisible by the batch number of input_tokens"
|
||||
assert num_return_sequences % text_inputs.shape[0] == 0, \
|
||||
"The num_return_sequences must be divisible by the batch number of text_inputs"
|
||||
b = num_return_sequences // input_ids.shape[0]
|
||||
if b > 1:
|
||||
input_ids = input_ids.repeat(b, 1)
|
||||
attention_mask = attention_mask.repeat(b, 1)
|
||||
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
|
||||
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
|
||||
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
|
||||
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token,
|
||||
inputs = torch.cat([input_ids, input_tokens], dim=1)
|
||||
attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
|
||||
trunc_index = inputs.shape[1]
|
||||
logits_processor = LogitsProcessorList()
|
||||
if typical_sampling:
|
||||
# employ custom typical sampling
|
||||
if not (typical_mass > 0.0 and typical_mass < 1.0):
|
||||
raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
|
||||
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
|
||||
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
|
||||
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
|
||||
output = self.inference_model.generate(inputs,
|
||||
bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
|
||||
max_length=max_length, logits_processor=logits_processor,
|
||||
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
||||
return gen[:, trunc_index:]
|
||||
num_return_sequences=num_return_sequences,
|
||||
**hf_generate_kwargs)
|
||||
if isinstance(output, torch.Tensor):
|
||||
return output[:, trunc_index:]
|
||||
# GenerateOutput
|
||||
output.sequences = output.sequences[:, trunc_index:]
|
||||
return output
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from subprocess import CalledProcessError
|
||||
from typing import List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
@ -125,8 +125,13 @@ class IndexTTS:
|
||||
self.cache_cond_mel = None
|
||||
# 进度引用显示(可选)
|
||||
self.gr_progress = None
|
||||
self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None
|
||||
|
||||
def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
|
||||
"""
|
||||
Shrink special tokens (silent_token and stop_mel_token) in codes
|
||||
codes: [B, T]
|
||||
"""
|
||||
code_lens = []
|
||||
codes_list = []
|
||||
device = codes.device
|
||||
@ -134,59 +139,113 @@ class IndexTTS:
|
||||
isfix = False
|
||||
for i in range(0, codes.shape[0]):
|
||||
code = codes[i]
|
||||
if self.cfg.gpt.stop_mel_token not in code:
|
||||
code_lens.append(len(code))
|
||||
len_ = len(code)
|
||||
if not torch.any(code == self.stop_mel_token).item():
|
||||
len_ = code.size(0)
|
||||
else:
|
||||
# len_ = code.cpu().tolist().index(8193)+1
|
||||
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
|
||||
len_ = len_ - 2
|
||||
stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
|
||||
len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)
|
||||
|
||||
count = torch.sum(code == silent_token).item()
|
||||
if count > max_consecutive:
|
||||
code = code.cpu().tolist()
|
||||
ncode = []
|
||||
# code = code.cpu().tolist()
|
||||
ncode_idx = []
|
||||
n = 0
|
||||
for k in range(0, len_):
|
||||
for k in range(len_):
|
||||
assert code[k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here"
|
||||
if code[k] != silent_token:
|
||||
ncode.append(code[k])
|
||||
ncode_idx.append(k)
|
||||
n = 0
|
||||
elif code[k] == silent_token and n < 10:
|
||||
ncode.append(code[k])
|
||||
ncode_idx.append(k)
|
||||
n += 1
|
||||
# if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
|
||||
# n += 1
|
||||
len_ = len(ncode)
|
||||
ncode = torch.LongTensor(ncode)
|
||||
codes_list.append(ncode.to(device, dtype=dtype))
|
||||
# new code
|
||||
len_ = len(ncode_idx)
|
||||
codes_list.append(code[ncode_idx])
|
||||
isfix = True
|
||||
# codes[i] = self.stop_mel_token
|
||||
# codes[i, 0:len_] = ncode
|
||||
else:
|
||||
codes_list.append(codes[i])
|
||||
# shrink to len_
|
||||
codes_list.append(code[:len_])
|
||||
code_lens.append(len_)
|
||||
|
||||
codes = pad_sequence(codes_list, batch_first=True) if isfix else codes[:, :-2]
|
||||
code_lens = torch.LongTensor(code_lens).to(device, dtype=dtype)
|
||||
if isfix:
|
||||
if len(codes_list) > 1:
|
||||
codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
|
||||
else:
|
||||
codes = codes_list[0].unsqueeze(0)
|
||||
else:
|
||||
# unchanged
|
||||
pass
|
||||
# clip codes to max length
|
||||
max_len = max(code_lens)
|
||||
if max_len < codes.shape[1]:
|
||||
codes = codes[:, :max_len]
|
||||
code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
|
||||
return codes, code_lens
|
||||
|
||||
def bucket_sentences(self, sentences, enable=False):
|
||||
def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]:
|
||||
"""
|
||||
Sentence data bucketing
|
||||
Sentence data bucketing.
|
||||
if ``bucket_max_size=1``, return all sentences in one bucket.
|
||||
"""
|
||||
max_len = max(len(s) for s in sentences)
|
||||
half = max_len // 2
|
||||
outputs = [[], []]
|
||||
outputs: List[Dict] = []
|
||||
for idx, sent in enumerate(sentences):
|
||||
if enable is False or len(sent) <= half:
|
||||
outputs[0].append({"idx": idx, "sent": sent})
|
||||
else:
|
||||
outputs[1].append({"idx": idx, "sent": sent})
|
||||
return [item for item in outputs if item]
|
||||
outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
|
||||
|
||||
if len(outputs) > bucket_max_size:
|
||||
# split sentences into buckets by sentence length
|
||||
buckets: List[List[Dict]] = []
|
||||
factor = 1.5
|
||||
last_bucket = None
|
||||
last_bucket_sent_len_median = 0
|
||||
|
||||
def pad_tokens_cat(self, tokens: List[torch.Tensor]):
|
||||
if len(tokens) <= 1:
|
||||
return tokens[-1]
|
||||
for sent in sorted(outputs, key=lambda x: x["len"]):
|
||||
current_sent_len = sent["len"]
|
||||
if current_sent_len == 0:
|
||||
print(">> skip empty sentence")
|
||||
continue
|
||||
if last_bucket is None \
|
||||
or current_sent_len >= int(last_bucket_sent_len_median * factor) \
|
||||
or len(last_bucket) >= bucket_max_size:
|
||||
# new bucket
|
||||
buckets.append([sent])
|
||||
last_bucket = buckets[-1]
|
||||
last_bucket_sent_len_median = current_sent_len
|
||||
else:
|
||||
# current bucket can hold more sentences
|
||||
last_bucket.append(sent) # sorted
|
||||
mid = len(last_bucket) // 2
|
||||
last_bucket_sent_len_median = last_bucket[mid]["len"]
|
||||
last_bucket=None
|
||||
# merge all buckets with size 1
|
||||
out_buckets: List[List[Dict]] = []
|
||||
only_ones: List[Dict] = []
|
||||
for b in buckets:
|
||||
if len(b) == 1:
|
||||
only_ones.append(b[0])
|
||||
else:
|
||||
out_buckets.append(b)
|
||||
if len(only_ones) > 0:
|
||||
# merge into previous buckets if possible
|
||||
# print("only_ones:", [(o["idx"], o["len"]) for o in only_ones])
|
||||
for i in range(len(out_buckets)):
|
||||
b = out_buckets[i]
|
||||
if len(b) < bucket_max_size:
|
||||
b.append(only_ones.pop(0))
|
||||
if len(only_ones) == 0:
|
||||
break
|
||||
# combined all remaining sized 1 buckets
|
||||
if len(only_ones) > 0:
|
||||
out_buckets.extend([only_ones[i:i+bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
|
||||
return out_buckets
|
||||
return [outputs]
|
||||
|
||||
def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
|
||||
if self.model_version and self.model_version >= 1.5:
|
||||
# 1.5版本以上,直接使用stop_text_token 右侧填充,填充到最大长度
|
||||
# [1, N] -> [N,]
|
||||
tokens = [t.squeeze(0) for t in tokens]
|
||||
return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right")
|
||||
max_len = max(t.size(1) for t in tokens)
|
||||
outputs = []
|
||||
for tensor in tokens:
|
||||
@ -214,8 +273,18 @@ class IndexTTS:
|
||||
self.gr_progress(value, desc=desc)
|
||||
|
||||
# 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16)
|
||||
def infer_fast(self, audio_prompt, text, output_path, verbose=False):
|
||||
def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4, **generation_kwargs):
|
||||
"""
|
||||
Args:
|
||||
``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整
|
||||
- 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量
|
||||
- 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
|
||||
``sentences_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整
|
||||
- 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量
|
||||
- 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
|
||||
"""
|
||||
print(">> start fast inference...")
|
||||
|
||||
self._set_gr_progress(0, "start fast inference...")
|
||||
if verbose:
|
||||
print(f"origin text:{text}")
|
||||
@ -245,20 +314,22 @@ class IndexTTS:
|
||||
|
||||
# text_tokens
|
||||
text_tokens_list = self.tokenizer.tokenize(text)
|
||||
sentences = self.tokenizer.split_sentences(text_tokens_list)
|
||||
if verbose:
|
||||
print("text token count:", len(text_tokens_list))
|
||||
print("sentences count:", len(sentences))
|
||||
print(*sentences, sep="\n")
|
||||
|
||||
top_p = 0.8
|
||||
top_k = 30
|
||||
temperature = 1.0
|
||||
sentences = self.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=max_text_tokens_per_sentence)
|
||||
if verbose:
|
||||
print(">> text token count:", len(text_tokens_list))
|
||||
print(" splited sentences count:", len(sentences))
|
||||
print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
|
||||
print(*sentences, sep="\n")
|
||||
do_sample = generation_kwargs.pop("do_sample", True)
|
||||
top_p = generation_kwargs.pop("top_p", 0.8)
|
||||
top_k = generation_kwargs.pop("top_k", 30)
|
||||
temperature = generation_kwargs.pop("temperature", 1.0)
|
||||
autoregressive_batch_size = 1
|
||||
length_penalty = 0.0
|
||||
num_beams = 3
|
||||
repetition_penalty = 10.0
|
||||
max_mel_tokens = 600
|
||||
length_penalty = generation_kwargs.pop("length_penalty", 0.0)
|
||||
num_beams = generation_kwargs.pop("num_beams", 3)
|
||||
repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
|
||||
max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
|
||||
sampling_rate = 24000
|
||||
# lang = "EN"
|
||||
# lang = "ZH"
|
||||
@ -270,8 +341,13 @@ class IndexTTS:
|
||||
# text processing
|
||||
all_text_tokens: List[List[torch.Tensor]] = []
|
||||
self._set_gr_progress(0.1, "text processing...")
|
||||
bucket_enable = True # 预分桶开关,优先保证质量=True。优先保证速度=False。
|
||||
all_sentences = self.bucket_sentences(sentences, enable=bucket_enable)
|
||||
bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1
|
||||
all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size)
|
||||
bucket_count = len(all_sentences)
|
||||
if verbose:
|
||||
print(">> sentences bucket_count:", bucket_count,
|
||||
"bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences],
|
||||
"bucket_max_size:", bucket_max_size)
|
||||
for sentences in all_sentences:
|
||||
temp_tokens: List[torch.Tensor] = []
|
||||
all_text_tokens.append(temp_tokens)
|
||||
@ -289,24 +365,25 @@ class IndexTTS:
|
||||
|
||||
|
||||
# Sequential processing of bucketing data
|
||||
all_batch_num = 0
|
||||
all_batch_num = sum(len(s) for s in all_sentences)
|
||||
all_batch_codes = []
|
||||
processed_num = 0
|
||||
for item_tokens in all_text_tokens:
|
||||
batch_num = len(item_tokens)
|
||||
batch_text_tokens = self.pad_tokens_cat(item_tokens)
|
||||
batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0)
|
||||
batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0)
|
||||
all_batch_num += batch_num
|
||||
|
||||
if batch_num > 1:
|
||||
batch_text_tokens = self.pad_tokens_cat(item_tokens)
|
||||
else:
|
||||
batch_text_tokens = item_tokens[0]
|
||||
processed_num += batch_num
|
||||
# gpt speech
|
||||
self._set_gr_progress(0.2, "gpt inference speech...")
|
||||
self._set_gr_progress(0.2 + 0.3 * processed_num/all_batch_num, f"gpt inference speech... {processed_num}/{all_batch_num}")
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
|
||||
cond_mel_lengths=batch_cond_mel_lengths,
|
||||
temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens,
|
||||
cond_mel_lengths=cond_mel_lengths,
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
@ -314,7 +391,8 @@ class IndexTTS:
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
max_generate_length=max_mel_tokens,
|
||||
**generation_kwargs)
|
||||
all_batch_codes.append(temp_codes)
|
||||
gpt_gen_time += time.perf_counter() - m_start_time
|
||||
|
||||
@ -322,14 +400,26 @@ class IndexTTS:
|
||||
self._set_gr_progress(0.5, "gpt inference latents...")
|
||||
all_idxs = []
|
||||
all_latents = []
|
||||
has_warned = False
|
||||
for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
|
||||
for i in range(batch_codes.shape[0]):
|
||||
codes = batch_codes[i] # [x]
|
||||
codes = codes[codes != self.cfg.gpt.stop_mel_token]
|
||||
codes, _ = torch.unique_consecutive(codes, return_inverse=True)
|
||||
if not has_warned and codes[-1] != self.stop_mel_token:
|
||||
warnings.warn(
|
||||
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
|
||||
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
has_warned = True
|
||||
codes = codes.unsqueeze(0) # [x] -> [1, x]
|
||||
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
|
||||
if verbose:
|
||||
print("codes:", codes.shape)
|
||||
print(codes)
|
||||
codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
|
||||
if verbose:
|
||||
print("fix codes:", codes.shape)
|
||||
print(codes)
|
||||
print("code_lens:", code_lens)
|
||||
text_tokens = batch_tokens[i]
|
||||
all_idxs.append(batch_sentences[i]["idx"])
|
||||
m_start_time = time.perf_counter()
|
||||
@ -343,14 +433,16 @@ class IndexTTS:
|
||||
return_latent=True, clip_inputs=False)
|
||||
gpt_forward_time += time.perf_counter() - m_start_time
|
||||
all_latents.append(latent)
|
||||
|
||||
del all_batch_codes, all_text_tokens, all_sentences
|
||||
# bigvgan chunk
|
||||
chunk_size = 2
|
||||
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
|
||||
if verbose:
|
||||
print(">> all_latents:", len(all_latents))
|
||||
print(" latents length:", [l.shape[1] for l in all_latents])
|
||||
chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
|
||||
chunk_length = len(chunk_latents)
|
||||
latent_length = len(all_latents)
|
||||
all_latents = None
|
||||
|
||||
# bigvgan chunk decode
|
||||
self._set_gr_progress(0.7, "bigvgan decode...")
|
||||
@ -370,7 +462,7 @@ class IndexTTS:
|
||||
|
||||
# clear cache
|
||||
tqdm_progress.close() # 确保进度条被关闭
|
||||
chunk_latents.clear()
|
||||
del all_latents, chunk_latents
|
||||
end_time = time.perf_counter()
|
||||
self.torch_empty_cache()
|
||||
|
||||
@ -385,7 +477,7 @@ class IndexTTS:
|
||||
print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
|
||||
print(f">> Generated audio length: {wav_length:.2f} seconds")
|
||||
print(f">> [fast] bigvgan chunk_length: {chunk_length}")
|
||||
print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}")
|
||||
print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}", f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "")
|
||||
print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")
|
||||
|
||||
# save audio
|
||||
@ -403,7 +495,7 @@ class IndexTTS:
|
||||
return (sampling_rate, wav_data)
|
||||
|
||||
# 原始推理模式
|
||||
def infer(self, audio_prompt, text, output_path, verbose=False):
|
||||
def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs):
|
||||
print(">> start inference...")
|
||||
self._set_gr_progress(0, "start inference...")
|
||||
if verbose:
|
||||
@ -429,21 +521,24 @@ class IndexTTS:
|
||||
cond_mel_frame = cond_mel.shape[-1]
|
||||
pass
|
||||
|
||||
self._set_gr_progress(0.1, "text processing...")
|
||||
auto_conditioning = cond_mel
|
||||
text_tokens_list = self.tokenizer.tokenize(text)
|
||||
sentences = self.tokenizer.split_sentences(text_tokens_list)
|
||||
sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
|
||||
if verbose:
|
||||
print("text token count:", len(text_tokens_list))
|
||||
print("sentences count:", len(sentences))
|
||||
print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
|
||||
print(*sentences, sep="\n")
|
||||
top_p = 0.8
|
||||
top_k = 30
|
||||
temperature = 1.0
|
||||
do_sample = generation_kwargs.pop("do_sample", True)
|
||||
top_p = generation_kwargs.pop("top_p", 0.8)
|
||||
top_k = generation_kwargs.pop("top_k", 30)
|
||||
temperature = generation_kwargs.pop("temperature", 1.0)
|
||||
autoregressive_batch_size = 1
|
||||
length_penalty = 0.0
|
||||
num_beams = 3
|
||||
repetition_penalty = 10.0
|
||||
max_mel_tokens = 600
|
||||
length_penalty = generation_kwargs.pop("length_penalty", 0.0)
|
||||
num_beams = generation_kwargs.pop("num_beams", 3)
|
||||
repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
|
||||
max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
|
||||
sampling_rate = 24000
|
||||
# lang = "EN"
|
||||
# lang = "ZH"
|
||||
@ -451,7 +546,8 @@ class IndexTTS:
|
||||
gpt_gen_time = 0
|
||||
gpt_forward_time = 0
|
||||
bigvgan_time = 0
|
||||
|
||||
progress = 0
|
||||
has_warned = False
|
||||
for sent in sentences:
|
||||
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
|
||||
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
|
||||
@ -467,7 +563,8 @@ class IndexTTS:
|
||||
|
||||
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
|
||||
# print(text_len)
|
||||
|
||||
progress += 1
|
||||
self._set_gr_progress(0.2 + 0.4 * (progress-1) / len(sentences), f"gpt inference latent... {progress}/{len(sentences)}")
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
@ -475,7 +572,7 @@ class IndexTTS:
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
||||
device=text_tokens.device),
|
||||
# text_lengths=text_len,
|
||||
do_sample=True,
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
@ -483,9 +580,18 @@ class IndexTTS:
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens)
|
||||
max_generate_length=max_mel_tokens,
|
||||
**generation_kwargs)
|
||||
gpt_gen_time += time.perf_counter() - m_start_time
|
||||
# codes = codes[:, :-2]
|
||||
if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
|
||||
warnings.warn(
|
||||
f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
|
||||
f"Input text tokens: {text_tokens.shape[1]}. "
|
||||
f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
has_warned = True
|
||||
|
||||
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
|
||||
if verbose:
|
||||
print(codes, type(codes))
|
||||
@ -499,7 +605,7 @@ class IndexTTS:
|
||||
print(codes, type(codes))
|
||||
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
|
||||
print(f"code len: {code_lens}")
|
||||
|
||||
self._set_gr_progress(0.2 + 0.4 * progress / len(sentences), f"gpt inference speech... {progress}/{len(sentences)}")
|
||||
m_start_time = time.perf_counter()
|
||||
# latent, text_lens_out, code_lens_out = \
|
||||
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
@ -517,11 +623,12 @@ class IndexTTS:
|
||||
wav = wav.squeeze(1)
|
||||
|
||||
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
|
||||
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
|
||||
if verbose:
|
||||
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
|
||||
# wavs.append(wav[:, :-512])
|
||||
wavs.append(wav.cpu()) # to cpu before saving
|
||||
end_time = time.perf_counter()
|
||||
|
||||
self._set_gr_progress(0.9, "save audio...")
|
||||
wav = torch.cat(wavs, dim=1)
|
||||
wav_length = wav.shape[-1] / sampling_rate
|
||||
print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
|
||||
|
||||
@ -84,7 +84,8 @@ class TextNormalizer:
|
||||
# print(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
||||
# sys.path.append(model_dir)
|
||||
import platform
|
||||
|
||||
if self.zh_normalizer is not None and self.en_normalizer is not None:
|
||||
return
|
||||
if platform.system() == "Darwin":
|
||||
from wetext import Normalizer
|
||||
|
||||
@ -355,22 +356,24 @@ class TextTokenizer:
|
||||
sentences.append(current_sentence)
|
||||
else:
|
||||
# 如果当前tokens的长度超过最大限制
|
||||
if "," in current_sentence or "▁," in current_sentence:
|
||||
if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence):
|
||||
# 如果当前tokens中有,,则按,分割
|
||||
sub_sentences = TextTokenizer.split_sentences_by_token(
|
||||
current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
|
||||
)
|
||||
elif "-" in current_sentence:
|
||||
elif not ("-" in split_tokens ) and "-" in current_sentence:
|
||||
# 没有,,则按-分割
|
||||
sub_sentences = TextTokenizer.split_sentences_by_token(
|
||||
current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
|
||||
)
|
||||
else:
|
||||
# 按照长度分割
|
||||
sub_sentences = [
|
||||
current_sentence[:max_tokens_per_sentence],
|
||||
current_sentence[max_tokens_per_sentence:],
|
||||
]
|
||||
sub_sentences = []
|
||||
for j in range(0, len(current_sentence), max_tokens_per_sentence):
|
||||
if j + max_tokens_per_sentence < len(current_sentence):
|
||||
sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence])
|
||||
else:
|
||||
sub_sentences.append(current_sentence[j:])
|
||||
warnings.warn(
|
||||
f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, "
|
||||
f"Tokens in sentence: {current_sentence}."
|
||||
@ -448,6 +451,7 @@ if __name__ == "__main__":
|
||||
"蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),美国商业经理、工业工程师和工业开发商,现任苹果公司首席执行官。",
|
||||
# 长句子
|
||||
"《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧,莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由莱昂纳多·迪卡普里奥扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。",
|
||||
"清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。\n宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。",
|
||||
]
|
||||
# 测试分词器
|
||||
tokenizer = TextTokenizer(
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
import torch
|
||||
from transformers import LogitsWarper
|
||||
from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper
|
||||
|
||||
|
||||
class TypicalLogitsWarper(LogitsWarper):
|
||||
class TypicalLogitsWarper(BaseTypicalLogitsWarper):
|
||||
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
self.filter_value = filter_value
|
||||
self.mass = mass
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# calculate entropy
|
||||
|
||||
8
tests/cases.jsonl
Normal file
8
tests/cases.jsonl
Normal file
@ -0,0 +1,8 @@
|
||||
{"prompt_audio":"sample_prompt.wav","text":"IndexTTS 正式发布1.0版本了,效果666","infer_mode":0}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!","infer_mode":0}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"晕XUAN4是一种GAN3觉","infer_mode":0}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"最zhong4要的是:不要chong2蹈覆辙","infer_mode":0}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"Matt Hougan, chief investment officer at Bitwise, predicts Bitcoin (BTC) will reach $200,000 by the end of 2025 due to a supply shock from heightened institutional demand. In an interview with Cointelegraph at Consensus 2025 in Toronto, the executive said that Bitwise's Bitcoin price prediction model is driven exclusively by supply and demand metrics. \"I think eventually that will exhaust sellers at the $100,000 level where we have been stuck, and I think the next stopping point above that is $200,000,\" the executive said.","infer_mode":1}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"《盗梦空间》(英语:Inception)是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰(Christopher Edward Nolan)执导并编剧,莱昂纳多·迪卡普里奥(Leonardo Wilhelm DiCaprio)、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。豆瓣评分:9.4,IMDB 8.8。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由 Leonardo 扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。","infer_mode":1}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。","infer_mode":1}
|
||||
{"prompt_audio":"sample_prompt.wav","text":"当地时间15日,随着特朗普与阿联酋敲定2000亿美元协议,特朗普的中东之行正式收官。特朗普已宣布获得沙特6000亿美元和卡塔尔2430亿美元投资承诺。商业协议成为特朗普重返白宫后首次外访的核心成果。香港英文媒体《南华早报》(South China Morning Post)称,特朗普访问期间提出了以经济合作为驱动的中东及南亚和平计划。分析人士称,该战略包含多项旨在遏制中国在这些地区影响力的措施。英国伦敦国王学院安全研究教授安德烈亚斯·克里格(Dr. Andreas Krieg)对半岛电视台表示,特朗普访问海湾地区的主要目标有三个:其一,以军工投资和能源合作的形式获得海湾国家的切实承诺;其二,加强与“让美国再次伟大”运动结盟的外交伙伴关系,维持美国外交影响力;其三:将海湾国家重新定位为美国在从加沙到伊朗等地区危机前线的调解人,这样就可以不用增强军事部署。安德烈亚斯·克里格直言:“海湾国家不会为美国牺牲与中国的关系,他们的战略自主性远超特朗普想象。”美国有线电视新闻网(CNN)报道称,特朗普到访的三个能源富国,每个国家都对美国有着长长的诉求清单。尽管这些国家豪掷重金,但美国并未实现所有诉求。","infer_mode":1}
|
||||
99
tests/padding_test.py
Normal file
99
tests/padding_test.py
Normal file
@ -0,0 +1,99 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from indextts.infer import IndexTTS
|
||||
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from torch.nn import functional as F
|
||||
|
||||
if __name__ == "__main__":
    # Test the padding of text tokens in inference.
    #
    #   python tests/padding_test.py checkpoints
    #   python tests/padding_test.py IndexTTS-1.5
    #
    # The script decodes one sentence without padding (the baseline), then
    # decodes several padded variants one by one and as a single batch, and
    # reports every variant whose generated codes differ from the baseline.
    import transformers
    transformers.set_seed(42)
    import sys
    sys.path.append("..")

    def _report_mismatches(expected, outputs):
        """Print the indices (and values) of `outputs` entries that differ from `expected`."""
        mismatch_idx = [i for i, out in enumerate(outputs) if not expected.equal(out)]
        if mismatch_idx:
            print("mismatch:", mismatch_idx)
            for i in mismatch_idx:
                print(f"[{i}]: {outputs[i]}")
        else:
            print("all matched")

    # Optional first CLI argument selects the model directory.
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "checkpoints"
    audio_prompt = "tests/sample_prompt.wav"
    tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
    text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
    text_tokens = tts.tokenizer.encode(text)
    text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0)  # [1, L]

    # Build the conditioning mel from the reference prompt: mono, resampled to 24 kHz.
    audio, sr = torchaudio.load(audio_prompt)
    audio = torch.mean(audio, dim=0, keepdim=True)
    audio = torchaudio.transforms.Resample(sr, 24000)(audio)
    auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device)
    cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device)

    with torch.no_grad():
        # do_sample=False with num_beams=1 keeps decoding greedy, so outputs are comparable.
        kwargs = {
            "cond_mel_lengths": cond_mel_lengths,
            "do_sample": False,
            "top_p": 0.8,
            "top_k": None,
            "temperature": 1.0,
            "num_return_sequences": 1,
            "length_penalty": 0.0,
            "num_beams": 1,
            "repetition_penalty": 10.0,
            "max_generate_length": 100,
        }
        # baseline for non-pad
        baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
        baseline = baseline.squeeze(0)

        print("Inference padded text tokens...")
        pad_text_tokens = [
            F.pad(text_tokens, (8, 0), value=0),  # left bos
            F.pad(text_tokens, (0, 8), value=1),  # right eos
            F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1),  # both side
            F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1),
            F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1),
        ]
        output_for_padded = []
        for t in pad_text_tokens:
            # test for each padded text
            # Feed the padded variant `t`; passing the unpadded `text_tokens`
            # here would make the comparison below trivially match the baseline.
            out = tts.gpt.inference_speech(auto_conditioning, t, **kwargs)
            output_for_padded.append(out.squeeze(0))

        # batched inference
        print("Inference padded text tokens as one batch...")
        batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device)
        assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2
        batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs)
        del pad_text_tokens

    print("baseline:", baseline.shape, baseline)
    print("--"*10)
    print("baseline vs padded output:")
    _report_mismatches(baseline, output_for_padded)

    del output_for_padded
    print("--"*10)
    print("baseline vs batched output:")
    _report_mismatches(baseline, batch_output)

    print("Test finished.")
|
||||
161
webui.py
161
webui.py
@ -1,5 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -12,70 +12,201 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(current_dir)
|
||||
sys.path.append(os.path.join(current_dir, "indextts"))
|
||||
|
||||
import argparse

# Command-line options for the web UI; parsed at import time so the model
# directory can be validated before the heavier imports below run.
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()

# Exit early with a readable message when the model directory is missing.
if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

# Every file listed here must be present inside the model directory.
for required_name in ("bigvgan_generator.pth", "bpe.model", "gpt.pth", "config.yaml"):
    file_path = os.path.join(cmd_args.model_dir, required_name)
    if os.path.exists(file_path):
        continue
    print(f"Required file {file_path} does not exist. Please download it.")
    sys.exit(1)
|
||||
|
||||
import gradio as gr
|
||||
from indextts.utils.webui_utils import next_page, prev_page
|
||||
|
||||
from indextts.infer import IndexTTS
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto(language="zh_CN")
|
||||
MODE = 'local'
|
||||
tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
|
||||
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),)
|
||||
|
||||
|
||||
os.makedirs("outputs/tasks",exist_ok=True)
|
||||
os.makedirs("prompts",exist_ok=True)
|
||||
|
||||
# Load the demo examples: one JSON object per non-blank line of tests/cases.jsonl.
# Each example becomes [prompt audio path, text, inference-mode label].
example_cases = []
with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
    for raw_line in f:
        stripped = raw_line.strip()
        if stripped:
            example = json.loads(stripped)
            prompt_path = os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav"))
            # `infer_mode` is an index into the two radio labels (defaults to 0).
            mode_label = ["普通推理", "批次推理"][example.get("infer_mode", 0)]
            example_cases.append([prompt_path, example.get("text"), mode_label])
|
||||
|
||||
def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
                *args, progress=gr.Progress()):
    """Synthesize speech for `text` with the reference audio `prompt`.

    Args:
        prompt: Reference audio passed through to `tts.infer` / `tts.infer_fast`.
        text: Target text to synthesize.
        infer_mode: "普通推理" selects `tts.infer`; any other value selects
            `tts.infer_fast` (batched inference).
        max_text_tokens_per_sentence: Sentence-splitting token limit (cast to int).
        sentences_bucket_max_size: Bucket capacity for batched inference (cast to int).
        *args: Exactly eight advanced generation values, in this order:
            do_sample, top_p, top_k, temperature, length_penalty, num_beams,
            repetition_penalty, max_mel_tokens.
        progress: Gradio progress tracker, stored on the shared `tts` instance.

    Returns:
        A `gr.update` that sets the output component's value and makes it visible.

    Raises:
        ValueError: If `args` does not contain exactly eight values.
    """
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # set gradio progress
    tts.gr_progress = progress
    do_sample, top_p, top_k, temperature, \
        length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
    # UI widgets may deliver numbers as floats, so every value is coerced to
    # the type the generation call expects.
    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        # A non-positive top_k disables top-k filtering.
        "top_k": int(top_k) if int(top_k) > 0 else None,
        "temperature": float(temperature),
        "length_penalty": float(length_penalty),
        "num_beams": int(num_beams),
        "repetition_penalty": float(repetition_penalty),
        "max_mel_tokens": int(max_mel_tokens),
        # "typical_sampling": bool(typical_sampling),
        # "typical_mass": float(typical_mass),
    }
    if infer_mode == "普通推理":
        # Normal (sentence-by-sentence) inference.
        output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
                           max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                           **kwargs)
    else:
        # Batched inference.
        output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
                                max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                                sentences_bucket_max_size=int(sentences_bucket_max_size),
                                **kwargs)
    return gr.update(value=output,visible=True)
|
||||
|
||||
def update_prompt_audio():
    """Return a Gradio update that marks its output component as interactive."""
    return gr.update(interactive=True)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Blocks(title="IndexTTS Demo") as demo:
|
||||
mutex = threading.Lock()
|
||||
gr.HTML('''
|
||||
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
|
||||
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
|
||||
|
||||
<p align="center">
|
||||
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
|
||||
</p>
|
||||
''')
|
||||
with gr.Tab("音频生成"):
|
||||
with gr.Row():
|
||||
os.makedirs("prompts",exist_ok=True)
|
||||
prompt_audio = gr.Audio(label="请上传参考音频",key="prompt_audio",
|
||||
prompt_audio = gr.Audio(label="参考音频",key="prompt_audio",
|
||||
sources=["upload","microphone"],type="filepath")
|
||||
prompt_list = os.listdir("prompts")
|
||||
default = ''
|
||||
if prompt_list:
|
||||
default = prompt_list[0]
|
||||
with gr.Column():
|
||||
input_text_single = gr.TextArea(label="请输入目标文本",key="input_text_single")
|
||||
infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="选择推理模式(批次推理:更适合长句,性能翻倍)",value="普通推理")
|
||||
gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
|
||||
input_text_single = gr.TextArea(label="文本",key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
|
||||
infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式",info="批次推理:更适合长句,性能翻倍",value="普通推理")
|
||||
gen_button = gr.Button("生成语音", key="gen_button",interactive=True)
|
||||
output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio")
|
||||
with gr.Accordion("高级生成参数设置", open=False):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
|
||||
with gr.Row():
|
||||
do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
|
||||
temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
|
||||
with gr.Row():
|
||||
top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
|
||||
top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
|
||||
num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
|
||||
with gr.Row():
|
||||
repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
|
||||
length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
|
||||
max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
|
||||
# with gr.Row():
|
||||
# typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
|
||||
# typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
|
||||
with gr.Column(scale=2):
|
||||
gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
|
||||
with gr.Row():
|
||||
max_text_tokens_per_sentence = gr.Slider(
|
||||
label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
|
||||
info="建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高",
|
||||
)
|
||||
sentences_bucket_max_size = gr.Slider(
|
||||
label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
|
||||
info="建议2-8之间,值越大,一批次推理包含的分句数越多,过大可能导致内存溢出",
|
||||
)
|
||||
with gr.Accordion("预览分句结果", open=True) as sentences_settings:
|
||||
sentences_preview = gr.Dataframe(
|
||||
headers=["序号", "分句内容", "Token数"],
|
||||
key="sentences_preview",
|
||||
wrap=True,
|
||||
)
|
||||
advanced_params = [
|
||||
do_sample, top_p, top_k, temperature,
|
||||
length_penalty, num_beams, repetition_penalty, max_mel_tokens,
|
||||
# typical_sampling, typical_mass,
|
||||
]
|
||||
|
||||
if len(example_cases) > 0:
|
||||
gr.Examples(
|
||||
examples=example_cases,
|
||||
inputs=[prompt_audio, input_text_single, infer_mode],
|
||||
)
|
||||
|
||||
def on_input_text_change(text, max_tokens_per_sentence):
    """Refresh the sentence-split preview table for the current input text.

    Args:
        text: Current content of the text input.
        max_tokens_per_sentence: Token limit per sentence; cast to int before use.

    Returns:
        A dict mapping the `sentences_preview` component to its update: one
        [index, sentence text, token count] row per split sentence, or an
        empty table with the same three columns when `text` is empty.
    """
    if not (text and len(text) > 0):
        # Nothing to split: show an empty table with the expected headers.
        empty_table = pd.DataFrame([], columns=["序号", "分句内容", "Token数"])
        return {
            sentences_preview: gr.update(value=empty_table)
        }

    tokens = tts.tokenizer.tokenize(text)
    split_result = tts.tokenizer.split_sentences(tokens, max_tokens_per_sentence=int(max_tokens_per_sentence))
    rows = [[idx, ''.join(sentence), len(sentence)] for idx, sentence in enumerate(split_result)]
    return {
        sentences_preview: gr.update(value=rows, visible=True, type="array"),
    }
|
||||
|
||||
input_text_single.change(
|
||||
on_input_text_change,
|
||||
inputs=[input_text_single, max_text_tokens_per_sentence],
|
||||
outputs=[sentences_preview]
|
||||
)
|
||||
max_text_tokens_per_sentence.change(
|
||||
on_input_text_change,
|
||||
inputs=[input_text_single, max_text_tokens_per_sentence],
|
||||
outputs=[sentences_preview]
|
||||
)
|
||||
prompt_audio.upload(update_prompt_audio,
|
||||
inputs=[],
|
||||
outputs=[gen_button])
|
||||
|
||||
gen_button.click(gen_single,
|
||||
inputs=[prompt_audio, input_text_single, infer_mode],
|
||||
inputs=[prompt_audio, input_text_single, infer_mode,
|
||||
max_text_tokens_per_sentence, sentences_bucket_max_size,
|
||||
*advanced_params,
|
||||
],
|
||||
outputs=[output_audio])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue(20)
|
||||
demo.launch(server_name="127.0.0.1")
|
||||
demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user