padding_test.py support model dir for test

This commit is contained in:
yrom 2025-05-18 19:57:11 +08:00
parent 76e7645a8d
commit c178198ed7

View File

@ -5,10 +5,23 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures
from torch.nn import functional as F
if __name__ == "__main__":
"""
Test the padding of text tokens in inference.
```
python tests/padding_test.py checkpoints
python tests/padding_test.py IndexTTS-1.5
```
"""
import transformers
transformers.set_seed(42)
import sys
sys.path.append("..")
if len(sys.argv) > 1:
model_dir = sys.argv[1]
else:
model_dir = "checkpoints"
audio_prompt="tests/sample_prompt.wav"
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False)
tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
text_tokens = tts.tokenizer.encode(text)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]