diff --git a/api_server.py b/api_server.py
new file mode 100644
index 0000000..51a4e0a
--- /dev/null
+++ b/api_server.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+"""HTTP API server exposing IndexTTS2 speech synthesis (offline mode)."""
+import os
+import sys
+
+# Key: force offline mode so nothing tries to connect to HuggingFace.
+# These are set before the model libraries are imported further down.
+os.environ["HF_HUB_OFFLINE"] = "1"
+os.environ["TRANSFORMERS_OFFLINE"] = "1"
+os.environ["HF_DATASETS_OFFLINE"] = "1"
+
+import shutil
+import traceback
+from pathlib import Path
+from typing import Optional
+
+
+def log_error(msg):
+    """Print an error message to stderr."""
+    print(f"[ERROR] {msg}", file=sys.stderr)
+
+
+def log_info(msg):
+    """Print an informational message to stdout."""
+    print(f"[INFO] {msg}")
+
+
+async def _save_upload(upload, dest_path):
+    """Write the whole uploaded file to dest_path and return dest_path."""
+    with open(dest_path, "wb") as fh:
+        fh.write(await upload.read())
+    return dest_path
+
+
+try:
+    log_info("Starting API server (OFFLINE MODE)...")
+    log_info(f"Python version: {sys.version}")
+
+    from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+    from fastapi.responses import FileResponse
+    from starlette.background import BackgroundTask
+    import uvicorn
+    import tempfile
+
+    log_info("Dependencies imported successfully")
+
+    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+    from indextts.infer_v2 import IndexTTS2
+    from transformers import SeamlessM4TFeatureExtractor
+
+    log_info("IndexTTS2 imported successfully")
+
+    # Check the model paths (relative to the current working directory).
+    cfg_path = "checkpoints/config.yaml"
+    model_dir = "checkpoints"
+
+    if not os.path.exists(cfg_path):
+        log_error(f"Config file not found: {cfg_path}")
+        sys.exit(1)
+
+    log_info("Initializing model...")
+
+    # DGX Spark tuning: disable CUDA kernel compilation (sm_121 does not
+    # support compute_70) and force pure-PyTorch mode.
+    os.environ["DISABLE_BIGVGAN_CUDA_KERNEL"] = "1"
+
+    local_w2v_path = "./models/w2v-bert-2.0"
+
+    if os.path.exists(local_w2v_path):
+        log_info(f"Loading w2v-bert-2.0 from local path: {local_w2v_path}")
+        # Pre-flight check only: the result is discarded, IndexTTS2 loads
+        # its own feature extractor from the same path.
+        SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_path)
+        log_info("Feature extractor loaded locally")
+    else:
+        log_info("Local w2v-bert-2.0 not found, will try to load from cache...")
+
+    # Initialize IndexTTS2.
+    try:
+        tts = IndexTTS2(
+            cfg_path=cfg_path,
+            model_dir=model_dir,
+            use_fp16=True,  # DGX Spark supports FP16
+            # Key: disable the custom CUDA kernel to avoid sm_121 build errors.
+            use_cuda_kernel=False,
+        )
+        log_info("Model loaded successfully!")
+    except Exception as e:
+        log_error(f"Model initialization failed: {e}")
+        traceback.print_exc()
+        sys.exit(1)
+
+    app = FastAPI(title="IndexTTS2 API (DGX Spark)")
+
+    @app.post("/tts")
+    async def text_to_speech(
+        text: str = Form(...),
+        spk_audio: UploadFile = File(...),
+        emo_audio: Optional[UploadFile] = File(None),
+        emo_alpha: float = Form(0.8),
+    ):
+        """Run `tts.infer` on `text` with the uploaded prompts and return the WAV.
+
+        `spk_audio` is required and `emo_audio` optional; each is saved to a
+        temporary file whose path is passed to `tts.infer` with `emo_alpha`.
+        """
+        log_info(f"Processing: {text[:50]}...")
+
+        # One private directory per request holds the uploads and the result,
+        # so a single rmtree cleans everything up.
+        work_dir = tempfile.mkdtemp(prefix="indextts_")
+        try:
+            spk_path = await _save_upload(spk_audio, os.path.join(work_dir, "spk.wav"))
+            emo_path = None
+            if emo_audio:
+                emo_path = await _save_upload(emo_audio, os.path.join(work_dir, "emo.wav"))
+            output_path = os.path.join(work_dir, "output.wav")
+
+            # NOTE: infer() is called synchronously, so the event loop is
+            # blocked while it runs.
+            tts.infer(
+                spk_audio_prompt=spk_path,
+                text=text,
+                output_path=output_path,
+                emo_audio_prompt=emo_path,
+                emo_alpha=emo_alpha,
+            )
+
+            if not os.path.isfile(output_path):
+                raise HTTPException(status_code=500, detail="TTS produced no output file")
+        except BaseException:
+            # Failure or cancellation: nothing will be sent back, so remove
+            # the working directory immediately and re-raise.
+            shutil.rmtree(work_dir, ignore_errors=True)
+            raise
+
+        # FileResponse reads the file after this function returns, so the
+        # directory is removed by a background task once the response is sent.
+        return FileResponse(
+            output_path,
+            media_type="audio/wav",
+            filename="generated.wav",
+            background=BackgroundTask(shutil.rmtree, work_dir, ignore_errors=True),
+        )
+
+    @app.get("/health")
+    def health_check():
+        """Return a static status payload."""
+        return {"status": "ok", "model": "IndexTTS2", "device": "cuda"}
+
+    @app.get("/")
+    def root():
+        """Return a short banner pointing at the interactive docs."""
+        return {"message": "IndexTTS2 API for DGX Spark", "docs": "/docs"}
+
+    if __name__ == "__main__":
+        log_info("Starting server on port 11996...")
+        uvicorn.run(app, host="0.0.0.0", port=11996)
+
+except Exception as e:
+    log_error(f"Fatal error: {e}")
+    traceback.print_exc()
+    sys.exit(1)
diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py
index c5b85fb..c5415d6 100644
--- a/indextts/infer_v2.py
+++ b/indextts/infer_v2.py
@@ -113,7 +113,7 @@ class IndexTTS2:
                 print(f"{e!r}")
                 self.use_cuda_kernel = False
 
-        self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+        self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("./models/w2v-bert-2.0")
         self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
             os.path.join(self.model_dir, self.cfg.w2v_stat))
         self.semantic_model = self.semantic_model.to(self.device)
@@ -122,7 +122,8 @@ class IndexTTS2:
         self.semantic_std = self.semantic_std.to(self.device)
 
         semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
-        semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors")
+        # semantic_code_ckpt = hf_hub_download("./models/MaskGCT", filename="./models/MaskGCT/semantic_codec/model.safetensors")
+        semantic_code_ckpt = "./models/MaskGCT/semantic_codec/model.safetensors"
         safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
         self.semantic_codec = semantic_codec.to(self.device)
         self.semantic_codec.eval()
@@ -151,9 +152,8 @@ class IndexTTS2:
         print(">> s2mel weights restored from:", s2mel_path)
 
         # load campplus_model
-        campplus_ckpt_path = hf_hub_download(
-            "funasr/campplus", filename="campplus_cn_common.bin"
-        )
+        # campplus_ckpt_path = hf_hub_download("funasr/campplus", filename="campplus_cn_common.bin")
+        campplus_ckpt_path = "./models/campplus/campplus_cn_common.bin"
         campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
         campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
         self.campplus_model = campplus_model.to(self.device)
diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
index aae7506..e024e2d 100644
Binary files a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps and b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps differ
diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
index 773d44e..b1be289 100644
--- a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
+++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
@@ -1,2 +1,3 @@
 # ninja log v7
 0 8799 1771051743097851885 anti_alias_activation.o 6d4acaa78fc0c336
+0 9806 1771307045391159629 anti_alias_activation.o 6d4acaa78fc0c336
diff --git a/indextts/utils/maskgct_utils.py b/indextts/utils/maskgct_utils.py
index 40b9cb0..99c0e07 100644
--- a/indextts/utils/maskgct_utils.py
+++ b/indextts/utils/maskgct_utils.py
@@ -85,7 +85,7 @@ class JsonHParams:
 
 
 def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
-    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+    semantic_model = Wav2Vec2BertModel.from_pretrained("./models/w2v-bert-2.0")
     semantic_model.eval()
     stat_mean_var = torch.load(path_)
     semantic_mean = stat_mean_var["mean"]