# (file metadata from extraction: 134 lines, 4.1 KiB, Python)
#!/usr/bin/env python3
"""Offline-mode FastAPI wrapper around IndexTTS2 for a DGX Spark deployment."""
import os
import sys

# Critical: force offline mode so transformers/datasets never contact the
# HuggingFace Hub. These must be set BEFORE transformers is imported below.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

import traceback
from pathlib import Path  # NOTE(review): unused in this file as shown — confirm before removing
def log_error(msg):
    """Print *msg* to stderr, tagged with an error prefix."""
    prefix = "[ERROR]"
    print(f"{prefix} {msg}", file=sys.stderr)
def log_info(msg):
    """Print *msg* to stdout, tagged with an info prefix."""
    prefix = "[INFO]"
    print(f"{prefix} {msg}")
try:
    log_info("Starting API server (OFFLINE MODE)...")
    log_info(f"Python version: {sys.version}")

    from fastapi import FastAPI, UploadFile, File, Form
    from fastapi.responses import FileResponse
    import uvicorn
    import tempfile

    log_info("Dependencies imported successfully")

    # Make the script's directory importable so `indextts` resolves
    # without the package being installed.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from indextts.infer_v2 import IndexTTS2
    from transformers import SeamlessM4TFeatureExtractor

    log_info("IndexTTS2 imported successfully")

    # Model locations, relative to the working directory.
    cfg_path = "checkpoints/config.yaml"
    model_dir = "checkpoints"

    if not os.path.exists(cfg_path):
        log_error(f"Config file not found: {cfg_path}")
        sys.exit(1)

    log_info("Initializing model...")

    # DGX Spark workaround: disable custom CUDA kernel compilation
    # (sm_121 does not support compute_70) and force the pure-PyTorch path.
    os.environ["DISABLE_BIGVGAN_CUDA_KERNEL"] = "1"

    local_w2v_path = "./models/w2v-bert-2.0"

    if os.path.exists(local_w2v_path):
        log_info(f"Loading w2v-bert-2.0 from local path: {local_w2v_path}")
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_path)
        log_info("Feature extractor loaded locally")
    else:
        # Offline mode is forced above, so this relies on the HF cache
        # already containing the weights.
        log_info("Local w2v-bert-2.0 not found, will try to load from cache...")

    # Initialize IndexTTS2; a failure here is fatal for the server.
    try:
        tts = IndexTTS2(
            cfg_path=cfg_path,
            model_dir=model_dir,
            use_fp16=True,          # DGX Spark supports FP16
            use_cuda_kernel=False,  # key: avoid sm_121 compile errors from custom CUDA kernels
        )
        log_info("Model loaded successfully!")
    except Exception as e:
        log_error(f"Model initialization failed: {e}")
        traceback.print_exc()
        sys.exit(1)

    app = FastAPI(title="IndexTTS2 API (DGX Spark)")

    @app.post("/tts")
    async def text_to_speech(
        text: str = Form(...),
        spk_audio: UploadFile = File(...),
        # Explicit File(None) instead of a bare `= None` default so FastAPI
        # treats this unambiguously as an optional file-upload field.
        emo_audio: UploadFile = File(None),
        emo_alpha: float = Form(0.8),
    ):
        """Synthesize speech for `text` using `spk_audio` as the voice prompt.

        Optionally conditions emotion on `emo_audio`, blended with strength
        `emo_alpha`. Returns the generated audio as a WAV file download.
        """
        log_info(f"Processing: {text[:50]}...")

        # Persist the uploaded speaker prompt to a temp file for the model.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as spk_tmp:
            content = await spk_audio.read()
            spk_tmp.write(content)
            spk_path = spk_tmp.name

        emo_path = None
        if emo_audio:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as emo_tmp:
                content = await emo_audio.read()
                emo_tmp.write(content)
                emo_path = emo_tmp.name

        # mkstemp instead of the deprecated, race-prone tempfile.mktemp:
        # the file is created atomically; close our fd and let the model
        # write to the path.
        out_fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(out_fd)

        try:
            tts.infer(
                spk_audio_prompt=spk_path,
                text=text,
                output_path=output_path,
                emo_audio_prompt=emo_path,
                emo_alpha=emo_alpha,
            )

            # NOTE(review): output_path is not deleted here because
            # FileResponse streams it after this handler returns; consider a
            # background task to remove it once the response has been sent.
            return FileResponse(
                output_path,
                media_type="audio/wav",
                filename="generated.wav",
            )
        finally:
            # Always remove the uploaded prompt temp files.
            if os.path.exists(spk_path):
                os.unlink(spk_path)
            if emo_path and os.path.exists(emo_path):
                os.unlink(emo_path)

    @app.get("/health")
    def health_check():
        """Liveness probe reporting the served model and device."""
        return {"status": "ok", "model": "IndexTTS2", "device": "cuda"}

    @app.get("/")
    def root():
        """Service banner with a pointer to the interactive API docs."""
        return {"message": "IndexTTS2 API for DGX Spark", "docs": "/docs"}

    log_info("Starting server on port 11996...")

    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=11996)

except Exception as e:
    log_error(f"Fatal error: {e}")
    traceback.print_exc()
    sys.exit(1)