2026-02-18 02:08:51 +08:00

134 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import os
import sys
# Critical: force offline mode so nothing tries to connect to HuggingFace.
# These flags are set here, ahead of the transformers / indextts imports
# further down, so they are already in place when those imports run.
os.environ.update(
    {
        "HF_HUB_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_DATASETS_OFFLINE": "1",
    }
)
import traceback
from pathlib import Path
def log_error(msg: object) -> None:
    """Write *msg* to standard error, prefixed with ``[ERROR]``.

    Args:
        msg: The message to report; it is interpolated with ``str()`` rules.
    """
    print(f"[ERROR] {msg}", file=sys.stderr)
def log_info(msg: object) -> None:
    """Write *msg* to standard output, prefixed with ``[INFO]``.

    Args:
        msg: The message to report; it is interpolated with ``str()`` rules.
    """
    # Flush on every call: when stdout is piped (service manager, container
    # logs) it is block-buffered, which would otherwise delay startup
    # progress lines and let them appear out of order with stderr output.
    print(f"[INFO] {msg}", flush=True)
# Script body: import the heavy dependencies, load the IndexTTS2 model, declare
# the HTTP routes and (when run as a script) start uvicorn. Everything lives in
# one try block so that any unexpected exception during startup is logged with
# a traceback and turned into exit status 1.
try:
    log_info("Starting API server (OFFLINE MODE)...")
    log_info(f"Python version: {sys.version}")

    from fastapi import BackgroundTasks, FastAPI, UploadFile, File, Form
    from fastapi.responses import FileResponse
    import uvicorn
    import shutil
    import tempfile

    log_info("Dependencies imported successfully")

    # Put this script's directory first on sys.path before importing indextts
    # (presumably so the package sitting next to the script is the one found).
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from indextts.infer_v2 import IndexTTS2
    from transformers import SeamlessM4TFeatureExtractor

    log_info("IndexTTS2 imported successfully")

    # Check the model paths (relative to the current working directory).
    cfg_path = "checkpoints/config.yaml"
    model_dir = "checkpoints"
    if not os.path.exists(cfg_path):
        log_error(f"Config file not found: {cfg_path}")
        sys.exit(1)

    log_info("Initializing model...")

    # DGX Spark optimization: disable CUDA kernel compilation (because sm_121
    # does not support compute_70) and force pure PyTorch mode.
    os.environ["DISABLE_BIGVGAN_CUDA_KERNEL"] = "1"

    local_w2v_path = "./models/w2v-bert-2.0"
    if os.path.exists(local_w2v_path):
        log_info(f"Loading w2v-bert-2.0 from local path: {local_w2v_path}")
        # NOTE(review): the loaded extractor is not passed to IndexTTS2 below;
        # presumably this only verifies that the local copy loads - confirm.
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_path)
        log_info("Feature extractor loaded locally")
    else:
        log_info("Local w2v-bert-2.0 not found, will try to load from cache...")

    # Initialize IndexTTS2.
    try:
        tts = IndexTTS2(
            cfg_path=cfg_path,
            model_dir=model_dir,
            use_fp16=True,  # DGX Spark supports FP16
            # Critical: disable the custom CUDA kernel to avoid sm_121
            # compilation errors.
            use_cuda_kernel=False,
        )
        log_info("Model loaded successfully!")
    except Exception as e:
        log_error(f"Model initialization failed: {e}")
        traceback.print_exc()
        sys.exit(1)

    app = FastAPI(title="IndexTTS2 API (DGX Spark)")

    def _discard(path):
        """Best-effort removal of a temp file or directory; ``None``/missing is ignored."""
        if not path:
            return
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
            return
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

    async def _save_upload(upload):
        """Copy an uploaded file into a named ``.wav`` temp file and return its path."""
        payload = await upload.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(payload)
            return tmp.name

    @app.post("/tts")
    async def text_to_speech(
        text: str = Form(...),
        spk_audio: UploadFile = File(...),
        emo_audio: UploadFile = File(None),
        emo_alpha: float = Form(0.8),
    ):
        """Synthesize speech for *text* and return it as a WAV download.

        Args:
            text: Text to synthesize (form field).
            spk_audio: Uploaded audio passed to the model as
                ``spk_audio_prompt``.
            emo_audio: Optional uploaded audio passed to the model as
                ``emo_audio_prompt``; ``None`` is passed when omitted.
            emo_alpha: Forwarded unchanged to ``tts.infer`` (default 0.8).

        Returns:
            A ``FileResponse`` with media type ``audio/wav`` and download
            name ``generated.wav``.

        Raises:
            Whatever ``tts.infer`` or the temp-file handling raises; all
            temporary files are removed before the exception propagates.
        """
        log_info(f"Processing: {text[:50]}...")

        spk_path = None
        emo_path = None
        out_dir = None
        try:
            spk_path = await _save_upload(spk_audio)
            if emo_audio:
                emo_path = await _save_upload(emo_audio)

            # A private directory from mkdtemp replaces the deprecated,
            # race-prone tempfile.mktemp: the output path is unpredictable to
            # other users and does not exist until the model writes it.
            out_dir = tempfile.mkdtemp(prefix="indextts2_")
            output_path = os.path.join(out_dir, "generated.wav")

            # NOTE(review): this call is synchronous, so the event loop (and
            # therefore /health) is blocked while it runs. Left as-is because
            # moving it to a thread pool would allow concurrent calls on the
            # single shared model - confirm that is safe before changing.
            tts.infer(
                spk_audio_prompt=spk_path,
                text=text,
                output_path=output_path,
                emo_audio_prompt=emo_path,
                emo_alpha=emo_alpha,
            )
        except BaseException:
            # No response will be sent, so the output directory must go now.
            _discard(out_dir)
            raise
        finally:
            # The prompts are only needed during inference.
            _discard(spk_path)
            _discard(emo_path)

        # The generated file has to outlive this function so it can be
        # streamed; remove it once the response has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(_discard, out_dir)
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename="generated.wav",
            background=cleanup,
        )

    @app.get("/health")
    def health_check():
        """Return a static status payload; no model or device probing is done."""
        return {"status": "ok", "model": "IndexTTS2", "device": "cuda"}

    @app.get("/")
    def root():
        """Return a short service description that points at the docs path."""
        return {"message": "IndexTTS2 API for DGX Spark", "docs": "/docs"}

    log_info("Starting server on port 11996...")
    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=11996)
except Exception as e:
    log_error(f"Fatal error: {e}")
    traceback.print_exc()
    sys.exit(1)