diff --git a/tests/padding_test.py b/tests/padding_test.py index 9fe418b..fcb67d0 100644 --- a/tests/padding_test.py +++ b/tests/padding_test.py @@ -5,10 +5,23 @@ from indextts.utils.feature_extractors import MelSpectrogramFeatures from torch.nn import functional as F if __name__ == "__main__": + """ + Test the padding of text tokens in inference. + ``` + python tests/padding_test.py checkpoints + python tests/padding_test.py IndexTTS-1.5 + ``` + """ import transformers transformers.set_seed(42) + import sys + sys.path.append("..") + if len(sys.argv) > 1: + model_dir = sys.argv[1] + else: + model_dir = "checkpoints" audio_prompt="tests/sample_prompt.wav" - tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False) + tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False) text = "晕 XUAN4 是 一 种 not very good GAN3 觉" text_tokens = tts.tokenizer.encode(text) text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L]