diff --git a/indextts/accel/__init__.py b/indextts/accel/__init__.py
new file mode 100644
index 0000000..8f6df0a
--- /dev/null
+++ b/indextts/accel/__init__.py
@@ -0,0 +1,9 @@
+from .accel_engine import AccelInferenceEngine  # noqa: F401
+from .attention import (  # noqa: F401
+    Attention,
+    get_forward_context,
+    reset_forward_context,
+    set_forward_context,
+)
+from .gpt2_accel import GPT2AccelAttention, GPT2AccelModel  # noqa: F401
+from .kv_manager import KVCacheManager, Seq  # noqa: F401
diff --git a/indextts/accel/accel_engine.py b/indextts/accel/accel_engine.py
new file mode 100644
index 0000000..2addd77
--- /dev/null
+++ b/indextts/accel/accel_engine.py
@@ -0,0 +1,609 @@
+import sys
+from typing import List, Optional
+
+import torch
+from torch import nn
+
+from .attention import (
+    ForwardContext,
+    get_forward_context,
+    reset_forward_context,
+    set_forward_context,
+)
+from .kv_manager import KVCacheManager, Seq
+
+
+class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(
+            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
+        ).argmax(dim=-1)
+        return sample_tokens
+
+
+class AccelInferenceEngine:
+    def __init__(
+        self,
+        model,
+        lm_head,
+        num_layers: int,
+        num_heads: int,
+        head_dim: int,
+        block_size: int = 256,
+        num_blocks: int = 128,
+        use_cuda_graph: bool = True,
+    ):
+        """
+        Args:
+            model: The GPT transformer model (should have accel attention)
+            lm_head: Language model head for generating logits
+            num_layers: Number of transformer layers
+            num_heads: Number of attention heads
+            head_dim: Dimension per head
+            block_size: KV cache block size
+            num_blocks: Total number of KV cache blocks
+            use_cuda_graph: Whether to use CUDA Graph for decode optimization
+        """
+        self.model = model
+        self.lm_head = lm_head
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+        model_dtype = next(model.parameters()).dtype
+        self.hidden_size = (
+            model.config.hidden_size
+            if hasattr(model, "config")
+            else head_dim * num_heads
+        )
+        self.kv_manager = KVCacheManager(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            head_dim=head_dim,
+            block_size=block_size,
+            num_blocks=num_blocks,
+            dtype=torch.float16,  # Force fp16 for FlashAttention
+        )
+        self.kv_manager.wire_kv_cache_to_model(model)
+        self.sampler = Sampler()
+        self.current_sequences = []
+        self.graphs = {}
+        self.graph_vars = None
+        self.graph_pool = None
+        self.graph_captured = False
+
+    def _prepare_prefill(self, requests: List[Seq]):
+        input_ids = []
+        positions = []
+        cu_seqlens_q = [0]
+        cu_seqlens_k = [0]
+        max_seqlen_q = 0
+        max_seqlen_k = 0
+        slot_mapping = []
+
+        for req in requests:
+            seqlen = len(req)
+            input_ids.extend(req[req.num_cached_tokens :])
+            positions.extend(list(range(req.num_cached_tokens, seqlen)))
+            seqlen_q = seqlen - req.num_cached_tokens
+            seqlen_k = seqlen
+            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
+            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
+            max_seqlen_q = max(seqlen_q, max_seqlen_q)
+            max_seqlen_k = max(seqlen_k, max_seqlen_k)
+
+            if req.block_table:
+                for i in range(req.num_cached_blocks, req.num_blocks):
+                    block_id = req.block_table[i]
+                    start = block_id * self.block_size
+                    if i != req.num_blocks - 1:
+                        end = start + self.block_size
+                    else:
+                        end = start + req.last_block_num_tokens
+                    slot_mapping.extend(list(range(start, end)))
+
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
+            non_blocking=True
+        )
+        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
+            non_blocking=True
+        )
+        cu_seqlens_q = torch.tensor(
+            cu_seqlens_q, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+        cu_seqlens_k = torch.tensor(
+            cu_seqlens_k, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+        slot_mapping = torch.tensor(
+            slot_mapping, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+
+        block_tables = None
+        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
+            max_len = max(len(req.block_table) for req in requests)
+            block_tables_list = []
+            for req in requests:
+                table = req.block_table + [-1] * (max_len - len(req.block_table))
+                block_tables_list.append(table)
+            block_tables = torch.tensor(
+                block_tables_list, dtype=torch.int32, pin_memory=True
+            ).cuda(non_blocking=True)
+
+        set_forward_context(
+            True,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            slot_mapping,
+            None,
+            block_tables,
+        )
+
+        return input_ids, positions
+
+    def _prepare_decode(self, requests: List[Seq]):
+        if not requests:
+            raise RuntimeError("FATAL: No requests provided to _prepare_decode!")
+
+        input_ids = []
+        positions = []
+        slot_mapping = []
+        context_lens = []
+
+        for req in requests:
+            input_ids.append(req.last_token)
+
+            pos = len(req) - 1
+            if hasattr(self, "_tts_mode") and self._tts_mode:
+                pos = pos - (self._tts_prompt_len - 1)
+            positions.append(pos)
+
+            context_lens.append(len(req))
+            slot_mapping.append(
+                req.block_table[-1] * self.block_size + req.last_block_num_tokens - 1
+            )
+
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
+            non_blocking=True
+        )
+        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
+            non_blocking=True
+        )
+        slot_mapping = torch.tensor(
+            slot_mapping, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+        context_lens = torch.tensor(
+            context_lens, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+
+        max_len = max(len(req.block_table) for req in requests)
+        block_tables_list = []
+        for req in requests:
+            table = req.block_table + [-1] * (max_len - len(req.block_table))
+            block_tables_list.append(table)
+        block_tables = torch.tensor(
+            block_tables_list, dtype=torch.int32, pin_memory=True
+        ).cuda(non_blocking=True)
+
+        assert block_tables.dim() == 2, (
+            f"block_tables must be 2D, got shape {block_tables.shape}"
+        )
+        assert block_tables.size(0) == len(requests), (
+            f"block_tables batch size mismatch: {block_tables.size(0)} vs {len(requests)}"
+        )
+
+        set_forward_context(
+            False,
+            slot_mapping=slot_mapping,
+            context_lens=context_lens,
+            block_tables=block_tables,
+        )
+
+        return input_ids, positions
+
+    def _prepare_sample(self, requests: List[Seq], temperature: float):
+        temperatures = [temperature] * len(requests)
+        temperatures = torch.tensor(
+            temperatures, dtype=torch.float32, pin_memory=True
+        ).cuda(non_blocking=True)
+        return temperatures
+
+    @torch.inference_mode()
+    def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
+        print("Capturing CUDA graphs for decode optimization...")
+        max_bs = 8  # Support up to batch size 8
+        max_num_blocks = (2048 + self.block_size - 1) // self.block_size
+        model_dtype = next(self.model.parameters()).dtype
+        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
+        positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
+        slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
+        context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
+        block_tables = torch.zeros(
+            max_bs, max_num_blocks, dtype=torch.int32, device="cuda"
+        )
+        outputs = torch.zeros(
+            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
+        )
+        inputs_embeds_buffer = torch.zeros(
+            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
+        )
+
+        self.graph_bs = [1]
+
+        use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
+
+        for bs in reversed(self.graph_bs):
+            graph = torch.cuda.CUDAGraph()
+
+            slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
+            context_lens[:bs] = bs + 1
+            block_tables[:bs, 0] = 0
+
+            set_forward_context(
+                False,
+                slot_mapping=slot_mapping[:bs],
+                context_lens=context_lens[:bs],
+                block_tables=block_tables[:bs],
+            )
+
+            # warmup
+            if use_tts:
+                assert tts_mel_embedding is not None
+                assert tts_text_pos_embedding is not None
+                emb = tts_mel_embedding(input_ids[:bs])
+                pos_clamped = torch.clamp(positions[:bs], min=0)
+                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
+                inputs_embeds_buffer[:bs] = emb + pos_emb
+                out = self.model(
+                    inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
+                    return_dict=True,
+                ).last_hidden_state
+            else:
+                out = self.model(
+                    input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
+                ).last_hidden_state
+            outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
+
+            with torch.cuda.graph(graph, self.graph_pool):
+                if use_tts:
+                    assert tts_mel_embedding is not None
+                    assert tts_text_pos_embedding is not None
+                    emb = tts_mel_embedding(input_ids[:bs])
+                    pos_clamped = torch.clamp(positions[:bs], min=0)
+                    pos_emb = tts_text_pos_embedding.emb(pos_clamped)
+                    inputs_embeds_buffer[:bs] = emb + pos_emb
+                    out = self.model(
+                        inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
+                        return_dict=True,
+                    ).last_hidden_state
+                else:
+                    out = self.model(
+                        input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
+                    ).last_hidden_state
+                outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
+
+            if self.graph_pool is None:
+                self.graph_pool = graph.pool()
+
+            self.graphs[bs] = graph
+            torch.cuda.synchronize()
+            reset_forward_context()
+
+        self.graph_vars = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "slot_mapping": slot_mapping,
+            "context_lens": context_lens,
+            "block_tables": block_tables,
+            "outputs": outputs,
+            "inputs_embeds": inputs_embeds_buffer,
+        }
+        print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
+
+    @torch.inference_mode()
+    def _run_decode_with_graph(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        context: ForwardContext,
+        tts_mel_embedding: Optional[torch.nn.Module] = None,
+        tts_text_pos_embedding: Optional[torch.nn.Module] = None,
+    ) -> torch.Tensor:
+        bs = input_ids.size(0)
+        use_tts_embedding = hasattr(self, "_tts_mode") and self._tts_mode
+
+        if not self.use_cuda_graph or not self.graphs:
+            if use_tts_embedding:
+                assert tts_mel_embedding is not None
+                assert tts_text_pos_embedding is not None
+                inputs_embeds = tts_mel_embedding(input_ids)
+                pos_clamped = torch.clamp(positions, min=0)
+                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
+                inputs_embeds = inputs_embeds + pos_emb
+                out = self.model(
+                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
+                ).last_hidden_state
+            else:
+                out = self.model(
+                    input_ids=input_ids.unsqueeze(1), return_dict=True
+                ).last_hidden_state
+            return out.squeeze(1) if out.dim() == 3 else out
+
+        graph_bs = next((x for x in self.graph_bs if x >= bs), None)
+        if graph_bs is None:
+            if use_tts_embedding:
+                assert tts_mel_embedding is not None
+                assert tts_text_pos_embedding is not None
+                inputs_embeds = tts_mel_embedding(input_ids)
+                pos_clamped = torch.clamp(positions, min=0)
+                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
+                inputs_embeds = inputs_embeds + pos_emb
+                out = self.model(
+                    inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
+                ).last_hidden_state
+            else:
+                out = self.model(
+                    input_ids=input_ids.unsqueeze(1), return_dict=True
+                ).last_hidden_state
+            return out.squeeze(1) if out.dim() == 3 else out
+
+        graph = self.graphs[graph_bs]
+        graph_vars = self.graph_vars
+
+        if graph_vars is None:
+            raise RuntimeError("Graph variables not initialized")
+
+        set_forward_context(
+            False,
+            slot_mapping=graph_vars["slot_mapping"][:graph_bs],
+            context_lens=graph_vars["context_lens"][:graph_bs],
+            block_tables=graph_vars["block_tables"][:graph_bs],
+        )
+
+        graph_vars["input_ids"][:bs] = input_ids
+        graph_vars["positions"][:bs] = positions
+        graph_vars["slot_mapping"].fill_(-1)
+        graph_vars["slot_mapping"][:bs] = context.slot_mapping
+        graph_vars["context_lens"].zero_()
+        graph_vars["context_lens"][:bs] = context.context_lens
+        graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
+            context.block_tables
+        )
+        graph.replay()
+
+        return graph_vars["outputs"][:bs]
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        max_new_tokens: int = 100,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        stop_tokens: Optional[List[int]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        tts_embeddings: Optional[
+            torch.Tensor
+        ] = None,  # TTS: [pad][cond][text] embeddings (87 tokens, NO start_mel)
+        tts_mel_embedding: Optional[torch.nn.Module] = None,  # TTS: mel_embedding layer
+        tts_text_pos_embedding: Optional[
+            torch.nn.Module
+        ] = None,  # TTS: text_pos_embedding layer
+    ) -> torch.Tensor:
+        """
+        Generate tokens.
+
+        Args:
+            input_ids: Input token IDs [batch_size, seq_len]
+            max_new_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_k: Top-k sampling
+            top_p: Nucleus sampling threshold
+            stop_tokens: List of token IDs that stop generation
+
+        Returns:
+            Generated token IDs [batch_size, total_len]
+        """
+        batch_size = input_ids.size(0)
+        device = input_ids.device
+
+        self._tts_mode = tts_embeddings is not None
+        self._tts_prompt_len = input_ids.size(1) if self._tts_mode else 0
+
+        if self.use_cuda_graph and not self.graph_captured:
+            print(
+                f"[CAPTURE] use_cuda_graph={self.use_cuda_graph}, graph_captured={self.graph_captured}",
+                file=sys.stderr,
+                flush=True,
+            )
+            self._capture_cuda_graphs(
+                tts_mel_embedding=tts_mel_embedding,
+                tts_text_pos_embedding=tts_text_pos_embedding,
+            )
+            self.graph_captured = True
+            print(
+                f"[CAPTURE] Completed! graphs={list(self.graphs.keys())}",
+                file=sys.stderr,
+                flush=True,
+            )
+
+        if tts_embeddings is not None:
+            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
+            pass
+        else:
+            actual_seq_len = input_ids.size(1)
+
+        sequences = []
+        for i in range(batch_size):
+            token_ids = [1] * actual_seq_len
+            if tts_embeddings is not None and actual_seq_len > 0:
+                token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
+            else:
+                token_ids = input_ids[i].tolist()
+            req = Seq(token_ids)
+            self.kv_manager.allocate(req)
+            sequences.append(req)
+
+        self.current_sequences = sequences
+
+        # Prefill phase
+        prefill_ids, prefill_pos = self._prepare_prefill(sequences)
+
+        if prefill_ids.dim() == 1:
+            prefill_ids = prefill_ids.unsqueeze(
+                0
+            )  # [total_tokens] -> [1, total_tokens]
+        if prefill_pos.dim() == 1:
+            prefill_pos = prefill_pos.unsqueeze(
+                0
+            )  # [total_tokens] -> [1, total_tokens]
+
+        if (
+            tts_embeddings is not None
+            and tts_mel_embedding is not None
+            and tts_text_pos_embedding is not None
+        ):
+            start_token_id = input_ids[0, -1] if input_ids.size(1) > 0 else 8192
+
+            start_emb = tts_mel_embedding(
+                torch.tensor([[start_token_id]], device="cuda")
+            )  # [1, 1, hidden_dim]
+
+            start_emb = start_emb + tts_text_pos_embedding(start_emb)
+
+            full_embeddings = torch.cat(
+                [tts_embeddings, start_emb], dim=1
+            )  # [1, 88, hidden_dim]
+
+            model_dtype = next(self.model.parameters()).dtype
+            if full_embeddings.dtype != model_dtype:
+                full_embeddings = full_embeddings.to(model_dtype)
+
+            hidden_states = self.model(
+                inputs_embeds=full_embeddings, return_dict=True
+            ).last_hidden_state
+
+        else:
+            hidden_states = self.model(
+                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
+            ).last_hidden_state
+
+        reset_forward_context()
+
+        last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]
+
+        if self.lm_head is not None:
+            if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
+                last_hidden = last_hidden.to(next(self.lm_head.parameters()).dtype)
+            logits = self.lm_head(last_hidden)  # [batch_size, vocab_size]
+        else:
+            logits = self.model.compute_logits(last_hidden)  # [batch_size, vocab_size]
+
+        temperatures = self._prepare_sample(sequences, temperature)
+        if temperature > 0:
+            first_token = self.sampler(logits, temperatures)
+        else:
+            first_token = torch.argmax(logits, dim=-1)
+
+        first_token_list = first_token.tolist()
+
+        generated_tokens = [[] for _ in range(batch_size)]
+        hit_stop_on_first = False
+
+        for i, token_id in enumerate(first_token_list):
+            if stop_tokens and token_id in stop_tokens:
+                hit_stop_on_first = True
+            else:
+                generated_tokens[i].append(token_id)
+
+        if hit_stop_on_first:
+            for req in sequences:
+                self.kv_manager.remove_seq(req)
+            self.current_sequences = []
+
+            output_ids = []
+            for i in range(batch_size):
+                full_sequence = input_ids[i].tolist() + generated_tokens[i]
+                output_ids.append(full_sequence)
+
+            output = torch.tensor(output_ids, dtype=torch.long, device=device)
+            return output
+
+        if not hit_stop_on_first:
+            for i, req in enumerate(sequences):
+                req.append_token(first_token_list[i])
+                self.kv_manager.append_to_seq(req)
+
+        remaining_tokens = max_new_tokens - 1
+
+        for step in range(remaining_tokens):
+            decode_ids, decode_pos = self._prepare_decode(sequences)
+
+            # Forward pass
+            if batch_size > 8:
+                raise RuntimeError(
+                    f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
+                )
+
+            context = get_forward_context()
+            hidden_states = self._run_decode_with_graph(
+                decode_ids,
+                decode_pos,
+                context,
+                tts_mel_embedding=tts_mel_embedding,
+                tts_text_pos_embedding=tts_text_pos_embedding,
+            )
+
+            # Get logits
+            if self.lm_head is not None:
+                logits = self.lm_head(hidden_states)  # [batch_size, vocab_size]
+            else:
+                logits = self.model.compute_logits(
+                    hidden_states
+                )  # [batch_size, vocab_size]
+
+            reset_forward_context()
+
+            temperatures = self._prepare_sample(sequences, temperature)
+            if temperature > 0:
+                next_token = self.sampler(logits, temperatures)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+            next_token_list = next_token.tolist()
+
+            should_stop = False
+            for i, token_id in enumerate(next_token_list):
+                if stop_tokens and token_id in stop_tokens:
+                    should_stop = True
+                else:
+                    sequences[i].append_token(token_id)
+                    self.kv_manager.append_to_seq(sequences[i])
+                    generated_tokens[i].append(token_id)
+
+            if should_stop:
+                break
+
+        for req in sequences:
+            self.kv_manager.remove_seq(req)
+        self.current_sequences = []
+
+        output_ids = []
+        for i in range(batch_size):
+            initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
+            full_sequence = initial_tokens + generated_tokens[i]
+            output_ids.append(full_sequence)
+
+        output = torch.tensor(output_ids, dtype=torch.long, device=device)
+
+        assert output.size(0) == batch_size, (
+            f"Output batch size mismatch: {output.size(0)} != {batch_size}"
+        )
+
+        return output
diff --git a/indextts/accel/attention.py b/indextts/accel/attention.py
new file mode 100644
index 0000000..2946c5d
--- /dev/null
+++ b/indextts/accel/attention.py
@@ -0,0 +1,154 @@
+from dataclasses import dataclass
+
+import torch
+import triton
+import triton.language as tl
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from torch import nn
+
+
+@dataclass
+class ForwardContext:
+    is_prefill: bool = False
+    cu_seqlens_q: torch.Tensor | None = None
+    cu_seqlens_k: torch.Tensor | None = None
+    max_seqlen_q: int = 0
+    max_seqlen_k: int = 0
+    slot_mapping: torch.Tensor | None = None
+    context_lens: torch.Tensor | None = None
+    block_tables: torch.Tensor | None = None
+
+
+_FORWARD_CONTEXT = ForwardContext()
+
+
+def get_forward_context():
+    return _FORWARD_CONTEXT
+
+
+def set_forward_context(
+    is_prefill,
+    cu_seqlens_q=None,
+    cu_seqlens_k=None,
+    max_seqlen_q=0,
+    max_seqlen_k=0,
+    slot_mapping=None,
+    context_lens=None,
+    block_tables=None,
+):
+    global _FORWARD_CONTEXT
+    _FORWARD_CONTEXT = ForwardContext(
+        is_prefill,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        slot_mapping,
+        context_lens,
+        block_tables,
+    )
+
+
+def reset_forward_context():
+    global _FORWARD_CONTEXT
+    _FORWARD_CONTEXT = ForwardContext()
+
+
+@triton.jit
+def store_kvcache_kernel(
+    key_ptr,
+    key_stride,
+    value_ptr,
+    value_stride,
+    k_cache_ptr,
+    v_cache_ptr,
+    slot_mapping_ptr,
+    D: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 2048
+    idx = tl.program_id(0)
+    slot = tl.load(slot_mapping_ptr + idx)
+    if slot == -1:
+        return
+    d_offset = 0
+    while d_offset < D:
+        cur_block_size = min(BLOCK_SIZE, D - d_offset)
+        key_offsets = idx * key_stride + d_offset + tl.arange(0, BLOCK_SIZE)
+        value_offsets = idx * value_stride + d_offset + tl.arange(0, BLOCK_SIZE)
+        cache_offsets = slot * D + d_offset + tl.arange(0, BLOCK_SIZE)
+
+        mask = tl.arange(0, BLOCK_SIZE) < cur_block_size
+        key = tl.load(key_ptr + key_offsets, mask=mask, other=0.0)
+        value = tl.load(value_ptr + value_offsets, mask=mask, other=0.0)
+        tl.store(k_cache_ptr + cache_offsets, key, mask=mask)
+        tl.store(v_cache_ptr + cache_offsets, value, mask=mask)
+
+        d_offset += BLOCK_SIZE
+
+
+def store_kvcache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+):
+    N, num_heads, head_dim = key.shape
+    D = num_heads * head_dim
+    assert key.stride(-1) == 1 and value.stride(-1) == 1
+    assert key.stride(1) == head_dim and value.stride(1) == head_dim
+    assert k_cache.stride(1) == D and v_cache.stride(1) == D
+    assert slot_mapping.numel() == N
+    store_kvcache_kernel[(N,)](
+        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
+    )
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scale: float,
+        num_kv_heads: int,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = scale
+        self.num_kv_heads = num_kv_heads
+        self.k_cache = self.v_cache = torch.tensor([])
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        context = get_forward_context()
+        k_cache, v_cache = self.k_cache, self.v_cache
+
+        if k_cache.numel() and v_cache.numel() and context.slot_mapping is not None:
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+
+        if context.is_prefill:
+            if context.block_tables is not None:
+                k, v = k_cache, v_cache
+            o = flash_attn_varlen_func(
+                q,
+                k,
+                v,
+                max_seqlen_q=context.max_seqlen_q,
+                cu_seqlens_q=context.cu_seqlens_q,
+                max_seqlen_k=context.max_seqlen_k,
+                cu_seqlens_k=context.cu_seqlens_k,
+                softmax_scale=self.scale,
+                causal=True,
+                block_table=context.block_tables,
+            )
+        else:
+            o = flash_attn_with_kvcache(
+                q.unsqueeze(1),
+                k_cache,
+                v_cache,
+                cache_seqlens=context.context_lens,
+                block_table=context.block_tables,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+        return o
diff --git a/indextts/accel/gpt2_accel.py b/indextts/accel/gpt2_accel.py
new file mode 100644
index 0000000..f96800e
--- /dev/null
+++ b/indextts/accel/gpt2_accel.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model
+
+from .attention import Attention
+
+
+class GPT2AccelAttention(nn.Module):
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(
+                torch.ones((max_positions, max_positions), dtype=torch.bool)
+            ).view(1, 1, max_positions, max_positions),
+            persistent=False,
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        self.scale_attn_weights = config.scale_attn_weights
+
+        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+        scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
+        self.accel_attn = Attention(
+            self.num_heads, self.head_dim, scale, self.num_heads
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        past_key_value=None,
+        **kwargs,
+    ):
+        if encoder_hidden_states is not None:
+            raise NotImplementedError("Cross attention not supported in accel mode")
+
+        qkv = self.c_attn(hidden_states)
+        query, key, value = qkv.split(self.split_size, dim=2)
+
+        # [B, T, H*D] -> [B, H, T, D]
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        # flatten to [B*T, H, D]
+        bsz, num_heads, seq_len, head_dim = query.shape
+        q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
+        k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
+        v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
+
+        # ensure fp16
+        if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
+            orig_dtype = q_flat.dtype
+            q_flat = q_flat.to(torch.float16)
+            k_flat = k_flat.to(torch.float16)
+            v_flat = v_flat.to(torch.float16)
+        else:
+            orig_dtype = q_flat.dtype
+
+        o_flat = self.accel_attn(q_flat, k_flat, v_flat)  # [B*T, H, D]
+
+        if o_flat.dtype != orig_dtype:
+            o_flat = o_flat.to(orig_dtype)
+
+        # Reshape back: [B*T, H, D] -> [B, H, T, D]
+        attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, None)
+        if output_attentions:
+            outputs += (None,)
+
+        return outputs
+
+    def _split_heads(self, tensor, num_heads, head_dim):
+        new_shape = tensor.size()[:-1] + (num_heads, head_dim)
+        tensor = tensor.view(new_shape)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, head_dim):
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
+        return tensor.view(new_shape)
+
+
+class GPT2AccelBlock(GPT2Block):
+    def __init__(self, config, layer_idx=None):
+        super().__init__(config, layer_idx)
+        self.attn = GPT2AccelAttention(config, layer_idx)
+
+
+class GPT2AccelModel(GPT2Model):
+    def __init__(self, config):
+        super().__init__(config)
+        self.h = nn.ModuleList(
+            [
+                GPT2AccelBlock(config, layer_idx=i)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+
+            for block in self.h:
+                hidden_states = block(hidden_states)[0]
+
+            hidden_states = self.ln_f(hidden_states)
+
+            if return_dict:
+                return BaseModelOutputWithPastAndCrossAttentions(
+                    last_hidden_state=hidden_states,
+                    past_key_values=None,
+                    hidden_states=None,
+                    attentions=None,
+                )
+            return (hidden_states,)
+        else:
+            return super().forward(
+                input_ids=input_ids,
+                past_key_values=None,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=False,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
diff --git a/indextts/accel/kv_manager.py b/indextts/accel/kv_manager.py
new file mode 100644
index 0000000..68232b9
--- /dev/null
+++ b/indextts/accel/kv_manager.py
@@ -0,0 +1,209 @@
+import hashlib
+import pickle
+from collections import deque
+from copy import copy
+from typing import Dict, List, Optional, Set
+
+import torch
+
+
+class KVCacheBlock:
+    def __init__(self, block_id: int):
+        self.block_id = block_id
+        self.ref_cnt = 0
+        self._block_hash = None
+        self.token_ids = []
+
+    @property
+    def block_hash(self) -> Optional[bytes]:
+        return self._block_hash
+
+    def update(self, block_hash: bytes, token_ids: List[int]):
+        self._block_hash = block_hash
+        self.token_ids = token_ids
+
+    def reset(self):
+        self.ref_cnt = 1
+        self._block_hash = None
+        self.token_ids = []
+
+
+class Seq:
+    def __init__(self, token_ids: List[int], block_size: int = 256):
+        self.token_ids = copy(token_ids)
+        self.last_token = token_ids[-1] if token_ids else 0
+        self.num_tokens = len(self.token_ids)
+        self.num_prompt_tokens = len(token_ids)
+        self.num_cached_tokens = 0
+        self.block_table: List[int] = []
+        self.block_size = block_size
+
+    def __len__(self):
+        return self.num_tokens
+
+    def __getitem__(self, key):
+        return self.token_ids[key]
+
+    @property
+    def num_blocks(self):
+        return (self.num_tokens + self.block_size - 1) // self.block_size
+
+    @property
+    def num_cached_blocks(self):
+        return self.num_cached_tokens // self.block_size
+
+    @property
+    def last_block_num_tokens(self):
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size
+
+    def get_block_tokens(self, block_idx: int) -> List[int]:
+        assert 0 <= block_idx < self.num_blocks
+        start = block_idx * self.block_size
+        end = start + self.block_size
+        return self.token_ids[start:end]
+
+    def append_token(self, token_id: int):
+        self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+
+class KVCacheManager:
+    def __init__(
+        self,
+        num_layers: int,
+        num_heads: int,
+        head_dim: int,
+        block_size: int,
+        num_blocks: int,
+        dtype: torch.dtype,
+    ):
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+        self.dtype = dtype
+
+        self.blocks: List[KVCacheBlock] = [KVCacheBlock(i) for i in range(num_blocks)]
+        self.block_hash_to_id: Dict[bytes, int] = {}
+        self.free_block_ids: deque = deque(range(num_blocks))
+        self.used_block_ids: Set[int] = set()
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        cache_dtype = torch.float16 if device == "cuda" else dtype
+        self.kv_cache = torch.empty(
+            2,
+            num_layers,
+            num_blocks,
+            block_size,
+            num_heads,
+            head_dim,
+            dtype=cache_dtype,
+            device=device,
+        )
+
+    @classmethod
+    def compute_block_hash(
+        cls, token_ids: List[int], parent_hash: Optional[bytes] = None
+    ) -> bytes:
+        hash_input = []
+        if parent_hash is not None:
+            hash_input.append(parent_hash)
+        hash_input.extend(token_ids)
+        input_bytes = pickle.dumps(tuple(hash_input), protocol=pickle.HIGHEST_PROTOCOL)
+        return hashlib.sha256(input_bytes).digest()
+
+    def _allocate_block(self, block_id: int) -> KVCacheBlock:
+        block = self.blocks[block_id]
+        assert block.ref_cnt == 0
+        block.reset()
+        self.free_block_ids.remove(block_id)
+        self.used_block_ids.add(block_id)
+        return block
+
+    def _deallocate_block(self, block_id: int):
+        assert self.blocks[block_id].ref_cnt == 0
+        self.used_block_ids.remove(block_id)
+        self.free_block_ids.append(block_id)
+
+    def allocate(self, sequence: Seq):
+        assert not sequence.block_table, "Sequence already has allocated blocks"
+
+        parent_hash = None
+        cache_miss = False
+
+        for i in range(sequence.num_blocks):
+            token_ids = sequence.get_block_tokens(i)
+            block_hash = (
+                self.compute_block_hash(token_ids, parent_hash)
+                if len(token_ids) == self.block_size
+                else None
+            )
+            block_id = self.block_hash_to_id.get(block_hash) if block_hash else None
+
+            if block_id is None or self.blocks[block_id].token_ids != token_ids:
+                cache_miss = True
+
+            if cache_miss:
+                block_id = self.free_block_ids[0]
+                block = self._allocate_block(block_id)
+            else:
+                sequence.num_cached_tokens += self.block_size
+                if block_id is not None and block_id in self.used_block_ids:
+                    block = self.blocks[block_id]
+                    block.ref_cnt += 1
+                else:
+                    block_id = self.free_block_ids[0]
+                    block = self._allocate_block(block_id)
+
+            if block_hash is not None:
+                block.update(block_hash, token_ids)
+                self.block_hash_to_id[block_hash] = block_id
+                parent_hash = block_hash
+
+            sequence.block_table.append(block_id)
+
+    def deallocate(self, sequence: Seq):
+        for block_id in reversed(sequence.block_table):
+            block = self.blocks[block_id]
+            block.ref_cnt -= 1
+            if block.ref_cnt == 0:
+                self._deallocate_block(block_id)
+
+        sequence.num_cached_tokens = 0
+        sequence.block_table.clear()
+
+    def append_to_seq(self, sequence: Seq):
+        block_table = sequence.block_table
+        last_block = self.blocks[block_table[-1]]
+
+        if len(sequence) % self.block_size == 1:
+            assert last_block.block_hash is not None
+            block_id = self.free_block_ids[0]
+            self._allocate_block(block_id)
+            block_table.append(block_id)
+        elif len(sequence) % self.block_size == 0:
+            assert last_block.block_hash is None
+            token_ids = sequence.get_block_tokens(sequence.num_blocks - 1)
+            parent_hash = (
+                self.blocks[block_table[-2]].block_hash
+                if len(block_table) > 1
+                else None
+            )
+            block_hash = self.compute_block_hash(token_ids, parent_hash)
+            last_block.update(block_hash, token_ids)
+            self.block_hash_to_id[block_hash] = last_block.block_id
+        else:
+            assert last_block.block_hash is None
+
+    def remove_seq(self, sequence: Seq):
+        self.deallocate(sequence)
+
+    def wire_kv_cache_to_model(self, model):
+        layer_id = 0
+        for module in model.modules():
+            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
+                module.k_cache = self.kv_cache[0, layer_id]
+                module.v_cache = self.kv_cache[1, layer_id]
+                layer_id += 1
diff --git a/indextts/gpt/model_v2.py b/indextts/gpt/model_v2.py
index 3f39bec..3de3ac6 100644
--- a/indextts/gpt/model_v2.py
+++ b/indextts/gpt/model_v2.py
@@ -307,7 +307,7 @@ class UnifiedVoice(nn.Module):
                  start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
                  checkpointing=True, types=1,
-                 condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
+                 condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None, use_accel=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -409,6 +409,9 @@ class UnifiedVoice(nn.Module):
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
 
+        self.use_accel = use_accel
+        self.accel_engine = None  # Will be initialized in post_init_gpt2_config
+
     def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         gpt_config = GPT2Config(
@@ -421,6 +424,38 @@ class UnifiedVoice(nn.Module):
             gradient_checkpointing=False,
             use_cache=True,
         )
+
+        if self.use_accel and torch.cuda.is_available():
+            # Check if flash attention is available
+            try:
+                import flash_attn
+            except ImportError:
+                raise ImportError("flash_attn is required for acceleration but not installed. Please install from https://github.com/Dao-AILab/flash-attention/releases/")
+
+            from indextts.accel import GPT2AccelModel, AccelInferenceEngine
+
+            # Create accel model
+            accel_gpt = GPT2AccelModel(gpt_config)
+            accel_gpt.load_state_dict(self.gpt.state_dict(), strict=False)
+
+            if half:
+                accel_gpt = accel_gpt.half().cuda()
+            else:
+                accel_gpt = accel_gpt.cuda()
+            accel_gpt.eval()
+
+            lm_head_with_norm = nn.Sequential(self.final_norm, self.mel_head)
+            self.accel_engine = AccelInferenceEngine(
+                model=accel_gpt,
+                lm_head=lm_head_with_norm,
+                num_layers=self.layers,
+                num_heads=self.heads,
+                head_dim=self.model_dim // self.heads,
+                block_size=256,
+                num_blocks=16,  # Reduce to save memory (16*256 = 4096 tokens capacity)
+                use_cuda_graph=True,
+            )
+            print("acceleration engine initialized")
         self.inference_model = GPT2InferenceModel(
             gpt_config,
             self.gpt,
@@ -721,12 +756,26 @@ class UnifiedVoice(nn.Module):
             min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
             logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
         max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
-        output = self.inference_model.generate(inputs,
-                                               bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
-                                               eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
-                                               max_length=max_length, logits_processor=logits_processor,
-                                               num_return_sequences=num_return_sequences,
-                                               **hf_generate_kwargs)
+
+        # Use accel engine if available (single sequence only)
+        if self.accel_engine is not None and num_return_sequences == 1:
+            output = self.accel_engine.generate(
+                inputs,  # fake input_ids (all 1s + start_mel_token)
+                max_new_tokens=max_length - trunc_index,
+                attention_mask=attention_mask,
+                temperature=hf_generate_kwargs.get('temperature', 1),
+                stop_tokens=[self.stop_mel_token],
+                tts_embeddings=inputs_embeds,  # [pad][cond][text] embeddings (87 tokens, NO start_mel_token)
+                tts_mel_embedding=self.inference_model.embeddings,  # mel_embedding layer
+                tts_text_pos_embedding=self.inference_model.text_pos_embedding,  # text_pos_embedding layer
+            )
+        else:
+            output = self.inference_model.generate(inputs,
+                                                   bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
+                                                   eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
+                                                   max_length=max_length, logits_processor=logits_processor,
+                                                   num_return_sequences=num_return_sequences,
+                                                   **hf_generate_kwargs)
         if isinstance(output, torch.Tensor):
             return output[:, trunc_index:], speech_conditioning_latent
         # GenerateOutput
diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py
index a4dde5b..c506bba 100644
--- a/indextts/infer_v2.py
+++ b/indextts/infer_v2.py
@@ -38,7 +38,7 @@ import torch.nn.functional as F
 class IndexTTS2:
     def __init__(
             self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
-            use_cuda_kernel=None,use_deepspeed=False
+            use_cuda_kernel=None,use_deepspeed=False, use_accel=False
     ):
         """
         Args:
@@ -48,6 +48,7 @@ class IndexTTS2:
            device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
            use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
            use_deepspeed (bool): whether to use DeepSpeed or not.
+           use_accel (bool): whether to use acceleration engine for GPT2 or not.
         """
         if device is not None:
             self.device = device
@@ -75,10 +76,11 @@ class IndexTTS2:
         self.model_dir = model_dir
         self.dtype = torch.float16 if self.use_fp16 else None
         self.stop_mel_token = self.cfg.gpt.stop_mel_token
+        self.use_accel = use_accel
 
         self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
 
-        self.gpt = UnifiedVoice(**self.cfg.gpt)
+        self.gpt = UnifiedVoice(**self.cfg.gpt, use_accel=self.use_accel)
         self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
         load_checkpoint(self.gpt, self.gpt_path)
         self.gpt = self.gpt.to(self.device)
@@ -453,7 +455,7 @@ class IndexTTS2:
             ref_mel = self.cache_mel
 
         if emo_vector is not None:
-            weight_vector = torch.tensor(emo_vector).to(self.device)
+            weight_vector = torch.tensor(emo_vector, device=self.device)
             if use_random:
                 random_index = [random.randint(0, x - 1) for x in self.emo_num]
             else:
@@ -580,15 +582,16 @@ class IndexTTS2:
             # print(f"code len: {code_lens}")
             code_lens = []
+            max_code_len = 0
             for code in codes:
                 if self.stop_mel_token not in code:
-                    code_lens.append(len(code))
                     code_len = len(code)
                 else:
-                    len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
-                    code_len = len_ - 1
+                    len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0]
+                    code_len = len_[0].item() if len_.numel() > 0 else len(code)
                 code_lens.append(code_len)
-            codes = codes[:, :code_len]
+                max_code_len = max(max_code_len, code_len)
+            codes = codes[:, :max_code_len]
 
             code_lens = torch.LongTensor(code_lens)
             code_lens = code_lens.to(self.device)
             if verbose:
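
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch; illustrative only). With this change
# applied, the acceleration path is opt-in through the new `use_accel` flag on
# IndexTTS2.__init__ shown above. It needs a CUDA device plus the flash_attn
# package, and generation falls back to the stock HF generate() path whenever
# num_return_sequences > 1. The config/checkpoint paths below are just the
# defaults from the constructor signature.
#
#     from indextts.infer_v2 import IndexTTS2
#
#     tts = IndexTTS2(
#         cfg_path="checkpoints/config.yaml",
#         model_dir="checkpoints",
#         use_fp16=True,
#         use_accel=True,
#     )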