Fix split_sentences_by_token

yrom 2025-04-24 23:40:49 +08:00
parent 475fb12574
commit d3bd7eb8b2


@@ -282,14 +282,12 @@ class TextTokenizer:
         return vocab
     @overload
-    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+    def convert_ids_to_tokens(self, ids: int) -> str: ...
     @overload
-    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...
+    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: ...
     def convert_ids_to_tokens(self, ids: Union[List[int], int]):
-        if isinstance(ids, int):
-            ids = [ids]
         return self.sp_model.IdToPiece(ids)
     def convert_tokens_to_ids(self, tokens: Union[List[str], str]) -> List[int]:
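The overloads above drop the skip_special_tokens parameter, which the implementation never used, and ids now flows straight into self.sp_model.IdToPiece, which in recent sentencepiece releases accepts either a single id or a list of ids and returns a str or List[str] accordingly. A minimal sketch of that behavior, assuming a hypothetical model file bpe.model:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="bpe.model")  # hypothetical model path
    print(sp.IdToPiece(10))        # single id   -> one piece (str)
    print(sp.IdToPiece([10, 11]))  # list of ids -> list of pieces (List[str])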
@@ -301,6 +299,10 @@ class TextTokenizer:
         return self.encode(text, out_type=str)
     def encode(self, text: str, **kwargs):
+        if len(text) == 0:
+            return []
+        if len(text.strip()) == 1:
+            return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs)
         # Preprocessing
         if self.normalizer:
             text = self.normalizer.normalize(text)
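With these early returns, empty input encodes to an empty list, and a single visible character skips normalization entirely and goes straight to SentencePiece with out_type still honored. A small sketch of that fast path, mirroring the added lines with a bare SentencePieceProcessor (hypothetical model path):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="bpe.model")  # hypothetical model path
    text = "好"
    if len(text) == 0:
        ids = []                             # empty input short-circuits
    elif len(text.strip()) == 1:
        ids = sp.Encode(text, out_type=int)  # same call the new fast path makes
    print(ids)                               # a short list of raw SentencePiece ids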
@@ -358,7 +360,7 @@ class TextTokenizer:
                 sub_sentences = TextTokenizer.split_sentences_by_token(
                     current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
                 )
-            elif "-" in current_sentence or "":
+            elif "-" in current_sentence:
                 # No comma token, so split on "-" instead
                 sub_sentences = TextTokenizer.split_sentences_by_token(
                     current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
@@ -411,7 +413,7 @@ if __name__ == "__main__":
     # Test program
     text_normalizer = TextNormalizer()
     text_normalizer.load()
     cases = [
         "IndexTTS 正式发布1.0版本了效果666",
         "晕XUAN4是一种GAN3觉",
@@ -468,9 +470,11 @@ if __name__ == "__main__":
         tokens = tokenizer.convert_tokens_to_ids(t)
         if tokenizer.unk_token_id in tokens:
             print(f"Warning: {t} is unknown token")
-        print(t, "->", tokens)
+        print(f"`{t}`", "->", tokens, "->", tokenizer.convert_ids_to_tokens(tokens))
     for ch in set(tokenizer.normalizer.zh_char_rep_map.values()):
-        print(ch, "->", tokenizer.encode(ch, out_type=str))
+        # Check that characters produced by normalization are recognized by the tokenizer
+        print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
+        print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
     for i in range(len(cases)):
         print(f"原始文本: {cases[i]}")
         print(f"Normalized: {text_normalizer.normalize(cases[i])}")