优化文本掩码填充逻辑,改进句子桶化处理

This commit is contained in:
yrom 2025-05-17 20:59:07 +08:00
parent 4de7611bda
commit a50cb8c287
2 changed files with 82 additions and 21 deletions

View File

@ -556,7 +556,6 @@ class UnifiedVoice(nn.Module):
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
text_inputs = self.set_text_padding(text_inputs, text_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
@ -588,14 +587,48 @@ class UnifiedVoice(nn.Module):
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
@staticmethod
def pad_text_masks(text_masks):
"""
pad masks to [b, t+2]. e.g.:
```
[
[1, 1, 1, 1, 1] -> [1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 0] -> [1, 1, 1, 1, 1, 1, 0]
[0, 0, 0, 0, 0] -> [0, 0, 0, 0, 0, 1, 1]
[1, 0, 0, 0, 0] -> [1, 1, 1, 0, 0, 0, 0]
[0, 0, 1, 1, 1] -> [0, 0, 1, 1, 1, 1, 1]
]
```
text_masks: [b, t]
padded_masks: [b, t+2]
"""
b, t = text_masks.size()
padded = torch.nn.functional.pad(text_masks, (1, 1), value=0) # [b, t+2]
padded = torch.ones_like(padded, dtype=text_masks.dtype, device=text_masks.device)
# Find the first and last non-zero index
# and set the values before the first index and after the last index to 0
nonzero_mask = text_masks != 0
has_nonzero = nonzero_mask.any(dim=1)
first_idx = nonzero_mask.float().argmax(dim=1)
rev_mask = nonzero_mask.flip(dims=[1])
last_idx = t - rev_mask.float().argmax(dim=1)
first_idx = first_idx.reshape(b, 1)
last_idx = last_idx.reshape(b, 1)
col_idx = torch.arange(t+2, device=text_masks.device).unsqueeze(0).expand(b, -1) # [b, t_2]
padded[col_idx < first_idx] = 0
padded[col_idx > last_idx+1] = 0
# all zeros
padded[~has_nonzero, :-2] = 0
return padded
def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
text_masks = ((text_inputs != self.stop_text_token) & (text_inputs != self.start_text_token)).long()
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_masks = F.pad(text_masks, (1, 1), value=1) # (-1, +1)
text_masks = UnifiedVoice.pad_text_masks(text_masks) # [b, t+2]
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
conds = speech_conditioning_latent

View File

@ -188,38 +188,59 @@ class IndexTTS:
Sentence data bucketing.
if ``bucket_max_size=1``, return all sentences in one bucket.
"""
outputs: List = []
outputs: List[Dict] = []
for idx, sent in enumerate(sentences):
outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
if len(outputs) > bucket_max_size:
# split sentences into buckets by sentence length
buckets = []
buckets: List[List[Dict]] = []
factor = 1.5
last_bucket = None
last_bucket_sent_len_median = 0
for sent in sorted(outputs, key=lambda x: x["len"]):
current_sent_len = sent["len"]
if current_sent_len == 0:
print(">> skip empty sentence")
continue
if last_bucket_sent_len_median == 0 \
or current_sent_len > last_bucket_sent_len_median * factor \
or len(buckets[-1]) > bucket_max_size:
if last_bucket is None \
or current_sent_len >= int(last_bucket_sent_len_median * factor) \
or len(last_bucket) >= bucket_max_size:
# new bucket
buckets.append([sent])
last_bucket = buckets[-1]
last_bucket_sent_len_median = current_sent_len
else:
# current bucket can hold more sentences
buckets[-1].append(sent) # sorted
mid = len(buckets[-1]) // 2
last_bucket_sent_len_median = buckets[-1][mid]["len"]
return buckets
last_bucket.append(sent) # sorted
mid = len(last_bucket) // 2
last_bucket_sent_len_median = last_bucket[mid]["len"]
last_bucket=None
# merge all buckets with size 1
out_buckets: List[List[Dict]] = []
only_ones: List[Dict] = []
for b in buckets:
if len(b) == 1:
only_ones.append(b[0])
else:
out_buckets.append(b)
if len(only_ones) > 0:
# merge into previous buckets if possible
# print("only_ones:", [(o["idx"], o["len"]) for o in only_ones])
for i in range(len(out_buckets)):
b = out_buckets[i]
if len(b) < bucket_max_size:
b.append(only_ones.pop(0))
if len(only_ones) == 0:
break
# combined all remaining sized 1 buckets
if len(only_ones) > 0:
out_buckets.extend([only_ones[i:i+bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
return out_buckets
return [outputs]
def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
if len(tokens) <= 1:
return tokens[-1]
if self.model_version and self.model_version >= 1.5:
# 1.5版本以上直接使用stop_text_token 右侧填充
# [1, N] -> [N,]
@ -324,7 +345,9 @@ class IndexTTS:
all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size)
bucket_count = len(all_sentences)
if verbose:
print(">> sentences bucket_count:", bucket_count, [len(s) for s in all_sentences])
print(">> sentences bucket_count:", bucket_count,
"bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences],
"bucket_max_size:", bucket_max_size)
for sentences in all_sentences:
temp_tokens: List[torch.Tensor] = []
all_text_tokens.append(temp_tokens)
@ -346,9 +369,14 @@ class IndexTTS:
all_batch_codes = []
for item_tokens in all_text_tokens:
batch_num = len(item_tokens)
batch_text_tokens = self.pad_tokens_cat(item_tokens)
batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num]
batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L]
if batch_num > 1:
batch_text_tokens = self.pad_tokens_cat(item_tokens)
batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num) # [batch_num]
batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1) # [batch_num, n_mels, L]
else:
batch_text_tokens = item_tokens[0]
batch_cond_mel_lengths = cond_mel_lengths
batch_auto_conditioning = auto_conditioning
all_batch_num += batch_num
# gpt speech
@ -363,9 +391,9 @@ class IndexTTS:
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
num_return_sequences=autoregressive_batch_size if batch_num == 1 else 1,
length_penalty=length_penalty,
num_beams=num_beams,
num_beams=num_beams if batch_num == 1 else 1,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
all_batch_codes.append(temp_codes)
@ -406,7 +434,7 @@ class IndexTTS:
all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
if verbose:
print(">> all_latents:", len(all_latents))
print(*[l.shape for l in all_latents], sep=", ")
print(" latents length:", [l.shape[1] for l in all_latents])
chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
chunk_length = len(chunk_latents)
latent_length = len(all_latents)