适配新版本transformers (adapt to newer transformers versions)

This commit is contained in:
yrom 2025-05-18 16:19:28 +08:00
parent 22eeb7625f
commit 1b7529cacd
2 changed files with 11 additions and 8 deletions

View File

@ -388,7 +388,7 @@ class UnifiedVoice(nn.Module):
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
@ -687,7 +687,13 @@ class UnifiedVoice(nn.Module):
inputs = torch.cat([input_ids, input_tokens], dim=1)
attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
trunc_index = inputs.shape[1]
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
logits_processor = LogitsProcessorList()
if typical_sampling:
# employ custom typical sampling
if not (typical_mass > 0.0 and typical_mass < 1.0):
raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
output = self.inference_model.generate(inputs,
bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,

View File

@ -1,12 +1,9 @@
import torch
from transformers import LogitsWarper
from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper
class TypicalLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(BaseTypicalLogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy