From d3bd7eb8b218fa7eb5b1ea94fba8c10e6b8a866a Mon Sep 17 00:00:00 2001
From: yrom
Date: Thu, 24 Apr 2025 23:40:49 +0800
Subject: [PATCH] Fix split_sentences_by_token

---
 indextts/utils/front.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/indextts/utils/front.py b/indextts/utils/front.py
index 5b45a40..d7ff8c7 100644
--- a/indextts/utils/front.py
+++ b/indextts/utils/front.py
@@ -282,14 +282,12 @@ class TextTokenizer:
         return vocab
 
     @overload
-    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+    def convert_ids_to_tokens(self, ids: int) -> str: ...
 
     @overload
-    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...
+    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: ...
 
     def convert_ids_to_tokens(self, ids: Union[List[int], int]):
-        if isinstance(ids, int):
-            ids = [ids]
         return self.sp_model.IdToPiece(ids)
 
     def convert_tokens_to_ids(self, tokens: Union[List[str], str]) -> List[int]:
@@ -301,6 +299,10 @@ class TextTokenizer:
         return self.encode(text, out_type=str)
 
     def encode(self, text: str, **kwargs):
+        if len(text) == 0:
+            return []
+        if len(text.strip()) == 1:
+            return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs)
         # Preprocessing
         if self.normalizer:
             text = self.normalizer.normalize(text)
@@ -358,7 +360,7 @@ class TextTokenizer:
                 sub_sentences = TextTokenizer.split_sentences_by_token(
                     current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
                 )
-            elif "-" in current_sentence or "":
+            elif "-" in current_sentence:
                 # No "," present, so fall back to splitting on "-"
                 sub_sentences = TextTokenizer.split_sentences_by_token(
                     current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
                 )
@@ -411,7 +413,7 @@ if __name__ == "__main__":
     # Test program
     text_normalizer = TextNormalizer()
-    text_normalizer.load()
+
     cases = [
         "IndexTTS 正式发布1.0版本了,效果666",
         "晕XUAN4是一种GAN3觉",
@@ -468,9 +470,11 @@ if __name__ == "__main__":
             tokens = tokenizer.convert_tokens_to_ids(t)
             if tokenizer.unk_token_id in tokens:
                 print(f"Warning: {t} is unknown token")
-            print(t, "->", tokens)
+            print(f"`{t}`", "->", tokens, "->", tokenizer.convert_ids_to_tokens(tokens))
     for ch in set(tokenizer.normalizer.zh_char_rep_map.values()):
-        print(ch, "->", tokenizer.encode(ch, out_type=str))
+        # Check that characters produced by normalization are known to the tokenizer
+        print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
+        print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
     for i in range(len(cases)):
         print(f"Original text: {cases[i]}")
         print(f"Normalized: {text_normalizer.normalize(cases[i])}")
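
A note on the convert_ids_to_tokens change: the removed isinstance shim was
redundant because SentencePiece's IdToPiece already mirrors the shape of its
input, returning a single piece for an int and a list of pieces for a list of
ints, which is exactly what the two @overload signatures promise. A minimal
sketch, assuming a trained model at the hypothetical path "bpe.model":

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("bpe.model")  # hypothetical model path, not from this patch

    print(sp.IdToPiece(42))       # single id   -> single piece (str)
    print(sp.IdToPiece([42, 7]))  # list of ids -> list of pieces (List[str])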
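
A note on the new early returns in encode(): empty input now short-circuits
to [] without touching SentencePiece, and a single visible character skips
normalization and goes straight to sp_model.Encode; kwargs.pop("out_type", int)
preserves a caller-supplied output type while defaulting to token ids. A
hedged usage sketch (the TextTokenizer constructor arguments shown here are
assumed, not taken from this patch):

    tokenizer = TextTokenizer("bpe.model")  # hypothetical construction

    assert tokenizer.encode("") == []           # empty text: returns [] immediately
    print(tokenizer.encode("A"))                # single char: ids, normalizer skipped
    print(tokenizer.encode("A", out_type=str))  # single char: pieces, same fast path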
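
A note on the elif fix: the removed `or ""` was dead code rather than a
behavior change, since an empty string is falsy and `cond or ""` evaluates
the same as `cond` wherever a boolean is expected:

    assert (("-" in "a-b") or "") is True  # truthy branch: unchanged
    assert (("-" in "ab") or "") == ""     # falsy either way, so the elif is still skipped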