这是常用的一版

This commit is contained in:
cloyir 2026-02-18 02:08:51 +08:00
parent d7ba0be4ed
commit 5038282fe6
5 changed files with 141 additions and 6 deletions

134
api_server.py Normal file
View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python3
import os
import sys
# Critical: force offline mode so the HuggingFace Hub is never contacted.
# These env vars MUST be set before `transformers` is imported below,
# which is why they sit between the import lines instead of after them.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
import traceback
from pathlib import Path  # NOTE(review): Path appears unused in this file
def log_error(msg):
    """Emit a `[ERROR]`-prefixed line on stderr."""
    print("[ERROR]", msg, file=sys.stderr)
def log_info(msg):
    """Emit an `[INFO]`-prefixed line on stdout."""
    print("[INFO]", msg)
# Whole startup sequence is wrapped so any failure is logged with a traceback
# and exits non-zero instead of dying silently under a process supervisor.
try:
    log_info("Starting API server (OFFLINE MODE)...")
    log_info(f"Python version: {sys.version}")
    from fastapi import FastAPI, UploadFile, File, Form
    from fastapi.responses import FileResponse
    import uvicorn
    import tempfile
    log_info("Dependencies imported successfully")
    # Make the directory containing this script importable so `indextts`
    # resolves regardless of the working directory the server is launched from.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from indextts.infer_v2 import IndexTTS2
    from transformers import SeamlessM4TFeatureExtractor
    log_info("IndexTTS2 imported successfully")
    # Verify the checkpoint config exists before any heavy initialization.
    cfg_path = "checkpoints/config.yaml"
    model_dir = "checkpoints"
    if not os.path.exists(cfg_path):
        log_error(f"Config file not found: {cfg_path}")
        sys.exit(1)
    log_info("Initializing model...")
    # DGX Spark workaround: disable CUDA kernel compilation (sm_121 does not
    # support compute_70); force the pure-PyTorch code path instead.
    os.environ["DISABLE_BIGVGAN_CUDA_KERNEL"] = "1"
    local_w2v_path = "./models/w2v-bert-2.0"
    if os.path.exists(local_w2v_path):
        log_info(f"Loading w2v-bert-2.0 from local path: {local_w2v_path}")
        # NOTE(review): `feature_extractor` is never used again below —
        # presumably a sanity check that the local weights load; confirm.
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_path)
        log_info("Feature extractor loaded locally")
    else:
        log_info("Local w2v-bert-2.0 not found, will try to load from cache...")
    # Initialize IndexTTS2; a model-load failure is fatal for the server.
    try:
        tts = IndexTTS2(
            cfg_path=cfg_path,
            model_dir=model_dir,
            use_fp16=True,  # DGX Spark supports FP16
            use_cuda_kernel=False  # critical: skip the custom CUDA kernel to avoid sm_121 build errors
        )
        log_info("Model loaded successfully!")
    except Exception as e:
        log_error(f"Model initialization failed: {e}")
        traceback.print_exc()
        sys.exit(1)
    app = FastAPI(title="IndexTTS2 API (DGX Spark)")
    @app.post("/tts")
    async def text_to_speech(
        text: str = Form(...),
        spk_audio: UploadFile = File(...),
        emo_audio: UploadFile = None,
        emo_alpha: float = Form(0.8)
    ):
        """Synthesize speech for `text` conditioned on the uploaded speaker
        prompt, optionally blending emotion from `emo_audio` by `emo_alpha`.

        Returns the generated audio as a WAV file download.
        """
        log_info(f"Processing: {text[:50]}...")
        # Spool the uploaded speaker prompt to disk: the TTS engine takes
        # file paths, not in-memory bytes.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as spk_tmp:
            content = await spk_audio.read()
            spk_tmp.write(content)
            spk_path = spk_tmp.name
        emo_path = None
        if emo_audio:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as emo_tmp:
                content = await emo_audio.read()
                emo_tmp.write(content)
                emo_path = emo_tmp.name
        # NOTE(review): tempfile.mktemp is deprecated/race-prone, and this
        # output file is never deleted after the response is sent — consider
        # NamedTemporaryFile plus a starlette BackgroundTask cleanup.
        output_path = tempfile.mktemp(suffix=".wav")
        try:
            tts.infer(
                spk_audio_prompt=spk_path,
                text=text,
                output_path=output_path,
                emo_audio_prompt=emo_path,
                emo_alpha=emo_alpha
            )
            return FileResponse(
                output_path,
                media_type="audio/wav",
                filename="generated.wav"
            )
        finally:
            # Always remove the uploaded prompt temp files, success or failure.
            if os.path.exists(spk_path):
                os.unlink(spk_path)
            if emo_path and os.path.exists(emo_path):
                os.unlink(emo_path)
    @app.get("/health")
    def health_check():
        """Liveness probe; returns a static payload (device is not probed)."""
        return {"status": "ok", "model": "IndexTTS2", "device": "cuda"}
    @app.get("/")
    def root():
        """Service banner pointing clients at the interactive docs."""
        return {"message": "IndexTTS2 API for DGX Spark", "docs": "/docs"}
    log_info("Starting server on port 11996...")
    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=11996)
except Exception as e:
    log_error(f"Fatal error: {e}")
    traceback.print_exc()
    sys.exit(1)

View File

@ -113,7 +113,7 @@ class IndexTTS2:
print(f"{e!r}") print(f"{e!r}")
self.use_cuda_kernel = False self.use_cuda_kernel = False
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("./models/w2v-bert-2.0")
self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model( self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
os.path.join(self.model_dir, self.cfg.w2v_stat)) os.path.join(self.model_dir, self.cfg.w2v_stat))
self.semantic_model = self.semantic_model.to(self.device) self.semantic_model = self.semantic_model.to(self.device)
@ -122,7 +122,8 @@ class IndexTTS2:
self.semantic_std = self.semantic_std.to(self.device) self.semantic_std = self.semantic_std.to(self.device)
semantic_codec = build_semantic_codec(self.cfg.semantic_codec) semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors") # semantic_code_ckpt = hf_hub_download("./models/MaskGCT", filename="./models/MaskGCT/semantic_codec/model.safetensors")
semantic_code_ckpt = "./models/MaskGCT/semantic_codec/model.safetensors"
safetensors.torch.load_model(semantic_codec, semantic_code_ckpt) safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
self.semantic_codec = semantic_codec.to(self.device) self.semantic_codec = semantic_codec.to(self.device)
self.semantic_codec.eval() self.semantic_codec.eval()
@ -151,9 +152,8 @@ class IndexTTS2:
print(">> s2mel weights restored from:", s2mel_path) print(">> s2mel weights restored from:", s2mel_path)
# load campplus_model # load campplus_model
campplus_ckpt_path = hf_hub_download( # campplus_ckpt_path = hf_hub_download("funasr/campplus", filename="campplus_cn_common.bin")
"funasr/campplus", filename="campplus_cn_common.bin" campplus_ckpt_path = "./models/campplus/campplus_cn_common.bin"
)
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu")) campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
self.campplus_model = campplus_model.to(self.device) self.campplus_model = campplus_model.to(self.device)

View File

@ -1,2 +1,3 @@
# ninja log v7 # ninja log v7
0 8799 1771051743097851885 anti_alias_activation.o 6d4acaa78fc0c336 0 8799 1771051743097851885 anti_alias_activation.o 6d4acaa78fc0c336
0 9806 1771307045391159629 anti_alias_activation.o 6d4acaa78fc0c336

View File

@ -85,7 +85,7 @@ class JsonHParams:
def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'): def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") semantic_model = Wav2Vec2BertModel.from_pretrained("./models/w2v-bert-2.0")
semantic_model.eval() semantic_model.eval()
stat_mean_var = torch.load(path_) stat_mean_var = torch.load(path_)
semantic_mean = stat_mean_var["mean"] semantic_mean = stat_mean_var["mean"]