Add cli mode for inference

This commit is contained in:
Yrom 2025-04-11 14:57:20 +08:00
parent eff6eb8f43
commit 471a45435c
No known key found for this signature in database
6 changed files with 147 additions and 5 deletions

8
.gitignore vendored
View File

@ -1,4 +1,10 @@
venv/
__pycache__
*.egg-info
*.DS_Store
.idea/
.idea/
checkpoints/*.pth
checkpoints/*.vocab
checkpoints/*.model
checkpoints/.cache
outputs/

View File

@ -103,7 +103,21 @@ conda activate index-tts
pip install -r requirements.txt
apt-get install ffmpeg
```
3. Download models:
Download by `huggingface-cli`:
```bash
# If the download is slow, you can use the official mirror
export HF_ENDPOINT="https://hf-mirror.com"
huggingface-cli download IndexTeam/Index-TTS \
bigvgan_discriminator.pth bigvgan_generator.pth bpe.model dvae.pth gpt.pth unigram_12000.vocab \
--local-dir checkpoints
```
Or by `wget`:
```bash
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_discriminator.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_generator.pth -P checkpoints
@ -112,11 +126,32 @@ wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/dvae.pth -P checkpo
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/gpt.pth -P checkpoints
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/unigram_12000.vocab -P checkpoints
```
4. Run test script:
```bash
# Please put your prompt audio in 'test_data' and rename it to 'input.wav'
PYTHONPATH=. python indextts/infer.py
```
5. Use as command line tool:
```bash
# Make sure pytorch has been installed before running this command
pip install -e .
indextts "大家好我现在正在bilibili 体验 ai 科技说实话来之前我绝对想不到AI技术已经发展到这样匪夷所思的地步了" \
--voice reference_voice.wav \
--model_dir checkpoints \
--config checkpoints/config.yaml \
--output output.wav
```
Use `--help` to see more options.
```bash
indextts --help
```
#### Web Demo
```bash
python webui.py

51
indextts/cli.py Normal file
View File

@ -0,0 +1,51 @@
import os
import sys
import warnings

# Suppress noisy warnings from tensorflow and other libraries
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def main():
    """Entry point for the ``indextts`` console script.

    Parses command-line arguments, validates the input files, then runs
    IndexTTS inference and writes the synthesized speech to a wav file.
    Exits with status 1 on any validation or environment error.
    """
    import argparse
    parser = argparse.ArgumentParser(description="IndexTTS Command Line")
    parser.add_argument("text", type=str, help="Text to be synthesized")
    parser.add_argument("-v", "--voice", type=str, required=True,
                        help="Path to the audio prompt file (wav format)")
    parser.add_argument("-o", "--output_path", type=str, default="gen.wav",
                        help="Path to the output wav file")
    parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml",
                        help="Path to the config file. Default is 'checkpoints/config.yaml'")
    parser.add_argument("--model_dir", type=str, default="checkpoints",
                        help="Path to the model directory. Default is 'checkpoints'")
    # BooleanOptionalAction keeps `--fp16` working and adds `--no-fp16`.
    # The previous `action="store_true", default=True` made the flag a
    # no-op: it was always True and could never be turned off.
    parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, default=True,
                        help="Use FP16 for inference if available (disable with --no-fp16)")
    parser.add_argument("-f", "--force", action="store_true", default=False,
                        help="Force to overwrite the output file if it exists")
    args = parser.parse_args()

    # Validate inputs before doing any expensive model loading.
    if not os.path.exists(args.voice):
        print(f"Audio prompt file {args.voice} does not exist.")
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.config):
        print(f"Config file {args.config} does not exist.")
        parser.print_help()
        sys.exit(1)

    output_path = args.output_path
    if os.path.exists(output_path):
        if not args.force:
            print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
            parser.print_help()
            sys.exit(1)
        else:
            os.remove(output_path)

    # PyTorch is a hard requirement; fail with a clear message rather than
    # an ImportError traceback from deep inside the model code.
    try:
        import torch
        if not torch.cuda.is_available():
            print("WARNING: CUDA is not available. Running in CPU mode.")
    except ImportError:
        print("ERROR: PyTorch is not installed. Please install it first.")
        sys.exit(1)

    # Imported lazily so that `--help` and argument errors stay fast.
    from indextts.infer import IndexTTS
    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16)
    tts.infer(audio_prompt=args.voice, text=args.text, output_path=output_path)


if __name__ == "__main__":
    main()

View File

@ -17,9 +17,9 @@ accelerate==0.25.0
tensorboard==2.9.1
omegaconf
sentencepiece
pypinyin
librosa
gradio
tqdm
WeTextProcessing # comment out this line if installation fails on ARM machines
wetext
WeTextProcessing; platform_system != "Darwin"
wetext; platform_system == "Darwin"

50
setup.py Normal file
View File

@ -0,0 +1,50 @@
import platform

from setuptools import find_packages, setup

# Read the long description up front inside a `with` block so the file
# handle is closed promptly (passing an open() result straight to setup()
# leaks the handle for the life of the process).
with open("README.md", encoding="utf8") as readme:
    long_description = readme.read()

setup(
    name="indextts",
    version="0.1.0",
    author="Index SpeechTeam",
    author_email="xuanwu@bilibili.com",
    long_description=long_description,
    long_description_content_type="text/markdown",
    description="An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System",
    url="https://github.com/index-tts/index-tts",
    packages=find_packages(),
    include_package_data=True,
    install_requires=[
        "torch==2.6.0",
        "torchaudio",
        "transformers==4.36.2",
        "accelerate",
        "tokenizers==0.15.0",
        "einops==0.8.1",
        "matplotlib==3.8.2",
        "omegaconf",
        "sentencepiece",
        "librosa",
        "numpy",
        # WeTextProcessing has no macOS wheels; fall back to wetext there.
        "wetext" if platform.system() == "Darwin" else "WeTextProcessing",
    ],
    extras_require={
        # The web demo is optional; `pip install indextts[webui]` pulls gradio.
        "webui": ["gradio"],
    },
    entry_points={
        "console_scripts": [
            "indextts = indextts.cli:main",
        ]
    },
    license="Apache-2.0",
    python_requires=">=3.10",
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
        "License :: OSI Approved :: Apache Software License",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)

View File

@ -68,4 +68,4 @@ with gr.Blocks() as demo:
if __name__ == "__main__":
demo.queue(20)
demo.launch(server_name="0.0.0.0")
demo.launch(server_name="127.0.0.1")