padding_test.py support model dir for test
This commit is contained in:
parent
76e7645a8d
commit
c178198ed7
@ -5,10 +5,23 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
||||
from torch.nn import functional as F
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Test the padding of text tokens in inference.
|
||||
```
|
||||
python tests/padding_test.py checkpoints
|
||||
python tests/padding_test.py IndexTTS-1.5
|
||||
```
|
||||
"""
|
||||
import transformers
|
||||
transformers.set_seed(42)
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
if len(sys.argv) > 1:
|
||||
model_dir = sys.argv[1]
|
||||
else:
|
||||
model_dir = "checkpoints"
|
||||
audio_prompt="tests/sample_prompt.wav"
|
||||
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False)
|
||||
tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
|
||||
text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
|
||||
text_tokens = tts.tokenizer.encode(text)
|
||||
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user