diff --git a/indextts/accel/accel_engine.py b/indextts/accel/accel_engine.py
index 2addd77..51a062d 100644
--- a/indextts/accel/accel_engine.py
+++ b/indextts/accel/accel_engine.py
@@ -12,21 +12,6 @@ from .attention import (
 )
 from .kv_manager import KVCacheManager, Seq
 
-
-class Sampler(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @torch.compile
-    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
-        probs = torch.softmax(logits, dim=-1)
-        sample_tokens = probs.div_(
-            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
-        ).argmax(dim=-1)
-        return sample_tokens
-
-
 class AccelInferenceEngine:
     def __init__(
         self,
@@ -55,7 +40,6 @@ class AccelInferenceEngine:
         self.block_size = block_size
         self.num_blocks = num_blocks
         self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
-        model_dtype = next(model.parameters()).dtype
         self.hidden_size = (
             model.config.hidden_size
             if hasattr(model, "config")
@@ -98,14 +82,15 @@ class AccelInferenceEngine:
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
 
             if req.block_table:
-                for i in range(req.num_cached_blocks, req.num_blocks):
-                    block_id = req.block_table[i]
-                    start = block_id * self.block_size
-                    if i != req.num_blocks - 1:
-                        end = start + self.block_size
-                    else:
-                        end = start + req.last_block_num_tokens
-                    slot_mapping.extend(list(range(start, end)))
+                num_cached = req.num_cached_tokens
+                num_total = len(req)
+
+                for token_idx in range(num_cached, num_total):
+                    block_idx = token_idx // self.block_size
+                    block_offset = token_idx % self.block_size
+                    block_id = req.block_table[block_idx]
+                    slot_idx = block_id * self.block_size + block_offset
+                    slot_mapping.append(slot_idx)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
             non_blocking=True
         )
@@ -214,13 +199,12 @@ class AccelInferenceEngine:
         ).cuda(non_blocking=True)
         return temperatures
 
-    @torch.inference_mode()
     def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
         print("Capturing CUDA graphs for decode optimization...")
         max_bs = 8  # Support up to batch size 8
         max_num_blocks = (2048 + self.block_size - 1) // self.block_size
         model_dtype = next(self.model.parameters()).dtype
-        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
+        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda")
         positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
         slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
         context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
@@ -234,7 +218,7 @@ class AccelInferenceEngine:
             max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
         )
 
-        self.graph_bs = [1]
+        self.graph_bs = [1, 2, 4, 8]
 
         use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
 
@@ -243,7 +227,7 @@ class AccelInferenceEngine:
 
             slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
             context_lens[:bs] = bs + 1
-            block_tables[:bs, 0] = 0
+            block_tables[:bs, :] = 0
 
             set_forward_context(
                 False,
@@ -306,7 +290,6 @@ class AccelInferenceEngine:
         }
         print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
 
-    @torch.inference_mode()
     def _run_decode_with_graph(
         self,
         input_ids: torch.Tensor,
@@ -359,19 +342,13 @@ class AccelInferenceEngine:
         if graph_vars is None:
             raise RuntimeError("Graph variables not initialized")
 
-        set_forward_context(
-            False,
-            slot_mapping=graph_vars["slot_mapping"][:graph_bs],
-            context_lens=graph_vars["context_lens"][:graph_bs],
-            block_tables=graph_vars["block_tables"][:graph_bs],
-        )
-
         graph_vars["input_ids"][:bs] = input_ids
         graph_vars["positions"][:bs] = positions
         graph_vars["slot_mapping"].fill_(-1)
         graph_vars["slot_mapping"][:bs] = context.slot_mapping
         graph_vars["context_lens"].zero_()
         graph_vars["context_lens"][:bs] = context.context_lens
+        graph_vars["block_tables"][:bs, :].fill_(-1)
         graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
             context.block_tables
         )
@@ -379,7 +356,6 @@ class AccelInferenceEngine:
 
         return graph_vars["outputs"][:bs]
 
-    @torch.inference_mode()
     def generate(
         self,
         input_ids: torch.Tensor,
@@ -436,14 +412,26 @@ class AccelInferenceEngine:
 
         if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
-            pass
         else:
             actual_seq_len = input_ids.size(1)
 
+        is_varlen_batch = (
+            tts_embeddings is not None
+            and attention_mask is not None
+            and batch_size > 1
+            and (attention_mask.sum(dim=1) != attention_mask.size(1)).any()
+        )
+
+        if is_varlen_batch:
+            seq_lens = [attention_mask[i].sum().item() for i in range(batch_size)]
+        else:
+            seq_lens = [actual_seq_len] * batch_size
+
         sequences = []
         for i in range(batch_size):
-            token_ids = [1] * actual_seq_len
-            if tts_embeddings is not None and actual_seq_len > 0:
+            seq_len = seq_lens[i]
+            token_ids = [1] * seq_len
+            if tts_embeddings is not None and seq_len > 0:
                 token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
             else:
                 token_ids = input_ids[i].tolist()
@@ -453,18 +441,8 @@ class AccelInferenceEngine:
 
         self.current_sequences = sequences
 
-        # Prefill phase
         prefill_ids, prefill_pos = self._prepare_prefill(sequences)
 
-        if prefill_ids.dim() == 1:
-            prefill_ids = prefill_ids.unsqueeze(
-                0
-            )  # [total_tokens] -> [1, total_tokens]
-        if prefill_pos.dim() == 1:
-            prefill_pos = prefill_pos.unsqueeze(
-                0
-            )  # [total_tokens] -> [1, total_tokens]
-
         if (
             tts_embeddings is not None
             and tts_mel_embedding is not None
@@ -476,11 +454,31 @@ class AccelInferenceEngine:
             start_emb = tts_mel_embedding(
                 torch.tensor([[start_token_id]], device="cuda")
             )  # [1, 1, hidden_dim]
-            start_emb = start_emb + tts_text_pos_embedding(start_emb)
+            start_pos = torch.tensor(
+                [[tts_embeddings.size(1)]], device="cuda", dtype=torch.long
+            )
+            pos_emb = tts_text_pos_embedding.emb(start_pos)
+            start_emb = start_emb + pos_emb
+            start_emb = start_emb.repeat(batch_size, 1, 1)
 
-            full_embeddings = torch.cat(
-                [tts_embeddings, start_emb], dim=1
-            )  # [1, 88, hidden_dim]
+            if is_varlen_batch:
+                valid_embeddings = []
+                for i in range(batch_size):
+                    emb_len = seq_lens[i] - 1
+                    padding_len = tts_embeddings.size(1) - emb_len
+                    valid_emb = tts_embeddings[i, padding_len:].unsqueeze(
+                        0
+                    )  # [1, emb_len, hidden_dim]
+                    valid_embeddings.append(
+                        torch.cat([valid_emb, start_emb[i : i + 1]], dim=1)
+                    )
+                full_embeddings = torch.cat(
+                    valid_embeddings, dim=1
+                )  # [1, total_tokens, hidden_dim]
+            else:
+                full_embeddings = torch.cat(
+                    [tts_embeddings, start_emb], dim=1
+                )  # [batch_size, seq_len, hidden_dim]
 
             model_dtype = next(self.model.parameters()).dtype
             if full_embeddings.dtype != model_dtype:
@@ -495,9 +493,16 @@ class AccelInferenceEngine:
                 input_ids=input_ids, attention_mask=attention_mask, return_dict=True
             ).last_hidden_state
 
-        reset_forward_context()
+        if is_varlen_batch:
+            context = get_forward_context()
+            cu_seqlens = context.cu_seqlens_q.cpu().tolist()
+            last_hidden = torch.stack(
+                [hidden_states[0, cu_seqlens[i + 1] - 1] for i in range(batch_size)]
+            )
+        else:
+            last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
 
-        last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
+        reset_forward_context()
 
         if self.lm_head is not None:
             if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
@@ -515,15 +520,17 @@ class AccelInferenceEngine:
         first_token_list = first_token.tolist()
 
         generated_tokens = [[] for _ in range(batch_size)]
-        hit_stop_on_first = False
+        is_finished = [False] * batch_size
 
         for i, token_id in enumerate(first_token_list):
             if stop_tokens and token_id in stop_tokens:
-                hit_stop_on_first = True
+                is_finished[i] = True
             else:
                 generated_tokens[i].append(token_id)
+                sequences[i].append_token(token_id)
+                self.kv_manager.append_to_seq(sequences[i])
 
-        if hit_stop_on_first:
+        if all(is_finished):
             for req in sequences:
                 self.kv_manager.remove_seq(req)
             self.current_sequences = []
@@ -536,22 +543,11 @@ class AccelInferenceEngine:
             output = torch.tensor(output_ids, dtype=torch.long, device=device)
             return output
 
-        if not hit_stop_on_first:
-            for i, req in enumerate(sequences):
-                req.append_token(first_token_list[i])
-                self.kv_manager.append_to_seq(req)
-
         remaining_tokens = max_new_tokens - 1
         for step in range(remaining_tokens):
             decode_ids, decode_pos = self._prepare_decode(sequences)
 
-            # Forward pass
-            if batch_size > 8:
-                raise RuntimeError(
-                    f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
-                )
-
             context = get_forward_context()
             hidden_states = self._run_decode_with_graph(
                 decode_ids,
                 decode_pos,
@@ -578,32 +574,67 @@ class AccelInferenceEngine:
 
             next_token = torch.argmax(logits, dim=-1)
             next_token_list = next_token.tolist()
-            should_stop = False
             for i, token_id in enumerate(next_token_list):
-                if stop_tokens and token_id in stop_tokens:
-                    should_stop = True
+                if is_finished[i]:
+                    continue
+                elif stop_tokens and token_id in stop_tokens:
+                    is_finished[i] = True
                 else:
                     sequences[i].append_token(token_id)
                     self.kv_manager.append_to_seq(sequences[i])
                     generated_tokens[i].append(token_id)
 
-            if should_stop:
+            if all(is_finished):
                 break
 
         for req in sequences:
             self.kv_manager.remove_seq(req)
         self.current_sequences = []
 
-        output_ids = []
-        for i in range(batch_size):
-            initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
-            full_sequence = initial_tokens + generated_tokens[i]
-            output_ids.append(full_sequence)
+        pad_token = stop_tokens[0] if stop_tokens else 0
 
-        output = torch.tensor(output_ids, dtype=torch.long, device=device)
+        if is_varlen_batch:
+            max_prompt_len = attention_mask.size(1)
+            output_ids = []
+
+            for i in range(batch_size):
+                padding_len = max_prompt_len - seq_lens[i]
+                initial_tokens = sequences[i].token_ids[
+                    : sequences[i].num_prompt_tokens
+                ]
+                padded_prompt = [pad_token] * padding_len + initial_tokens
+                full_sequence = padded_prompt + generated_tokens[i]
+                output_ids.append(full_sequence)
+        else:
+            output_ids = [
+                sequences[i].token_ids[: sequences[i].num_prompt_tokens]
+                + generated_tokens[i]
+                for i in range(batch_size)
+            ]
+
+        max_length = max(len(seq) for seq in output_ids)
+        padded_output_ids = [
+            seq + [pad_token] * (max_length - len(seq)) for seq in output_ids
+        ]
+
+        output = torch.tensor(padded_output_ids, dtype=torch.long, device=device)
 
         assert output.size(0) == batch_size, (
             f"Output batch size mismatch: {output.size(0)} != {batch_size}"
         )
 
         return output
+
+
+class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(
+            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
+        ).argmax(dim=-1)
+        return sample_tokens
\ No newline at end of file