Add startup parameters for cuda_kernel (#302)

* Update webui.py

* Update infer_v2.py
This commit is contained in:
十字鱼 2025-09-09 20:31:05 +08:00 committed by GitHub
parent 32f111d906
commit 055a23a12b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -144,7 +144,7 @@ class IndexTTS2:
print(">> campplus_model weights restored from:", campplus_ckpt_path)
bigvgan_name = self.cfg.vocoder.name
self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=True if self.use_cuda_kernel else False)
self.bigvgan = self.bigvgan.to(self.device)
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()

View File

@ -25,6 +25,7 @@ parser.add_argument("--port", type=int, default=7860, help="Port to run the web
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory")
parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use cuda kernel for inference if available")
parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
cmd_args = parser.parse_args()
@ -50,7 +51,12 @@ from tools.i18n.i18n import I18nAuto
i18n = I18nAuto(language="Auto")
MODE = 'local'
tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),use_fp16=cmd_args.fp16)
tts = IndexTTS2(
model_dir=cmd_args.model_dir,
cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
use_fp16=cmd_args.fp16,
use_cuda_kernel=cmd_args.cuda_kernel,
)
# 支持的语言列表
LANGUAGES = {