feat: achieve inference acceleration for the gpt2 stage
Signed-off-by: storyicon <storyicon@foxmail.com>
This commit is contained in:
parent
bde7d0bdf0
commit
c1ef4148af
9
indextts/accel/__init__.py
Normal file
9
indextts/accel/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from .accel_engine import AccelInferenceEngine # noqa: F401
|
||||
from .attention import ( # noqa: F401
|
||||
Attention,
|
||||
get_forward_context,
|
||||
reset_forward_context,
|
||||
set_forward_context,
|
||||
)
|
||||
from .gpt2_accel import GPT2AccelAttention, GPT2AccelModel # noqa: F401
|
||||
from .kv_manager import KVCacheManager, Seq # noqa: F401
|
||||
609
indextts/accel/accel_engine.py
Normal file
609
indextts/accel/accel_engine.py
Normal file
@ -0,0 +1,609 @@
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
reset_forward_context,
|
||||
set_forward_context,
|
||||
)
|
||||
from .kv_manager import KVCacheManager, Seq
|
||||
|
||||
|
||||
class Sampler(nn.Module):
    """Temperature-scaled categorical sampler using the exponential-race trick.

    Dividing the softmax probabilities by i.i.d. Exponential(1) noise and
    taking the argmax draws from the same categorical distribution as
    ``torch.multinomial``, but stays a fixed op sequence that is friendly to
    ``torch.compile`` (and to CUDA-graph capture).
    """

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        """Sample one token id per row of ``logits``.

        Args:
            logits: ``[batch, vocab]`` raw scores.
            temperatures: ``[batch]`` per-sequence temperatures (assumed > 0;
                the caller falls back to argmax for temperature == 0).

        Returns:
            ``[batch]`` tensor of sampled token ids.
        """
        # Scale each row by its temperature (in place on the float copy).
        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)
        # Exponential-race sampling; clamp_min_ guards against division by
        # noise values that underflow to ~0.
        sample_tokens = probs.div_(
            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        ).argmax(dim=-1)
        return sample_tokens
|
||||
|
||||
|
||||
class AccelInferenceEngine:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
lm_head,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
block_size: int = 256,
|
||||
num_blocks: int = 128,
|
||||
use_cuda_graph: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: The GPT transformer model (should have accel attention)
|
||||
lm_head: Language model head for generating logits
|
||||
num_layers: Number of transformer layers
|
||||
num_heads: Number of attention heads
|
||||
head_dim: Dimension per head
|
||||
block_size: KV cache block size
|
||||
num_blocks: Total number of KV cache blocks
|
||||
use_cuda_graph: Whether to use CUDA Graph for decode optimization
|
||||
"""
|
||||
self.model = model
|
||||
self.lm_head = lm_head
|
||||
self.block_size = block_size
|
||||
self.num_blocks = num_blocks
|
||||
self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
self.hidden_size = (
|
||||
model.config.hidden_size
|
||||
if hasattr(model, "config")
|
||||
else head_dim * num_heads
|
||||
)
|
||||
self.kv_manager = KVCacheManager(
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
block_size=block_size,
|
||||
num_blocks=num_blocks,
|
||||
dtype=torch.float16, # Force fp16 for FlashAttention
|
||||
)
|
||||
self.kv_manager.wire_kv_cache_to_model(model)
|
||||
self.sampler = Sampler()
|
||||
self.current_sequences = []
|
||||
self.graphs = {}
|
||||
self.graph_vars = None
|
||||
self.graph_pool = None
|
||||
self.graph_captured = False
|
||||
|
||||
    def _prepare_prefill(self, requests: List[Seq]):
        """Build flattened prefill inputs and install the forward context.

        Concatenates the uncached suffix of every request into one flat token
        stream (flash-attn varlen layout), computes cumulative q/k sequence
        lengths, and maps each new token to the physical KV-cache slot where
        its key/value will be written.

        Args:
            requests: Sequences to prefill; the first ``num_cached_tokens`` of
                each may already reside in the KV cache (prefix caching).

        Returns:
            Tuple ``(input_ids, positions)`` of 1-D CUDA tensors covering only
            the tokens that still need to be computed.
        """
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []

        for req in requests:
            seqlen = len(req)
            # Only the non-cached tail of the prompt is fed to the model.
            input_ids.extend(req[req.num_cached_tokens :])
            positions.extend(list(range(req.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - req.num_cached_tokens
            # Keys span the full sequence, cached prefix included.
            seqlen_k = seqlen
            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
            max_seqlen_q = max(seqlen_q, max_seqlen_q)
            max_seqlen_k = max(seqlen_k, max_seqlen_k)

            if req.block_table:
                # Translate block ids into flat cache-slot indices for the
                # blocks that are not yet cached.
                for i in range(req.num_cached_blocks, req.num_blocks):
                    block_id = req.block_table[i]
                    start = block_id * self.block_size
                    if i != req.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        # The last block may be only partially filled.
                        end = start + req.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))

        # Pinned host tensors + non-blocking copies overlap H2D transfer.
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        cu_seqlens_q = torch.tensor(
            cu_seqlens_q, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(
            cu_seqlens_k, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        block_tables = None
        # More keys than queries means a prefix is already cached, so the
        # attention kernel needs block tables to read those cached entries.
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
            max_len = max(len(req.block_table) for req in requests)
            block_tables_list = []
            for req in requests:
                # Right-pad with -1 so every row shares one width.
                table = req.block_table + [-1] * (max_len - len(req.block_table))
                block_tables_list.append(table)
            block_tables = torch.tensor(
                block_tables_list, dtype=torch.int32, pin_memory=True
            ).cuda(non_blocking=True)

        set_forward_context(
            True,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            slot_mapping,
            None,
            block_tables,
        )

        return input_ids, positions
|
||||
|
||||
    def _prepare_decode(self, requests: List[Seq]):
        """Build single-token decode inputs and install the forward context.

        Each sequence contributes exactly one token (its most recent one), the
        flat cache slot where that token's KV entry lives, its total context
        length, and its padded block table for paged attention.

        Returns:
            Tuple ``(input_ids, positions)`` of 1-D CUDA tensors with one
            entry per request.
        """
        if not requests:
            raise RuntimeError("FATAL: No requests provided to _prepare_decode!")

        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []

        for req in requests:
            input_ids.append(req.last_token)

            pos = len(req) - 1
            # In TTS mode positions are relative to the start of the mel
            # stream rather than the whole (prompt + mel) token sequence.
            if hasattr(self, "_tts_mode") and self._tts_mode:
                pos = pos - (self._tts_prompt_len - 1)
            positions.append(pos)

            context_lens.append(len(req))
            # Flat cache slot of the sequence's newest token.
            slot_mapping.append(
                req.block_table[-1] * self.block_size + req.last_block_num_tokens - 1
            )

        # Pinned host tensors + non-blocking copies overlap H2D transfer.
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(
            non_blocking=True
        )
        slot_mapping = torch.tensor(
            slot_mapping, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)
        context_lens = torch.tensor(
            context_lens, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        # Right-pad block tables with -1 so every row shares one width.
        max_len = max(len(req.block_table) for req in requests)
        block_tables_list = []
        for req in requests:
            table = req.block_table + [-1] * (max_len - len(req.block_table))
            block_tables_list.append(table)
        block_tables = torch.tensor(
            block_tables_list, dtype=torch.int32, pin_memory=True
        ).cuda(non_blocking=True)

        assert block_tables.dim() == 2, (
            f"block_tables must be 2D, got shape {block_tables.shape}"
        )
        assert block_tables.size(0) == len(requests), (
            f"block_tables batch size mismatch: {block_tables.size(0)} vs {len(requests)}"
        )

        set_forward_context(
            False,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
        )

        return input_ids, positions
|
||||
|
||||
def _prepare_sample(self, requests: List[Seq], temperature: float):
|
||||
temperatures = [temperature] * len(requests)
|
||||
temperatures = torch.tensor(
|
||||
temperatures, dtype=torch.float32, pin_memory=True
|
||||
).cuda(non_blocking=True)
|
||||
return temperatures
|
||||
|
||||
    @torch.inference_mode()
    def _capture_cuda_graphs(self, tts_mel_embedding=None, tts_text_pos_embedding=None):
        """Capture CUDA graphs of the single-token decode step.

        Allocates static input/output buffers sized for the largest supported
        batch, runs one eager warmup forward, then records one graph per batch
        size in ``self.graph_bs``.  Replaying a graph later only requires
        refreshing those static buffers (see ``_run_decode_with_graph``).

        Args:
            tts_mel_embedding: Optional mel-token embedding module; when given
                (together with ``tts_text_pos_embedding``) the captured graph
                embeds token ids itself and feeds ``inputs_embeds`` to the model.
            tts_text_pos_embedding: Optional positional-embedding module whose
                ``.emb`` is indexed with clamped positions.
        """
        print("Capturing CUDA graphs for decode optimization...")
        max_bs = 8  # Support up to batch size 8
        # Enough block-table columns for a 2048-token context.
        max_num_blocks = (2048 + self.block_size - 1) // self.block_size
        model_dtype = next(self.model.parameters()).dtype
        # Static buffers the graphs will read from / write to.
        input_ids = torch.ones(max_bs, dtype=torch.int64, device="cuda") * 8192
        positions = torch.ones(max_bs, dtype=torch.int64, device="cuda")
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        block_tables = torch.zeros(
            max_bs, max_num_blocks, dtype=torch.int32, device="cuda"
        )
        outputs = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )
        inputs_embeds_buffer = torch.zeros(
            max_bs, self.hidden_size, dtype=model_dtype, device="cuda"
        )

        # Batch sizes to capture (currently only bs=1).
        self.graph_bs = [1]

        use_tts = tts_mel_embedding is not None and tts_text_pos_embedding is not None

        # Capture largest batch first so the shared pool is sized once.
        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()

            # Plausible dummy metadata for the warmup/capture forward.
            slot_mapping[:bs] = torch.arange(bs, dtype=torch.int32, device="cuda")
            context_lens[:bs] = bs + 1
            block_tables[:bs, 0] = 0

            set_forward_context(
                False,
                slot_mapping=slot_mapping[:bs],
                context_lens=context_lens[:bs],
                block_tables=block_tables[:bs],
            )

            # warmup
            if use_tts:
                assert tts_mel_embedding is not None
                assert tts_text_pos_embedding is not None
                emb = tts_mel_embedding(input_ids[:bs])
                # Positions can be negative in TTS mode; clamp for the lookup.
                pos_clamped = torch.clamp(positions[:bs], min=0)
                pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                inputs_embeds_buffer[:bs] = emb + pos_emb
                out = self.model(
                    inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                    return_dict=True,
                ).last_hidden_state
            else:
                out = self.model(
                    input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                ).last_hidden_state
            outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out

            # Record the exact same op sequence into the graph.
            with torch.cuda.graph(graph, self.graph_pool):
                if use_tts:
                    assert tts_mel_embedding is not None
                    assert tts_text_pos_embedding is not None
                    emb = tts_mel_embedding(input_ids[:bs])
                    pos_clamped = torch.clamp(positions[:bs], min=0)
                    pos_emb = tts_text_pos_embedding.emb(pos_clamped)
                    inputs_embeds_buffer[:bs] = emb + pos_emb
                    out = self.model(
                        inputs_embeds=inputs_embeds_buffer[:bs].unsqueeze(1),
                        return_dict=True,
                    ).last_hidden_state
                else:
                    out = self.model(
                        input_ids=input_ids[:bs].unsqueeze(1), return_dict=True
                    ).last_hidden_state
                outputs[:bs] = out.squeeze(1) if out.dim() == 3 else out

            # First captured graph donates its memory pool to the rest.
            if self.graph_pool is None:
                self.graph_pool = graph.pool()

            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_forward_context()

        # Keep handles to the static buffers for replay-time refresh.
        self.graph_vars = {
            "input_ids": input_ids,
            "positions": positions,
            "slot_mapping": slot_mapping,
            "context_lens": context_lens,
            "block_tables": block_tables,
            "outputs": outputs,
            "inputs_embeds": inputs_embeds_buffer,
        }
        print(f"CUDA graphs captured for batch sizes: {self.graph_bs}")
|
||||
|
||||
@torch.inference_mode()
|
||||
def _run_decode_with_graph(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
context: ForwardContext,
|
||||
tts_mel_embedding: Optional[torch.nn.Module] = None,
|
||||
tts_text_pos_embedding: Optional[torch.nn.Module] = None,
|
||||
) -> torch.Tensor:
|
||||
bs = input_ids.size(0)
|
||||
use_tts_embedding = hasattr(self, "_tts_mode") and self._tts_mode
|
||||
|
||||
if not self.use_cuda_graph or not self.graphs:
|
||||
if use_tts_embedding:
|
||||
assert tts_mel_embedding is not None
|
||||
assert tts_text_pos_embedding is not None
|
||||
inputs_embeds = tts_mel_embedding(input_ids)
|
||||
pos_clamped = torch.clamp(positions, min=0)
|
||||
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
|
||||
inputs_embeds = inputs_embeds + pos_emb
|
||||
out = self.model(
|
||||
inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
|
||||
).last_hidden_state
|
||||
else:
|
||||
out = self.model(
|
||||
input_ids=input_ids.unsqueeze(1), return_dict=True
|
||||
).last_hidden_state
|
||||
return out.squeeze(1) if out.dim() == 3 else out
|
||||
|
||||
graph_bs = next((x for x in self.graph_bs if x >= bs), None)
|
||||
if graph_bs is None:
|
||||
if use_tts_embedding:
|
||||
assert tts_mel_embedding is not None
|
||||
assert tts_text_pos_embedding is not None
|
||||
inputs_embeds = tts_mel_embedding(input_ids)
|
||||
pos_clamped = torch.clamp(positions, min=0)
|
||||
pos_emb = tts_text_pos_embedding.emb(pos_clamped)
|
||||
inputs_embeds = inputs_embeds + pos_emb
|
||||
out = self.model(
|
||||
inputs_embeds=inputs_embeds.unsqueeze(1), return_dict=True
|
||||
).last_hidden_state
|
||||
else:
|
||||
out = self.model(
|
||||
input_ids=input_ids.unsqueeze(1), return_dict=True
|
||||
).last_hidden_state
|
||||
return out.squeeze(1) if out.dim() == 3 else out
|
||||
|
||||
graph = self.graphs[graph_bs]
|
||||
graph_vars = self.graph_vars
|
||||
|
||||
if graph_vars is None:
|
||||
raise RuntimeError("Graph variables not initialized")
|
||||
|
||||
set_forward_context(
|
||||
False,
|
||||
slot_mapping=graph_vars["slot_mapping"][:graph_bs],
|
||||
context_lens=graph_vars["context_lens"][:graph_bs],
|
||||
block_tables=graph_vars["block_tables"][:graph_bs],
|
||||
)
|
||||
|
||||
graph_vars["input_ids"][:bs] = input_ids
|
||||
graph_vars["positions"][:bs] = positions
|
||||
graph_vars["slot_mapping"].fill_(-1)
|
||||
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||
graph_vars["context_lens"].zero_()
|
||||
graph_vars["context_lens"][:bs] = context.context_lens
|
||||
graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = (
|
||||
context.block_tables
|
||||
)
|
||||
graph.replay()
|
||||
|
||||
return graph_vars["outputs"][:bs]
|
||||
|
||||
    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        stop_tokens: Optional[List[int]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        tts_embeddings: Optional[
            torch.Tensor
        ] = None,  # TTS: [pad][cond][text] embeddings (87 tokens, NO start_mel)
        tts_mel_embedding: Optional[torch.nn.Module] = None,  # TTS: mel_embedding layer
        tts_text_pos_embedding: Optional[
            torch.nn.Module
        ] = None,  # TTS: text_pos_embedding layer
    ) -> torch.Tensor:
        """
        Generate tokens: one prefill pass, then graph-accelerated decode steps.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0 means greedy argmax)
            top_k: Top-k sampling
            top_p: Nucleus sampling threshold
            stop_tokens: List of token IDs that stop generation
            attention_mask: Optional mask for the non-TTS prefill path
            tts_embeddings: Precomputed prompt embeddings (enables TTS mode)
            tts_mel_embedding / tts_text_pos_embedding: TTS embedding layers

        Returns:
            Generated token IDs [batch_size, total_len]

        NOTE(review): top_k and top_p are accepted but never referenced in the
        body — sampling is temperature-only; confirm whether they should be
        wired into the sampler.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        # TTS mode feeds precomputed prompt embeddings instead of token ids;
        # _prepare_decode then measures positions relative to the mel stream.
        self._tts_mode = tts_embeddings is not None
        self._tts_prompt_len = input_ids.size(1) if self._tts_mode else 0

        # Capture decode CUDA graphs lazily on the first generate() call.
        if self.use_cuda_graph and not self.graph_captured:
            print(
                f"[CAPTURE] use_cuda_graph={self.use_cuda_graph}, graph_captured={self.graph_captured}",
                file=sys.stderr,
                flush=True,
            )
            self._capture_cuda_graphs(
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )
            self.graph_captured = True
            print(
                f"[CAPTURE] Completed! graphs={list(self.graphs.keys())}",
                file=sys.stderr,
                flush=True,
            )

        if tts_embeddings is not None:
            actual_seq_len = tts_embeddings.size(1) + 1  # embeddings + start_mel_token
            pass
        else:
            actual_seq_len = input_ids.size(1)

        # Build one Seq (KV-cache bookkeeping object) per batch element.
        sequences = []
        for i in range(batch_size):
            # In TTS mode the prompt is embeddings, so these placeholder ids
            # only matter for length bookkeeping (plus the real start token
            # in the last slot).
            token_ids = [1] * actual_seq_len
            if tts_embeddings is not None and actual_seq_len > 0:
                token_ids[-1] = input_ids[i, -1].item() if input_ids.size(1) > 0 else 1
            else:
                token_ids = input_ids[i].tolist()
            req = Seq(token_ids)
            self.kv_manager.allocate(req)
            sequences.append(req)

        self.current_sequences = sequences

        # Prefill phase
        # NOTE(review): prefill_ids/prefill_pos are not passed to the model
        # below; _prepare_prefill is relied on for its side effect of
        # installing the forward context (slot mapping etc.) — confirm.
        prefill_ids, prefill_pos = self._prepare_prefill(sequences)

        if prefill_ids.dim() == 1:
            prefill_ids = prefill_ids.unsqueeze(
                0
            )  # [total_tokens] -> [1, total_tokens]
        if prefill_pos.dim() == 1:
            prefill_pos = prefill_pos.unsqueeze(
                0
            )  # [total_tokens] -> [1, total_tokens]

        if (
            tts_embeddings is not None
            and tts_mel_embedding is not None
            and tts_text_pos_embedding is not None
        ):
            # Append the embedded start-mel token to the prompt embeddings.
            start_token_id = input_ids[0, -1] if input_ids.size(1) > 0 else 8192

            start_emb = tts_mel_embedding(
                torch.tensor([[start_token_id]], device="cuda")
            )  # [1, 1, hidden_dim]

            start_emb = start_emb + tts_text_pos_embedding(start_emb)

            full_embeddings = torch.cat(
                [tts_embeddings, start_emb], dim=1
            )  # [1, 88, hidden_dim]

            # Match the model's parameter dtype before the forward pass.
            model_dtype = next(self.model.parameters()).dtype
            if full_embeddings.dtype != model_dtype:
                full_embeddings = full_embeddings.to(model_dtype)

            hidden_states = self.model(
                inputs_embeds=full_embeddings, return_dict=True
            ).last_hidden_state

        else:
            hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        reset_forward_context()

        last_hidden = hidden_states[:, -1, :]  # [batch_size, hidden_size]

        if self.lm_head is not None:
            # Align dtypes before projecting to the vocabulary.
            if last_hidden.dtype != next(self.lm_head.parameters()).dtype:
                last_hidden = last_hidden.to(next(self.lm_head.parameters()).dtype)
            logits = self.lm_head(last_hidden)  # [batch_size, vocab_size]
        else:
            logits = self.model.compute_logits(last_hidden)  # [batch_size, vocab_size]

        # Sample the first new token from the prefill logits.
        temperatures = self._prepare_sample(sequences, temperature)
        if temperature > 0:
            first_token = self.sampler(logits, temperatures)
        else:
            first_token = torch.argmax(logits, dim=-1)

        first_token_list = first_token.tolist()

        generated_tokens = [[] for _ in range(batch_size)]
        hit_stop_on_first = False

        for i, token_id in enumerate(first_token_list):
            if stop_tokens and token_id in stop_tokens:
                hit_stop_on_first = True
            else:
                generated_tokens[i].append(token_id)

        # Early exit: a stop token on the very first sample ends generation
        # for the whole batch.
        if hit_stop_on_first:
            for req in sequences:
                self.kv_manager.remove_seq(req)
            self.current_sequences = []

            output_ids = []
            for i in range(batch_size):
                full_sequence = input_ids[i].tolist() + generated_tokens[i]
                output_ids.append(full_sequence)

            output = torch.tensor(output_ids, dtype=torch.long, device=device)
            return output

        if not hit_stop_on_first:
            # Commit the first token so decode starts from it.
            for i, req in enumerate(sequences):
                req.append_token(first_token_list[i])
                self.kv_manager.append_to_seq(req)

        remaining_tokens = max_new_tokens - 1

        # Decode loop: one token per sequence per step.
        for step in range(remaining_tokens):
            decode_ids, decode_pos = self._prepare_decode(sequences)

            # Forward pass
            if batch_size > 8:
                raise RuntimeError(
                    f"FATAL: batch_size={batch_size} exceeds CUDA Graph limit (8)!"
                )

            context = get_forward_context()
            hidden_states = self._run_decode_with_graph(
                decode_ids,
                decode_pos,
                context,
                tts_mel_embedding=tts_mel_embedding,
                tts_text_pos_embedding=tts_text_pos_embedding,
            )

            # Get logits
            if self.lm_head is not None:
                logits = self.lm_head(hidden_states)  # [batch_size, vocab_size]
            else:
                logits = self.model.compute_logits(
                    hidden_states
                )  # [batch_size, vocab_size]

            reset_forward_context()

            temperatures = self._prepare_sample(sequences, temperature)
            if temperature > 0:
                next_token = self.sampler(logits, temperatures)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token_list = next_token.tolist()

            # A stop token in any sequence stops the whole batch after this
            # step; non-stopped sequences still record their token.
            should_stop = False
            for i, token_id in enumerate(next_token_list):
                if stop_tokens and token_id in stop_tokens:
                    should_stop = True
                else:
                    sequences[i].append_token(token_id)
                    self.kv_manager.append_to_seq(sequences[i])
                    generated_tokens[i].append(token_id)

            if should_stop:
                break

        # Release all KV-cache blocks held by this batch.
        for req in sequences:
            self.kv_manager.remove_seq(req)
        self.current_sequences = []

        output_ids = []
        for i in range(batch_size):
            # Prompt tokens (as recorded in the Seq) + generated continuation.
            initial_tokens = sequences[i].token_ids[: sequences[i].num_prompt_tokens]
            full_sequence = initial_tokens + generated_tokens[i]
            output_ids.append(full_sequence)

        output = torch.tensor(output_ids, dtype=torch.long, device=device)

        assert output.size(0) == batch_size, (
            f"Output batch size mismatch: {output.size(0)} != {batch_size}"
        )

        return output
|
||||
154
indextts/accel/attention.py
Normal file
154
indextts/accel/attention.py
Normal file
@ -0,0 +1,154 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from torch import nn
|
||||
|
||||
|
||||
@dataclass
class ForwardContext:
    """Per-forward attention metadata shared with every Attention layer.

    A module-level instance acts as an implicit argument: the engine fills it
    in via ``set_forward_context()`` before each model call and the attention
    layers read it back via ``get_forward_context()``.

    Field order is load-bearing: ``set_forward_context`` constructs this
    positionally.
    """

    # True during prompt prefill (varlen flash-attn path); False during
    # single-token decode (paged KV-cache path).
    is_prefill: bool = False
    # Cumulative query/key lengths for flash_attn_varlen_func (prefill only).
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # Flat KV-cache slot per token of this forward pass; -1 entries skipped.
    slot_mapping: torch.Tensor | None = None
    # Per-sequence total context lengths (decode only).
    context_lens: torch.Tensor | None = None
    # [batch, max_blocks] paged-attention block tables, right-padded with -1.
    block_tables: torch.Tensor | None = None
|
||||
|
||||
|
||||
_FORWARD_CONTEXT = ForwardContext()
|
||||
|
||||
|
||||
def get_forward_context():
    """Return the module-global ForwardContext for the current forward pass."""
    ctx = _FORWARD_CONTEXT
    return ctx
|
||||
|
||||
|
||||
def set_forward_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
):
    """Install a fresh ForwardContext describing the upcoming forward pass.

    Called by the engine before every prefill/decode model invocation; the
    attention layers read the installed context via ``get_forward_context``.
    """
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )
|
||||
|
||||
|
||||
def reset_forward_context():
    """Replace the global context with a default (empty) one after a forward."""
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext()
|
||||
|
||||
|
||||
@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    """Copy one token's key/value vectors into their flat cache slot.

    Launched with one program per token. ``slot_mapping_ptr[idx]`` gives the
    destination slot, or -1 to skip the token (e.g. CUDA-graph padding).
    ``D`` is the flattened feature size (num_heads * head_dim).
    """
    BLOCK_SIZE: tl.constexpr = 2048
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    # Sentinel slot: padded/inactive token, nothing to store.
    if slot == -1:
        return
    d_offset = 0
    # Stream the D elements in BLOCK_SIZE chunks so arbitrary feature sizes
    # are handled with a fixed tile width.
    while d_offset < D:
        cur_block_size = min(BLOCK_SIZE, D - d_offset)
        key_offsets = idx * key_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        value_offsets = idx * value_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        cache_offsets = slot * D + d_offset + tl.arange(0, BLOCK_SIZE)

        # Mask off the tail of a final partial chunk.
        mask = tl.arange(0, BLOCK_SIZE) < cur_block_size
        key = tl.load(key_ptr + key_offsets, mask=mask, other=0.0)
        value = tl.load(value_ptr + value_offsets, mask=mask, other=0.0)
        tl.store(k_cache_ptr + cache_offsets, key, mask=mask)
        tl.store(v_cache_ptr + cache_offsets, value, mask=mask)

        d_offset += BLOCK_SIZE
|
||||
|
||||
|
||||
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Scatter per-token key/value vectors into the flat KV cache.

    Args:
        key: [N, num_heads, head_dim] keys for this forward pass.
        value: [N, num_heads, head_dim] values for this forward pass.
        k_cache: Key cache addressed as [num_slots, num_heads * head_dim].
        v_cache: Value cache with the same slot layout.
        slot_mapping: [N] destination slot per token; -1 entries are skipped
            by the kernel.
    """
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel assumes densely packed feature dimensions on both sides.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    # One Triton program per token.
    store_kvcache_kernel[(N,)](
        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
    )
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Paged flash-attention core driven by the module-level ForwardContext.

    ``k_cache``/``v_cache`` start as empty placeholder tensors and are wired
    to the shared cache pool externally (see the KV cache manager).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        num_kv_heads: int,
    ):
        """
        Args:
            num_heads: Number of query heads.
            head_dim: Dimension per head.
            scale: Softmax scale passed to the flash-attn kernels.
            num_kv_heads: Number of key/value heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Placeholders until the manager installs views into the shared pool.
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Attend ``q`` over ``k``/``v`` (plus cached context).

        Inputs are in varlen layout [total_tokens, heads, head_dim]; which
        kernel runs depends on the currently installed ForwardContext.
        """
        context = get_forward_context()
        k_cache, v_cache = self.k_cache, self.v_cache

        # Persist this step's keys/values into their assigned cache slots.
        if k_cache.numel() and v_cache.numel() and context.slot_mapping is not None:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.block_tables is not None:
                # Prefix-cached prefill: read keys/values from the paged cache.
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
                block_table=context.block_tables,
            )
        else:
            # Single-token decode against the paged KV cache.
            o = flash_attn_with_kvcache(
                q.unsqueeze(1),
                k_cache,
                v_cache,
                cache_seqlens=context.context_lens,
                block_table=context.block_tables,
                softmax_scale=self.scale,
                causal=True,
            )
        return o
|
||||
181
indextts/accel/gpt2_accel.py
Normal file
181
indextts/accel/gpt2_accel.py
Normal file
@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model
|
||||
|
||||
from .attention import Attention
|
||||
|
||||
|
||||
class GPT2AccelAttention(nn.Module):
    """Drop-in replacement for GPT2Attention backed by the accel Attention core.

    Keeps the stock GPT-2 projection layout (fused ``c_attn`` QKV Conv1D and
    ``c_proj`` output Conv1D) so pretrained weights load unchanged, but routes
    the attention computation through the paged/flash ``Attention`` module,
    which reads its metadata from the global forward context.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Causal-mask buffers kept for compatibility with the stock
        # GPT2Attention layout; the accel path enforces causality itself.
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights

        # Fused QKV projection and output projection (GPT-2's Conv1D layout).
        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # num_kv_heads == num_heads: GPT-2 uses standard multi-head attention.
        scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
        self.accel_attn = Attention(
            self.num_heads, self.head_dim, scale, self.num_heads
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_value=None,
        **kwargs,
    ):
        """GPT2Attention-compatible forward.

        ``layer_past``/``use_cache``/``attention_mask``/``head_mask`` are
        accepted for signature compatibility but not used here — KV caching
        and masking are handled by the accel attention core via the global
        forward context.
        """
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross attention not supported in accel mode")

        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        # [B, T, H*D] -> [B, H, T, D]
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # flatten to [B*T, H, D] — the varlen layout the accel core expects
        bsz, num_heads, seq_len, head_dim = query.shape
        q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)

        # ensure fp16 on CUDA (the KV cache is allocated as fp16); remember
        # the original dtype so the output can be cast back.
        if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
            orig_dtype = q_flat.dtype
            q_flat = q_flat.to(torch.float16)
            k_flat = k_flat.to(torch.float16)
            v_flat = v_flat.to(torch.float16)
        else:
            orig_dtype = q_flat.dtype

        o_flat = self.accel_attn(q_flat, k_flat, v_flat)  # [B*T, H, D]

        if o_flat.dtype != orig_dtype:
            o_flat = o_flat.to(orig_dtype)

        # Reshape back: [B*T, H, D] -> [B, H, T, D]
        attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        # (attn_output, present) tuple shape expected by GPT2Block; present
        # is always None here (caching lives in the accel KV manager).
        outputs = (attn_output, None)
        if output_attentions:
            outputs += (None,)

        return outputs

    def _split_heads(self, tensor, num_heads, head_dim):
        """Reshape [B, T, H*D] -> [B, H, T, D]."""
        new_shape = tensor.size()[:-1] + (num_heads, head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, head_dim):
        """Reshape [B, H, T, D] -> [B, T, H*D]."""
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
        return tensor.view(new_shape)
|
||||
|
||||
|
||||
class GPT2AccelBlock(GPT2Block):
    """GPT2Block whose attention is swapped for the accel (paged/flash) variant."""

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        # Overwrite the stock attention module built by the parent constructor.
        self.attn = GPT2AccelAttention(config, layer_idx)
|
||||
|
||||
|
||||
class GPT2AccelModel(GPT2Model):
    """GPT2Model variant whose transformer blocks use the accel attention path.

    When called with ``inputs_embeds`` it runs a minimal forward — blocks plus
    the final LayerNorm — and skips the stock embedding/mask/past-kv handling;
    the accel attention layers get everything they need from the global
    forward context.  Any other call is delegated to ``GPT2Model.forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the transformer stack with accel blocks.
        self.h = nn.ModuleList(
            [
                GPT2AccelBlock(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if inputs_embeds is not None:
            # Fast accel path: embeddings go straight through the blocks.
            # NOTE(review): attention_mask / position_ids / token_type_ids are
            # ignored on this path — presumably positions are already baked
            # into inputs_embeds by the caller; confirm.
            hidden_states = inputs_embeds

            for block in self.h:
                hidden_states = block(hidden_states)[0]

            hidden_states = self.ln_f(hidden_states)

            if return_dict:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                )
            return (hidden_states,)
        else:
            # Token-id path: defer to the standard HF implementation with
            # HF caching disabled (the accel KV cache replaces it).
            return super().forward(
                input_ids=input_ids,
                past_key_values=None,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
|
||||
209
indextts/accel/kv_manager.py
Normal file
209
indextts/accel/kv_manager.py
Normal file
@ -0,0 +1,209 @@
|
||||
import hashlib
|
||||
import pickle
|
||||
from collections import deque
|
||||
from copy import copy
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVCacheBlock:
    """One fixed-size physical slot of the paged KV cache.

    A block stores bookkeeping only (the actual K/V tensors live in
    KVCacheManager.kv_cache); full blocks get a content hash so identical
    prefixes can be shared across sequences via reference counting.
    """

    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_cnt = 0          # number of sequences currently using this block
        self._block_hash = None   # content hash once the block is full, else None
        self.token_ids = []       # tokens whose KV entries this block holds

    @property
    def block_hash(self) -> Optional[bytes]:
        """Content hash of a full block, or None while partial/unhashed."""
        return self._block_hash

    def update(self, block_hash: bytes, token_ids: List[int]):
        """Record the content hash and the tokens stored in this block."""
        self._block_hash = block_hash
        self.token_ids = token_ids

    def reset(self):
        """Re-initialise for a fresh allocation: one owner, no hash yet."""
        self.ref_cnt = 1
        self._block_hash = None
        self.token_ids = []
|
||||
|
||||
|
||||
class Seq:
    """A token sequence plus its paged KV-cache bookkeeping.

    Tracks the tokens seen/generated so far, how many of them already have
    KV entries cached, and the table of physical block ids backing them.
    """

    def __init__(self, token_ids: List[int], block_size: int = 256):
        # Keep a private copy so callers cannot mutate our token list.
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1] if token_ids else 0
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0        # tokens whose KV entries are cached
        self.block_table: List[int] = []  # physical block ids, in order
        self.block_size = block_size

    def __len__(self):
        return self.num_tokens

    def __getitem__(self, key):
        return self.token_ids[key]

    @property
    def num_blocks(self):
        """Blocks needed to hold every token (ceil division)."""
        return -(-self.num_tokens // self.block_size)

    @property
    def num_cached_blocks(self):
        """Count of fully cached blocks (a partial block never counts)."""
        return self.num_cached_tokens // self.block_size

    @property
    def last_block_num_tokens(self):
        """Tokens occupying the final block (block_size when it is full)."""
        remainder = self.num_tokens % self.block_size
        return remainder if remainder else self.block_size

    def get_block_tokens(self, block_idx: int) -> List[int]:
        """Return the slice of tokens belonging to block ``block_idx``."""
        assert 0 <= block_idx < self.num_blocks
        lo = block_idx * self.block_size
        return self.token_ids[lo : lo + self.block_size]

    def append_token(self, token_id: int):
        """Record one newly generated token."""
        self.token_ids.append(token_id)
        self.num_tokens += 1
        self.last_token = token_id
|
||||
|
||||
|
||||
class KVCacheManager:
    """Paged KV-cache manager with hash-based prefix caching.

    The cache is a single pre-allocated tensor of shape
    ``(2, num_layers, num_blocks, block_size, num_heads, head_dim)``
    (index 0 = keys, 1 = values).  Full blocks are content-addressed by a
    chained SHA-256 hash so identical prefixes across sequences can share
    physical blocks via reference counting.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int,
        num_blocks: int,
        dtype: torch.dtype,
    ):
        """
        Args:
            num_layers: transformer layer count (one K/V plane per layer).
            num_heads: attention heads per layer.
            head_dim: per-head hidden size.
            block_size: tokens stored per cache block.
            num_blocks: total number of physical blocks.
            dtype: cache dtype used on CPU (GPU always uses float16).
        """
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.dtype = dtype

        self.blocks: List[KVCacheBlock] = [KVCacheBlock(i) for i in range(num_blocks)]
        self.block_hash_to_id: Dict[bytes, int] = {}
        self.free_block_ids: deque = deque(range(num_blocks))
        self.used_block_ids: Set[int] = set()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 on GPU halves cache memory; keep the requested dtype on CPU.
        cache_dtype = torch.float16 if device == "cuda" else dtype
        self.kv_cache = torch.empty(
            2,
            num_layers,
            num_blocks,
            block_size,
            num_heads,
            head_dim,
            dtype=cache_dtype,
            device=device,
        )

    @classmethod
    def compute_block_hash(
        cls, token_ids: List[int], parent_hash: Optional[bytes] = None
    ) -> bytes:
        """Chained content hash of a full block: H(parent_hash, *token_ids).

        Chaining the parent hash makes the hash position-dependent, so two
        identical blocks only collide when their whole prefixes match.
        """
        hash_input = []
        if parent_hash is not None:
            hash_input.append(parent_hash)
        hash_input.extend(token_ids)
        input_bytes = pickle.dumps(tuple(hash_input), protocol=pickle.HIGHEST_PROTOCOL)
        return hashlib.sha256(input_bytes).digest()

    def _take_free_block(self) -> int:
        """Return the id of the next free block, failing loudly if exhausted."""
        if not self.free_block_ids:
            raise RuntimeError(
                "KVCacheManager: out of free KV-cache blocks "
                f"(num_blocks={self.num_blocks}); increase num_blocks"
            )
        return self.free_block_ids[0]

    def _allocate_block(self, block_id: int) -> KVCacheBlock:
        """Move a free block to the used set and reset its bookkeeping."""
        block = self.blocks[block_id]
        assert block.ref_cnt == 0
        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int):
        """Return an unreferenced block to the free list (data kept intact)."""
        assert self.blocks[block_id].ref_cnt == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def allocate(self, sequence: Seq):
        """Assign physical blocks for every block of `sequence`'s tokens.

        Full blocks whose chained hash (and token contents) match an
        existing block are shared instead of recomputed; the sequence's
        ``num_cached_tokens`` is advanced accordingly.
        """
        assert not sequence.block_table, "Sequence already has allocated blocks"

        parent_hash = None
        cache_miss = False

        for i in range(sequence.num_blocks):
            token_ids = sequence.get_block_tokens(i)
            # Only full blocks are content-addressable; a trailing partial
            # block always gets a fresh, unhashed allocation.
            block_hash = (
                self.compute_block_hash(token_ids, parent_hash)
                if len(token_ids) == self.block_size
                else None
            )
            block_id = self.block_hash_to_id.get(block_hash) if block_hash else None

            # Once one block misses, every later block misses too (its hash
            # chains off this block's), hence the sticky flag.
            if block_id is None or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True

            if cache_miss:
                block = self._allocate_block(self._take_free_block())
                block_id = block.block_id
            else:
                sequence.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    # Live shared block: just bump its refcount.
                    block = self.blocks[block_id]
                    block.ref_cnt += 1
                else:
                    # Cache hit on a *freed* block: re-allocate that same
                    # block so the KV data it still holds is reused.
                    # (Bug fix: previously an arbitrary free block was
                    # taken here, counting the tokens as cached while their
                    # KV entries lived in a different, unreferenced block.)
                    block = self._allocate_block(block_id)

            if block_hash is not None:
                block.update(block_hash, token_ids)
                self.block_hash_to_id[block_hash] = block_id
                parent_hash = block_hash

            sequence.block_table.append(block_id)

    def deallocate(self, sequence: Seq):
        """Drop this sequence's references; free blocks nobody else uses."""
        # Reverse order so the most recently filled blocks free first.
        for block_id in reversed(sequence.block_table):
            block = self.blocks[block_id]
            block.ref_cnt -= 1
            if block.ref_cnt == 0:
                self._deallocate_block(block_id)

        sequence.num_cached_tokens = 0
        sequence.block_table.clear()

    def append_to_seq(self, sequence: Seq):
        """Update block bookkeeping after one token was appended to `sequence`."""
        block_table = sequence.block_table
        last_block = self.blocks[block_table[-1]]

        if len(sequence) % self.block_size == 1:
            # The appended token started a new block; the previous last
            # block must already be full (and therefore hashed).
            assert last_block.block_hash is not None
            new_id = self._take_free_block()
            self._allocate_block(new_id)
            block_table.append(new_id)
        elif len(sequence) % self.block_size == 0:
            # The last block just became full: hash it so future
            # sequences with the same prefix can share it.
            assert last_block.block_hash is None
            token_ids = sequence.get_block_tokens(sequence.num_blocks - 1)
            parent_hash = (
                self.blocks[block_table[-2]].block_hash
                if len(block_table) > 1
                else None
            )
            block_hash = self.compute_block_hash(token_ids, parent_hash)
            last_block.update(block_hash, token_ids)
            self.block_hash_to_id[block_hash] = last_block.block_id
        else:
            # Mid-block: nothing to do; the block stays partial/unhashed.
            assert last_block.block_hash is None

    def remove_seq(self, sequence: Seq):
        """Release all blocks held by `sequence` (alias of deallocate)."""
        self.deallocate(sequence)

    def wire_kv_cache_to_model(self, model):
        """Point each attention module's k_cache/v_cache at our tensor.

        Modules are matched by the presence of `k_cache`/`v_cache`
        attributes and assigned layer planes in model.modules() order —
        assumes that order equals layer order (TODO confirm).
        """
        layer_id = 0
        for module in model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = self.kv_cache[0, layer_id]
                module.v_cache = self.kv_cache[1, layer_id]
                layer_id += 1
|
||||
@ -307,7 +307,7 @@ class UnifiedVoice(nn.Module):
|
||||
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
|
||||
train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||
checkpointing=True, types=1,
|
||||
condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
|
||||
condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None, use_accel=False):
|
||||
"""
|
||||
Args:
|
||||
layers: Number of layers in transformer stack.
|
||||
@ -409,6 +409,9 @@ class UnifiedVoice(nn.Module):
|
||||
for module in embeddings:
|
||||
module.weight.data.normal_(mean=0.0, std=.02)
|
||||
|
||||
self.use_accel = use_accel
|
||||
self.accel_engine = None # Will be initialized in post_init_gpt2_config
|
||||
|
||||
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
|
||||
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
||||
gpt_config = GPT2Config(
|
||||
@ -421,6 +424,38 @@ class UnifiedVoice(nn.Module):
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if self.use_accel and torch.cuda.is_available():
|
||||
# Check if flash attention is available
|
||||
try:
|
||||
import flash_attn
|
||||
except ImportError:
|
||||
raise ImportError("flash_attn is required for acceleration but not installed. Please install from https://github.com/Dao-AILab/flash-attention/releases/")
|
||||
|
||||
from indextts.accel import GPT2AccelModel, AccelInferenceEngine
|
||||
|
||||
# Create accel model
|
||||
accel_gpt = GPT2AccelModel(gpt_config)
|
||||
accel_gpt.load_state_dict(self.gpt.state_dict(), strict=False)
|
||||
|
||||
if half:
|
||||
accel_gpt = accel_gpt.half().cuda()
|
||||
else:
|
||||
accel_gpt = accel_gpt.cuda()
|
||||
accel_gpt.eval()
|
||||
|
||||
lm_head_with_norm = nn.Sequential(self.final_norm, self.mel_head)
|
||||
self.accel_engine = AccelInferenceEngine(
|
||||
model=accel_gpt,
|
||||
lm_head=lm_head_with_norm,
|
||||
num_layers=self.layers,
|
||||
num_heads=self.heads,
|
||||
head_dim=self.model_dim // self.heads,
|
||||
block_size=256,
|
||||
num_blocks=16, # Reduce to save memory (16*256 = 4096 tokens capacity)
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
print("acceleration engine initialized")
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config,
|
||||
self.gpt,
|
||||
@ -721,6 +756,20 @@ class UnifiedVoice(nn.Module):
|
||||
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
|
||||
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
|
||||
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
|
||||
|
||||
# Use accel engine if available (single sequence only)
|
||||
if self.accel_engine is not None and num_return_sequences == 1:
|
||||
output = self.accel_engine.generate(
|
||||
inputs, # fake input_ids (all 1s + start_mel_token)
|
||||
max_new_tokens=max_length - trunc_index,
|
||||
attention_mask=attention_mask,
|
||||
temperature=hf_generate_kwargs.get('temperature', 1),
|
||||
stop_tokens=[self.stop_mel_token],
|
||||
tts_embeddings=inputs_embeds, # [pad][cond][text] embeddings (87 tokens, NO start_mel_token)
|
||||
tts_mel_embedding=self.inference_model.embeddings, # mel_embedding layer
|
||||
tts_text_pos_embedding=self.inference_model.text_pos_embedding, # text_pos_embedding layer
|
||||
)
|
||||
else:
|
||||
output = self.inference_model.generate(inputs,
|
||||
bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
|
||||
|
||||
@ -38,7 +38,7 @@ import torch.nn.functional as F
|
||||
class IndexTTS2:
|
||||
def __init__(
|
||||
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
|
||||
use_cuda_kernel=None,use_deepspeed=False
|
||||
use_cuda_kernel=None,use_deepspeed=False, use_accel=False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -48,6 +48,7 @@ class IndexTTS2:
|
||||
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
|
||||
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
|
||||
use_deepspeed (bool): whether to use DeepSpeed or not.
|
||||
use_accel (bool): whether to use acceleration engine for GPT2 or not.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
@ -75,10 +76,11 @@ class IndexTTS2:
|
||||
self.model_dir = model_dir
|
||||
self.dtype = torch.float16 if self.use_fp16 else None
|
||||
self.stop_mel_token = self.cfg.gpt.stop_mel_token
|
||||
self.use_accel = use_accel
|
||||
|
||||
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
|
||||
|
||||
self.gpt = UnifiedVoice(**self.cfg.gpt)
|
||||
self.gpt = UnifiedVoice(**self.cfg.gpt, use_accel=self.use_accel)
|
||||
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
|
||||
load_checkpoint(self.gpt, self.gpt_path)
|
||||
self.gpt = self.gpt.to(self.device)
|
||||
@ -453,7 +455,7 @@ class IndexTTS2:
|
||||
ref_mel = self.cache_mel
|
||||
|
||||
if emo_vector is not None:
|
||||
weight_vector = torch.tensor(emo_vector).to(self.device)
|
||||
weight_vector = torch.tensor(emo_vector, device=self.device)
|
||||
if use_random:
|
||||
random_index = [random.randint(0, x - 1) for x in self.emo_num]
|
||||
else:
|
||||
@ -580,15 +582,16 @@ class IndexTTS2:
|
||||
# print(f"code len: {code_lens}")
|
||||
|
||||
code_lens = []
|
||||
max_code_len = 0
|
||||
for code in codes:
|
||||
if self.stop_mel_token not in code:
|
||||
code_lens.append(len(code))
|
||||
code_len = len(code)
|
||||
else:
|
||||
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
|
||||
code_len = len_ - 1
|
||||
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0]
|
||||
code_len = len_[0].item() if len_.numel() > 0 else len(code)
|
||||
code_lens.append(code_len)
|
||||
codes = codes[:, :code_len]
|
||||
max_code_len = max(max_code_len, code_len)
|
||||
codes = codes[:, :max_code_len]
|
||||
code_lens = torch.LongTensor(code_lens)
|
||||
code_lens = code_lens.to(self.device)
|
||||
if verbose:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user