class Sampler(nn.Module):
    """Pick one token index per row of ``logits``, greedily or at random.

    Rows whose temperature is effectively zero take the plain argmax of the
    raw logits.  All other rows are sampled from the temperature-scaled
    softmax by dividing the probabilities by i.i.d. exponential noise and
    taking the argmax of the quotient, which avoids a separate multinomial
    call.
    """

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        """Return the chosen index along the last dimension of ``logits``.

        Args:
            logits: Unnormalised scores; the last dimension is reduced.
                Presumably ``(batch, vocab)`` -- confirm against the caller.
            temperatures: One temperature per row of ``logits``.  It is moved
                to ``logits.device`` and broadcast over the last dimension.

        Returns:
            Integer tensor of indices with the last dimension of ``logits``
            removed.
        """
        temps = torch.clamp(temperatures.to(logits.device), min=1e-8)
        # After the clamp, anything below 1e-5 (including non-positive
        # inputs) is treated as a request for deterministic decoding.
        use_argmax = temps < 1e-5

        # Deterministic candidates come from the unscaled logits.
        argmax_ids = torch.argmax(logits, dim=-1)

        # Greedy rows get a neutral divisor of 1.0 so the division below
        # never sees a near-zero temperature; their sampled result is
        # discarded by the final ``where`` anyway.
        divisor = torch.where(use_argmax, 1.0, temps)[..., None]
        distribution = torch.softmax(
            logits / divisor, dim=-1, dtype=torch.float32
        )

        # Exponential noise: argmax(p / Exp(1)) is a draw from the
        # categorical distribution p.
        noise = torch.empty_like(distribution).exponential_()
        random_ids = torch.argmax(distribution.div_(noise), dim=-1)

        return torch.where(use_argmax, argmax_ids, random_ids)