Add cli mode for inference
This commit is contained in:
parent
eff6eb8f43
commit
471a45435c
8
.gitignore
vendored
8
.gitignore
vendored
@ -1,4 +1,10 @@
|
||||
venv/
|
||||
__pycache__
|
||||
*.egg-info
|
||||
*.DS_Store
|
||||
.idea/
|
||||
.idea/
|
||||
checkpoints/*.pth
|
||||
checkpoints/*.vocab
|
||||
checkpoints/*.model
|
||||
checkpoints/.cache
|
||||
outputs/
|
||||
|
||||
35
README.md
35
README.md
@ -103,7 +103,21 @@ conda activate index-tts
|
||||
pip install -r requirements.txt
|
||||
apt-get install ffmpeg
|
||||
```
|
||||
|
||||
3. Download models:
|
||||
|
||||
Download by `huggingface-cli`:
|
||||
|
||||
```bash
|
||||
# 如果下载速度慢,可以使用官方的镜像
|
||||
export HF_ENDPOINT="https://hf-mirror.com"
|
||||
huggingface-cli download IndexTeam/Index-TTS \
|
||||
bigvgan_discriminator.pth bigvgan_generator.pth bpe.model dvae.pth gpt.pth unigram_12000.vocab \
|
||||
--local-dir checkpoints
|
||||
```
|
||||
|
||||
Or by `wget`:
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_discriminator.pth -P checkpoints
|
||||
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/bigvgan_generator.pth -P checkpoints
|
||||
@ -112,11 +126,32 @@ wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/dvae.pth -P checkpo
|
||||
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/gpt.pth -P checkpoints
|
||||
wget https://huggingface.co/IndexTeam/Index-TTS/resolve/main/unigram_12000.vocab -P checkpoints
|
||||
```
|
||||
|
||||
4. Run test script:
|
||||
|
||||
|
||||
```bash
|
||||
# Please put your prompt audio in 'test_data' and rename it to 'input.wav'
|
||||
PYTHONPATH=. python indextts/infer.py
|
||||
```
|
||||
|
||||
5. Use as command line tool:
|
||||
|
||||
```bash
|
||||
# Make sure pytorch has been installed before running this command
|
||||
pip install -e .
|
||||
indextts "大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!" \
|
||||
--voice reference_voice.wav \
|
||||
--model_dir checkpoints \
|
||||
--config checkpoints/config.yaml \
|
||||
--output_path output.wav
|
||||
```
|
||||
|
||||
Use `--help` to see more options.
|
||||
```bash
|
||||
indextts --help
|
||||
```
|
||||
|
||||
#### Web Demo
|
||||
```bash
|
||||
python webui.py
|
||||
|
||||
51
indextts/cli.py
Normal file
51
indextts/cli.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
# Suppress warnings from tensorflow and other libraries
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
def main():
    """Run the IndexTTS command-line interface.

    Parses the command line, checks that the voice prompt and the config file
    exist, refuses to replace an existing output file unless ``--force`` is
    given, and then synthesizes the text through ``IndexTTS.infer``.

    Exits with status 1 when a validation check fails (after printing the
    error and the usage help to stderr) or when PyTorch cannot be imported.
    """
    import argparse

    parser = argparse.ArgumentParser(description="IndexTTS Command Line")
    parser.add_argument("text", type=str, help="Text to be synthesized")
    parser.add_argument(
        "-v", "--voice", type=str, required=True,
        help="Path to the audio prompt file (wav format)",
    )
    parser.add_argument(
        "-o", "--output_path", type=str, default="gen.wav",
        help="Path to the output wav file",
    )
    parser.add_argument(
        "-c", "--config", type=str, default="checkpoints/config.yaml",
        help="Path to the config file. Default is 'checkpoints/config.yaml'",
    )
    parser.add_argument(
        "--model_dir", type=str, default="checkpoints",
        help="Path to the model directory. Default is 'checkpoints'",
    )
    # A plain store_true flag with default=True could never be turned off.
    # BooleanOptionalAction keeps `--fp16` working and adds `--no-fp16`.
    parser.add_argument(
        "--fp16", action=argparse.BooleanOptionalAction, default=True,
        help="Use FP16 for inference if available (disable with --no-fp16)",
    )
    parser.add_argument(
        "-f", "--force", action="store_true", default=False,
        help="Force to overwrite the output file if it exists",
    )
    args = parser.parse_args()

    def fail(message):
        """Report a validation error with the usage help and exit with 1."""
        print(message, file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)

    if not os.path.exists(args.voice):
        fail(f"Audio prompt file {args.voice} does not exist.")
    if not os.path.exists(args.config):
        fail(f"Config file {args.config} does not exist.")

    output_path = args.output_path
    if os.path.exists(output_path):
        if not args.force:
            fail(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
        # --force was given: drop the stale file before synthesis starts.
        os.remove(output_path)

    # Import torch only after the cheap argument checks so that usage errors
    # are reported without paying for the import.
    try:
        import torch
    except ImportError:
        print("ERROR: PyTorch is not installed. Please install it first.", file=sys.stderr)
        sys.exit(1)
    if not torch.cuda.is_available():
        print("WARNING: CUDA is not available. Running in CPU mode.")

    from indextts.infer import IndexTTS

    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16)
    tts.infer(audio_prompt=args.voice, text=args.text, output_path=output_path)


if __name__ == "__main__":
    main()
|
||||
@ -17,9 +17,9 @@ accelerate==0.25.0
|
||||
tensorboard==2.9.1
|
||||
omegaconf
|
||||
sentencepiece
|
||||
pypinyin
|
||||
librosa
|
||||
gradio
|
||||
tqdm
|
||||
WeTextProcessing # arm机器如果安装失败,请注释此行
|
||||
wetext
|
||||
|
||||
WeTextProcessing; platform_system != "Darwin"
|
||||
wetext; platform_system == "Darwin"
|
||||
50
setup.py
Normal file
50
setup.py
Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
import platform
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
|
||||
# Read the long description up front so the README file handle is closed
# deterministically rather than left open for the rest of the build.
with open("README.md", encoding="utf8") as readme_file:
    long_description = readme_file.read()

# Package metadata and dependencies for the `indextts` distribution.
setup(
    name="indextts",
    version="0.1.0",
    author="Index SpeechTeam",
    author_email="xuanwu@bilibili.com",
    long_description=long_description,
    long_description_content_type="text/markdown",
    description="An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System",
    url="https://github.com/index-tts/index-tts",
    packages=find_packages(),
    include_package_data=True,
    install_requires=[
        "torch==2.6.0",
        "torchaudio",
        "transformers==4.36.2",
        "accelerate",
        "tokenizers==0.15.0",
        "einops==0.8.1",
        "matplotlib==3.8.2",
        "omegaconf",
        "sentencepiece",
        "librosa",
        "numpy",
        # Environment markers are evaluated on the machine that installs the
        # package, so a wheel/sdist built on one OS still resolves the right
        # text-normalization dependency on another.
        'wetext; platform_system == "Darwin"',
        'WeTextProcessing; platform_system != "Darwin"',
    ],
    extras_require={
        "webui": ["gradio"],
    },
    entry_points={
        # Installs the `indextts` command, backed by indextts/cli.py:main.
        "console_scripts": [
            "indextts = indextts.cli:main",
        ]
    },
    license="Apache-2.0",
    python_requires=">=3.10",
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
        "License :: OSI Approved :: Apache Software License",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
|
||||
Loading…
x
Reference in New Issue
Block a user