From 055a23a12b969edecb3d515abef4232f77eca44f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=81=E5=AD=97=E9=B1=BC?= <52977964+gluttony-10@users.noreply.github.com> Date: Tue, 9 Sep 2025 20:31:05 +0800 Subject: [PATCH] Add startup parameters for cuda_kernel (#302) * Update webui.py * Update infer_v2.py --- indextts/infer_v2.py | 2 +- webui.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py index a52b5cb..842f50e 100644 --- a/indextts/infer_v2.py +++ b/indextts/infer_v2.py @@ -144,7 +144,7 @@ class IndexTTS2: print(">> campplus_model weights restored from:", campplus_ckpt_path) bigvgan_name = self.cfg.vocoder.name - self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False) + self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=True if self.use_cuda_kernel else False) self.bigvgan = self.bigvgan.to(self.device) self.bigvgan.remove_weight_norm() self.bigvgan.eval() diff --git a/webui.py b/webui.py index 657a298..705b126 100644 --- a/webui.py +++ b/webui.py @@ -25,6 +25,7 @@ parser.add_argument("--port", type=int, default=7860, help="Port to run the web parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on") parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory") parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available") +parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use cuda kernel for inference if available") parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment") cmd_args = parser.parse_args() @@ -50,7 +51,12 @@ from tools.i18n.i18n import I18nAuto i18n = I18nAuto(language="Auto") MODE = 'local' -tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),use_fp16=cmd_args.fp16) +tts = IndexTTS2( + model_dir=cmd_args.model_dir, + cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"), + use_fp16=cmd_args.fp16, + use_cuda_kernel=cmd_args.cuda_kernel, +) # 支持的语言列表 LANGUAGES = {