diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py index 6f98a45..4e64660 100644 --- a/indextts/gpt/model.py +++ b/indextts/gpt/model.py @@ -388,7 +388,7 @@ class UnifiedVoice(nn.Module): def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): seq_length = self.max_mel_tokens + self.max_text_tokens + 2 gpt_config = GPT2Config( - vocab_size=self.max_mel_tokens, + vocab_size=self.number_mel_codes, n_positions=seq_length, n_ctx=seq_length, n_embd=self.model_dim, @@ -687,7 +687,13 @@ class UnifiedVoice(nn.Module): inputs = torch.cat([input_ids, input_tokens], dim=1) attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1) trunc_index = inputs.shape[1] - logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() + logits_processor = LogitsProcessorList() + if typical_sampling: + # employ custom typical sampling + if not (typical_mass > 0.0 and typical_mass < 1.0): + raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}") + min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1 + logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep)) max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length output = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, diff --git a/indextts/utils/typical_sampling.py b/indextts/utils/typical_sampling.py index c982463..0b225e9 100644 --- a/indextts/utils/typical_sampling.py +++ b/indextts/utils/typical_sampling.py @@ -1,12 +1,9 @@ import torch -from transformers import LogitsWarper +from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper - -class TypicalLogitsWarper(LogitsWarper): +class TypicalLogitsWarper(BaseTypicalLogitsWarper): def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - self.filter_value = filter_value - self.mass = mass - self.min_tokens_to_keep = min_tokens_to_keep + super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # calculate entropy