Merge pull request #517 from storyicon/gpt2_accel

feat: achieve inference acceleration for the gpt2 stage (3.79×)
Vanka0051 committed on 2025-10-30 16:14:46 +08:00 (via GitHub) · commit 1d5d079aaa
7 changed files with 1228 additions and 14 deletions

indextts/accel/__init__.py Normal file

@@ -0,0 +1,9 @@
from .accel_engine import AccelInferenceEngine # noqa: F401
from .attention import ( # noqa: F401
Attention,
get_forward_context,
reset_forward_context,
set_forward_context,
)
from .gpt2_accel import GPT2AccelAttention, GPT2AccelModel # noqa: F401
from .kv_manager import KVCacheManager, Seq # noqa: F401
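# Public surface of the accel package: the inference engine, the attention layer
# with its forward-context helpers, the accelerated GPT-2 modules, and the paged
# KV-cache manager. The "# noqa: F401" markers silence unused-import warnings
# for these intentional re-exports.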

indextts/accel/accel_engine.py Normal file

@@ -0,0 +1,609 @@
import sys
from typing import List, Optional
import torch
from torch import nn
from .attention import (
ForwardContext,
get_forward_context,
reset_forward_context,
set_forward_context,
)
from .kv_manager import KVCacheManager, Seq
class Sampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(
torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
).argmax(dim=-1)
return sample_tokens
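# Note: dividing probs by i.i.d. Exponential(1) noise and taking the argmax is
# the exponential-race (Gumbel-max style) trick: index i wins with probability
# probs[i], so this samples from the categorical distribution without a
# torch.multinomial call and stays torch.compile/CUDA-graph friendly.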
class AccelInferenceEngine:
def __init__(
self,
model,
lm_head,
num_layers: int,
num_heads: int,
head_dim: int,
block_size: int = 256,
num_blocks: int = 128,
use_cuda_graph: bool = True,
):
"""
Args:
model: The GPT transformer model (should have accel attention)
lm_head: Language model head for generating logits
num_layers: Number of transformer layers
num_heads: Number of attention heads
head_dim: Dimension per head
block_size: KV cache block size
num_blocks: Total number of KV cache blocks
use_cuda_graph: Whether to use CUDA Graph for decode optimization
"""
self.model = model
self.lm_head = lm_head
self.block_size = block_size
self.num_blocks = num_blocks
self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
model_dtype = next(model.parameters()).dtype
self.hidden_size = (
model.config.hidden_size
if hasattr(model, "config")
else head_dim * num_heads
)
self.kv_manager = KVCacheManager(
num_layers=num_layers,
num_heads=num_heads,
head_dim=head_dim,
block_size=block_size,
num_blocks=num_blocks,
dtype=torch.float16, # Force fp16 for FlashAttention
)
self.kv_manager.wire_kv_cache_to_model(model)
self.sampler = Sampler()
self.current_sequences = []
self.graphs = {}
self.graph_vars = None
self.graph_pool = None
self.graph_captured = False
def _prepare_prefill(self, requests: List[Seq]):
input_ids = []
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
for req in requests:
seqlen = len(req)
input_ids.extend(req[req.num_cached_tokens :])
positions.extend(list(range(req.num_cached_tokens, seqlen)))
seqlen_q = seqlen - req.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if req.block_table:
for i in range(req.num_cached_blocks, req.num_blocks):
block_id = req.block_table[i]
start = block_id * self.block_size
if i != req.num_blocks - 1:
end = start + self.block_size
else:
end = start + req.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
non_blocking=True
)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
non_blocking=True
)
cu_seqlens_q = torch.tensor(
cu_seqlens_q, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(
cu_seqlens_k, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
slot_mapping = torch.tensor(
slot_mapping, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
block_tables = None
if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
max_len = max(len(req.block_table) for req in requests)
block_tables_list = []
for req in requests:
table = req.block_table + [-1] * (max_len - len(req.block_table))
block_tables_list.append(table)
block_tables = torch.tensor(
block_tables_list, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
set_forward_context(
True,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
slot_mapping,
None,
block_tables,
)
return input_ids, positions
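    # Worked example (hypothetical): two requests with 3 and 5 uncached tokens
    # and no cached prefix give cu_seqlens_q = cu_seqlens_k = [0, 3, 8], and
    # slot_mapping holds one flat cache slot (block_id * block_size + offset)
    # for every token that must be written into the paged KV cache.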
def _prepare_decode(self, requests: List[Seq]):
if not requests:
raise RuntimeError("FATAL: No requests provided to _prepare_decode!")
input_ids = []
positions = []
slot_mapping = []
context_lens = []
for req in requests:
input_ids.append(req.last_token)
pos = len(req) - 1
if hasattr(self, "_tts_mode") and self._tts_mode:
pos = pos - (self._tts_prompt_len - 1)
positions.append(pos)
context_lens.append(len(req))
slot_mapping.append(
req.block_table[-1] * self.block_size + req.last_block_num_tokens - 1
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
non_blocking=True
)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
non_blocking=True
)
slot_mapping = torch.tensor(
slot_mapping, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
context_lens = torch.tensor(
context_lens, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
max_len = max(len(req.block_table) for req in requests)
block_tables_list = []
for req in requests:
table = req.block_table + [-1] * (max_len - len(req.block_table))
block_tables_list.append(table)
block_tables = torch.tensor(
block_tables_list, dtype=torch.int32, pin_memory=True
).cuda(non_blocking=True)
assert block_tables.dim() == 2, (
f"block_tables must be 2D, got shape {block_tables.shape}"
)
assert block_tables.size(0) == len(requests), (
f"block_tables batch size mismatch: {block_tables.size(0)} vs {len(requests)}"
)
set_forward_context(
False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
return input_ids, positions
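    # Decode feeds exactly one new token per sequence; block tables are padded
    # with -1 to a rectangular [batch, max_num_blocks] tensor so the paged
    # FlashAttention kernel can index every sequence's blocks uniformly.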
def _prepare_sample(self, requests: List[Seq], temperature: float):
temperatures = [temperature] * len(requests)
temperatures = torch.tensor(
temperatures, dtype=torch.float32, pin_memory=True
).cuda(non_blocking=True)
return temperatures
@torch.inference_mode()
def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
print("Capturing CUDA graphs for decode optimization...")
max_bs = 8 # Support up to batch size 8
max_num_blocks = (2048 + self.block_size - 1) // self.block_size
model_dtype = next(self.model.parameters()).dtype
input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
block_tables = torch.zeros(
max_bs, max_num_blocks, dtype=torch.int32, device="cuda"
)
outputs = torch.zeros(
max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
)
inputs_embeds_buffer = torch.zeros(
max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
)
self.graph_bs = [1]
use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
context_lens[:bs] = bs + 1
block_tables[:bs, 0] = 0
set_forward_context(
False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs],
)
# warmup
if use_tts:
assert tts_mel_embedding is not None
assert tts_text_pos_embedding is not None
emb = tts_mel_embedding(input_ids[:bs])
pos_clamped = torch.clamp(positions[:bs], min=0)
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
inputs_embeds_buffer[:bs] = emb + pos_emb
out = self.model(
inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
return_dict=True,
).last_hidden_state
else:
out = self.model(
input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
).last_hidden_state
outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
with torch.cuda.graph(graph, self.graph_pool):
if use_tts:
assert tts_mel_embedding is not None
assert tts_text_pos_embedding is not None
emb = tts_mel_embedding(input_ids[:bs])
pos_clamped = torch.clamp(positions[:bs], min=0)
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
inputs_embeds_buffer[:bs] = emb + pos_emb
out = self.model(
inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
return_dict=True,
).last_hidden_state
else:
out = self.model(
input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
).last_hidden_state
outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_forward_context()
self.graph_vars = {
"input_ids": input_ids,
"positions": positions,
"slot_mapping": slot_mapping,
"context_lens": context_lens,
"block_tables": block_tables,
"outputs": outputs,
"inputs_embeds": inputs_embeds_buffer,
}
print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
@torch.inference_mode()
def _run_decode_with_graph(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
context: ForwardContext,
tts_mel_embedding: Optional[torch.nn.Module] = None,
tts_text_pos_embedding: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
bs = input_ids.size(0)
use_tts_embedding = hasattr(self, "_tts_mode") and self._tts_mode
if not self.use_cuda_graph or not self.graphs:
if use_tts_embedding:
assert tts_mel_embedding is not None
assert tts_text_pos_embedding is not None
inputs_embeds = tts_mel_embedding(input_ids)
pos_clamped = torch.clamp(positions, min=0)
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
inputs_embeds = inputs_embeds + pos_emb
out = self.model(
inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
).last_hidden_state
else:
out = self.model(
input_ids=input_ids.unsqueeze(1), return_dict=True
).last_hidden_state
return out.squeeze(1) if out.dim() == 3 else out
graph_bs = next((x for x in self.graph_bs if x >= bs), None)
if graph_bs is None:
if use_tts_embedding:
assert tts_mel_embedding is not None
assert tts_text_pos_embedding is not None
inputs_embeds = tts_mel_embedding(input_ids)
pos_clamped = torch.clamp(positions, min=0)
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
inputs_embeds = inputs_embeds + pos_emb
out = self.model(
inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
).last_hidden_state
else:
out = self.model(
input_ids=input_ids.unsqueeze(1), return_dict=True
).last_hidden_state
return out.squeeze(1) if out.dim() == 3 else out
graph = self.graphs[graph_bs]
graph_vars = self.graph_vars
if graph_vars is None:
raise RuntimeError("Graph variables not initialized")
set_forward_context(
False,
slot_mapping=graph_vars["slot_mapping"][:graph_bs],
context_lens=graph_vars["context_lens"][:graph_bs],
block_tables=graph_vars["block_tables"][:graph_bs],
)
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
context.block_tables
)
graph.replay()
return graph_vars["outputs"][:bs]
@torch.inference_mode()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
stop_tokens: Optional[List[int]] = None,
attention_mask: Optional[torch.Tensor] = None,
tts_embeddings: Optional[
torch.Tensor
] = None, # TTS: [pad][cond][text] embeddings (87 tokens, NO start_mel)
tts_mel_embedding: Optional[torch.nn.Module] = None, # TTS: mel_embedding layer
tts_text_pos_embedding: Optional[
torch.nn.Module
] = None, # TTS: text_pos_embedding layer
) -> torch.Tensor:
"""
Generate tokens.
Args:
input_ids: Input token IDs [batch_size, seq_len]
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature
            top_k: Top-k sampling (accepted for API compatibility; currently unused)
            top_p: Nucleus sampling threshold (accepted for API compatibility; currently unused)
stop_tokens: List of token IDs that stop generation
Returns:
Generated token IDs [batch_size, total_len]
"""
batch_size = input_ids.size(0)
device = input_ids.device
self._tts_mode = tts_embeddings is not None
self._tts_prompt_len = input_ids.size(1) if self._tts_mode else 0
if self.use_cuda_graph and not self.graph_captured:
print(
f"[CAPTURE] use_cuda_graph={self.use_cuda_graph}, graph_captured={self.graph_captured}",
file=sys.stderr,
flush=True,
)
self._capture_cuda_graphs(
tts_mel_embedding=tts_mel_embedding,
tts_text_pos_embedding=tts_text_pos_embedding,
)
self.graph_captured = True
print(
f"[CAPTURE] Completed! graphs={list(self.graphs.keys())}",
file=sys.stderr,
flush=True,
)
        if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
        else:
            actual_seq_len = input_ids.size(1)
sequences = []
for i in range(batch_size):
token_ids = [1] * actual_seq_len
if tts_embeddings is not None and actual_seq_len > 0:
token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
else:
token_ids = input_ids[i].tolist()
req = Seq(token_ids)
self.kv_manager.allocate(req)
sequences.append(req)
self.current_sequences = sequences
# Prefill phase
prefill_ids, prefill_pos = self._prepare_prefill(sequences)
if prefill_ids.dim() == 1:
prefill_ids = prefill_ids.unsqueeze(
0
) # [total_tokens] -> [1, total_tokens]
if prefill_pos.dim() == 1:
prefill_pos = prefill_pos.unsqueeze(
0
) # [total_tokens] -> [1, total_tokens]
if (
tts_embeddings is not None
and tts_mel_embedding is not None
and tts_text_pos_embedding is not None
):
start_token_id = input_ids[0, -1] if input_ids.size(1) > 0 else 8192
start_emb = tts_mel_embedding(
torch.tensor([[start_token_id]], device="cuda")
) # [1, 1, hidden_dim]
start_emb = start_emb + tts_text_pos_embedding(start_emb)
full_embeddings = torch.cat(
[tts_embeddings, start_emb], dim=1
) # [1, 88, hidden_dim]
model_dtype = next(self.model.parameters()).dtype
if full_embeddings.dtype != model_dtype:
full_embeddings = full_embeddings.to(model_dtype)
hidden_states = self.model(
inputs_embeds=full_embeddings, return_dict=True
).last_hidden_state
else:
hidden_states = self.model(
input_ids=input_ids, attention_mask=attention_mask, return_dict=True
).last_hidden_state
reset_forward_context()
last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
if self.lm_head is not None:
if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
last_hidden = last_hidden.to(next(self.lm_head.parameters()).dtype)
logits = self.lm_head(last_hidden) # [batch_size, vocab_size]
else:
logits = self.model.compute_logits(last_hidden) # [batch_size, vocab_size]
temperatures = self._prepare_sample(sequences, temperature)
if temperature > 0:
first_token = self.sampler(logits, temperatures)
else:
first_token = torch.argmax(logits, dim=-1)
first_token_list = first_token.tolist()
generated_tokens = [[] for _ in range(batch_size)]
hit_stop_on_first = False
for i, token_id in enumerate(first_token_list):
if stop_tokens and token_id in stop_tokens:
hit_stop_on_first = True
else:
generated_tokens[i].append(token_id)
if hit_stop_on_first:
for req in sequences:
self.kv_manager.remove_seq(req)
self.current_sequences = []
output_ids = []
for i in range(batch_size):
full_sequence = input_ids[i].tolist() + generated_tokens[i]
output_ids.append(full_sequence)
output = torch.tensor(output_ids, dtype=torch.long, device=device)
return output
if not hit_stop_on_first:
for i, req in enumerate(sequences):
req.append_token(first_token_list[i])
self.kv_manager.append_to_seq(req)
remaining_tokens = max_new_tokens - 1
for step in range(remaining_tokens):
decode_ids, decode_pos = self._prepare_decode(sequences)
# Forward pass
if batch_size > 8:
raise RuntimeError(
f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
)
context = get_forward_context()
hidden_states = self._run_decode_with_graph(
decode_ids,
decode_pos,
context,
tts_mel_embedding=tts_mel_embedding,
tts_text_pos_embedding=tts_text_pos_embedding,
)
# Get logits
if self.lm_head is not None:
logits = self.lm_head(hidden_states) # [batch_size, vocab_size]
else:
logits = self.model.compute_logits(
hidden_states
) # [batch_size, vocab_size]
reset_forward_context()
temperatures = self._prepare_sample(sequences, temperature)
if temperature > 0:
next_token = self.sampler(logits, temperatures)
else:
next_token = torch.argmax(logits, dim=-1)
next_token_list = next_token.tolist()
should_stop = False
for i, token_id in enumerate(next_token_list):
if stop_tokens and token_id in stop_tokens:
should_stop = True
else:
sequences[i].append_token(token_id)
self.kv_manager.append_to_seq(sequences[i])
generated_tokens[i].append(token_id)
if should_stop:
break
for req in sequences:
self.kv_manager.remove_seq(req)
self.current_sequences = []
output_ids = []
for i in range(batch_size):
initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
full_sequence = initial_tokens + generated_tokens[i]
output_ids.append(full_sequence)
output = torch.tensor(output_ids, dtype=torch.long, device=device)
assert output.size(0) == batch_size, (
f"Output batch size mismatch: {output.size(0)} != {batch_size}"
)
return output
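# Minimal usage sketch (hypothetical wiring; `gpt` is a GPT2AccelModel and
# `head` any module mapping hidden states to logits):
#
#     engine = AccelInferenceEngine(gpt, head, num_layers=24, num_heads=16,
#                                   head_dim=64, use_cuda_graph=True)
#     out = engine.generate(input_ids, max_new_tokens=100, temperature=0.8,
#                           stop_tokens=[eos_id])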

indextts/accel/attention.py Normal file

@@ -0,0 +1,154 @@
from dataclasses import dataclass
import torch
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from torch import nn
@dataclass
class ForwardContext:
is_prefill: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
_FORWARD_CONTEXT = ForwardContext()
def get_forward_context():
return _FORWARD_CONTEXT
def set_forward_context(
is_prefill,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=0,
max_seqlen_k=0,
slot_mapping=None,
context_lens=None,
block_tables=None,
):
global _FORWARD_CONTEXT
_FORWARD_CONTEXT = ForwardContext(
is_prefill,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
slot_mapping,
context_lens,
block_tables,
)
def reset_forward_context():
global _FORWARD_CONTEXT
_FORWARD_CONTEXT = ForwardContext()
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 2048
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1:
return
d_offset = 0
while d_offset < D:
cur_block_size = min(BLOCK_SIZE, D - d_offset)
key_offsets = idx * key_stride + d_offset + tl.arange(0, BLOCK_SIZE)
value_offsets = idx * value_stride + d_offset + tl.arange(0, BLOCK_SIZE)
cache_offsets = slot * D + d_offset + tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < cur_block_size
key = tl.load(key_ptr + key_offsets, mask=mask, other=0.0)
value = tl.load(value_ptr + value_offsets, mask=mask, other=0.0)
tl.store(k_cache_ptr + cache_offsets, key, mask=mask)
tl.store(v_cache_ptr + cache_offsets, value, mask=mask)
d_offset += BLOCK_SIZE
def store_kvcache(
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](
key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
)
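# Example: with block_size=256, a token mapped to block 3 at offset 17 gets
# slot = 3 * 256 + 17 = 785; the kernel copies its D = num_heads * head_dim
# contiguous K/V values into k_cache[785] / v_cache[785]. A slot of -1 marks
# padding and is skipped.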
class Attention(nn.Module):
def __init__(
self,
num_heads: int,
head_dim: int,
scale: float,
num_kv_heads: int,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_forward_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel() and context.slot_mapping is not None:
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None:
k, v = k_cache, v_cache
o = flash_attn_varlen_func(
q,
k,
v,
max_seqlen_q=context.max_seqlen_q,
cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k,
cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale,
causal=True,
block_table=context.block_tables,
)
else:
o = flash_attn_with_kvcache(
q.unsqueeze(1),
k_cache,
v_cache,
cache_seqlens=context.context_lens,
block_table=context.block_tables,
softmax_scale=self.scale,
causal=True,
)
return o
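# Prefill runs flash_attn_varlen_func over packed sequences (reading K/V back
# from the paged cache when a cached prefix exists); decode runs
# flash_attn_with_kvcache, which gathers K/V pages through the block tables.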

indextts/accel/gpt2_accel.py Normal file

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model
from .attention import Attention
class GPT2AccelAttention(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(
torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale_attn_weights = config.scale_attn_weights
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
self.accel_attn = Attention(
self.num_heads, self.head_dim, scale, self.num_heads
)
def forward(
self,
hidden_states: torch.Tensor,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
past_key_value=None,
**kwargs,
):
if encoder_hidden_states is not None:
raise NotImplementedError("Cross attention not supported in accel mode")
qkv = self.c_attn(hidden_states)
query, key, value = qkv.split(self.split_size, dim=2)
# [B, T, H*D] -> [B, H, T, D]
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# flatten to [B*T, H, D]
bsz, num_heads, seq_len, head_dim = query.shape
q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
# ensure fp16
if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
orig_dtype = q_flat.dtype
q_flat = q_flat.to(torch.float16)
k_flat = k_flat.to(torch.float16)
v_flat = v_flat.to(torch.float16)
else:
orig_dtype = q_flat.dtype
o_flat = self.accel_attn(q_flat, k_flat, v_flat) # [B*T, H, D]
if o_flat.dtype != orig_dtype:
o_flat = o_flat.to(orig_dtype)
# Reshape back: [B*T, H, D] -> [B, H, T, D]
attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, None)
if output_attentions:
outputs += (None,)
return outputs
def _split_heads(self, tensor, num_heads, head_dim):
new_shape = tensor.size()[:-1] + (num_heads, head_dim)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, head_dim):
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
return tensor.view(new_shape)
class GPT2AccelBlock(GPT2Block):
def __init__(self, config, layer_idx=None):
super().__init__(config, layer_idx)
self.attn = GPT2AccelAttention(config, layer_idx)
class GPT2AccelModel(GPT2Model):
def __init__(self, config):
super().__init__(config)
self.h = nn.ModuleList(
[
GPT2AccelBlock(config, layer_idx=i)
for i in range(config.num_hidden_layers)
]
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
if inputs_embeds is not None:
hidden_states = inputs_embeds
for block in self.h:
hidden_states = block(hidden_states)[0]
hidden_states = self.ln_f(hidden_states)
if return_dict:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=None,
attentions=None,
)
return (hidden_states,)
else:
return super().forward(
input_ids=input_ids,
past_key_values=None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=False,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
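# Note: the inputs_embeds fast path assumes the caller has already folded any
# position/token-type embeddings into the inputs (the accel engine does this for
# TTS decode), so blocks consume the embeddings directly; the input_ids path
# defers to the stock GPT2Model.forward with use_cache=False, since KV caching
# is handled externally by KVCacheManager.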

indextts/accel/kv_manager.py Normal file

@@ -0,0 +1,209 @@
import hashlib
import pickle
from collections import deque
from copy import copy
from typing import Dict, List, Optional, Set
import torch
class KVCacheBlock:
def __init__(self, block_id: int):
self.block_id = block_id
self.ref_cnt = 0
self._block_hash = None
self.token_ids = []
@property
def block_hash(self) -> Optional[bytes]:
return self._block_hash
def update(self, block_hash: bytes, token_ids: List[int]):
self._block_hash = block_hash
self.token_ids = token_ids
def reset(self):
self.ref_cnt = 1
self._block_hash = None
self.token_ids = []
class Seq:
def __init__(self, token_ids: List[int], block_size: int = 256):
self.token_ids = copy(token_ids)
self.last_token = token_ids[-1] if token_ids else 0
self.num_tokens = len(self.token_ids)
self.num_prompt_tokens = len(token_ids)
self.num_cached_tokens = 0
self.block_table: List[int] = []
self.block_size = block_size
def __len__(self):
return self.num_tokens
def __getitem__(self, key):
return self.token_ids[key]
@property
def num_blocks(self):
return (self.num_tokens + self.block_size - 1) // self.block_size
@property
def num_cached_blocks(self):
return self.num_cached_tokens // self.block_size
@property
def last_block_num_tokens(self):
return self.num_tokens - (self.num_blocks - 1) * self.block_size
def get_block_tokens(self, block_idx: int) -> List[int]:
assert 0 <= block_idx < self.num_blocks
start = block_idx * self.block_size
end = start + self.block_size
return self.token_ids[start:end]
def append_token(self, token_id: int):
self.token_ids.append(token_id)
self.last_token = token_id
self.num_tokens += 1
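# Example: block_size=256 and num_tokens=600 give num_blocks=3 (ceil division)
# and last_block_num_tokens = 600 - 2 * 256 = 88.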
class KVCacheManager:
def __init__(
self,
num_layers: int,
num_heads: int,
head_dim: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
):
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.block_size = block_size
self.num_blocks = num_blocks
self.dtype = dtype
self.blocks: List[KVCacheBlock] = [KVCacheBlock(i) for i in range(num_blocks)]
self.block_hash_to_id: Dict[bytes, int] = {}
self.free_block_ids: deque = deque(range(num_blocks))
self.used_block_ids: Set[int] = set()
device = "cuda" if torch.cuda.is_available() else "cpu"
cache_dtype = torch.float16 if device == "cuda" else dtype
self.kv_cache = torch.empty(
2,
num_layers,
num_blocks,
block_size,
num_heads,
head_dim,
dtype=cache_dtype,
device=device,
)
@classmethod
def compute_block_hash(
cls, token_ids: List[int], parent_hash: Optional[bytes] = None
) -> bytes:
hash_input = []
if parent_hash is not None:
hash_input.append(parent_hash)
hash_input.extend(token_ids)
input_bytes = pickle.dumps(tuple(hash_input), protocol=pickle.HIGHEST_PROTOCOL)
return hashlib.sha256(input_bytes).digest()
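    # Hashes chain through parent_hash, so a block's identity covers the whole
    # prefix before it: two sequences can share a cached block only when every
    # earlier block matches as well (prefix caching).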
def _allocate_block(self, block_id: int) -> KVCacheBlock:
block = self.blocks[block_id]
assert block.ref_cnt == 0
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return block
def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_cnt == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def allocate(self, sequence: Seq):
assert not sequence.block_table, "Sequence already has allocated blocks"
parent_hash = None
cache_miss = False
for i in range(sequence.num_blocks):
token_ids = sequence.get_block_tokens(i)
block_hash = (
self.compute_block_hash(token_ids, parent_hash)
if len(token_ids) == self.block_size
else None
)
block_id = self.block_hash_to_id.get(block_hash) if block_hash else None
if block_id is None or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else:
sequence.num_cached_tokens += self.block_size
if block_id is not None and block_id in self.used_block_ids:
block = self.blocks[block_id]
block.ref_cnt += 1
else:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
if block_hash is not None:
block.update(block_hash, token_ids)
self.block_hash_to_id[block_hash] = block_id
parent_hash = block_hash
sequence.block_table.append(block_id)
def deallocate(self, sequence: Seq):
for block_id in reversed(sequence.block_table):
block = self.blocks[block_id]
block.ref_cnt -= 1
if block.ref_cnt == 0:
self._deallocate_block(block_id)
sequence.num_cached_tokens = 0
sequence.block_table.clear()
def append_to_seq(self, sequence: Seq):
block_table = sequence.block_table
last_block = self.blocks[block_table[-1]]
if len(sequence) % self.block_size == 1:
assert last_block.block_hash is not None
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif len(sequence) % self.block_size == 0:
assert last_block.block_hash is None
token_ids = sequence.get_block_tokens(sequence.num_blocks - 1)
parent_hash = (
self.blocks[block_table[-2]].block_hash
if len(block_table) > 1
else None
)
block_hash = self.compute_block_hash(token_ids, parent_hash)
last_block.update(block_hash, token_ids)
self.block_hash_to_id[block_hash] = last_block.block_id
else:
assert last_block.block_hash is None
def remove_seq(self, sequence: Seq):
self.deallocate(sequence)
def wire_kv_cache_to_model(self, model):
layer_id = 0
for module in model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1
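# Footprint sketch: the cache holds 2 (K and V) * num_layers * num_blocks *
# block_size * num_heads * head_dim elements. For a hypothetical 24-layer,
# 16-head, head_dim-64 model with the engine defaults (128 blocks of 256 slots)
# in fp16, that is 2*24*128*256*16*64*2 bytes = 3 GiB; the 16 blocks used in
# post_init_gpt2_config cut this to 384 MiB (4096 cached tokens).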


@@ -307,7 +307,7 @@ class UnifiedVoice(nn.Module):
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1,
-                 condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
+                 condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None, use_accel=False):
"""
Args:
layers: Number of layers in transformer stack.
@@ -409,6 +409,9 @@ class UnifiedVoice(nn.Module):
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
+        self.use_accel = use_accel
+        self.accel_engine = None  # Will be initialized in post_init_gpt2_config
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
@@ -421,6 +424,38 @@ class UnifiedVoice(nn.Module):
gradient_checkpointing=False,
use_cache=True,
)
+        if self.use_accel and torch.cuda.is_available():
+            # Check if flash attention is available
+            try:
+                import flash_attn
+            except ImportError:
+                raise ImportError("flash_attn is required for acceleration but not installed. Please install from https://github.com/Dao-AILab/flash-attention/releases/")
+            from indextts.accel import GPT2AccelModel, AccelInferenceEngine
+            # Create accel model
+            accel_gpt = GPT2AccelModel(gpt_config)
+            accel_gpt.load_state_dict(self.gpt.state_dict(), strict=False)
+            if half:
+                accel_gpt = accel_gpt.half().cuda()
+            else:
+                accel_gpt = accel_gpt.cuda()
+            accel_gpt.eval()
+            lm_head_with_norm = nn.Sequential(self.final_norm, self.mel_head)
+            self.accel_engine = AccelInferenceEngine(
+                model=accel_gpt,
+                lm_head=lm_head_with_norm,
+                num_layers=self.layers,
+                num_heads=self.heads,
+                head_dim=self.model_dim // self.heads,
+                block_size=256,
+                num_blocks=16,  # Reduce to save memory (16*256 = 4096 tokens capacity)
+                use_cuda_graph=True,
+            )
+            print("acceleration engine initialized")
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
@@ -721,12 +756,26 @@ class UnifiedVoice(nn.Module):
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
-        output = self.inference_model.generate(inputs,
-                                               bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
-                                               eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
-                                               max_length=max_length, logits_processor=logits_processor,
-                                               num_return_sequences=num_return_sequences,
-                                               **hf_generate_kwargs)
+        # Use accel engine if available (single sequence only)
+        if self.accel_engine is not None and num_return_sequences == 1:
+            output = self.accel_engine.generate(
+                inputs,  # fake input_ids (all 1s + start_mel_token)
+                max_new_tokens=max_length - trunc_index,
+                attention_mask=attention_mask,
+                temperature=hf_generate_kwargs.get('temperature', 1),
+                stop_tokens=[self.stop_mel_token],
+                tts_embeddings=inputs_embeds,  # [pad][cond][text] embeddings (87 tokens, NO start_mel_token)
+                tts_mel_embedding=self.inference_model.embeddings,  # mel_embedding layer
+                tts_text_pos_embedding=self.inference_model.text_pos_embedding,  # text_pos_embedding layer
+            )
+        else:
+            output = self.inference_model.generate(inputs,
+                                                   bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
+                                                   eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
+                                                   max_length=max_length, logits_processor=logits_processor,
+                                                   num_return_sequences=num_return_sequences,
+                                                   **hf_generate_kwargs)
if isinstance(output, torch.Tensor):
return output[:, trunc_index:], speech_conditioning_latent
# GenerateOutput


@@ -38,7 +38,7 @@ import torch.nn.functional as F
class IndexTTS2:
def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
-        use_cuda_kernel=None,use_deepspeed=False, use_torch_compile=False
+        use_cuda_kernel=None,use_deepspeed=False, use_accel=False, use_torch_compile=False
):
"""
Args:
@@ -48,6 +48,7 @@ class IndexTTS2:
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
use_deepspeed (bool): whether to use DeepSpeed or not.
+            use_accel (bool): whether to use acceleration engine for GPT2 or not.
use_torch_compile (bool): whether to use torch.compile for optimization or not.
"""
if device is not None:
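
A minimal usage sketch for the new flag (module path and call pattern assumed from the surrounding repo, not shown in this diff):

    from indextts.infer_v2 import IndexTTS2  # module path assumed
    tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_accel=True)
    # single-sequence GPT-2 decoding now routes through AccelInferenceEngine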
@@ -76,11 +77,12 @@ class IndexTTS2:
self.model_dir = model_dir
self.dtype = torch.float16 if self.use_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
+        self.use_accel = use_accel
self.use_torch_compile = use_torch_compile
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
-        self.gpt = UnifiedVoice(**self.cfg.gpt)
+        self.gpt = UnifiedVoice(**self.cfg.gpt, use_accel=self.use_accel)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device)
@@ -462,7 +464,7 @@ class IndexTTS2:
ref_mel = self.cache_mel
if emo_vector is not None:
-            weight_vector = torch.tensor(emo_vector).to(self.device)
+            weight_vector = torch.tensor(emo_vector, device=self.device)
if use_random:
random_index = [random.randint(0, x - 1) for x in self.emo_num]
else:
@@ -589,15 +591,16 @@ class IndexTTS2:
# print(f"code len: {code_lens}")
code_lens = []
+        max_code_len = 0
        for code in codes:
            if self.stop_mel_token not in code:
-                code_lens.append(len(code))
+                code_len = len(code)
            else:
-                len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
-                code_len = len_ - 1
+                len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0]
+                code_len = len_[0].item() if len_.numel() > 0 else len(code)
            code_lens.append(code_len)
-            codes = codes[:, :code_len]
+            max_code_len = max(max_code_len, code_len)
+        codes = codes[:, :max_code_len]
code_lens = torch.LongTensor(code_lens)
code_lens = code_lens.to(self.device)
if verbose: