Merge pull request #537 from pandalee99/perf/gumbel_softmax_sampler

feat(sampler): enhance with greedy sampling mode
This commit is contained in:
Vanka0051 2025-11-07 15:32:31 +08:00 committed by GitHub
commit 42a73394e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,25 @@ from .attention import (
)
from .kv_manager import KVCacheManager, Seq
class Sampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
temperatures = temperatures.to(logits.device).clamp(min=1e-8)
greedy_mask = temperatures < 1e-5
temp_for_scaling = torch.where(greedy_mask, 1.0, temperatures)
scaled_logits = logits / temp_for_scaling.unsqueeze(-1)
probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
sampled_tokens = probs.div_(q).argmax(dim=-1)
greedy_tokens = logits.argmax(dim=-1)
return torch.where(greedy_mask, greedy_tokens, sampled_tokens)
class AccelInferenceEngine:
def __init__(
self,