From cb6c73d391cdc413fe1e8c4130b2d6f9860035ac Mon Sep 17 00:00:00 2001 From: yrom Date: Sat, 17 May 2025 11:16:54 +0800 Subject: [PATCH 01/11] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=96=87=E6=9C=AC?= =?UTF-8?q?=E5=BD=92=E4=B8=80=E5=8C=96=E5=92=8C=E5=88=86=E5=8F=A5=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复可能的递归问题 (Fixes #124) --- indextts/utils/front.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/indextts/utils/front.py b/indextts/utils/front.py index d7ff8c7..e83a3c2 100644 --- a/indextts/utils/front.py +++ b/indextts/utils/front.py @@ -84,7 +84,8 @@ class TextNormalizer: # print(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) # sys.path.append(model_dir) import platform - + if self.zh_normalizer is not None and self.en_normalizer is not None: + return if platform.system() == "Darwin": from wetext import Normalizer @@ -355,22 +356,24 @@ class TextTokenizer: sentences.append(current_sentence) else: # 如果当前tokens的长度超过最大限制 - if "," in current_sentence or "▁," in current_sentence: + if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence): # 如果当前tokens中有,,则按,分割 sub_sentences = TextTokenizer.split_sentences_by_token( current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence ) - elif "-" in current_sentence: + elif not ("-" in split_tokens ) and "-" in current_sentence: # 没有,,则按-分割 sub_sentences = TextTokenizer.split_sentences_by_token( current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence ) else: # 按照长度分割 - sub_sentences = [ - current_sentence[:max_tokens_per_sentence], - current_sentence[max_tokens_per_sentence:], - ] + sub_sentences = [] + for j in range(0, len(current_sentence), max_tokens_per_sentence): + if j + max_tokens_per_sentence < len(current_sentence): + sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence]) + else: + sub_sentences.append(current_sentence[j:]) warnings.warn( f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, " f"Tokens in sentence: {current_sentence}." @@ -448,6 +451,7 @@ if __name__ == "__main__": "蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),美国商业经理、工业工程师和工业开发商,现任苹果公司首席执行官。", # 长句子 "《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧,莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由莱昂纳多·迪卡普里奥扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。", + "清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。\n宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。", ] # 测试分词器 tokenizer = TextTokenizer( From 8f7c1f3e93f485b5bc4b90c245643aff7cf981c8 Mon Sep 17 00:00:00 2001 From: yrom Date: Sat, 17 May 2025 14:38:01 +0800 Subject: [PATCH 02/11] =?UTF-8?q?=E4=BC=98=E5=8C=96inference=20attention?= =?UTF-8?q?=20mask?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/gpt/model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index 2906b31..5852650 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -592,10 +592,11 @@ class UnifiedVoice(nn.Module): def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + text_masks = ((text_inputs != self.stop_text_token) & (text_inputs != self.start_text_token)).long() text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_masks = F.pad(text_masks, (1, 1), value=1) # (-1, +1) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) conds = speech_conditioning_latent emb = torch.cat([conds, text_emb], dim=1) @@ -604,7 +605,11 @@ class UnifiedVoice(nn.Module): # +1 for the start_audio_token fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - + attention_mask = torch.cat([ + torch.ones((conds.shape[0], conds.shape[1]), dtype=torch.long, device=text_inputs.device), + text_masks, + torch.ones((conds.shape[0], 1), dtype=torch.long, device=text_inputs.device), + ], dim=1) fake_inputs[:, -1] = self.start_mel_token trunc_index = fake_inputs.shape[1] if input_tokens is None: @@ -614,12 +619,15 @@ class UnifiedVoice(nn.Module): 0] == 0, "The number of return sequences must be divisible by the number of input sequences" fake_inputs = fake_inputs.repeat(num_return_sequences, 1) input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + input_tokens_mask = ((input_tokens != self.stop_text_token) & (input_tokens != self.start_text_token)).long() inputs = torch.cat([fake_inputs, input_tokens], dim=1) + attention_mask = torch.cat([attention_mask.repeat(num_return_sequences, 1), input_tokens_mask], dim=1) + logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, - eos_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, attention_mask=attention_mask, max_length=max_length, logits_processor=logits_processor, num_return_sequences=num_return_sequences, **hf_generate_kwargs) return gen[:, trunc_index:] From 4de7611bda5e2f8a10f63829ad9d6706fa94858a Mon Sep 17 00:00:00 2001 From: yrom Date: Sat, 17 May 2025 14:40:01 +0800 Subject: [PATCH 03/11] =?UTF-8?q?fix=20=E6=89=B9=E9=87=8F=E6=8E=A8?= =?UTF-8?q?=E7=90=861.5=E7=89=88=E6=9C=AC=E6=A8=A1=E5=9E=8B=E9=97=AE?= =?UTF-8?q?=E9=A2=98=EF=BC=8C=E8=B0=83=E6=95=B4=E5=88=86=E5=8F=A5=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E5=92=8C=E5=8F=82=E6=95=B0=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将pad 改为全 eos token - 优化bucket_sentences 算法 --- indextts/infer.py | 158 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 109 insertions(+), 49 deletions(-) diff --git a/indextts/infer.py b/indextts/infer.py index 7bacde8..59e5c90 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -2,7 +2,7 @@ import os import re import time from subprocess import CalledProcessError -from typing import List +from typing import Dict, List, Tuple import numpy as np import sentencepiece as spm @@ -125,8 +125,13 @@ class IndexTTS: self.cache_cond_mel = None # 进度引用显示(可选) self.gr_progress = None + self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30): + """ + Shrink special tokens (silent_token and stop_mel_token) in codes + codes: [B, T] + """ code_lens = [] codes_list = [] device = codes.device @@ -134,59 +139,92 @@ class IndexTTS: isfix = False for i in range(0, codes.shape[0]): code = codes[i] - if self.cfg.gpt.stop_mel_token not in code: - code_lens.append(len(code)) - len_ = len(code) + if not torch.any(code == self.stop_mel_token).item(): + len_ = code.size(0) else: - # len_ = code.cpu().tolist().index(8193)+1 - len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1 - len_ = len_ - 2 + stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False) + len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0) count = torch.sum(code == silent_token).item() if count > max_consecutive: - code = code.cpu().tolist() - ncode = [] + # code = code.cpu().tolist() + ncode_idx = [] n = 0 - for k in range(0, len_): + for k in range(len_): + assert code[k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here" if code[k] != silent_token: - ncode.append(code[k]) + ncode_idx.append(k) n = 0 elif code[k] == silent_token and n < 10: - ncode.append(code[k]) + ncode_idx.append(k) n += 1 # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52): # n += 1 - len_ = len(ncode) - ncode = torch.LongTensor(ncode) - codes_list.append(ncode.to(device, dtype=dtype)) + # new code + len_ = len(ncode_idx) + codes_list.append(code[ncode_idx]) isfix = True - # codes[i] = self.stop_mel_token - # codes[i, 0:len_] = ncode else: - codes_list.append(codes[i]) + # shrink to len_ + codes_list.append(code[:len_]) code_lens.append(len_) - - codes = pad_sequence(codes_list, batch_first=True) if isfix else codes[:, :-2] - code_lens = torch.LongTensor(code_lens).to(device, dtype=dtype) + if isfix: + if len(codes_list) > 1: + codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token) + else: + codes = codes_list[0].unsqueeze(0) + else: + # unchanged + pass + # clip codes to max length + max_len = max(code_lens) + if max_len < codes.shape[1]: + codes = codes[:, :max_len] + code_lens = torch.tensor(code_lens, dtype=torch.long, device=device) return codes, code_lens - def bucket_sentences(self, sentences, enable=False): + def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]: """ - Sentence data bucketing + Sentence data bucketing. + if ``bucket_max_size=1``, return all sentences in one bucket. """ - max_len = max(len(s) for s in sentences) - half = max_len // 2 - outputs = [[], []] + outputs: List = [] for idx, sent in enumerate(sentences): - if enable is False or len(sent) <= half: - outputs[0].append({"idx": idx, "sent": sent}) - else: - outputs[1].append({"idx": idx, "sent": sent}) - return [item for item in outputs if item] + outputs.append({"idx": idx, "sent": sent, "len": len(sent)}) + + if len(outputs) > bucket_max_size: + # split sentences into buckets by sentence length + buckets = [] + factor = 1.5 + last_bucket_sent_len_median = 0 + for sent in sorted(outputs, key=lambda x: x["len"]): + current_sent_len = sent["len"] + if current_sent_len == 0: + print(">> skip empty sentence") + continue + if last_bucket_sent_len_median == 0 \ + or current_sent_len > last_bucket_sent_len_median * factor \ + or len(buckets[-1]) > bucket_max_size: + # new bucket + buckets.append([sent]) + last_bucket_sent_len_median = current_sent_len + else: + # current bucket can hold more sentences + buckets[-1].append(sent) # sorted + mid = len(buckets[-1]) // 2 + last_bucket_sent_len_median = buckets[-1][mid]["len"] + return buckets + return [outputs] - def pad_tokens_cat(self, tokens: List[torch.Tensor]): + + def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor: if len(tokens) <= 1: return tokens[-1] + if self.model_version and self.model_version >= 1.5: + # 1.5版本以上,直接使用stop_text_token 右侧填充 + # [1, N] -> [N,] + tokens = [t.squeeze(0) for t in tokens] + return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right") max_len = max(t.size(1) for t in tokens) outputs = [] for tensor in tokens: @@ -214,8 +252,18 @@ class IndexTTS: self.gr_progress(value, desc=desc) # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16) - def infer_fast(self, audio_prompt, text, output_path, verbose=False): + def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4): + """ + Args: + ``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整 + - 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量 + - 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理 + ``sentences_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整 + - 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量 + - 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理 + """ print(">> start fast inference...") + self._set_gr_progress(0, "start fast inference...") if verbose: print(f"origin text:{text}") @@ -245,10 +293,12 @@ class IndexTTS: # text_tokens text_tokens_list = self.tokenizer.tokenize(text) - sentences = self.tokenizer.split_sentences(text_tokens_list) + + sentences = self.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=max_text_tokens_per_sentence) if verbose: - print("text token count:", len(text_tokens_list)) - print("sentences count:", len(sentences)) + print(">> text token count:", len(text_tokens_list)) + print(" splited sentences count:", len(sentences)) + print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") top_p = 0.8 @@ -270,8 +320,11 @@ class IndexTTS: # text processing all_text_tokens: List[List[torch.Tensor]] = [] self._set_gr_progress(0.1, "text processing...") - bucket_enable = True # 预分桶开关,优先保证质量=True。优先保证速度=False。 - all_sentences = self.bucket_sentences(sentences, enable=bucket_enable) + bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1 + all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size) + bucket_count = len(all_sentences) + if verbose: + print(">> sentences bucket_count:", bucket_count, [len(s) for s in all_sentences]) for sentences in all_sentences: temp_tokens: List[torch.Tensor] = [] all_text_tokens.append(temp_tokens) @@ -294,8 +347,8 @@ class IndexTTS: for item_tokens in all_text_tokens: batch_num = len(item_tokens) batch_text_tokens = self.pad_tokens_cat(item_tokens) - batch_cond_mel_lengths = torch.cat([cond_mel_lengths] * batch_num, dim=0) - batch_auto_conditioning = torch.cat([auto_conditioning] * batch_num, dim=0) + batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num] + batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L] all_batch_num += batch_num # gpt speech @@ -325,11 +378,15 @@ class IndexTTS: for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences): for i in range(batch_codes.shape[0]): codes = batch_codes[i] # [x] - codes = codes[codes != self.cfg.gpt.stop_mel_token] - codes, _ = torch.unique_consecutive(codes, return_inverse=True) codes = codes.unsqueeze(0) # [x] -> [1, x] - code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) + if verbose: + print("codes:", codes.shape) + print(codes) codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + if verbose: + print("fix codes:", codes.shape) + print(codes) + print("code_lens:", code_lens) text_tokens = batch_tokens[i] all_idxs.append(batch_sentences[i]["idx"]) m_start_time = time.perf_counter() @@ -343,14 +400,16 @@ class IndexTTS: return_latent=True, clip_inputs=False) gpt_forward_time += time.perf_counter() - m_start_time all_latents.append(latent) - + del all_batch_codes, all_text_tokens, all_sentences # bigvgan chunk chunk_size = 2 all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))] + if verbose: + print(">> all_latents:", len(all_latents)) + print(*[l.shape for l in all_latents], sep=", ") chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)] chunk_length = len(chunk_latents) latent_length = len(all_latents) - all_latents = None # bigvgan chunk decode self._set_gr_progress(0.7, "bigvgan decode...") @@ -370,7 +429,7 @@ class IndexTTS: # clear cache tqdm_progress.close() # 确保进度条被关闭 - chunk_latents.clear() + del all_latents, chunk_latents end_time = time.perf_counter() self.torch_empty_cache() @@ -385,7 +444,7 @@ class IndexTTS: print(f">> Total fast inference time: {end_time - start_time:.2f} seconds") print(f">> Generated audio length: {wav_length:.2f} seconds") print(f">> [fast] bigvgan chunk_length: {chunk_length}") - print(f">> [fast] batch_num: {all_batch_num} bucket_enable: {bucket_enable}") + print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}", f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "") print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}") # save audio @@ -403,7 +462,7 @@ class IndexTTS: return (sampling_rate, wav_data) # 原始推理模式 - def infer(self, audio_prompt, text, output_path, verbose=False): + def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120): print(">> start inference...") self._set_gr_progress(0, "start inference...") if verbose: @@ -431,10 +490,11 @@ class IndexTTS: auto_conditioning = cond_mel text_tokens_list = self.tokenizer.tokenize(text) - sentences = self.tokenizer.split_sentences(text_tokens_list) + sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) if verbose: print("text token count:", len(text_tokens_list)) print("sentences count:", len(sentences)) + print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") top_p = 0.8 top_k = 30 From a50cb8c2872fa965e5fdf180752939c027736242 Mon Sep 17 00:00:00 2001 From: yrom Date: Sat, 17 May 2025 20:59:07 +0800 Subject: [PATCH 04/11] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=96=87=E6=9C=AC?= =?UTF-8?q?=E6=8E=A9=E7=A0=81=E5=A1=AB=E5=85=85=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=E5=8F=A5=E5=AD=90=E6=A1=B6=E5=8C=96=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/gpt/model.py | 37 ++++++++++++++++++++++-- indextts/infer.py | 66 ++++++++++++++++++++++++++++++------------- 2 files changed, 82 insertions(+), 21 deletions(-) diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index 5852650..e432c00 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -556,7 +556,6 @@ class UnifiedVoice(nn.Module): # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1 mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths) - text_inputs = self.set_text_padding(text_inputs, text_lengths) text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) @@ -588,14 +587,48 @@ class UnifiedVoice(nn.Module): loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits + @staticmethod + def pad_text_masks(text_masks): + """ + pad masks to [b, t+2]. e.g.: + ``` + [ + [1, 1, 1, 1, 1] -> [1, 1, 1, 1, 1, 1, 1] + [1, 1, 1, 1, 0] -> [1, 1, 1, 1, 1, 1, 0] + [0, 0, 0, 0, 0] -> [0, 0, 0, 0, 0, 1, 1] + [1, 0, 0, 0, 0] -> [1, 1, 1, 0, 0, 0, 0] + [0, 0, 1, 1, 1] -> [0, 0, 1, 1, 1, 1, 1] + ] + ``` + text_masks: [b, t] + padded_masks: [b, t+2] + """ + b, t = text_masks.size() + padded = torch.nn.functional.pad(text_masks, (1, 1), value=0) # [b, t+2] + padded = torch.ones_like(padded, dtype=text_masks.dtype, device=text_masks.device) + # Find the first and last non-zero index + # and set the values before the first index and after the last index to 0 + nonzero_mask = text_masks != 0 + has_nonzero = nonzero_mask.any(dim=1) + first_idx = nonzero_mask.float().argmax(dim=1) + rev_mask = nonzero_mask.flip(dims=[1]) + last_idx = t - rev_mask.float().argmax(dim=1) + first_idx = first_idx.reshape(b, 1) + last_idx = last_idx.reshape(b, 1) + col_idx = torch.arange(t+2, device=text_masks.device).unsqueeze(0).expand(b, -1) # [b, t_2] + padded[col_idx < first_idx] = 0 + padded[col_idx > last_idx+1] = 0 + # all zeros + padded[~has_nonzero, :-2] = 0 + return padded def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): text_masks = ((text_inputs != self.stop_text_token) & (text_inputs != self.start_text_token)).long() text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_masks = F.pad(text_masks, (1, 1), value=1) # (-1, +1) + text_masks = UnifiedVoice.pad_text_masks(text_masks) # [b, t+2] text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) conds = speech_conditioning_latent diff --git a/indextts/infer.py b/indextts/infer.py index 59e5c90..37bda7f 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -188,38 +188,59 @@ class IndexTTS: Sentence data bucketing. if ``bucket_max_size=1``, return all sentences in one bucket. """ - outputs: List = [] + outputs: List[Dict] = [] for idx, sent in enumerate(sentences): outputs.append({"idx": idx, "sent": sent, "len": len(sent)}) if len(outputs) > bucket_max_size: # split sentences into buckets by sentence length - buckets = [] + buckets: List[List[Dict]] = [] factor = 1.5 + last_bucket = None last_bucket_sent_len_median = 0 + for sent in sorted(outputs, key=lambda x: x["len"]): current_sent_len = sent["len"] if current_sent_len == 0: print(">> skip empty sentence") continue - if last_bucket_sent_len_median == 0 \ - or current_sent_len > last_bucket_sent_len_median * factor \ - or len(buckets[-1]) > bucket_max_size: + if last_bucket is None \ + or current_sent_len >= int(last_bucket_sent_len_median * factor) \ + or len(last_bucket) >= bucket_max_size: # new bucket buckets.append([sent]) + last_bucket = buckets[-1] last_bucket_sent_len_median = current_sent_len else: # current bucket can hold more sentences - buckets[-1].append(sent) # sorted - mid = len(buckets[-1]) // 2 - last_bucket_sent_len_median = buckets[-1][mid]["len"] - return buckets + last_bucket.append(sent) # sorted + mid = len(last_bucket) // 2 + last_bucket_sent_len_median = last_bucket[mid]["len"] + last_bucket=None + # merge all buckets with size 1 + out_buckets: List[List[Dict]] = [] + only_ones: List[Dict] = [] + for b in buckets: + if len(b) == 1: + only_ones.append(b[0]) + else: + out_buckets.append(b) + if len(only_ones) > 0: + # merge into previous buckets if possible + # print("only_ones:", [(o["idx"], o["len"]) for o in only_ones]) + for i in range(len(out_buckets)): + b = out_buckets[i] + if len(b) < bucket_max_size: + b.append(only_ones.pop(0)) + if len(only_ones) == 0: + break + # combined all remaining sized 1 buckets + if len(only_ones) > 0: + out_buckets.extend([only_ones[i:i+bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)]) + return out_buckets return [outputs] - def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor: - if len(tokens) <= 1: - return tokens[-1] if self.model_version and self.model_version >= 1.5: # 1.5版本以上,直接使用stop_text_token 右侧填充 # [1, N] -> [N,] @@ -324,7 +345,9 @@ class IndexTTS: all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size) bucket_count = len(all_sentences) if verbose: - print(">> sentences bucket_count:", bucket_count, [len(s) for s in all_sentences]) + print(">> sentences bucket_count:", bucket_count, + "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences], + "bucket_max_size:", bucket_max_size) for sentences in all_sentences: temp_tokens: List[torch.Tensor] = [] all_text_tokens.append(temp_tokens) @@ -346,9 +369,14 @@ class IndexTTS: all_batch_codes = [] for item_tokens in all_text_tokens: batch_num = len(item_tokens) - batch_text_tokens = self.pad_tokens_cat(item_tokens) - batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num] - batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L] + if batch_num > 1: + batch_text_tokens = self.pad_tokens_cat(item_tokens) + batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num] + batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L] + else: + batch_text_tokens = item_tokens[0] + batch_cond_mel_lengths = cond_mel_lengths + batch_auto_conditioning = auto_conditioning all_batch_num += batch_num # gpt speech @@ -363,9 +391,9 @@ class IndexTTS: top_p=top_p, top_k=top_k, temperature=temperature, - num_return_sequences=autoregressive_batch_size, + num_return_sequences=autoregressive_batch_size if batch_num == 1 else 1, length_penalty=length_penalty, - num_beams=num_beams, + num_beams=num_beams if batch_num == 1 else 1, repetition_penalty=repetition_penalty, max_generate_length=max_mel_tokens) all_batch_codes.append(temp_codes) @@ -406,7 +434,7 @@ class IndexTTS: all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))] if verbose: print(">> all_latents:", len(all_latents)) - print(*[l.shape for l in all_latents], sep=", ") + print(" latents length:", [l.shape[1] for l in all_latents]) chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)] chunk_length = len(chunk_latents) latent_length = len(all_latents) From 22eeb7625f179413e5765ae940ea6bec48221993 Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 15:27:53 +0800 Subject: [PATCH 05/11] =?UTF-8?q?=E4=BF=AE=E6=AD=A3attention=20mask?= =?UTF-8?q?=E5=92=8Cpositional=20embeddings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将之前只有text右侧填充改为cond+text 整体左侧填充 - 添加填充测试用例 --- indextts/gpt/model.py | 176 +++++++++++++++++++++++++----------------- indextts/infer.py | 76 +++++++++--------- tests/padding_test.py | 86 +++++++++++++++++++++ 3 files changed, 230 insertions(+), 108 deletions(-) create mode 100644 tests/padding_test.py diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index e432c00..6f98a45 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -40,6 +40,7 @@ class ResBlock(nn.Module): class GPT2InferenceModel(GPT2PreTrainedModel): def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False): super().__init__(config) + # Note: the argument named `text_pos_emb` here actually represents the mel position embedding self.transformer = gpt self.text_pos_embedding = text_pos_emb self.embeddings = embeddings @@ -97,7 +98,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask == 0, 0) if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) else: @@ -134,7 +135,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel): return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) - # Create embedding mel_len = self.cached_mel_emb.shape[1] if input_ids.shape[1] != 1: @@ -587,80 +587,116 @@ class UnifiedVoice(nn.Module): loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - @staticmethod - def pad_text_masks(text_masks): + + def prepare_gpt_inputs( + self, + conditional_latents: torch.Tensor, + text_inputs: torch.Tensor, + ): + """ - pad masks to [b, t+2]. e.g.: - ``` - [ - [1, 1, 1, 1, 1] -> [1, 1, 1, 1, 1, 1, 1] - [1, 1, 1, 1, 0] -> [1, 1, 1, 1, 1, 1, 0] - [0, 0, 0, 0, 0] -> [0, 0, 0, 0, 0, 1, 1] - [1, 0, 0, 0, 0] -> [1, 1, 1, 0, 0, 0, 0] - [0, 0, 1, 1, 1] -> [0, 0, 1, 1, 1, 1, 1] - ] - ``` - text_masks: [b, t] - padded_masks: [b, t+2] + Prepare the inputs for the GPT2InferenceModel to generate. + Args: + conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()` + text_inputs: (b, L) + Returns: + input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate() + inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward() + attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate() """ - b, t = text_masks.size() - padded = torch.nn.functional.pad(text_masks, (1, 1), value=0) # [b, t+2] - padded = torch.ones_like(padded, dtype=text_masks.dtype, device=text_masks.device) - - # Find the first and last non-zero index - # and set the values before the first index and after the last index to 0 - nonzero_mask = text_masks != 0 - has_nonzero = nonzero_mask.any(dim=1) - first_idx = nonzero_mask.float().argmax(dim=1) - rev_mask = nonzero_mask.flip(dims=[1]) - last_idx = t - rev_mask.float().argmax(dim=1) - first_idx = first_idx.reshape(b, 1) - last_idx = last_idx.reshape(b, 1) - col_idx = torch.arange(t+2, device=text_masks.device).unsqueeze(0).expand(b, -1) # [b, t_2] - padded[col_idx < first_idx] = 0 - padded[col_idx > last_idx+1] = 0 - # all zeros - padded[~has_nonzero, :-2] = 0 - return padded - def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1, - max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): - - text_masks = ((text_inputs != self.stop_text_token) & (text_inputs != self.start_text_token)).long() - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_masks = UnifiedVoice.pad_text_masks(text_masks) # [b, t+2] - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) - conds = speech_conditioning_latent - emb = torch.cat([conds, text_emb], dim=1) - self.inference_model.store_mel_emb(emb) - - # +1 for the start_audio_token - fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long, - device=text_inputs.device) - attention_mask = torch.cat([ - torch.ones((conds.shape[0], conds.shape[1]), dtype=torch.long, device=text_inputs.device), - text_masks, - torch.ones((conds.shape[0], 1), dtype=torch.long, device=text_inputs.device), - ], dim=1) + b, L = text_inputs.shape[:2] + device = text_inputs.device + single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1 + if not single_cond: + assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}" + batched_mel_emb = [] + attention_masks = [] + target_len = conditional_latents.shape[1] + L + 2 + for i in range(b): + valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token) + text_input = text_inputs[i][valid_mask] + text_input = F.pad(text_input, (1, 0), value=self.start_text_token) + text_input = F.pad(text_input, (0, 1), value=self.stop_text_token) + text_input_pos = torch.arange(0, text_input.size(-1), device=device) + text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos) + # concatenate [conditional latents][text embeddings] + conds_text_emb = [ + conditional_latents.squeeze(0) if single_cond else conditional_latents[i], + text_emb, + ] + # +1 for the start_mel_token + attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device) + # check this text input is padded + padding: int = L + 2 - text_input.size(-1) + # pad left of [cond][text] -> [pad][cond][text] + if padding > 0: + pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim] + conds_text_emb.insert(0, pad) + attention_mask[:padding] = 0 + mel_emb = torch.cat(conds_text_emb) #[s, dim] + assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}" + batched_mel_emb.append(mel_emb) + attention_masks.append(attention_mask) + # [b, s, dim] + batched_mel_emb = torch.stack(batched_mel_emb, dim=0) + # [b, s+1] + attention_mask = torch.stack(attention_masks, dim=0) + # [b, s+1] + fake_inputs = torch.ones( + ( + batched_mel_emb.shape[0], + batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token + ), + dtype=torch.long, + device=device, + ) fake_inputs[:, -1] = self.start_mel_token - trunc_index = fake_inputs.shape[1] + return fake_inputs, batched_mel_emb, attention_mask + def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1, + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + """ + Args: + speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames) + text_inputs: (b, L) + cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,) + input_tokens: additional tokens for generation in shape (b, s) or (s,) + max_generate_length: limit the number of generated tokens + hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)` + """ + if speech_conditioning_mel.ndim == 2: + speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0) + if cond_mel_lengths is None: + cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device) + conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths) + input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs) + self.inference_model.store_mel_emb(inputs_embeds) if input_tokens is None: - inputs = fake_inputs + inputs = input_ids else: - assert num_return_sequences % input_tokens.shape[ - 0] == 0, "The number of return sequences must be divisible by the number of input sequences" - fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + if input_tokens.ndim == 1: + input_tokens = input_tokens.unsqueeze(0) + assert num_return_sequences % input_tokens.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of input_tokens" + assert num_return_sequences % text_inputs.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of text_inputs" + b = num_return_sequences // input_ids.shape[0] + if b > 1: + input_ids = input_ids.repeat(b, 1) + attention_mask = attention_mask.repeat(b, 1) input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) - input_tokens_mask = ((input_tokens != self.stop_text_token) & (input_tokens != self.start_text_token)).long() - inputs = torch.cat([fake_inputs, input_tokens], dim=1) - attention_mask = torch.cat([attention_mask.repeat(num_return_sequences, 1), input_tokens_mask], dim=1) - - + inputs = torch.cat([input_ids, input_tokens], dim=1) + attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1) + trunc_index = inputs.shape[1] logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() - max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length - gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, + max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length + output = self.inference_model.generate(inputs, + bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, attention_mask=attention_mask, max_length=max_length, logits_processor=logits_processor, - num_return_sequences=num_return_sequences, **hf_generate_kwargs) - return gen[:, trunc_index:] + num_return_sequences=num_return_sequences, + **hf_generate_kwargs) + if isinstance(output, torch.Tensor): + return output[:, trunc_index:] + # GenerateOutput + output.sequences = output.sequences[:, trunc_index:] + return output diff --git a/indextts/infer.py b/indextts/infer.py index 37bda7f..3d07dfe 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -242,7 +242,7 @@ class IndexTTS: def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor: if self.model_version and self.model_version >= 1.5: - # 1.5版本以上,直接使用stop_text_token 右侧填充 + # 1.5版本以上,直接使用stop_text_token 右侧填充,填充到最大长度 # [1, N] -> [N,] tokens = [t.squeeze(0) for t in tokens] return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, padding_side="right") @@ -273,7 +273,7 @@ class IndexTTS: self.gr_progress(value, desc=desc) # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16) - def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4): + def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4, **sample_kwargs): """ Args: ``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整 @@ -321,15 +321,15 @@ class IndexTTS: print(" splited sentences count:", len(sentences)) print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") - - top_p = 0.8 - top_k = 30 - temperature = 1.0 + do_sample = sample_kwargs.pop("do_sample", True) + top_p = sample_kwargs.pop("top_p", 0.8) + top_k = sample_kwargs.pop("top_k", 30) + temperature = sample_kwargs.pop("temperature", 1.0) autoregressive_batch_size = 1 - length_penalty = 0.0 - num_beams = 3 - repetition_penalty = 10.0 - max_mel_tokens = 600 + length_penalty = sample_kwargs.pop("length_penalty", 0.0) + num_beams = sample_kwargs.pop("num_beams", 3) + repetition_penalty = sample_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = sample_kwargs.pop("max_mel_tokens", 600) sampling_rate = 24000 # lang = "EN" # lang = "ZH" @@ -365,35 +365,31 @@ class IndexTTS: # Sequential processing of bucketing data - all_batch_num = 0 + all_batch_num = sum(len(s) for s in all_sentences) all_batch_codes = [] + processed_num = 0 for item_tokens in all_text_tokens: batch_num = len(item_tokens) if batch_num > 1: batch_text_tokens = self.pad_tokens_cat(item_tokens) - batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num] - batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L] else: - batch_text_tokens = item_tokens[0] - batch_cond_mel_lengths = cond_mel_lengths - batch_auto_conditioning = auto_conditioning - all_batch_num += batch_num - + batch_text_tokens = torch.nn.functional.pad(item_tokens[0], (8, 0), value=self.cfg.gpt.start_text_token) + processed_num += batch_num # gpt speech - self._set_gr_progress(0.2, "gpt inference speech...") + self._set_gr_progress(0.2 + 0.3 * processed_num/all_batch_num, f"gpt inference speech... {processed_num}/{all_batch_num}") m_start_time = time.perf_counter() with torch.no_grad(): with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): - temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens, - cond_mel_lengths=batch_cond_mel_lengths, + temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens, + cond_mel_lengths=cond_mel_lengths, # text_lengths=text_len, - do_sample=True, + do_sample=do_sample, top_p=top_p, top_k=top_k, temperature=temperature, - num_return_sequences=autoregressive_batch_size if batch_num == 1 else 1, + num_return_sequences=autoregressive_batch_size, length_penalty=length_penalty, - num_beams=num_beams if batch_num == 1 else 1, + num_beams=num_beams, repetition_penalty=repetition_penalty, max_generate_length=max_mel_tokens) all_batch_codes.append(temp_codes) @@ -490,7 +486,7 @@ class IndexTTS: return (sampling_rate, wav_data) # 原始推理模式 - def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120): + def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, **sample_kwargs): print(">> start inference...") self._set_gr_progress(0, "start inference...") if verbose: @@ -516,6 +512,7 @@ class IndexTTS: cond_mel_frame = cond_mel.shape[-1] pass + self._set_gr_progress(0.1, "text processing...") auto_conditioning = cond_mel text_tokens_list = self.tokenizer.tokenize(text) sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) @@ -524,14 +521,15 @@ class IndexTTS: print("sentences count:", len(sentences)) print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") - top_p = 0.8 - top_k = 30 - temperature = 1.0 + do_sample = sample_kwargs.pop("do_sample", True) + top_p = sample_kwargs.pop("top_p", 0.8) + top_k = sample_kwargs.pop("top_k", 30) + temperature = sample_kwargs.pop("temperature", 1.0) autoregressive_batch_size = 1 - length_penalty = 0.0 - num_beams = 3 - repetition_penalty = 10.0 - max_mel_tokens = 600 + length_penalty = sample_kwargs.pop("length_penalty", 0.0) + num_beams = sample_kwargs.pop("num_beams", 3) + repetition_penalty = sample_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = sample_kwargs.pop("max_mel_tokens", 600) sampling_rate = 24000 # lang = "EN" # lang = "ZH" @@ -539,7 +537,7 @@ class IndexTTS: gpt_gen_time = 0 gpt_forward_time = 0 bigvgan_time = 0 - + progress = 0 for sent in sentences: text_tokens = self.tokenizer.convert_tokens_to_ids(sent) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) @@ -555,7 +553,8 @@ class IndexTTS: # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device) # print(text_len) - + progress += 1 + self._set_gr_progress(0.2 + 0.4 * (progress-1) / len(sentences), f"gpt inference latent... {progress}/{len(sentences)}") m_start_time = time.perf_counter() with torch.no_grad(): with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): @@ -563,7 +562,7 @@ class IndexTTS: cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), # text_lengths=text_len, - do_sample=True, + do_sample=do_sample, top_p=top_p, top_k=top_k, temperature=temperature, @@ -587,7 +586,7 @@ class IndexTTS: print(codes, type(codes)) print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") print(f"code len: {code_lens}") - + self._set_gr_progress(0.2 + 0.4 * progress / len(sentences), f"gpt inference speech... {progress}/{len(sentences)}") m_start_time = time.perf_counter() # latent, text_lens_out, code_lens_out = \ with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): @@ -605,11 +604,12 @@ class IndexTTS: wav = wav.squeeze(1) wav = torch.clamp(32767 * wav, -32767.0, 32767.0) - print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max()) + if verbose: + print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max()) # wavs.append(wav[:, :-512]) wavs.append(wav.cpu()) # to cpu before saving end_time = time.perf_counter() - + self._set_gr_progress(0.9, "save audio...") wav = torch.cat(wavs, dim=1) wav_length = wav.shape[-1] / sampling_rate print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds") diff --git a/tests/padding_test.py b/tests/padding_test.py new file mode 100644 index 0000000..9fe418b --- /dev/null +++ b/tests/padding_test.py @@ -0,0 +1,86 @@ +import torch +import torchaudio +from indextts.infer import IndexTTS +from indextts.utils.feature_extractors import MelSpectrogramFeatures +from torch.nn import functional as F + +if __name__ == "__main__": + import transformers + transformers.set_seed(42) + audio_prompt="tests/sample_prompt.wav" + tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) + text = "晕 XUAN4 是 一 种 not very good GAN3 觉" + text_tokens = tts.tokenizer.encode(text) + text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L] + + audio, sr = torchaudio.load(audio_prompt) + audio = torch.mean(audio, dim=0, keepdim=True) + audio = torchaudio.transforms.Resample(sr, 24000)(audio) + auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device) + cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device) + with torch.no_grad(): + kwargs = { + "cond_mel_lengths": cond_mel_lengths, + "do_sample": False, + "top_p": 0.8, + "top_k": None, + "temperature": 1.0, + "num_return_sequences": 1, + "length_penalty": 0.0, + "num_beams": 1, + "repetition_penalty": 10.0, + "max_generate_length": 100, + } + # baseline for non-pad + baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs) + baseline = baseline.squeeze(0) + print("Inference padded text tokens...") + pad_text_tokens = [ + F.pad(text_tokens, (8, 0), value=0), # left bos + F.pad(text_tokens, (0, 8), value=1), # right eos + F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1), # both side + F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1), + F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1), + ] + output_for_padded = [] + for t in pad_text_tokens: + # test for each padded text + out = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs) + output_for_padded.append(out.squeeze(0)) + # batched inference + print("Inference padded text tokens as one batch...") + batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device) + assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2 + batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs) + del pad_text_tokens + mismatch_idx = [] + print("baseline:", baseline.shape, baseline) + print("--"*10) + print("baseline vs padded output:") + for i in range(len(output_for_padded)): + if not baseline.equal(output_for_padded[i]): + mismatch_idx.append(i) + + if len(mismatch_idx) > 0: + print("mismatch:", mismatch_idx) + for i in mismatch_idx: + print(f"[{i}]: {output_for_padded[i]}") + else: + print("all matched") + + del output_for_padded + print("--"*10) + print("baseline vs batched output:") + mismatch_idx = [] + for i in range(batch_output.shape[0]): + if not baseline.equal(batch_output[i]): + mismatch_idx.append(i) + if len(mismatch_idx) > 0: + print("mismatch:", mismatch_idx) + for i in mismatch_idx: + print(f"[{i}]: {batch_output[i]}") + + else: + print("all matched") + + print("Test finished.") \ No newline at end of file From 1b7529cacd997bf3ebe21d7c48a7481d9cb6addb Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 16:19:28 +0800 Subject: [PATCH 06/11] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=96=B0=E7=89=88?= =?UTF-8?q?=E6=9C=ACtransformers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/gpt/model.py | 10 ++++++++-- indextts/utils/typical_sampling.py | 9 +++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index 6f98a45..4e64660 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -388,7 +388,7 @@ class UnifiedVoice(nn.Module): def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): seq_length = self.max_mel_tokens + self.max_text_tokens + 2 gpt_config = GPT2Config( - vocab_size=self.max_mel_tokens, + vocab_size=self.number_mel_codes, n_positions=seq_length, n_ctx=seq_length, n_embd=self.model_dim, @@ -687,7 +687,13 @@ class UnifiedVoice(nn.Module): inputs = torch.cat([input_ids, input_tokens], dim=1) attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1) trunc_index = inputs.shape[1] - logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() + logits_processor = LogitsProcessorList() + if typical_sampling: + # employ custom typical sampling + if not (typical_mass > 0.0 and typical_mass < 1.0): + raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}") + min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1 + logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep)) max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length output = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, diff --git a/indextts/utils/typical_sampling.py b/indextts/utils/typical_sampling.py index c982463..0b225e9 100644 --- a/indextts/utils/typical_sampling.py +++ b/indextts/utils/typical_sampling.py @@ -1,12 +1,9 @@ import torch -from transformers import LogitsWarper +from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper - -class TypicalLogitsWarper(LogitsWarper): +class TypicalLogitsWarper(BaseTypicalLogitsWarper): def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - self.filter_value = filter_value - self.mass = mass - self.min_tokens_to_keep = min_tokens_to_keep + super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # calculate entropy From 7e52976bd12e7e9eea5f07ecfb7cb52bc9316edd Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 16:28:54 +0800 Subject: [PATCH 07/11] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/indextts/infer.py b/indextts/infer.py index 3d07dfe..01c774d 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -373,7 +373,7 @@ class IndexTTS: if batch_num > 1: batch_text_tokens = self.pad_tokens_cat(item_tokens) else: - batch_text_tokens = torch.nn.functional.pad(item_tokens[0], (8, 0), value=self.cfg.gpt.start_text_token) + batch_text_tokens = item_tokens[0] processed_num += batch_num # gpt speech self._set_gr_progress(0.2 + 0.3 * processed_num/all_batch_num, f"gpt inference speech... {processed_num}/{all_batch_num}") From 96d3b757086347f963789142a33a99971a9f7d26 Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 16:29:17 +0800 Subject: [PATCH 08/11] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=AD=A6=E5=91=8A?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=EF=BC=9A=E7=94=9F=E6=88=90=E5=81=9C=E6=AD=A2?= =?UTF-8?q?=E5=9B=A0=E8=B6=85=E5=87=BA=20`max=5Fmel=5Ftokens`=20=E9=99=90?= =?UTF-8?q?=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/infer.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/indextts/infer.py b/indextts/infer.py index 01c774d..bac1b83 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -399,9 +399,17 @@ class IndexTTS: self._set_gr_progress(0.5, "gpt inference latents...") all_idxs = [] all_latents = [] + has_warned = False for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences): for i in range(batch_codes.shape[0]): codes = batch_codes[i] # [x] + if not has_warned and codes[-1] != self.stop_mel_token: + warnings.warn( + f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " + f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", + category=RuntimeWarning + ) + has_warned = True codes = codes.unsqueeze(0) # [x] -> [1, x] if verbose: print("codes:", codes.shape) @@ -538,6 +546,7 @@ class IndexTTS: gpt_forward_time = 0 bigvgan_time = 0 progress = 0 + has_warned = False for sent in sentences: text_tokens = self.tokenizer.convert_tokens_to_ids(sent) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) @@ -572,7 +581,15 @@ class IndexTTS: repetition_penalty=repetition_penalty, max_generate_length=max_mel_tokens) gpt_gen_time += time.perf_counter() - m_start_time - # codes = codes[:, :-2] + if not has_warned and (codes[:, -1] != self.stop_mel_token).any(): + warnings.warn( + f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " + f"Input text tokens: {text_tokens.shape[1]}. " + f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", + category=RuntimeWarning + ) + has_warned = True + code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) if verbose: print(codes, type(codes)) From 60a2238eac08bfc2a71c9f473b75ddeaabd107f4 Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 16:43:44 +0800 Subject: [PATCH 09/11] =?UTF-8?q?=E5=B0=86=20`sample=5Fkwargs`=20=E6=9B=BF?= =?UTF-8?q?=E6=8D=A2=E4=B8=BA=20`generation=5Fkwargs`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- indextts/infer.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/indextts/infer.py b/indextts/infer.py index bac1b83..3afe945 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -273,7 +273,7 @@ class IndexTTS: self.gr_progress(value, desc=desc) # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16) - def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4, **sample_kwargs): + def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, sentences_bucket_max_size=4, **generation_kwargs): """ Args: ``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整 @@ -321,15 +321,15 @@ class IndexTTS: print(" splited sentences count:", len(sentences)) print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") - do_sample = sample_kwargs.pop("do_sample", True) - top_p = sample_kwargs.pop("top_p", 0.8) - top_k = sample_kwargs.pop("top_k", 30) - temperature = sample_kwargs.pop("temperature", 1.0) + do_sample = generation_kwargs.pop("do_sample", True) + top_p = generation_kwargs.pop("top_p", 0.8) + top_k = generation_kwargs.pop("top_k", 30) + temperature = generation_kwargs.pop("temperature", 1.0) autoregressive_batch_size = 1 - length_penalty = sample_kwargs.pop("length_penalty", 0.0) - num_beams = sample_kwargs.pop("num_beams", 3) - repetition_penalty = sample_kwargs.pop("repetition_penalty", 10.0) - max_mel_tokens = sample_kwargs.pop("max_mel_tokens", 600) + length_penalty = generation_kwargs.pop("length_penalty", 0.0) + num_beams = generation_kwargs.pop("num_beams", 3) + repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600) sampling_rate = 24000 # lang = "EN" # lang = "ZH" @@ -391,7 +391,8 @@ class IndexTTS: length_penalty=length_penalty, num_beams=num_beams, repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens) + max_generate_length=max_mel_tokens, + **generation_kwargs) all_batch_codes.append(temp_codes) gpt_gen_time += time.perf_counter() - m_start_time @@ -494,7 +495,7 @@ class IndexTTS: return (sampling_rate, wav_data) # 原始推理模式 - def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, **sample_kwargs): + def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs): print(">> start inference...") self._set_gr_progress(0, "start inference...") if verbose: @@ -529,15 +530,15 @@ class IndexTTS: print("sentences count:", len(sentences)) print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) print(*sentences, sep="\n") - do_sample = sample_kwargs.pop("do_sample", True) - top_p = sample_kwargs.pop("top_p", 0.8) - top_k = sample_kwargs.pop("top_k", 30) - temperature = sample_kwargs.pop("temperature", 1.0) + do_sample = generation_kwargs.pop("do_sample", True) + top_p = generation_kwargs.pop("top_p", 0.8) + top_k = generation_kwargs.pop("top_k", 30) + temperature = generation_kwargs.pop("temperature", 1.0) autoregressive_batch_size = 1 - length_penalty = sample_kwargs.pop("length_penalty", 0.0) - num_beams = sample_kwargs.pop("num_beams", 3) - repetition_penalty = sample_kwargs.pop("repetition_penalty", 10.0) - max_mel_tokens = sample_kwargs.pop("max_mel_tokens", 600) + length_penalty = generation_kwargs.pop("length_penalty", 0.0) + num_beams = generation_kwargs.pop("num_beams", 3) + repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600) sampling_rate = 24000 # lang = "EN" # lang = "ZH" @@ -579,7 +580,8 @@ class IndexTTS: length_penalty=length_penalty, num_beams=num_beams, repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens) + max_generate_length=max_mel_tokens, + **generation_kwargs) gpt_gen_time += time.perf_counter() - m_start_time if not has_warned and (codes[:, -1] != self.stop_mel_token).any(): warnings.warn( From 76e7645a8d787707d2f9e2e1e5db52c827269172 Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 16:48:07 +0800 Subject: [PATCH 10/11] =?UTF-8?q?=E6=9B=B4=E6=96=B0WebUI=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=A8=A1=E5=9E=8B=E7=9B=AE=E5=BD=95=E6=A3=80=E6=9F=A5?= =?UTF-8?q?=E5=92=8C=E5=BF=85=E8=A6=81=E6=96=87=E4=BB=B6=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增示例 - 新增模型版本提示 - 新增生成参数设置 - 新增分句预览 --- README.md | 4 ++ tests/cases.jsonl | 8 +++ webui.py | 161 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 158 insertions(+), 15 deletions(-) create mode 100644 tests/cases.jsonl diff --git a/README.md b/README.md index f9c88d9..ffbecb4 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P che wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints ``` 4. Run test script: @@ -180,6 +181,9 @@ indextts --help ```bash pip install -e ".[webui]" python webui.py + +# use another model version: +python webui.py --model_dir IndexTTS-1.5 ``` Open your browser and visit `http://127.0.0.1:7860` to see the demo. diff --git a/tests/cases.jsonl b/tests/cases.jsonl new file mode 100644 index 0000000..6420880 --- /dev/null +++ b/tests/cases.jsonl @@ -0,0 +1,8 @@ +{"prompt_audio":"sample_prompt.wav","text":"IndexTTS 正式发布1.0版本了,效果666","infer_mode":0} +{"prompt_audio":"sample_prompt.wav","text":"大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!","infer_mode":0} +{"prompt_audio":"sample_prompt.wav","text":"晕XUAN4是一种GAN3觉","infer_mode":0} +{"prompt_audio":"sample_prompt.wav","text":"最zhong4要的是:不要chong2蹈覆辙","infer_mode":0} +{"prompt_audio":"sample_prompt.wav","text":"Matt Hougan, chief investment officer at Bitwise, predicts Bitcoin (BTC) will reach $200,000 by the end of 2025 due to a supply shock from heightened institutional demand. In an interview with Cointelegraph at Consensus 2025 in Toronto, the executive said that Bitwise's Bitcoin price prediction model is driven exclusively by supply and demand metrics. \"I think eventually that will exhaust sellers at the $100,000 level where we have been stuck, and I think the next stopping point above that is $200,000,\" the executive said.","infer_mode":1} +{"prompt_audio":"sample_prompt.wav","text":"《盗梦空间》(英语:Inception)是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰(Christopher Edward Nolan)执导并编剧,莱昂纳多·迪卡普里奥(Leonardo Wilhelm DiCaprio)、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。豆瓣评分:9.4,IMDB 8.8。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由 Leonardo 扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。","infer_mode":1} +{"prompt_audio":"sample_prompt.wav","text":"清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。","infer_mode":1} +{"prompt_audio":"sample_prompt.wav","text":"当地时间15日,随着特朗普与阿联酋敲定2000亿美元协议,特朗普的中东之行正式收官。特朗普已宣布获得沙特6000亿美元和卡塔尔2430亿美元投资承诺。商业协议成为特朗普重返白宫后首次外访的核心成果。香港英文媒体《南华早报》(South China Morning Post)称,特朗普访问期间提出了以经济合作为驱动的中东及南亚和平计划。分析人士称,该战略包含多项旨在遏制中国在这些地区影响力的措施。英国伦敦国王学院安全研究教授安德烈亚斯·克里格(Dr. Andreas Krieg)对半岛电视台表示,特朗普访问海湾地区的主要目标有三个:其一,以军工投资和能源合作的形式获得海湾国家的切实承诺;其二,加强与“让美国再次伟大”运动结盟的外交伙伴关系,维持美国外交影响力;其三:将海湾国家重新定位为美国在从加沙到伊朗等地区危机前线的调解人,这样就可以不用增强军事部署。安德烈亚斯·克里格直言:“海湾国家不会为美国牺牲与中国的关系,他们的战略自主性远超特朗普想象。”美国有线电视新闻网(CNN)报道称,特朗普到访的三个能源富国,每个国家都对美国有着长长的诉求清单。尽管这些国家豪掷重金,但美国并未实现所有诉求。","infer_mode":1} diff --git a/webui.py b/webui.py index 29d44e8..94b2ab4 100644 --- a/webui.py +++ b/webui.py @@ -1,5 +1,5 @@ +import json import os -import shutil import sys import threading import time @@ -12,70 +12,201 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(current_dir) sys.path.append(os.path.join(current_dir, "indextts")) +import argparse +parser = argparse.ArgumentParser(description="IndexTTS WebUI") +parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode") +parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on") +parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on") +parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory") +cmd_args = parser.parse_args() + +if not os.path.exists(cmd_args.model_dir): + print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.") + sys.exit(1) + +for file in [ + "bigvgan_generator.pth", + "bpe.model", + "gpt.pth", + "config.yaml", +]: + file_path = os.path.join(cmd_args.model_dir, file) + if not os.path.exists(file_path): + print(f"Required file {file_path} does not exist. Please download it.") + sys.exit(1) + import gradio as gr -from indextts.utils.webui_utils import next_page, prev_page from indextts.infer import IndexTTS from tools.i18n.i18n import I18nAuto i18n = I18nAuto(language="zh_CN") MODE = 'local' -tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml") +tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),) + os.makedirs("outputs/tasks",exist_ok=True) os.makedirs("prompts",exist_ok=True) +with open("tests/cases.jsonl", "r", encoding="utf-8") as f: + example_cases = [] + for line in f: + line = line.strip() + if not line: + continue + example = json.loads(line) + example_cases.append([os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")), + example.get("text"), ["普通推理", "批次推理"][example.get("infer_mode", 0)]]) -def gen_single(prompt, text, infer_mode, progress=gr.Progress()): +def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4, + *args, progress=gr.Progress()): output_path = None if not output_path: output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav") # set gradio progress tts.gr_progress = progress + do_sample, top_p, top_k, temperature, \ + length_penalty, num_beams, repetition_penalty, max_mel_tokens = args + kwargs = { + "do_sample": bool(do_sample), + "top_p": float(top_p), + "top_k": int(top_k) if int(top_k) > 0 else None, + "temperature": float(temperature), + "length_penalty": float(length_penalty), + "num_beams": num_beams, + "repetition_penalty": float(repetition_penalty), + "max_mel_tokens": int(max_mel_tokens), + # "typical_sampling": bool(typical_sampling), + # "typical_mass": float(typical_mass), + } if infer_mode == "普通推理": - output = tts.infer(prompt, text, output_path) # 普通推理 + output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose, + max_text_tokens_per_sentence=int(max_text_tokens_per_sentence), + **kwargs) else: - output = tts.infer_fast(prompt, text, output_path) # 批次推理 + # 批次推理 + output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose, + max_text_tokens_per_sentence=int(max_text_tokens_per_sentence), + sentences_bucket_max_size=(sentences_bucket_max_size), + **kwargs) return gr.update(value=output,visible=True) def update_prompt_audio(): update_button = gr.update(interactive=True) return update_button - -with gr.Blocks() as demo: +with gr.Blocks(title="IndexTTS Demo") as demo: mutex = threading.Lock() gr.HTML('''

IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System

(一款工业级可控且高效的零样本文本转语音系统)

-

+

''') with gr.Tab("音频生成"): with gr.Row(): os.makedirs("prompts",exist_ok=True) - prompt_audio = gr.Audio(label="请上传参考音频",key="prompt_audio", + prompt_audio = gr.Audio(label="参考音频",key="prompt_audio", sources=["upload","microphone"],type="filepath") prompt_list = os.listdir("prompts") default = '' if prompt_list: default = prompt_list[0] with gr.Column(): - input_text_single = gr.TextArea(label="请输入目标文本",key="input_text_single") - infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="选择推理模式(批次推理:更适合长句,性能翻倍)",value="普通推理") - gen_button = gr.Button("生成语音",key="gen_button",interactive=True) + input_text_single = gr.TextArea(label="文本",key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0")) + infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式",info="批次推理:更适合长句,性能翻倍",value="普通推理") + gen_button = gr.Button("生成语音", key="gen_button",interactive=True) output_audio = gr.Audio(label="生成结果", visible=True,key="output_audio") + with gr.Accordion("高级生成参数设置", open=False): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_") + with gr.Row(): + do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样") + temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1) + with gr.Row(): + top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01) + top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1) + num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1) + with gr.Row(): + repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1) + length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1) + max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens") + # with gr.Row(): + # typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用") + # typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1) + with gr.Column(scale=2): + gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_") + with gr.Row(): + max_text_tokens_per_sentence = gr.Slider( + label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence", + info="建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高", + ) + sentences_bucket_max_size = gr.Slider( + label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size", + info="建议2-8之间,值越大,一批次推理包含的分句数越多,过大可能导致内存溢出", + ) + with gr.Accordion("预览分句结果", open=True) as sentences_settings: + sentences_preview = gr.Dataframe( + headers=["序号", "分句内容", "Token数"], + key="sentences_preview", + wrap=True, + ) + advanced_params = [ + do_sample, top_p, top_k, temperature, + length_penalty, num_beams, repetition_penalty, max_mel_tokens, + # typical_sampling, typical_mass, + ] + + if len(example_cases) > 0: + gr.Examples( + examples=example_cases, + inputs=[prompt_audio, input_text_single, infer_mode], + ) + def on_input_text_change(text, max_tokens_per_sentence): + if text and len(text) > 0: + text_tokens_list = tts.tokenizer.tokenize(text) + + sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence)) + data = [] + for i, s in enumerate(sentences): + sentence_str = ''.join(s) + tokens_count = len(s) + data.append([i, sentence_str, tokens_count]) + + return { + sentences_preview: gr.update(value=data, visible=True, type="array"), + } + else: + df = pd.DataFrame([], columns=["序号", "分句内容", "Token数"]) + return { + sentences_preview: gr.update(value=df) + } + + input_text_single.change( + on_input_text_change, + inputs=[input_text_single, max_text_tokens_per_sentence], + outputs=[sentences_preview] + ) + max_text_tokens_per_sentence.change( + on_input_text_change, + inputs=[input_text_single, max_text_tokens_per_sentence], + outputs=[sentences_preview] + ) prompt_audio.upload(update_prompt_audio, inputs=[], outputs=[gen_button]) gen_button.click(gen_single, - inputs=[prompt_audio, input_text_single, infer_mode], + inputs=[prompt_audio, input_text_single, infer_mode, + max_text_tokens_per_sentence, sentences_bucket_max_size, + *advanced_params, + ], outputs=[output_audio]) if __name__ == "__main__": demo.queue(20) - demo.launch(server_name="127.0.0.1") + demo.launch(server_name=cmd_args.host, server_port=cmd_args.port) From c178198ed7b36f88219051e000830f8cbed45d34 Mon Sep 17 00:00:00 2001 From: yrom Date: Sun, 18 May 2025 19:57:11 +0800 Subject: [PATCH 11/11] padding_test.py support model dir for test --- tests/padding_test.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/padding_test.py b/tests/padding_test.py index 9fe418b..fcb67d0 100644 --- a/tests/padding_test.py +++ b/tests/padding_test.py @@ -5,10 +5,23 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures from torch.nn import functional as F if __name__ == "__main__": + """ + Test the padding of text tokens in inference. + ``` + python tests/padding_test.py checkpoints + python tests/padding_test.py IndexTTS-1.5 + ``` + """ import transformers transformers.set_seed(42) + import sys + sys.path.append("..") + if len(sys.argv) > 1: + model_dir = sys.argv[1] + else: + model_dir = "checkpoints" audio_prompt="tests/sample_prompt.wav" - tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) + tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False) text = "晕 XUAN4 是 一 种 not very good GAN3 觉" text_tokens = tts.tokenizer.encode(text) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]