From a3b884ff6ff401923e6242fb87630be0a6d0f221 Mon Sep 17 00:00:00 2001
From: PAN <1162953505@qq.com>
Date: Thu, 6 Nov 2025 13:12:11 +0800
Subject: [PATCH] feat: gumbel_softmax_sampler

Signed-off-by: PAN <1162953505@qq.com>
---
 indextts/accel/accel_engine.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/indextts/accel/accel_engine.py b/indextts/accel/accel_engine.py
index 2addd77..843fd24 100644
--- a/indextts/accel/accel_engine.py
+++ b/indextts/accel/accel_engine.py
@@ -19,12 +19,22 @@ class Sampler(nn.Module):
 
     @torch.compile
     def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
-        probs = torch.softmax(logits, dim=-1)
-        sample_tokens = probs.div_(
-            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
-        ).argmax(dim=-1)
-        return sample_tokens
+        # Near-zero temperatures are treated as greedy decoding; clamp the
+        # rest away from zero so the division below stays finite.
+        temperatures = temperatures.to(logits.device).clamp(min=1e-8)
+        greedy_mask = temperatures < 1e-5
+        # Use a dummy temperature of 1.0 on greedy rows; their sampled value
+        # is discarded by the final torch.where.
+        temp_for_scaling = torch.where(greedy_mask, 1.0, temperatures)
+        scaled_logits = logits / temp_for_scaling.unsqueeze(-1)
+        probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
+        # Exponential-race (Gumbel-max) trick: argmax(probs / Exp(1)) samples
+        # from Categorical(probs). Clamp q away from zero so probs / q cannot
+        # overflow to inf (exponential_ can yield values at/near 0).
+        q = torch.empty_like(probs)
+        q.exponential_().clamp_min_(1e-10)
+        sampled_tokens = probs.div_(q).argmax(dim=-1)
+        greedy_tokens = logits.argmax(dim=-1)
+        return torch.where(greedy_mask, greedy_tokens, sampled_tokens)
 
 
 class AccelInferenceEngine: