Optimize text mask padding logic and improve sentence bucketing
This commit is contained in:
parent 4de7611bda
commit a50cb8c287
@@ -556,7 +556,6 @@ class UnifiedVoice(nn.Module):
        # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
        mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
        mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)

        text_inputs = self.set_text_padding(text_inputs, text_lengths)
        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
        mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
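A quick, illustrative comparison of the old trunc-based length against the new `ceil(...) + 1` (the `1024` compression factor below is an assumed placeholder, not necessarily the model's configured `mel_length_compression`):

```python
import torch

# Hypothetical wav lengths and compression factor, purely for illustration.
wav_lengths = torch.tensor([10240, 10241, 2048])
compression = 1024

old = torch.div(wav_lengths, compression, rounding_mode='trunc')  # tensor([10, 10, 2])
new = torch.ceil(wav_lengths / compression).long() + 1            # tensor([11, 12, 3])
print(old, new)
```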
@@ -588,14 +587,48 @@ class UnifiedVoice(nn.Module):
        loss_text = F.cross_entropy(text_logits, text_targets.long())
        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
        return loss_text.mean(), loss_mel.mean(), mel_logits

    @staticmethod
    def pad_text_masks(text_masks):
        """
        pad masks to [b, t+2]. e.g.:
        ```
        [
            [1, 1, 1, 1, 1] -> [1, 1, 1, 1, 1, 1, 1]
            [1, 1, 1, 1, 0] -> [1, 1, 1, 1, 1, 1, 0]
            [0, 0, 0, 0, 0] -> [0, 0, 0, 0, 0, 1, 1]
            [1, 0, 0, 0, 0] -> [1, 1, 1, 0, 0, 0, 0]
            [0, 0, 1, 1, 1] -> [0, 0, 1, 1, 1, 1, 1]
        ]
        ```
        text_masks: [b, t]
        padded_masks: [b, t+2]
        """
        b, t = text_masks.size()
        padded = torch.nn.functional.pad(text_masks, (1, 1), value=0)  # [b, t+2]
        padded = torch.ones_like(padded, dtype=text_masks.dtype, device=text_masks.device)

        # Find the first and last non-zero index
        # and set the values before the first index and after the last index to 0
        nonzero_mask = text_masks != 0
        has_nonzero = nonzero_mask.any(dim=1)
        first_idx = nonzero_mask.float().argmax(dim=1)
        rev_mask = nonzero_mask.flip(dims=[1])
        last_idx = t - rev_mask.float().argmax(dim=1)
        first_idx = first_idx.reshape(b, 1)
        last_idx = last_idx.reshape(b, 1)
        col_idx = torch.arange(t + 2, device=text_masks.device).unsqueeze(0).expand(b, -1)  # [b, t+2]
        padded[col_idx < first_idx] = 0
        padded[col_idx > last_idx + 1] = 0
        # all zeros
        padded[~has_nonzero, :-2] = 0
        return padded

    def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):

        text_masks = ((text_inputs != self.stop_text_token) & (text_inputs != self.start_text_token)).long()
        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
        text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
        text_masks = F.pad(text_masks, (1, 1), value=1)  # (-1, +1)
        text_masks = UnifiedVoice.pad_text_masks(text_masks)  # [b, t+2]
        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
        speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
        conds = speech_conditioning_latent
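As a sanity check on the `pad_text_masks` behaviour documented above, here is a standalone sketch that mirrors the same logic outside the class and asserts the docstring examples (illustration only, not part of the commit):

```python
import torch

def pad_text_masks_ref(text_masks: torch.Tensor) -> torch.Tensor:
    # Standalone mirror of the new UnifiedVoice.pad_text_masks logic.
    b, t = text_masks.size()
    padded = torch.ones(b, t + 2, dtype=text_masks.dtype, device=text_masks.device)
    nonzero = text_masks != 0
    has_nonzero = nonzero.any(dim=1)
    first_idx = nonzero.float().argmax(dim=1).reshape(b, 1)
    last_idx = (t - nonzero.flip(dims=[1]).float().argmax(dim=1)).reshape(b, 1)
    col = torch.arange(t + 2, device=text_masks.device).unsqueeze(0).expand(b, -1)
    padded[col < first_idx] = 0      # zero columns before the first 1, keeping one extra slot for the start token
    padded[col > last_idx + 1] = 0   # zero columns after the last 1, keeping one extra slot for the stop token
    padded[~has_nonzero, :-2] = 0    # rows with no 1s keep only the two trailing pad positions
    return padded

masks = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 1, 1, 1],
])
expected = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 1, 1],
    [1, 1, 1, 0, 0, 0, 0],
    [0, 0, 1, 1, 1, 1, 1],
])
assert torch.equal(pad_text_masks_ref(masks), expected)
```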
@@ -188,38 +188,59 @@ class IndexTTS:
        Sentence data bucketing.
        if ``bucket_max_size=1``, return all sentences in one bucket.
        """
        outputs: List = []
        outputs: List[Dict] = []
        for idx, sent in enumerate(sentences):
            outputs.append({"idx": idx, "sent": sent, "len": len(sent)})

        if len(outputs) > bucket_max_size:
            # split sentences into buckets by sentence length
            buckets = []
            buckets: List[List[Dict]] = []
            factor = 1.5
            last_bucket = None
            last_bucket_sent_len_median = 0

            for sent in sorted(outputs, key=lambda x: x["len"]):
                current_sent_len = sent["len"]
                if current_sent_len == 0:
                    print(">> skip empty sentence")
                    continue
                if last_bucket_sent_len_median == 0 \
                        or current_sent_len > last_bucket_sent_len_median * factor \
                        or len(buckets[-1]) > bucket_max_size:
                if last_bucket is None \
                        or current_sent_len >= int(last_bucket_sent_len_median * factor) \
                        or len(last_bucket) >= bucket_max_size:
                    # new bucket
                    buckets.append([sent])
                    last_bucket = buckets[-1]
                    last_bucket_sent_len_median = current_sent_len
                else:
                    # current bucket can hold more sentences
                    buckets[-1].append(sent)  # sorted
                    mid = len(buckets[-1]) // 2
                    last_bucket_sent_len_median = buckets[-1][mid]["len"]
            return buckets
                    last_bucket.append(sent)  # sorted
                    mid = len(last_bucket) // 2
                    last_bucket_sent_len_median = last_bucket[mid]["len"]
            last_bucket = None
            # merge all buckets with size 1
            out_buckets: List[List[Dict]] = []
            only_ones: List[Dict] = []
            for b in buckets:
                if len(b) == 1:
                    only_ones.append(b[0])
                else:
                    out_buckets.append(b)
            if len(only_ones) > 0:
                # merge into previous buckets if possible
                # print("only_ones:", [(o["idx"], o["len"]) for o in only_ones])
                for i in range(len(out_buckets)):
                    b = out_buckets[i]
                    if len(b) < bucket_max_size:
                        b.append(only_ones.pop(0))
                        if len(only_ones) == 0:
                            break
                # combine all remaining size-1 buckets
                if len(only_ones) > 0:
                    out_buckets.extend([only_ones[i:i+bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
            return out_buckets
        return [outputs]

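For reference, a minimal sketch of the bucketing rule used in the loop above (start a new bucket when a sentence is at least `factor` times the current bucket's median length, or when the bucket is full); it uses plain strings and omits the later merge of size-1 buckets, so it is illustrative only:

```python
from typing import Dict, List

def bucket_by_length(sentences: List[str], bucket_max_size: int = 4, factor: float = 1.5) -> List[List[Dict]]:
    items = [{"idx": i, "sent": s, "len": len(s)} for i, s in enumerate(sentences) if len(s) > 0]
    buckets: List[List[Dict]] = []
    median_len = 0
    for item in sorted(items, key=lambda x: x["len"]):
        if not buckets or item["len"] >= int(median_len * factor) or len(buckets[-1]) >= bucket_max_size:
            buckets.append([item])  # start a new bucket
            median_len = item["len"]
        else:
            buckets[-1].append(item)  # items stay sorted by length
            median_len = buckets[-1][len(buckets[-1]) // 2]["len"]
    return buckets

sents = ["hi", "hello there", "a somewhat longer sentence", "ok", "short one", "medium text here"]
for bucket in bucket_by_length(sents, bucket_max_size=3):
    print([(item["idx"], item["len"]) for item in bucket])
```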
    def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
        if len(tokens) <= 1:
            return tokens[-1]
        if self.model_version and self.model_version >= 1.5:
            # For version 1.5 and above, right-pad directly with stop_text_token
            # [1, N] -> [N,]
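The v1.5+ branch above right-pads each text sequence with `stop_text_token` before concatenation. A minimal standalone sketch of that padding step (the `pad_id` value is a placeholder, not the real stop token id):

```python
import torch

def right_pad_cat(tokens, pad_id):
    # [1, N_i] tensors -> one [B, N_max] batch, right-padded with pad_id
    seqs = [t.squeeze(0) for t in tokens]
    max_len = max(s.size(0) for s in seqs)
    out = torch.full((len(seqs), max_len), pad_id, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, : s.size(0)] = s
    return out

batch = right_pad_cat([torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])], pad_id=1)
print(batch)  # tensor([[5, 6, 7], [8, 9, 1]])
```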
@@ -324,7 +345,9 @@ class IndexTTS:
        all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size)
        bucket_count = len(all_sentences)
        if verbose:
            print(">> sentences bucket_count:", bucket_count, [len(s) for s in all_sentences])
            print(">> sentences bucket_count:", bucket_count,
                  "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences],
                  "bucket_max_size:", bucket_max_size)
        for sentences in all_sentences:
            temp_tokens: List[torch.Tensor] = []
            all_text_tokens.append(temp_tokens)
@@ -346,9 +369,14 @@ class IndexTTS:
        all_batch_codes = []
        for item_tokens in all_text_tokens:
            batch_num = len(item_tokens)
            batch_text_tokens = self.pad_tokens_cat(item_tokens)
            batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num)  # [batch_num]
            batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1)  # [batch_num, n_mels, L]
            if batch_num > 1:
                batch_text_tokens = self.pad_tokens_cat(item_tokens)
                batch_cond_mel_lengths = cond_mel_lengths.expand(batch_num)  # [batch_num]
                batch_auto_conditioning = auto_conditioning.expand(batch_num, -1, -1)  # [batch_num, n_mels, L]
            else:
                batch_text_tokens = item_tokens[0]
                batch_cond_mel_lengths = cond_mel_lengths
                batch_auto_conditioning = auto_conditioning
            all_batch_num += batch_num

            # gpt speech
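For the `batch_num > 1` branch above, `expand()` broadcasts the single conditioning tensors to the batch size without copying memory; a small illustration with made-up shapes:

```python
import torch

auto_conditioning = torch.randn(1, 100, 32)  # [1, n_mels, L], illustrative sizes
cond_mel_lengths = torch.tensor([32])
batch_num = 3

print(auto_conditioning.expand(batch_num, -1, -1).shape)  # torch.Size([3, 100, 32])
print(cond_mel_lengths.expand(batch_num))                 # tensor([32, 32, 32])
```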
@@ -363,9 +391,9 @@ class IndexTTS:
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    num_return_sequences=autoregressive_batch_size,
                    num_return_sequences=autoregressive_batch_size if batch_num == 1 else 1,
                    length_penalty=length_penalty,
                    num_beams=num_beams,
                    num_beams=num_beams if batch_num == 1 else 1,
                    repetition_penalty=repetition_penalty,
                    max_generate_length=max_mel_tokens)
            all_batch_codes.append(temp_codes)
@@ -406,7 +434,7 @@ class IndexTTS:
        all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
        if verbose:
            print(">> all_latents:", len(all_latents))
            print(*[l.shape for l in all_latents], sep=", ")
            print(" latents length:", [l.shape[1] for l in all_latents])
        chunk_latents = [all_latents[i : i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
        chunk_length = len(chunk_latents)
        latent_length = len(all_latents)