feat: gumbel_softmax_sampler

Signed-off-by: PAN <1162953505@qq.com>
This commit is contained in:
PAN 2025-11-06 13:12:11 +08:00
parent 1d5d079aaa
commit a3b884ff6f

View File

@ -19,12 +19,16 @@ class Sampler(nn.Module):
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(
torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
).argmax(dim=-1)
return sample_tokens
temperatures = temperatures.to(logits.device).clamp(min=1e-8)
greedy_mask = temperatures < 1e-5
temp_for_scaling = torch.where(greedy_mask, 1.0, temperatures)
scaled_logits = logits / temp_for_scaling.unsqueeze(-1)
probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
sampled_tokens = probs.div_(q).argmax(dim=-1)
greedy_tokens = logits.argmax(dim=-1)
return torch.where(greedy_mask, greedy_tokens, sampled_tokens)
class AccelInferenceEngine: