From 31e7e855e21779c5ad0cd6b7b0ae4329c0c91ec2 Mon Sep 17 00:00:00 2001 From: storyicon Date: Fri, 24 Oct 2025 07:30:20 +0000 Subject: [PATCH] feat: optimize s2mel stage Signed-off-by: storyicon --- indextts/infer_v2.py | 28 +++++++++++++++++-- indextts/s2mel/modules/commons.py | 15 ++++++++-- .../s2mel/modules/diffusion_transformer.py | 2 +- indextts/s2mel/modules/flow_matching.py | 15 ++++++++++ 4 files changed, 54 insertions(+), 6 deletions(-) diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py index a4dde5b..bf5e8cc 100644 --- a/indextts/infer_v2.py +++ b/indextts/infer_v2.py @@ -38,7 +38,7 @@ import torch.nn.functional as F class IndexTTS2: def __init__( self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None, - use_cuda_kernel=None,use_deepspeed=False + use_cuda_kernel=None,use_deepspeed=False, use_torch_compile=False ): """ Args: @@ -48,6 +48,7 @@ class IndexTTS2: device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. use_deepspeed (bool): whether to use DeepSpeed or not. + use_torch_compile (bool): whether to use torch.compile for optimization or not. 
""" if device is not None: self.device = device @@ -75,6 +76,7 @@ class IndexTTS2: self.model_dir = model_dir self.dtype = torch.float16 if self.use_fp16 else None self.stop_mel_token = self.cfg.gpt.stop_mel_token + self.use_torch_compile = use_torch_compile self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path)) @@ -135,6 +137,13 @@ class IndexTTS2: ) self.s2mel = s2mel.to(self.device) self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192) + + # Enable torch.compile optimization if requested + if self.use_torch_compile: + print(">> Enabling torch.compile optimization") + self.s2mel.enable_torch_compile() + print(">> torch.compile optimization enabled successfully") + self.s2mel.eval() print(">> s2mel weights restored from:", s2mel_path) @@ -815,6 +824,19 @@ class QwenEmotion: if __name__ == "__main__": prompt_wav = "examples/voice_01.wav" text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。' - - tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False) + tts = IndexTTS2( + cfg_path="checkpoints/config.yaml", + model_dir="checkpoints", + use_cuda_kernel=False, + use_torch_compile=True + ) tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True) + char_size = 5 + import string + time_buckets = [] + for i in range(10): + text = ''.join(random.choices(string.ascii_letters, k=char_size)) + start_time = time.time() + tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True) + time_buckets.append(time.time() - start_time) + print(time_buckets) diff --git a/indextts/s2mel/modules/commons.py b/indextts/s2mel/modules/commons.py index 35fced4..63c14b3 100644 --- a/indextts/s2mel/modules/commons.py +++ b/indextts/s2mel/modules/commons.py @@ -133,8 +133,10 @@ def subsequent_mask(length): def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): n_channels_int = n_channels[0] in_act = input_a + input_b - t_act = 
torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + # use torch.split to avoid dynamic slicing + t_act_part, s_act_part = torch.split(in_act, n_channels_int, dim=1) + t_act = torch.tanh(t_act_part) + s_act = torch.sigmoid(s_act_part) acts = t_act * s_act return acts @@ -437,6 +439,15 @@ class MyModel(nn.Module): x = self.models['gpt_layer'](x) return x + def enable_torch_compile(self): + """Enable torch.compile optimization. + + This method applies torch.compile to the model for significant + performance improvements during inference. + """ + if 'cfm' in self.models: + self.models['cfm'].enable_torch_compile() + def build_model(args, stage="DiT"): diff --git a/indextts/s2mel/modules/diffusion_transformer.py b/indextts/s2mel/modules/diffusion_transformer.py index 23d6912..1606a40 100644 --- a/indextts/s2mel/modules/diffusion_transformer.py +++ b/indextts/s2mel/modules/diffusion_transformer.py @@ -233,7 +233,7 @@ class DiT(torch.nn.Module): if self.time_as_token: # False x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1) - x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True + x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True input_pos = self.input_pos[:x_in.size(1)] # (T,) range(0,1863) x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None # torch.Size([1, 1, 1863, 1863] x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512] diff --git a/indextts/s2mel/modules/flow_matching.py b/indextts/s2mel/modules/flow_matching.py index c396695..245c71f 100644 --- a/indextts/s2mel/modules/flow_matching.py +++ b/indextts/s2mel/modules/flow_matching.py @@ -169,3 +169,18 @@ class CFM(BASECFM): self.estimator = DiT(args) else: raise NotImplementedError(f"Unknown diffusion type 
{args.dit_type}") + + def enable_torch_compile(self): + """Enable torch.compile optimization for the estimator model. + + This method applies torch.compile to the estimator (DiT model) for significant + performance improvements during inference. If torch.distributed is initialized, + it also sets Inductor's reorder_for_compute_comm_overlap option. + """ + if torch.distributed.is_initialized(): + torch._inductor.config.reorder_for_compute_comm_overlap = True + self.estimator = torch.compile( + self.estimator, + fullgraph=True, + dynamic=True, + )