[Operator] Improve Page Attention #1
base: main
Conversation
Hi @yaoyaoding, I believe this concludes my part of the page attention work. I found a bug in the "flash attention" implementation, and verified the correctness of generation for the page attention op and the cache rewrite op. Due to implementation details, the key cache was stored in a different layout, which caused some issues with cache rewrite; I have fixed that. The full tests can be found in . I have written an end-to-end example to test correctness:

```python
# %%
import torch
import hidet


class LLM:
    """Wrapper around the hidet prefill/decode compiled graphs with a paged KV cache."""

    def __init__(self, model, prefill_graph, decode_graph, num_layers, num_heads, head_size, block_size, dtype='float32'):
        self.model = model
        self.prefill_graph = prefill_graph
        self.decode_graph = decode_graph
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.dtype = dtype
        self.cache = None
        self.cur_len = 0

    def init_cache(self, num_blocks: int):
        # One key cache and one value cache per layer, each laid out as
        # [num_blocks, num_heads, head_size, block_size].
        dtype = getattr(torch, self.dtype)
        cache = [
            torch.zeros((num_blocks, self.num_heads, self.head_size, self.block_size), dtype=dtype).cuda()
            for _ in range(self.num_layers * 2)
        ]
        self.cache = [hidet.from_torch(x) for x in cache]

    def enlarge_cache(self, num_blocks: int):
        # Append `num_blocks` zero-initialized blocks to every cache tensor.
        dtype = getattr(torch, self.dtype)
        new_cache = [
            torch.zeros((num_blocks, self.num_heads, self.head_size, self.block_size), dtype=dtype).cuda()
            for _ in range(self.num_layers * 2)
        ]
        self.cache = [hidet.from_torch(torch.cat([x.torch(), y], dim=0)) for x, y in zip(self.cache, new_cache)]

    def prefill(self, tokens):
        self.init_cache(len(tokens) // self.block_size + 1)
        self.cur_len = len(tokens)
        input_ids = hidet.asarray([tokens], dtype='int32').cuda()
        position_ids = hidet.asarray([list(range(len(tokens)))], dtype='int32').cuda()
        cache_slots = hidet.asarray([list(range(len(tokens)))], dtype='int64').cuda()
        # cache_blocks = hidet.asarray([list(range(len(tokens) // self.block_size + 1))], dtype='int32').cuda()
        seq_lengths = hidet.asarray([len(tokens)], dtype='int32').cuda()
        inputs = [input_ids, position_ids, cache_slots, seq_lengths, *self.cache]
        outputs = self.prefill_graph.run_async(inputs)
        hidden_states = outputs[0]
        return hidden_states

    def decode(self, tok: int):
        blocks = self.cache[0].shape[0]
        if blocks < self.cur_len // self.block_size + 1:
            self.enlarge_cache(5)
        input_ids = hidet.asarray([[tok]], dtype='int32').cuda()
        position_ids = hidet.asarray([[self.cur_len]], dtype='int32').cuda()
        cache_slots = hidet.asarray([[self.cur_len]], dtype='int64').cuda()
        max_context_length = hidet.full([], self.cur_len, dtype='int32').cuda()
        cache_blocks = hidet.asarray([[list(range(self.cur_len // self.block_size + 1))]], dtype='int32').cuda()
        seq_lengths = hidet.asarray([self.cur_len], dtype='int32').cuda()
        inputs = [input_ids, position_ids, cache_slots, seq_lengths, max_context_length, cache_blocks, *self.cache]
        outputs = self.decode_graph.run_async(inputs)
        hidden_states = outputs[0]
        self.cur_len += 1
        return hidden_states


class LLamaLLM:
    """Reference implementation using the eager HuggingFace model and its own KV cache."""

    def __init__(self, model):
        self.model = model
        self.cache = None

    def prefill(self, tokens):
        input_ids = torch.tensor([tokens], dtype=torch.int32).cuda()
        position_ids = torch.tensor([list(range(len(tokens)))]).cuda()
        ht = self.model(input_ids, position_ids=position_ids, use_cache=True, output_hidden_states=True)
        self.cache = ht.past_key_values
        return ht.hidden_states[-1]

    def decode(self, tok: int):
        input_ids = torch.tensor([[tok]], dtype=torch.int32).cuda()
        position_ids = torch.tensor([[self.cache[0][0].shape[2]]], dtype=torch.int32).cuda()
        ht = self.model(input_ids, position_ids=position_ids, past_key_values=self.cache, use_cache=True, output_hidden_states=True)
        self.cache = ht.past_key_values
        return ht.hidden_states[-1]


def make_hidet_llm(model: torch.nn.Module, block_size=16):
    from hidet.apps.llm.builder import _build_decode_graph, _build_prefill_graph
    from hidet.apps.llm.modeling.pretrained import copy_weights
    from hidet.apps.llm.modeling.llama.modeling import LlamaForCausalLM

    model.config.torch_dtype = torch.float32
    hidet_model = LlamaForCausalLM(model.config).cuda()
    copy_weights(model, hidet_model)
    prefill_graph = _build_prefill_graph(hidet_model, device='cuda', block_size=block_size, kernel_search_space=0)
    decode_graph = _build_decode_graph(hidet_model, device='cuda', block_size=block_size, kernel_search_space=0)
    return LLM(
        hidet_model,
        prefill_graph,
        decode_graph,
        model.config.num_hidden_layers,
        model.config.num_attention_heads,
        model.config.hidden_size // model.config.num_attention_heads,
        block_size,
        dtype='float32',
    )


from transformers.models.llama import LlamaForCausalLM, LlamaConfig

hidet.option.cache_dir('./llmcache')
hidet.utils.clear_cache_dir()

# A small random LLaMA so the comparison runs quickly.
config = LlamaConfig(vocab_size=512, hidden_size=1024, num_attention_heads=8, intermediate_size=2048, num_hidden_layers=2)
model = LlamaForCausalLM(config).cuda()

torch_llm = LLamaLLM(model)
hidet_llm = make_hidet_llm(model)

tokens = [1, 2, 3, 4]
hidden0 = torch_llm.prefill(tokens)
hidden1 = hidet_llm.prefill(tokens).torch()

print(hidden0.shape, hidden1.shape)
print((hidden0 - hidden1).abs().max())
print((hidden0 - hidden1).abs().mean())
print(hidet_llm.cache[0].torch().sum())
```

However, I am puzzled that the last line prints 0, meaning that the cache was not written to. Do you have any clues @yaoyaoding?
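To narrow this down, one check I still plan to run is whether the prefill graph returns updated cache tensors as additional outputs instead of writing the input cache tensors in place. A minimal sketch, reusing the objects from the example above and assuming `run_async` simply returns the full list of graph outputs as a Python list:

```python
# Assumption: run_async returns every output of the flow graph, with the hidden states first.
inputs = [
    hidet.asarray([tokens], dtype='int32').cuda(),                    # input_ids
    hidet.asarray([list(range(len(tokens)))], dtype='int32').cuda(),  # position_ids
    hidet.asarray([list(range(len(tokens)))], dtype='int64').cuda(),  # cache_slots
    hidet.asarray([len(tokens)], dtype='int32').cuda(),               # seq_lengths
    *hidet_llm.cache,
]
outputs = hidet_llm.prefill_graph.run_async(inputs)

print(len(outputs))                        # more than one output would suggest that the
for out in outputs[1:]:                    # updated caches come back as graph outputs
    print(out.shape, float(out.torch().abs().sum()))
print(float(hidet_llm.cache[0].torch().abs().sum()))  # stays 0.0 if the inputs are never mutated
```

If the extra outputs are non-zero while the input caches stay zero, the cache write itself works but the graph does not mutate the tensors we pass in, and the wrapper would need to re-assign `self.cache` from the outputs after each run.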
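For reference, this is the slot and block-table convention the end-to-end example assumes when it builds `cache_slots` and `cache_blocks`: token position `i` of a sequence that starts at block 0 lives in block `i // block_size` at offset `i % block_size`, i.e. `slot = block_id * block_size + offset`. The helper below is purely illustrative and not part of the hidet API:

```python
def slots_and_block_table(num_tokens: int, block_size: int = 16):
    # Illustrative only: a single sequence whose blocks are 0, 1, 2, ... in order.
    slots = [(i // block_size) * block_size + (i % block_size) for i in range(num_tokens)]
    block_table = list(range(num_tokens // block_size + 1))
    return slots, block_table

# For the 4-token prompt above this reproduces cache_slots == [0, 1, 2, 3]
# and a block table of [0], matching the inputs built in prefill().
print(slots_and_block_table(4, block_size=16))
```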