Add a warning when generation stops due to exceeding the max_mel_tokens limit

yrom 2025-05-18 16:29:17 +08:00
parent 7e52976bd1
commit 96d3b75708


@@ -399,9 +399,17 @@ class IndexTTS:
         self._set_gr_progress(0.5, "gpt inference latents...")
         all_idxs = []
         all_latents = []
+        has_warned = False
         for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
             for i in range(batch_codes.shape[0]):
                 codes = batch_codes[i]  # [x]
+                if not has_warned and codes[-1] != self.stop_mel_token:
+                    warnings.warn(
+                        f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
+                        f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
+                        category=RuntimeWarning
+                    )
+                    has_warned = True
                 codes = codes.unsqueeze(0)  # [x] -> [1, x]
                 if verbose:
                     print("codes:", codes.shape)
@@ -538,6 +546,7 @@ class IndexTTS:
         gpt_forward_time = 0
         bigvgan_time = 0
         progress = 0
+        has_warned = False
         for sent in sentences:
             text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
             text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
@@ -572,7 +581,15 @@ class IndexTTS:
                 repetition_penalty=repetition_penalty,
                 max_generate_length=max_mel_tokens)
             gpt_gen_time += time.perf_counter() - m_start_time
             # codes = codes[:, :-2]
+            if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
+                warnings.warn(
+                    f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
+                    f"Input text tokens: {text_tokens.shape[1]}. "
+                    f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
+                    category=RuntimeWarning
+                )
+                has_warned = True
             code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
             if verbose:
                 print(codes, type(codes))
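
Because both code paths now report truncation as a `RuntimeWarning` instead of failing silently, a caller can promote the warning to a hard error with the standard-library `warnings` filters and react by raising `max_mel_tokens` or lowering `max_text_tokens_per_sentence`. A rough, runnable sketch of that pattern; `fake_infer` is only a stand-in for the real `IndexTTS` inference call, whose full signature is not shown in this diff, and the warning text is illustrative:

import warnings

def fake_infer():
    # Stand-in for the real tts.infer(...) call; it just reproduces the new warning.
    warnings.warn(
        "WARN: generation stopped due to exceeding `max_mel_tokens`.",
        category=RuntimeWarning,
    )

with warnings.catch_warnings():
    # Escalate the truncation warning into an exception so it cannot be missed.
    warnings.simplefilter("error", category=RuntimeWarning)
    try:
        fake_infer()  # in real code: tts.infer(...) with your own arguments
    except RuntimeWarning as w:
        # Retry with a larger max_mel_tokens, or split the input into shorter
        # sentences via max_text_tokens_per_sentence.
        print(f"Mel generation was truncated: {w}")
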