[Operator] Improve Page Attention #1

Open

Aalanli wants to merge 23 commits into main
Conversation

@Aalanli commented Feb 7, 2024

No description provided.

@Aalanli (Author) commented Feb 18, 2024

Hi @yaoyaoding

I believe this concludes my part of the page attention work. I found a bug in the "flash attention" implementation and verified the correctness of generation for the page attention op and the cache rewrite op. Due to implementation details, the key-cache was stored in a different layout, which caused some issues with cache rewrite; I have fixed that. The full tests can be found in the hidet/tests/apps/llm/ops folder.
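
For context, here is a minimal sketch (not part of the PR) of how a flat cache slot maps onto the paged layout used in the example below, assuming the straightforward slot -> (slot // block_size, slot % block_size) mapping and the per-layer cache shape (num_blocks, num_heads, head_size, block_size); block_size = 16 is just the default used by make_hidet_llm:

block_size = 16  # default block size used by make_hidet_llm below
for position in [0, 1, 15, 16, 37]:
    block_index = position // block_size   # which physical block holds this token's KV entries
    block_offset = position % block_size   # where the token sits inside that block
    print(f'slot {position} -> block {block_index}, offset {block_offset}')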

I have written an end-to-end example to test correctness:

# %%
import torch
import hidet

class LLM:
    def __init__(self, model, prefill_graph, decode_graph, num_layers, num_heads, head_size, block_size, dtype='float32'):
        self.model = model
        self.prefill_graph = prefill_graph
        self.decode_graph = decode_graph

        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_size = head_size
        self.block_size = block_size
        self.dtype = dtype

        self.cache = None
        self.cur_len = 0

    def init_cache(self, num_blocks: int):
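        # Allocate a zero-initialized paged KV cache: one key tensor and one value tensor
        # per layer, each of shape (num_blocks, num_heads, head_size, block_size).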
        dtype = getattr(torch, self.dtype)
        cache = [
            torch.zeros((num_blocks, self.num_heads, self.head_size, self.block_size), dtype=dtype).cuda()
            for _ in range(self.num_layers * 2)
        ]
        self.cache = [hidet.from_torch(x) for x in cache]
    
    def enlarge_cache(self, num_blocks: int):
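        # Append num_blocks fresh zero blocks to every per-layer cache tensor by
        # concatenating along the block dimension.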
        dtype = getattr(torch, self.dtype)
        new_cache = [
            torch.zeros((num_blocks, self.num_heads, self.head_size, self.block_size), dtype=dtype).cuda()
            for _ in range(self.num_layers * 2)
        ]
        self.cache = [hidet.from_torch(torch.cat([x.torch(), y], dim=0)) for x, y in zip(self.cache, new_cache)]
    
    def prefill(self, tokens):
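        # Run the prefill graph over the whole prompt; the token at position i is written to cache slot i.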
        self.init_cache(len(tokens) // self.block_size + 1)
        self.cur_len = len(tokens)
        input_ids = hidet.asarray([tokens], dtype='int32').cuda()
        position_ids = hidet.asarray([list(range(len(tokens)))], dtype='int32').cuda()
        cache_slots = hidet.asarray([list(range(len(tokens)))], dtype='int64').cuda()
        # cache_blocks = hidet.asarray([list(range(len(tokens) // self.block_size + 1))], dtype='int32').cuda()
        seq_lengths = hidet.asarray([len(tokens)], dtype='int32').cuda()
        inputs = [input_ids, position_ids, cache_slots, seq_lengths, *self.cache]
        outputs = self.prefill_graph.run_async(inputs)
        hidden_states = outputs[0]
        return hidden_states

    def decode(self, tok: int):
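        # Run one decode step for a single new token, growing the cache when the
        # next slot would fall outside the currently allocated blocks.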
        blocks = self.cache[0].shape[0]
        if blocks < self.cur_len // self.block_size + 1:
            self.enlarge_cache(5)
        input_ids = hidet.asarray([[tok]], dtype='int32').cuda()
        position_ids = hidet.asarray([[self.cur_len]], dtype='int32').cuda()
        cache_slots = hidet.asarray([[self.cur_len]], dtype='int64').cuda()
        max_context_length = hidet.full([], self.cur_len, dtype='int32').cuda()
        cache_blocks = hidet.asarray([[list(range(self.cur_len // self.block_size + 1))]], dtype='int32').cuda()
        seq_lengths = hidet.asarray([self.cur_len], dtype='int32').cuda()
        inputs = [input_ids, position_ids, cache_slots, seq_lengths, max_context_length, cache_blocks, *self.cache]
        outputs = self.decode_graph.run_async(inputs)
        hidden_states = outputs[0]
        self.cur_len += 1
        return hidden_states


class LLamaLLM:
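    # Reference implementation: the HuggingFace model with its own KV cache,
    # used to check the hidet graphs against.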
    def __init__(self, model):
        self.model = model
        self.cache = None
    
    def prefill(self, tokens):
        input_ids = torch.tensor([tokens], dtype=torch.int32).cuda()
        position_ids = torch.tensor([list(range(len(tokens)))]).cuda()
        ht = self.model(input_ids, position_ids=position_ids, use_cache=True, output_hidden_states=True)
        self.cache = ht.past_key_values
        return ht.hidden_states[-1]
    
    def decode(self, tok: int):
        input_ids = torch.tensor([[tok]], dtype=torch.int32).cuda()
        position_ids = torch.tensor([[self.cache[0][0].shape[2]]], dtype=torch.int32).cuda()
        ht = self.model(input_ids, position_ids=position_ids, past_key_values=self.cache, use_cache=True, output_hidden_states=True)
        self.cache = ht.past_key_values
        return ht.hidden_states[-1]

def make_hidet_llm(model: torch.nn.Module, block_size=16):
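    # Build hidet prefill/decode graphs with weights copied from the given HuggingFace model.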
    from hidet.apps.llm.builder import _build_decode_graph, _build_prefill_graph
    from hidet.apps.llm.modeling.pretrained import copy_weights
    from hidet.apps.llm.modeling.llama.modeling import LlamaForCausalLM

    model.config.torch_dtype = torch.float32
    hidet_model = LlamaForCausalLM(model.config).cuda()
    copy_weights(model, hidet_model)

    prefill_graph = _build_prefill_graph(hidet_model, device='cuda', block_size=block_size, kernel_search_space=0)
    decode_graph = _build_decode_graph(hidet_model, device='cuda', block_size=block_size, kernel_search_space=0)

    return LLM(
        hidet_model,
        prefill_graph,
        decode_graph,
        num_layers=model.config.num_hidden_layers,
        num_heads=model.config.num_attention_heads,
        head_size=model.config.hidden_size // model.config.num_attention_heads,
        block_size=block_size,
        dtype='float32',
    )


from transformers.models.llama import LlamaForCausalLM, LlamaConfig
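# Build a small random-weight Llama model and compare prefill hidden states between the torch and hidet paths.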
hidet.option.cache_dir('./llmcache')
hidet.utils.clear_cache_dir()
config = LlamaConfig(vocab_size=512, hidden_size=1024, num_attention_heads=8, intermediate_size=2048, num_hidden_layers=2)
model = LlamaForCausalLM(config).cuda()

torch_llm = LLamaLLM(model)
hidet_llm = make_hidet_llm(model)


tokens = [1, 2, 3, 4]
hidden0 = torch_llm.prefill(tokens)
hidden1 = hidet_llm.prefill(tokens).torch()

print(hidden0.shape, hidden1.shape)

print((hidden0 - hidden1).abs().max())
print((hidden0 - hidden1).abs().mean())

print(hidet_llm.cache[0].torch().sum())

However, I am puzzled that the last line prints 0, which suggests the cache was never written to. Do you have any clues, @yaoyaoding?
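
A couple of follow-up checks along the same lines, as a sketch (nothing here is conclusive; tok below is just an arbitrary token id from the toy vocabulary):

# Check whether every layer's key/value cache is untouched, not just layer 0.
for i, layer_cache in enumerate(hidet_llm.cache):
    print(i, layer_cache.torch().abs().sum().item())

# Run one decode step through both paths and compare the resulting hidden states.
tok = 5
d0 = torch_llm.decode(tok)
d1 = hidet_llm.decode(tok).torch()
print((d0 - d1).abs().max())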
