Merge pull request #516 from storyicon/s2mel_accel

feat: achieve inference acceleration for the s2mel stage (1.61×)
This commit is contained in:
Vanka0051 2025-10-30 15:55:25 +08:00 committed by GitHub
commit e42480ced8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 54 additions and 6 deletions

View File

@ -38,7 +38,7 @@ import torch.nn.functional as F
class IndexTTS2:
def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
use_cuda_kernel=None,use_deepspeed=False
use_cuda_kernel=None,use_deepspeed=False, use_torch_compile=False
):
"""
Args:
@ -48,6 +48,7 @@ class IndexTTS2:
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
use_deepspeed (bool): whether to use DeepSpeed or not.
use_torch_compile (bool): whether to use torch.compile for optimization or not.
"""
if device is not None:
self.device = device
@ -75,6 +76,7 @@ class IndexTTS2:
self.model_dir = model_dir
self.dtype = torch.float16 if self.use_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
self.use_torch_compile = use_torch_compile
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
@ -135,6 +137,13 @@ class IndexTTS2:
)
self.s2mel = s2mel.to(self.device)
self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
# Enable torch.compile optimization if requested
if self.use_torch_compile:
print(">> Enabling torch.compile optimization")
self.s2mel.enable_torch_compile()
print(">> torch.compile optimization enabled successfully")
self.s2mel.eval()
print(">> s2mel weights restored from:", s2mel_path)
@ -815,6 +824,19 @@ class QwenEmotion:
if __name__ == "__main__":
    import random
    import string
    import time

    prompt_wav = "examples/voice_01.wav"
    text = '欢迎大家来体验indextts2并给予我们意见与反馈谢谢大家。'
    # Build the model exactly once: construction loads all checkpoints and is
    # expensive, so there must be a single IndexTTS2 instantiation here.
    tts = IndexTTS2(
        cfg_path="checkpoints/config.yaml",
        model_dir="checkpoints",
        use_cuda_kernel=False,
        use_torch_compile=True
    )
    # Warm-up run: the first inference also triggers torch.compile tracing,
    # so it is excluded from the timing loop below.
    tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
    # Rough latency benchmark over short random ASCII texts.
    char_size = 5
    time_buckets = []
    for i in range(10):
        text = ''.join(random.choices(string.ascii_letters, k=char_size))
        start_time = time.time()
        tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
        time_buckets.append(time.time() - start_time)
    print(time_buckets)

View File

@ -133,8 +133,10 @@ def subsequent_mask(length):
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation unit: ``tanh(x[:, :n]) * sigmoid(x[:, n:])`` of the summed inputs.

    Args:
        input_a: tensor of shape (B, 2*n, T) — first addend.
        input_b: tensor broadcastable to ``input_a`` — second addend.
        n_channels: one-element sequence/tensor; ``n_channels[0]`` is the
            channel split point ``n`` (half of the channel dimension).

    Returns:
        Tensor of shape (B, n, T) with the gated activations.
    """
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    # torch.split instead of dynamic slicing: avoids data-dependent slice
    # bounds in the traced graph, which keeps it torch.compile-friendly.
    t_act_part, s_act_part = torch.split(in_act, n_channels_int, dim=1)
    t_act = torch.tanh(t_act_part)
    s_act = torch.sigmoid(s_act_part)
    acts = t_act * s_act
    return acts
@ -437,6 +439,15 @@ class MyModel(nn.Module):
x = self.models['gpt_layer'](x)
return x
def enable_torch_compile(self):
    """Turn on torch.compile for the sub-models that support it.

    Currently only the 'cfm' sub-model is compiled; the request is simply
    forwarded to its own ``enable_torch_compile`` when it is registered.
    """
    try:
        cfm_model = self.models['cfm']
    except KeyError:
        return
    cfm_model.enable_torch_compile()
def build_model(args, stage="DiT"):

View File

@ -233,7 +233,7 @@ class DiT(torch.nn.Module):
if self.time_as_token: # False
x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True
input_pos = self.input_pos[:x_in.size(1)] # (T,) range01863
x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None # torch.Size([1, 1, 1863, 1863]
x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512]

View File

@ -169,3 +169,18 @@ class CFM(BASECFM):
self.estimator = DiT(args)
else:
raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
def enable_torch_compile(self):
    """Compile the estimator (DiT) with torch.compile for faster inference.

    Replaces ``self.estimator`` with a ``torch.compile`` wrapper using
    ``fullgraph=True`` (fail loudly on graph breaks) and ``dynamic=True``
    (tolerate varying sequence lengths without recompiling). When a
    distributed process group is active, also enables inductor's
    compute/communication overlap reordering.
    """
    # Check is_available() first: per PyTorch docs, on builds compiled
    # without distributed support, is_initialized() must not be called.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # NOTE(review): torch._inductor.config is a private API and may
        # change between PyTorch releases — pin the torch version or guard
        # this if upgrading.
        torch._inductor.config.reorder_for_compute_comm_overlap = True
    self.estimator = torch.compile(
        self.estimator,
        fullgraph=True,
        dynamic=True,
    )