Merge branch 'main' into gpt2_accel
This commit is contained in:
commit
5d67f6271b
@ -38,7 +38,7 @@ import torch.nn.functional as F
|
||||
class IndexTTS2:
|
||||
def __init__(
|
||||
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
|
||||
use_cuda_kernel=None,use_deepspeed=False, use_accel=False
|
||||
use_cuda_kernel=None,use_deepspeed=False, use_accel=False, use_torch_compile=False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -49,6 +49,7 @@ class IndexTTS2:
|
||||
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
|
||||
use_deepspeed (bool): whether to use DeepSpeed or not.
|
||||
use_accel (bool): whether to use acceleration engine for GPT2 or not.
|
||||
use_torch_compile (bool): whether to use torch.compile for optimization or not.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
@ -77,6 +78,7 @@ class IndexTTS2:
|
||||
self.dtype = torch.float16 if self.use_fp16 else None
|
||||
self.stop_mel_token = self.cfg.gpt.stop_mel_token
|
||||
self.use_accel = use_accel
|
||||
self.use_torch_compile = use_torch_compile
|
||||
|
||||
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
|
||||
|
||||
@ -137,6 +139,13 @@ class IndexTTS2:
|
||||
)
|
||||
self.s2mel = s2mel.to(self.device)
|
||||
self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
|
||||
|
||||
# Enable torch.compile optimization if requested
|
||||
if self.use_torch_compile:
|
||||
print(">> Enabling torch.compile optimization")
|
||||
self.s2mel.enable_torch_compile()
|
||||
print(">> torch.compile optimization enabled successfully")
|
||||
|
||||
self.s2mel.eval()
|
||||
print(">> s2mel weights restored from:", s2mel_path)
|
||||
|
||||
@ -818,6 +827,19 @@ class QwenEmotion:
|
||||
if __name__ == "__main__":
|
||||
prompt_wav = "examples/voice_01.wav"
|
||||
text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'
|
||||
|
||||
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
|
||||
tts = IndexTTS2(
|
||||
cfg_path="checkpoints/config.yaml",
|
||||
model_dir="checkpoints",
|
||||
use_cuda_kernel=False,
|
||||
use_torch_compile=True
|
||||
)
|
||||
tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
|
||||
char_size = 5
|
||||
import string
|
||||
time_buckets = []
|
||||
for i in range(10):
|
||||
text = ''.join(random.choices(string.ascii_letters, k=char_size))
|
||||
start_time = time.time()
|
||||
tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
|
||||
time_buckets.append(time.time() - start_time)
|
||||
print(time_buckets)
|
||||
|
||||
@ -133,8 +133,10 @@ def subsequent_mask(length):
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
# use torch.split to avoid dynamic slicing
|
||||
t_act_part, s_act_part = torch.split(in_act, n_channels_int, dim=1)
|
||||
t_act = torch.tanh(t_act_part)
|
||||
s_act = torch.sigmoid(s_act_part)
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
@ -437,6 +439,15 @@ class MyModel(nn.Module):
|
||||
x = self.models['gpt_layer'](x)
|
||||
return x
|
||||
|
||||
def enable_torch_compile(self):
|
||||
"""Enable torch.compile optimization.
|
||||
|
||||
This method applies torch.compile to the model for significant
|
||||
performance improvements during inference.
|
||||
"""
|
||||
if 'cfm' in self.models:
|
||||
self.models['cfm'].enable_torch_compile()
|
||||
|
||||
|
||||
|
||||
def build_model(args, stage="DiT"):
|
||||
|
||||
@ -233,7 +233,7 @@ class DiT(torch.nn.Module):
|
||||
if self.time_as_token: # False
|
||||
x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
|
||||
|
||||
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True
|
||||
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True
|
||||
input_pos = self.input_pos[:x_in.size(1)] # (T,) range(0,1863)
|
||||
x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None # torch.Size([1, 1, 1863, 1863]
|
||||
x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512]
|
||||
|
||||
@ -169,3 +169,18 @@ class CFM(BASECFM):
|
||||
self.estimator = DiT(args)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
|
||||
|
||||
def enable_torch_compile(self):
|
||||
"""Enable torch.compile optimization for the estimator model.
|
||||
|
||||
This method applies torch.compile to the estimator (DiT model) for significant
|
||||
performance improvements during inference. It also configures distributed
|
||||
training optimizations if applicable.
|
||||
"""
|
||||
if torch.distributed.is_initialized():
|
||||
torch._inductor.config.reorder_for_compute_comm_overlap = True
|
||||
self.estimator = torch.compile(
|
||||
self.estimator,
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user