Fix split_sentences_by_token
This commit is contained in:
parent
475fb12574
commit
d3bd7eb8b2
@ -282,14 +282,12 @@ class TextTokenizer:
|
||||
return vocab
|
||||
|
||||
@overload
|
||||
def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
|
||||
def convert_ids_to_tokens(self, ids: int) -> str: ...
|
||||
|
||||
@overload
|
||||
def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...
|
||||
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: ...
|
||||
|
||||
def convert_ids_to_tokens(self, ids: Union[List[int], int]):
|
||||
if isinstance(ids, int):
|
||||
ids = [ids]
|
||||
return self.sp_model.IdToPiece(ids)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens: Union[List[str], str]) -> List[int]:
|
||||
@ -301,6 +299,10 @@ class TextTokenizer:
|
||||
return self.encode(text, out_type=str)
|
||||
|
||||
def encode(self, text: str, **kwargs):
|
||||
if len(text) == 0:
|
||||
return []
|
||||
if len(text.strip()) == 1:
|
||||
return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs)
|
||||
# 预处理
|
||||
if self.normalizer:
|
||||
text = self.normalizer.normalize(text)
|
||||
@ -358,7 +360,7 @@ class TextTokenizer:
|
||||
sub_sentences = TextTokenizer.split_sentences_by_token(
|
||||
current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
|
||||
)
|
||||
elif "-" in current_sentence or "":
|
||||
elif "-" in current_sentence:
|
||||
# 没有,,则按-分割
|
||||
sub_sentences = TextTokenizer.split_sentences_by_token(
|
||||
current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
|
||||
@ -411,7 +413,7 @@ if __name__ == "__main__":
|
||||
# 测试程序
|
||||
|
||||
text_normalizer = TextNormalizer()
|
||||
text_normalizer.load()
|
||||
|
||||
cases = [
|
||||
"IndexTTS 正式发布1.0版本了,效果666",
|
||||
"晕XUAN4是一种GAN3觉",
|
||||
@ -468,9 +470,11 @@ if __name__ == "__main__":
|
||||
tokens = tokenizer.convert_tokens_to_ids(t)
|
||||
if tokenizer.unk_token_id in tokens:
|
||||
print(f"Warning: {t} is unknown token")
|
||||
print(t, "->", tokens)
|
||||
print(f"`{t}`", "->", tokens, "->", tokenizer.convert_ids_to_tokens(tokens))
|
||||
for ch in set(tokenizer.normalizer.zh_char_rep_map.values()):
|
||||
print(ch, "->", tokenizer.encode(ch, out_type=str))
|
||||
# 测试 normalize后的字符能被分词器识别
|
||||
print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
|
||||
print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
|
||||
for i in range(len(cases)):
|
||||
print(f"原始文本: {cases[i]}")
|
||||
print(f"Normalized: {text_normalizer.normalize(cases[i])}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user