#!/usr/bin/env python3
"""IndexTTS2 HTTP API server (offline mode, tuned for DGX Spark).

Routes:
    POST /tts     Synthesize ``text`` using the uploaded speaker prompt
                  (and optional emotion prompt); responds with a WAV file.
    GET  /health  Static liveness response.
    GET  /        Short service description.
"""
import os
import sys

# Key point: force offline mode so nothing tries to contact HuggingFace.
# These are set before transformers / indextts are imported below.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

# DGX Spark: disable BigVGAN custom CUDA kernel compilation (sm_121 does not
# support compute_70) and fall back to plain PyTorch. Set before importing
# indextts so the flag is already present if it is read at import time.
os.environ["DISABLE_BIGVGAN_CUDA_KERNEL"] = "1"

import shutil
import tempfile
import threading
import traceback
from pathlib import Path
from typing import Optional


def log_error(msg):
    """Print *msg* to stderr with an ``[ERROR]`` prefix."""
    print(f"[ERROR] {msg}", file=sys.stderr)


def log_info(msg):
    """Print *msg* to stdout with an ``[INFO]`` prefix."""
    print(f"[INFO] {msg}")


async def _save_upload(upload, dest):
    """Write the full contents of the uploaded file *upload* to path *dest*."""
    content = await upload.read()
    with open(dest, "wb") as fh:
        fh.write(content)


try:
    log_info("Starting API server (OFFLINE MODE)...")
    log_info(f"Python version: {sys.version}")

    from fastapi import FastAPI, UploadFile, File, Form
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import FileResponse
    from starlette.background import BackgroundTask
    import uvicorn
    log_info("Dependencies imported successfully")

    # Make the ``indextts`` package next to this script importable.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from indextts.infer_v2 import IndexTTS2
    from transformers import SeamlessM4TFeatureExtractor
    log_info("IndexTTS2 imported successfully")

    # Check model paths. NOTE: these are relative to the current working
    # directory, not to this script's directory.
    cfg_path = "checkpoints/config.yaml"
    model_dir = "checkpoints"
    if not os.path.exists(cfg_path):
        log_error(f"Config file not found: {cfg_path}")
        sys.exit(1)

    log_info("Initializing model...")

    local_w2v_path = "./models/w2v-bert-2.0"
    if os.path.exists(local_w2v_path):
        log_info(f"Loading w2v-bert-2.0 from local path: {local_w2v_path}")
        # NOTE(review): this extractor is never passed to IndexTTS2 below, so
        # loading it only proves the local copy is readable. Confirm whether
        # IndexTTS2 was meant to receive it.
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_path)
        log_info("Feature extractor loaded locally")
    else:
        log_info("Local w2v-bert-2.0 not found, will try to load from cache...")

    # Initialize IndexTTS2.
    try:
        tts = IndexTTS2(
            cfg_path=cfg_path,
            model_dir=model_dir,
            use_fp16=True,          # DGX Spark supports FP16
            use_cuda_kernel=False,  # key: no custom CUDA kernels (sm_121 build errors)
        )
        log_info("Model loaded successfully!")
    except Exception as e:
        log_error(f"Model initialization failed: {e}")
        traceback.print_exc()
        sys.exit(1)

    # A single model instance serves every request. Inference runs in worker
    # threads (see below), so serialize it: the model is not known to be
    # thread-safe. A threading.Lock is used because it is not tied to any
    # particular asyncio event loop.
    _infer_lock = threading.Lock()

    def _run_inference(**kwargs):
        """Call ``tts.infer(**kwargs)`` while holding the model lock."""
        with _infer_lock:
            return tts.infer(**kwargs)

    app = FastAPI(title="IndexTTS2 API (DGX Spark)")

    @app.post("/tts")
    async def text_to_speech(
        text: str = Form(...),
        spk_audio: UploadFile = File(...),
        emo_audio: Optional[UploadFile] = File(None),
        emo_alpha: float = Form(0.8),
    ):
        """Synthesize *text* and return the result as ``generated.wav``.

        Args:
            text: Text to synthesize.
            spk_audio: Uploaded audio passed to ``tts.infer`` as
                ``spk_audio_prompt``.
            emo_audio: Optional uploaded audio passed as ``emo_audio_prompt``
                (``None`` when omitted).
            emo_alpha: Forwarded unchanged to ``tts.infer``.

        Returns:
            A ``FileResponse`` streaming the generated WAV file.

        Raises:
            Any exception from saving the uploads or from ``tts.infer``
            propagates (FastAPI turns it into a 500 response).
        """
        log_info(f"Processing: {text[:50]}...")

        # Every file for this request lives in its own private temp directory:
        # names are unique per request and cannot be claimed by another
        # process (unlike tempfile.mktemp). Cleanup is a single rmtree.
        # NOTE(review): unique prompt paths also matter if IndexTTS2 caches
        # prompt conditioning keyed on the path — confirm.
        work_dir = tempfile.mkdtemp(prefix="indextts_")
        try:
            spk_path = os.path.join(work_dir, "spk.wav")
            await _save_upload(spk_audio, spk_path)

            emo_path = None
            if emo_audio is not None:
                emo_path = os.path.join(work_dir, "emo.wav")
                await _save_upload(emo_audio, emo_path)

            output_path = os.path.join(work_dir, "generated.wav")

            # tts.infer blocks; run it in a worker thread so the event loop
            # keeps serving other requests (e.g. /health) during synthesis.
            await run_in_threadpool(
                _run_inference,
                spk_audio_prompt=spk_path,
                text=text,
                output_path=output_path,
                emo_audio_prompt=emo_path,
                emo_alpha=emo_alpha,
            )
        except BaseException:
            # Nothing will be served; drop all per-request files now.
            shutil.rmtree(work_dir, ignore_errors=True)
            raise

        # The response body is read after this function returns, so the
        # directory (prompts and output) is removed only once it has been sent.
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename="generated.wav",
            background=BackgroundTask(shutil.rmtree, work_dir, ignore_errors=True),
        )

    @app.get("/health")
    def health_check():
        """Return a static liveness payload ("device" is a fixed label)."""
        return {"status": "ok", "model": "IndexTTS2", "device": "cuda"}

    @app.get("/")
    def root():
        """Return a short description of the service and where its docs are."""
        return {"message": "IndexTTS2 API for DGX Spark", "docs": "/docs"}

    if __name__ == "__main__":
        log_info("Starting server on port 11996...")
        uvicorn.run(app, host="0.0.0.0", port=11996)

except Exception as e:
    log_error(f"Fatal error: {e}")
    traceback.print_exc()
    sys.exit(1)