From 478cbe0d7f63600cc3731f444d173bde8e8b0533 Mon Sep 17 00:00:00 2001 From: Celve Date: Mon, 29 Jul 2024 21:42:19 -0700 Subject: [PATCH 01/17] feat: add naive worker implementation --- deserve_worker/__init__.py | 0 deserve_worker/forward_engine.py | 112 +++++++ deserve_worker/kvcache.py | 103 ++++++ deserve_worker/layer_storage.py | 193 +++++++++++ deserve_worker/model.py | 535 +++++++++++++++++++++++++++++++ deserve_worker/paged_kvcache.py | 96 ++++++ deserve_worker/py.typed | 0 deserve_worker/pyproject.toml | 13 + deserve_worker/worker.py | 221 +++++++++++++ deserve_worker/worker_api.py | 48 +++ 10 files changed, 1321 insertions(+) create mode 100644 deserve_worker/__init__.py create mode 100644 deserve_worker/forward_engine.py create mode 100644 deserve_worker/kvcache.py create mode 100644 deserve_worker/layer_storage.py create mode 100644 deserve_worker/model.py create mode 100644 deserve_worker/paged_kvcache.py create mode 100644 deserve_worker/py.typed create mode 100644 deserve_worker/pyproject.toml create mode 100644 deserve_worker/worker.py create mode 100644 deserve_worker/worker_api.py diff --git a/deserve_worker/__init__.py b/deserve_worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py new file mode 100644 index 0000000..fa19468 --- /dev/null +++ b/deserve_worker/forward_engine.py @@ -0,0 +1,112 @@ +import queue + +import torch + +from .kvcache import KVCacheBase +from .layer_storage import LayerStorage +from .model import ENABLE_FLASH_ATTN + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
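+
+    Example (illustrative; without flash-attn the result is a complex64 tensor
+    of shape ``(end, dim // 2)``, with flash-attn a stacked cos/sin tensor of
+    shape ``(2, end, dim // 2)``):
+
+        freqs_cis = precompute_freqs_cis(128, 8192, theta=500000.0)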
+ + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + if ENABLE_FLASH_ATTN: + freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to("cuda") + + +class LayerForward: + def __init__( + self, + layer_storage: LayerStorage, + h: torch.Tensor, + seqlen: int, + start_pos: int, + kvcaches: dict[int, KVCacheBase], + back: queue.Queue[torch.Tensor], + ): + self.layer_storage = layer_storage + self.h = h + self.seqlen = seqlen + self.start_pos = start_pos + self.kvcaches = kvcaches + self.back = back + + +class ForwardEngine: + def __init__(self, max_total_bsz: int): + self.max_total_bsz = max_total_bsz + + self.queue = queue.Queue[LayerForward]() + + def run(self) -> None: + q = self.queue + while True: + q_buffered: list[LayerForward] = [q.get()] + while True: + try: + tasks = q.get(block=False) + q_buffered.append(tasks) + except queue.Empty: + break + prefill_tasks = [task for task in q_buffered if task.seqlen > 1] + decode_tasks = [task for task in q_buffered if task.seqlen == 1] + print( + f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {len(decode_tasks)}" + ) + + for task in prefill_tasks: + h = self.forward([task]) + self.process(h, [task]) + + for i in range(0, len(decode_tasks), self.max_total_bsz): + to_decode = decode_tasks[ + i : min(i + self.max_total_bsz, len(decode_tasks)) + ] + h = self.forward(to_decode) + self.process(h, to_decode) + + def add_layer_forward(self, task: LayerForward) -> None: + self.queue.put(task) + + def forward(self, tasks: list[LayerForward]) -> torch.Tensor: + # we need to check that all tasks share the same layer storage + with torch.inference_mode(): + layer_storage = tasks[0].layer_storage + h = torch.cat([t.h for t in tasks]) + bsz_list, start_pos_list = [1 for _ in tasks], [t.start_pos for t in tasks] + return layer_storage.forward( + h, + bsz_list, + start_pos_list, + global_freqs_cis, + [task.kvcaches for task in tasks], + ) + + def process(self, merged_h: torch.Tensor, tasks: list[LayerForward]) -> None: + ptr = 0 + for task in tasks: + task.back.put(merged_h[ptr : ptr + 1]) + ptr += 1 diff --git a/deserve_worker/kvcache.py b/deserve_worker/kvcache.py new file mode 100644 index 0000000..7501bcc --- /dev/null +++ b/deserve_worker/kvcache.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod + +import torch + +main_dtype = torch.float16 +main_device = torch.device("cuda") +torch.set_default_dtype(main_dtype) # type: ignore + + +def del_tensor(t: torch.Tensor) -> None: + t.detach() + t.grad = None + t.untyped_storage().resize_(0) + + +class KVCacheBase(ABC): + @abstractmethod + def renew(self, x: torch.Tensor, start_pos: int) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + +KV_CACHE_BLOCK_SIZE = 256 + + +class KVCache(KVCacheBase): + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += KV_CACHE_BLOCK_SIZE + return cur + + def __init__( + self, + x: torch.Tensor, + start_pos: int, + n_local_kv_heads: int, + head_dim: int, + ): + self.n_local_kv_heads = n_local_kv_heads + self.head_dim = head_dim + + bsz, seqlen = x.shape[0], x.shape[1] + length = self.get_kv_cache_length(0, start_pos + seqlen) + self.cache_k = torch.zeros( + ( + bsz, + length, + n_local_kv_heads, + head_dim, + 
), + device=main_device, + dtype=main_dtype, + ) + self.cache_v = torch.zeros( + ( + bsz, + length, + n_local_kv_heads, + head_dim, + ), + device=main_device, + dtype=main_dtype, + ) + self.main_device = main_device + + def renew(self, x: torch.Tensor, start_pos: int) -> None: + bsz, seqlen = x.shape[0], x.shape[1] + if start_pos + seqlen > self.cache_k.shape[1]: + length = self.get_kv_cache_length(self.cache_k.shape[1], start_pos + seqlen) + cache_k = torch.zeros( + ( + bsz, + length, + self.n_local_kv_heads, + self.head_dim, + ), + device=self.main_device, + ) + cache_v = torch.zeros( + ( + bsz, + length, + self.n_local_kv_heads, + self.head_dim, + ), + device=self.main_device, + ) + cache_k[:, :start_pos, :, :], cache_v[:, :start_pos, :, :] = ( + self.cache_k[:, :start_pos, :, :], + self.cache_v[:, :start_pos, :, :], + ) + del_tensor(self.cache_k) + del_tensor(self.cache_v) + self.cache_k = cache_k + self.cache_v = cache_v + + def clear(self) -> None: + del_tensor(self.cache_k) + del_tensor(self.cache_v) + torch.cuda.empty_cache() diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py new file mode 100644 index 0000000..d2de1c0 --- /dev/null +++ b/deserve_worker/layer_storage.py @@ -0,0 +1,193 @@ +import os +import threading +from concurrent.futures import ThreadPoolExecutor + +import requests +import torch + +from .kvcache import KVCacheBase +from .model import ENABLE_FLASH_ATTN, ModelArgs, RMSNorm, TransformerBlock + +llama_2_7b_args = { + "dim": 4096, + "multiple_of": 256, + "n_heads": 32, + "n_layers": 32, + "norm_eps": 1e-06, + "vocab_size": 32000, +} + +llama_2_13b_args = { + "dim": 5120, + "multiple_of": 256, + "n_heads": 40, + "n_layers": 40, + "norm_eps": 1e-05, + "vocab_size": 32000, +} + +llama_2_70b_args = { + "dim": 8192, + "multiple_of": 4096, + "ffn_dim_multiplier": 1.3, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 32000, +} + +llama_3_8b_args = { + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": 128256, + "multiple_of": 1024, + "ffn_dim_multiplier": 1.3, + "norm_eps": 1e-05, + "rope_theta": 500000.0, +} + +llama_3_70b_args = { + "dim": 8192, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 128256, + "rope_theta": 500000.0, +} + + +class LayerManager: + def __init__(self, main_device: torch.device): + self.main_device = main_device + self.network_executor = ThreadPoolExecutor(max_workers=80) + self.cache_dir = os.path.expanduser("~/.cache/fleece-worker/models") + self.layer_storage_map: dict[frozenset[str], LayerStorage] = {} + self.mutex = threading.Lock() + + def get_layer_storage(self, layer_names: list[str]) -> "LayerStorage": + frozen_layer_names = frozenset(layer_names) + with self.mutex: + if frozen_layer_names not in self.layer_storage_map: + self.layer_storage_map[frozen_layer_names] = LayerStorage( + layer_names, self.main_device + ) + return self.layer_storage_map[frozen_layer_names] + + +global_layer_manager = LayerManager(torch.device("cuda")) + + +class LayerStorage: + def __init__(self, layer_names: list[str], main_device: torch.device): + self.layer_names = layer_names + self.main_device = main_device + self.layers: dict[str, torch.nn.Module] = {} + + self.preload_layers(self.layer_names) + + def fetch_layer(self, full_layer_name: str) -> str: + model_name, layer_name = full_layer_name.split("/") + path = os.path.join( + global_layer_manager.cache_dir, 
model_name, f"{layer_name}.pt" + ) + if not os.path.exists(path): # TODO lock + os.makedirs( + os.path.join(global_layer_manager.cache_dir, model_name), exist_ok=True + ) + with requests.get( + f"https://huggingface.co/colearn/{model_name}/resolve/main/{layer_name}.pt", + stream=True, + ) as r: + r.raise_for_status() + with open(path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + return path + + def preload_layers(self, layer_names: list[str]) -> None: + threads = [] + for full_layer_name in layer_names: + if full_layer_name in self.layers: + continue + thread = global_layer_manager.network_executor.submit( + self.fetch_layer, full_layer_name + ) + threads.append((full_layer_name, thread)) + for full_layer_name, thread in threads: + path = thread.result() + model_name, layer_name = full_layer_name.split("/") + if model_name.startswith("llama-2-7b"): + model_args = ModelArgs(**llama_2_7b_args) + elif model_name.startswith("llama-2-13b"): + model_args = ModelArgs(**llama_2_13b_args) + elif model_name.startswith("llama-2-70b"): + model_args = ModelArgs(**llama_2_70b_args) + elif model_name.startswith("llama-3-8b"): + model_args = ModelArgs(**llama_3_8b_args) + elif model_name.startswith("llama-3-70b"): + model_args = ModelArgs(**llama_3_70b_args) + else: + raise NotImplementedError("Unknown model") + if layer_name == "tok_embeddings": + l = torch.nn.utils.skip_init( + torch.nn.Embedding, model_args.vocab_size, model_args.dim + ) + elif layer_name.startswith("layer"): + l = TransformerBlock(model_args) + elif layer_name == "norm": + l = RMSNorm(model_args.dim, eps=model_args.norm_eps) + elif layer_name == "output": + l = torch.nn.utils.skip_init( + torch.nn.Linear, + model_args.dim, + model_args.vocab_size, + bias=False, + ) + else: + raise NotImplementedError("Unknown layers") + l.load_state_dict(torch.load(path, map_location="cpu")) + l.to(self.main_device) + self.layers[full_layer_name] = l + print("Loaded", full_layer_name) + + def unload_layers(self) -> None: + for full_layer_name in self.layer_names: + if full_layer_name not in self.layers: + continue # TODO: continue or warning? + del self.layers[full_layer_name] + torch.cuda.empty_cache() + + def forward( + self, + h: torch.Tensor, + bsz_list: list[int], + start_pos_list: list[int], + global_freqs_cis: torch.Tensor, + kv_cache_list: list[dict[int, KVCacheBase]], + ) -> torch.Tensor: + for full_layer_name in self.layers: + _, layer_name = full_layer_name.split("/") + if layer_name == "tok_embeddings": + h = self.layers[full_layer_name](h) + elif layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + h = self.layers[full_layer_name]( + h, + bsz_list, + start_pos_list, + global_freqs_cis, + [kv_cache[layer_id] for kv_cache in kv_cache_list], + ) + elif layer_name == "norm": + h = self.layers[full_layer_name](h) + elif layer_name == "output": + h = self.layers[full_layer_name](h) + else: + raise NotImplementedError("Unknown layers") + return h diff --git a/deserve_worker/model.py b/deserve_worker/model.py new file mode 100644 index 0000000..09701f5 --- /dev/null +++ b/deserve_worker/model.py @@ -0,0 +1,535 @@ +# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement. 
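+#
+# Llama-style transformer building blocks (RMSNorm, Attention, FeedForward,
+# TransformerBlock) plus tensor (de)serialization helpers. Attention uses a
+# flash-attn paged-KV path when flash-attn is importable and falls back to a
+# plain PyTorch implementation otherwise (see ENABLE_FLASH_ATTN below).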
+ +import math +import pickle +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, cast + +import safetensors.torch +import torch +import torch.nn.functional as F +from torch import nn + +from deserve_worker.paged_kvcache import PagedKVCache + +from .kvcache import KVCache, KVCacheBase + +ENABLE_FLASH_ATTN = False +try: + from flash_attn import flash_attn_with_kvcache # type: ignore + + from .paged_kvcache import global_paged_memory + + ENABLE_FLASH_ATTN = True +except ImportError as e: + print( + "Package flash-attn is not found. Please install it for better performance. https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features" + ) + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + if ENABLE_FLASH_ATTN: + freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. 
+ + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """Multi-head attention module.""" + + def __init__(self, args: ModelArgs): + """ + Initialize the Attention module. + + Args: + args (ModelArgs): Model configuration parameters. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_local_heads (int): Number of local query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (ColumnParallelLinear): Linear transformation for queries. + wk (ColumnParallelLinear): Linear transformation for keys. + wv (ColumnParallelLinear): Linear transformation for values. + wo (RowParallelLinear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. 
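+
+        Note: in this worker the projections are plain ``nn.Linear`` modules
+        created via ``torch.nn.utils.skip_init`` (no tensor parallelism), and
+        keys/values are cached per request in the ``KVCacheBase`` objects
+        passed to ``forward`` rather than in module-level ``cache_k``/``cache_v``
+        buffers.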
+ + """ + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + # model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads # // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads # // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.utils.skip_init( + nn.Linear, + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + # self.cache_k = torch.zeros( + # ( + # args.max_batch_size, + # args.max_seq_len, + # self.n_local_kv_heads, + # self.head_dim, + # ) + # ) + # self.cache_v = torch.zeros( + # ( + # args.max_batch_size, + # args.max_seq_len, + # self.n_local_kv_heads, + # self.head_dim, + # ) + # ) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kv_cache_list: list[KVCacheBase], + ) -> torch.Tensor: + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for caching. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + mask (torch.Tensor, optional): Attention mask tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + _, seqlen, _ = x.shape + xq_, xk_, xv_ = self.wq(x), self.wk(x), self.wv(x) + + if ENABLE_FLASH_ATTN: + cache_seqlens = [] + for i, bsz in enumerate(bsz_list): + cache_seqlens += [start_pos_list[i]] * bsz + cache_seqlens_tch = torch.tensor( + cache_seqlens, dtype=torch.int32, device=x.device + ) + bsz = cache_seqlens_tch.shape[0] + paged_kv_cache_list: list[PagedKVCache] = kv_cache_list # type: ignore + + max_len = max([kvcache.shape()[1] for kvcache in paged_kv_cache_list]) + block_table = torch.zeros( + (bsz, max_len), dtype=torch.int32, device=x.device + ) + start = 0 + for i, bsz in enumerate(bsz_list): + block_table[ + start : start + bsz, : paged_kv_cache_list[i].shape()[1] + ] = paged_kv_cache_list[i].block_table + start += bsz + + bsz = cache_seqlens_tch.shape[0] + xq = xq_.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + cos = global_freqs_cis[0].type_as(xq) + sin = global_freqs_cis[1].type_as(xq) + output = flash_attn_with_kvcache( + xq, + global_paged_memory.cache_k_paged, + global_paged_memory.cache_v_paged, + xk, + xv, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens_tch, + block_table=block_table, + causal=True, + rotary_interleaved=True, + ) + output = output.view(bsz, seqlen, -1) + else: + start = 0 + output_list = [] + for i, bsz in enumerate(bsz_list): + xq = xq_[start : start + bsz].view( + bsz, seqlen, self.n_local_heads, self.head_dim + ) + xk = xk_[start : start + bsz].view( + bsz, seqlen, self.n_local_kv_heads, self.head_dim + ) + xv = xv_[start : start + bsz].view( + bsz, seqlen, self.n_local_kv_heads, self.head_dim + ) + start += bsz + + start_pos = start_pos_list[i] + kv_cache: KVCache = cast(KVCache, kv_cache_list[i]) + cache_k, cache_v = 
kv_cache.cache_k, kv_cache.cache_v + + freqs_cis = global_freqs_cis[start_pos : start_pos + seqlen] + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=x.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(x) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + cache_k[:bsz, start_pos : start_pos + seqlen] = xk + cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = cache_k[:bsz, : start_pos + seqlen] + values = cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + if mask is not None: + scores = ( + scores + mask + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output_list.append(output) + output = torch.cat([x for x in output_list]) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + self.w2 = torch.nn.utils.skip_init( + nn.Linear, + hidden_dim, + dim, + bias=False, + ) + self.w3 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + args (ModelArgs): Model configuration parameters. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
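+
+        Note: ``layer_id`` is not accepted or stored here; ``LayerStorage.forward``
+        derives the layer index from the layer name (e.g. ``layers.0``) when
+        selecting the matching KV cache.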
+ + """ + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + # self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kv_cache_list: list[KVCacheBase], + ) -> torch.Tensor: + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention.forward( + self.attention_norm(x), + bsz_list, + start_pos_list, + global_freqs_cis, + kv_cache_list, + ) + out = h + self.feed_forward.forward(self.ffn_norm(h)) + return out + + +def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes: + """ + Dump tensors and metadata into bytes + """ + + metadata_bytes = pickle.dumps(metadata) + tensors_bytes = safetensors.torch.save(tensors) + return ( + len(metadata_bytes).to_bytes(4, byteorder="big") + + metadata_bytes + + tensors_bytes + ) + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata diff --git a/deserve_worker/paged_kvcache.py b/deserve_worker/paged_kvcache.py new file mode 100644 index 0000000..85532ff --- /dev/null +++ b/deserve_worker/paged_kvcache.py @@ -0,0 +1,96 @@ +import queue +from typing import Optional + +import torch + +from .kvcache import KVCacheBase, main_device, main_dtype + + +class PagedMemory: + def __init__( + self, + num_blocks: int, + block_size: int, + main_device: torch.device, + main_dtype: torch.dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.cache_k_paged = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.cache_v_paged = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.avai_blocks = queue.Queue[int]() + for i in range(1, num_blocks): + self.avai_blocks.put(i) + + +global_paged_memory = PagedMemory(11500, 256, main_device, main_dtype) + + +class PagedKVCache(KVCacheBase): + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += global_paged_memory.block_size + return cur + + def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): + self.main_device = main_device + bsz, seqlen = x.shape[0], x.shape[1] + length = ( + self.get_kv_cache_length(0, start_pos + seqlen) + // global_paged_memory.block_size + ) + self.block_table = torch.zeros( + ( + bsz, + length, + ), + device=self.main_device, + dtype=torch.int32, + ) + for i in range(length): + for j in range(bsz): + self.block_table[j, i] = global_paged_memory.avai_blocks.get() + + def renew( + 
self, + x: torch.Tensor, + start_pos: int, + ) -> None: + bsz, seqlen = x.shape[0], x.shape[1] + if ( + start_pos + seqlen + > self.block_table.shape[1] * global_paged_memory.block_size + ): + # enlarge block table + length = ( + self.get_kv_cache_length( + self.block_table.shape[1] * global_paged_memory.block_size, + start_pos + seqlen, + ) + // global_paged_memory.block_size + ) + block_table = torch.zeros( + ( + bsz, + length, + ), + device=self.main_device, + dtype=torch.int32, + ) + block_table[:, : self.block_table.shape[1]] = self.block_table[:, :] + for i in range(self.block_table.shape[1], length): + for j in range(bsz): + block_table[j, i] = global_paged_memory.avai_blocks.get() + self.block_table = block_table + + def clear(self) -> None: + for row in self.block_table.tolist(): + for item in row: + global_paged_memory.avai_blocks.put(item) + + def shape(self) -> torch.Size: + return self.block_table.shape diff --git a/deserve_worker/py.typed b/deserve_worker/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/pyproject.toml b/deserve_worker/pyproject.toml new file mode 100644 index 0000000..b9621c9 --- /dev/null +++ b/deserve_worker/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_worker" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Worker" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py new file mode 100644 index 0000000..8098fc0 --- /dev/null +++ b/deserve_worker/worker.py @@ -0,0 +1,221 @@ +import queue +import threading +import traceback +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional, cast + +import requests +import torch +from pydantic import BaseModel +from transformers import AutoTokenizer # type: ignore + +from deserve_worker.paged_kvcache import PagedKVCache + +from .forward_engine import ForwardEngine, LayerForward +from .kvcache import KVCacheBase +from .layer_storage import global_layer_manager +from .model import dumps + +EOS_TOKEN_ID = 128001 # for llama 3 only +STOP_TOKEN_IDS = [128001, 128009] + +stop_tokens = torch.tensor(STOP_TOKEN_IDS) + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + +class PlanStep(BaseModel): + worker_id: str + worker_url: str + layers: list[str] + + +@dataclass +class TaskInfo: + start_pos: int + + kvcaches: dict[int, KVCacheBase] + """ + When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. + """ + + +class SamplingParams(BaseModel): + temperature: float + top_p: float + max_total_len: int + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
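+
+    Example (illustrative, with a toy 4-token distribution; p=0.7 keeps only
+    the two most probable tokens before renormalizing):
+
+        probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
+        next_token = sample_top_p(probs, p=0.7)  # tensor([[0]]) or tensor([[1]])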
+ + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +class Worker: + def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): + self.worker_id = worker_id + self.controller_url = controller_url + self.task_infos: dict[str, TaskInfo] = {} + self.forward_engine = ForwardEngine(max_total_bsz) + threading.Thread(target=self.forward_engine.run, daemon=True).start() + self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) + + def forward( + self, + x: torch.Tensor, + task_id: str, + round: int, + plan: list[PlanStep], + sampling_params: SamplingParams, + ) -> None: + try: + index = next( + ( + i + for i, worker in enumerate(plan) + if worker.worker_id == self.worker_id + ), + None, + ) + if index is None: + return None + + if round == 0: + kvcaches = {} + for full_layer_name in plan[index].layers: + _, layer_name = full_layer_name.split("/") + if layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + kvcaches[layer_id] = PagedKVCache(x, 0, torch.device("cuda")) + + # TODO: need double check whether request is repeated + self.task_infos[task_id] = TaskInfo( + start_pos=0, + kvcaches=cast(dict[int, KVCacheBase], kvcaches), + ) + + bsz, seqlen = x.shape[:2] # currently bsz is not used + task_info = self.task_infos[task_id] + layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) + back = queue.Queue[ + torch.Tensor + ]() # used to transfer tensor between engine and worker + layer_forward = LayerForward( + layer_storage=layer_storage, + h=x.to("cuda"), + seqlen=seqlen, + start_pos=task_info.start_pos, + kvcaches=task_info.kvcaches, + back=back, + ) + self.forward_engine.add_layer_forward(layer_forward) + h = back.get() + task_info.start_pos += seqlen + + to_pass: torch.Tensor + cancel = False + if index == len(plan) - 1: + # it's the last node in the plan, firstly generate token + if task_info.start_pos > sampling_params.max_total_len: + next_token = torch.tensor([[EOS_TOKEN_ID]]) + elif sampling_params.temperature > 0: + probs = torch.softmax( + h[:, -1] / sampling_params.temperature, dim=-1 + ) + next_token = sample_top_p(probs, sampling_params.top_p) + next_token = next_token.reshape(1, -1) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(1, -1) + to_pass = next_token.to("cpu") + + # check whether to stop + if to_pass[0] in STOP_TOKEN_IDS: + cancel = True + + round += 1 + self.network_executor.submit( + requests.post, + f"{self.controller_url}/update_tasks", + json=[ + { + "task_id": task_id, + "output_tokens": to_pass.tolist(), + } + ], + ) + else: + # pass tensor to next node + to_pass = h + + next_index = (index + 1) % len(plan) + next_worker_url = plan[next_index].worker_url + + if cancel: + self.network_executor.submit( + requests.post, + f"{next_worker_url}/cancel", + json={ + "task_id": task_id, + "plan": [step.model_dump() for step in plan], + }, + ) + else: + self.network_executor.submit( + requests.post, + f"{next_worker_url}/forward", + data=dumps( + {"x": to_pass}, + { + "task_id": task_id, + "round": round, + "plan": plan, + "sampling_params": sampling_params, + }, + ), + ) + except Exception as e: + traceback.print_exc() + + def cancel(self, 
task_id: str, plan: list[PlanStep]) -> None: + index = next( + (i for i, x in enumerate(plan) if x.worker_id == self.worker_id), None + ) + if index is None: + return + + task_info = self.task_infos.pop(task_id, None) + if task_info is not None: + for kvcache in task_info.kvcaches.values(): + kvcache.clear() + + if index != len(plan) - 1: + requests.post( + f"{plan[index + 1].worker_url}/cancel", + json={ + "task_id": task_id, + "plan": plan, + }, + ) diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py new file mode 100644 index 0000000..5094dae --- /dev/null +++ b/deserve_worker/worker_api.py @@ -0,0 +1,48 @@ +import sys +import traceback +from concurrent.futures import ThreadPoolExecutor + +from fastapi import FastAPI, Request +from pydantic import BaseModel + +from .model import loads +from .worker import PlanStep, SamplingParams, Worker + +app = FastAPI() +worker = Worker(sys.argv[2], 64, "http://localhost:29980") +runtime_executor = ThreadPoolExecutor(max_workers=64) + + +@app.post("/forward") +async def forward(request: Request) -> str: + try: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.forward, + tensors["x"], + metadata["task_id"], + metadata["round"], + [PlanStep.model_validate(step) for step in metadata["plan"]], + SamplingParams.model_validate(metadata["sampling_params"]), + ) + except Exception as e: + traceback.print_exc() + return "ok" + + +class CancelRequest(BaseModel): + task_id: str + plan: list[PlanStep] + + +@app.post("/cancel") +async def cancel(request: CancelRequest) -> str: + runtime_executor.submit(worker.cancel, request.task_id, request.plan) + return "ok" + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=int(sys.argv[1])) From 5f365a742894541086cea84541a4b5ffdc40c1d7 Mon Sep 17 00:00:00 2001 From: Celve Date: Tue, 30 Jul 2024 14:35:15 -0700 Subject: [PATCH 02/17] feat: support batch forward with correctness verified --- deserve_worker/forward_engine.py | 121 ++++++++---- deserve_worker/paged_kvcache.py | 12 +- deserve_worker/task.py | 61 ++++++ deserve_worker/worker.py | 321 ++++++++++++++++--------------- deserve_worker/worker_api.py | 18 +- 5 files changed, 341 insertions(+), 192 deletions(-) create mode 100644 deserve_worker/task.py diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index fa19468..61f3917 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -1,11 +1,18 @@ import queue +import time +from dataclasses import dataclass +from typing import Optional import torch +from deserve_worker.task import LayerForward, ResultBack + from .kvcache import KVCacheBase from .layer_storage import LayerStorage from .model import ENABLE_FLASH_ATTN +EOS_TOKEN_ID = 128001 # for llama 3 only + def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: """ @@ -37,45 +44,62 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to("cuda") -class LayerForward: - def __init__( - self, - layer_storage: LayerStorage, - h: torch.Tensor, - seqlen: int, - start_pos: int, - kvcaches: dict[int, KVCacheBase], - back: queue.Queue[torch.Tensor], - ): - self.layer_storage = layer_storage - self.h = h - self.seqlen = seqlen - self.start_pos = start_pos - self.kvcaches = kvcaches - self.back = back +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """ + Perform top-p 
(nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +def trace(begin: float, msg: str) -> float: + end = time.time() + print(f"{msg}: {(end - begin)*1000:.3f}ms") + return end class ForwardEngine: - def __init__(self, max_total_bsz: int): + def __init__( + self, max_total_bsz: int, sendback_queue: queue.Queue[list[ResultBack]] + ): self.max_total_bsz = max_total_bsz - - self.queue = queue.Queue[LayerForward]() + self.sendback_queue = sendback_queue + self.handling_queue = queue.Queue[list[LayerForward]]() def run(self) -> None: - q = self.queue + q = self.handling_queue while True: - q_buffered: list[LayerForward] = [q.get()] + forwards: list[LayerForward] = q.get() + begin = time.time() while True: try: - tasks = q.get(block=False) - q_buffered.append(tasks) + news = q.get(block=False) + forwards.extend(news) except queue.Empty: break - prefill_tasks = [task for task in q_buffered if task.seqlen > 1] - decode_tasks = [task for task in q_buffered if task.seqlen == 1] - print( - f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {len(decode_tasks)}" - ) + prefill_tasks = [task for task in forwards if task.seqlen > 1] + decode_tasks = [task for task in forwards if task.seqlen == 1] + # print( + # f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {len(decode_tasks)}" + # ) for task in prefill_tasks: h = self.forward([task]) @@ -88,25 +112,50 @@ def run(self) -> None: h = self.forward(to_decode) self.process(h, to_decode) - def add_layer_forward(self, task: LayerForward) -> None: - self.queue.put(task) + def add_layer_forward(self, forwards: list[LayerForward]) -> None: + self.handling_queue.put(forwards) def forward(self, tasks: list[LayerForward]) -> torch.Tensor: # we need to check that all tasks share the same layer storage with torch.inference_mode(): layer_storage = tasks[0].layer_storage h = torch.cat([t.h for t in tasks]) - bsz_list, start_pos_list = [1 for _ in tasks], [t.start_pos for t in tasks] + bsz_list, start_pos_list = [1 for _ in tasks], [ + t.task_info.start_pos for t in tasks + ] + for task in tasks: + for kvcache in task.task_info.kvcaches.values(): + kvcache.renew(task.h, task.task_info.start_pos) return layer_storage.forward( h, bsz_list, start_pos_list, global_freqs_cis, - [task.kvcaches for task in tasks], + [task.task_info.kvcaches for task in tasks], ) def process(self, merged_h: torch.Tensor, tasks: list[LayerForward]) -> None: - ptr = 0 - for task in tasks: - task.back.put(merged_h[ptr : ptr + 1]) - ptr += 1 + result: list[ResultBack] = [] + for ptr, task in enumerate(tasks): + h = merged_h[ptr : ptr + 1] + _, seqlen = h.shape[:2] + task_info = task.task_info + task_info.start_pos += seqlen + if task.need_sample: + task_info.round += 1 + sampling_params = task_info.sampling_params + if task_info.start_pos >= 
sampling_params.max_total_len: + next_token = torch.tensor([[EOS_TOKEN_ID]]) + elif sampling_params.temperature > 0: + probs = torch.softmax( + h[:, -1] / sampling_params.temperature, dim=-1 + ) + next_token = sample_top_p(probs, sampling_params.top_p) + next_token = next_token.reshape(1, -1) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(1, -1) + result.append(ResultBack(next_token.to("cpu"), task_info.task_id)) + else: + result.append(ResultBack(h.to("cpu"), task_info.task_id)) + self.sendback_queue.put(result) diff --git a/deserve_worker/paged_kvcache.py b/deserve_worker/paged_kvcache.py index 85532ff..0465c25 100644 --- a/deserve_worker/paged_kvcache.py +++ b/deserve_worker/paged_kvcache.py @@ -27,7 +27,7 @@ def __init__( self.avai_blocks.put(i) -global_paged_memory = PagedMemory(11500, 256, main_device, main_dtype) +global_paged_memory = PagedMemory(11600, 256, main_device, main_dtype) class PagedKVCache(KVCacheBase): @@ -38,6 +38,7 @@ def get_kv_cache_length(self, cur: int, seqlen: int) -> int: def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): self.main_device = main_device + self.is_clear = False bsz, seqlen = x.shape[0], x.shape[1] length = ( self.get_kv_cache_length(0, start_pos + seqlen) @@ -53,7 +54,11 @@ def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): ) for i in range(length): for j in range(bsz): - self.block_table[j, i] = global_paged_memory.avai_blocks.get() + try: + blk = global_paged_memory.avai_blocks.get(block=False) + except queue.Empty: + assert False, "No available block" + self.block_table[j, i] = blk def renew( self, @@ -88,6 +93,9 @@ def renew( self.block_table = block_table def clear(self) -> None: + if self.is_clear: + assert False, "Already cleared" + self.is_clear = True for row in self.block_table.tolist(): for item in row: global_paged_memory.avai_blocks.put(item) diff --git a/deserve_worker/task.py b/deserve_worker/task.py new file mode 100644 index 0000000..61c738e --- /dev/null +++ b/deserve_worker/task.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +import torch +from pydantic import BaseModel + +from .kvcache import KVCacheBase +from .layer_storage import LayerStorage + + +class PlanStep(BaseModel): + worker_id: str + worker_url: str + layers: list[str] + + +class SamplingParams(BaseModel): + temperature: float + top_p: float + max_total_len: int + + +class TaskInfo(BaseModel): + task_id: str + plan: list[PlanStep] + round: int + sampling_params: SamplingParams + + +@dataclass +class TaskData: + task_id: str + start_pos: int + plan: list[PlanStep] + round: int + sampling_params: SamplingParams + kvcaches: dict[int, KVCacheBase] + """ + When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. 
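+    Keys are transformer layer indices parsed from layer names such as
+    ``layers.0`` (see ``Worker.init_task_data``).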
+ """ + + +class LayerForward: + def __init__( + self, + layer_storage: LayerStorage, + h: torch.Tensor, + seqlen: int, + task_data: TaskData, + need_sample: bool, + ): + self.layer_storage = layer_storage + self.h = h + self.seqlen = seqlen + self.task_info = task_data + self.need_sample = need_sample + + +@dataclass +class ResultBack: + x: torch.Tensor + task_id: str diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 8098fc0..a466e1b 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -1,5 +1,6 @@ import queue import threading +import time import traceback from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -10,12 +11,12 @@ from pydantic import BaseModel from transformers import AutoTokenizer # type: ignore -from deserve_worker.paged_kvcache import PagedKVCache - -from .forward_engine import ForwardEngine, LayerForward +from .forward_engine import ForwardEngine from .kvcache import KVCacheBase from .layer_storage import global_layer_manager from .model import dumps +from .paged_kvcache import PagedKVCache, global_paged_memory +from .task import LayerForward, PlanStep, ResultBack, SamplingParams, TaskData, TaskInfo EOS_TOKEN_ID = 128001 # for llama 3 only STOP_TOKEN_IDS = [128001, 128009] @@ -25,63 +26,83 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") -class PlanStep(BaseModel): - worker_id: str - worker_url: str - layers: list[str] - - -@dataclass -class TaskInfo: - start_pos: int - - kvcaches: dict[int, KVCacheBase] - """ - When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. - """ - - -class SamplingParams(BaseModel): - temperature: float - top_p: float - max_total_len: int - - -def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
- - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - - class Worker: def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.worker_id = worker_id self.controller_url = controller_url - self.task_infos: dict[str, TaskInfo] = {} - self.forward_engine = ForwardEngine(max_total_bsz) + self.task_datas: dict[str, TaskData] = {} + self.relay_queue = queue.Queue[list[ResultBack]]() + self.forward_engine = ForwardEngine(max_total_bsz, self.relay_queue) threading.Thread(target=self.forward_engine.run, daemon=True).start() + threading.Thread(target=self.relay, daemon=True).start() self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) + def locate_in_plan(self, plan: list[PlanStep]) -> Optional[int]: + return next( + (i for i, worker in enumerate(plan) if worker.worker_id == self.worker_id), + None, + ) + + def init_task_data( + self, x: torch.Tensor, index: int, task_info: TaskInfo + ) -> TaskData: + if task_info.round == 0: + kvcaches = {} + for full_layer_name in task_info.plan[index].layers: + _, layer_name = full_layer_name.split("/") + if layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + kvcaches[layer_id] = PagedKVCache(x, 0, torch.device("cuda")) + + # TODO: need double check whether request is repeated + task_data = TaskData( + task_id=task_info.task_id, + start_pos=0, + plan=task_info.plan, + round=0, + sampling_params=task_info.sampling_params, + kvcaches=cast(dict[int, KVCacheBase], kvcaches), + ) + self.task_datas[task_info.task_id] = task_data + else: + task_data = self.task_datas[task_info.task_id] + task_data.round = task_info.round + + return task_data + + def batch_forward( + self, + xs: torch.Tensor, + task_infos: list[TaskInfo], + ) -> None: + ptr = 0 + forwards = [] + begin = time.time() + for ptr, task_info in enumerate(task_infos): + x = xs[ptr : ptr + 1] + plan = task_infos[0].plan + index = self.locate_in_plan(plan) + assert index is not None + layer_storage = global_layer_manager.get_layer_storage( + task_infos[0].plan[index].layers + ) + bsz, seqlen = x.shape[:2] # currently bsz is not used + forwards.append( + LayerForward( + layer_storage=layer_storage, + h=x.to("cuda"), + seqlen=seqlen, + task_data=self.init_task_data( + x, + index, + task_info, + ), + need_sample=(index == len(plan) - 1), + ) + ) + print("process time:", (time.time() - begin) * 1000) + self.forward_engine.add_layer_forward(forwards) + def forward( self, x: torch.Tensor, @@ -90,114 +111,110 @@ def forward( plan: list[PlanStep], sampling_params: SamplingParams, ) -> None: - try: - index = next( - ( - i - for i, worker in enumerate(plan) - if worker.worker_id == self.worker_id + index = self.locate_in_plan(plan) + if index is None: + return None + + bsz, seqlen = x.shape[:2] # currently bsz is not used + layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) + layer_forward = LayerForward( + layer_storage=layer_storage, + h=x.to("cuda"), + seqlen=seqlen, + task_data=self.init_task_data( + x, + index, + TaskInfo( + task_id=task_id, + plan=plan, + round=round, + sampling_params=sampling_params, ), - None, - ) - if index is None: - return None - - if round == 0: - kvcaches = {} - for full_layer_name 
in plan[index].layers: - _, layer_name = full_layer_name.split("/") - if layer_name.startswith("layers."): - layer_id = int(layer_name.split(".")[1]) - kvcaches[layer_id] = PagedKVCache(x, 0, torch.device("cuda")) - - # TODO: need double check whether request is repeated - self.task_infos[task_id] = TaskInfo( - start_pos=0, - kvcaches=cast(dict[int, KVCacheBase], kvcaches), - ) - - bsz, seqlen = x.shape[:2] # currently bsz is not used - task_info = self.task_infos[task_id] - layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) - back = queue.Queue[ - torch.Tensor - ]() # used to transfer tensor between engine and worker - layer_forward = LayerForward( - layer_storage=layer_storage, - h=x.to("cuda"), - seqlen=seqlen, - start_pos=task_info.start_pos, - kvcaches=task_info.kvcaches, - back=back, - ) - self.forward_engine.add_layer_forward(layer_forward) - h = back.get() - task_info.start_pos += seqlen - - to_pass: torch.Tensor - cancel = False - if index == len(plan) - 1: - # it's the last node in the plan, firstly generate token - if task_info.start_pos > sampling_params.max_total_len: - next_token = torch.tensor([[EOS_TOKEN_ID]]) - elif sampling_params.temperature > 0: - probs = torch.softmax( - h[:, -1] / sampling_params.temperature, dim=-1 + ), + need_sample=(index == len(plan) - 1), + ) + self.forward_engine.add_layer_forward([layer_forward]) + + def relay(self) -> None: + q = self.relay_queue + while True: + results: list[ResultBack] = q.get() + while True: + try: + tasks = q.get(block=False) + results.extend(tasks) + except queue.Empty: + break + + updated_tasks = [] + forward_tasks = [] + forward_tensors = [] + for result in results: + task_id = result.task_id + task_info = self.task_datas[task_id] + plan = task_info.plan + index = self.locate_in_plan(plan) + assert index is not None + + cancel = False + if index == len(plan) - 1: + tokens = result.x.tolist() + + updated_tasks.append( + { + "task_id": task_id, + "output_tokens": tokens, + } ) - next_token = sample_top_p(probs, sampling_params.top_p) - next_token = next_token.reshape(1, -1) - else: - next_token = torch.argmax(h[:, -1], dim=-1) - next_token = next_token.reshape(1, -1) - to_pass = next_token.to("cpu") - - # check whether to stop - if to_pass[0] in STOP_TOKEN_IDS: - cancel = True - round += 1 - self.network_executor.submit( - requests.post, - f"{self.controller_url}/update_tasks", - json=[ + if tokens[0][0] in STOP_TOKEN_IDS: + cancel = True + + next_index = (index + 1) % len(plan) + next_worker_url = task_info.plan[next_index].worker_url + if cancel: + task_info = self.task_datas.pop(task_id) + for kvcache in task_info.kvcaches.values(): + kvcache.clear() + if next_index != len(plan) - 1: + self.network_executor.submit( + requests.post, + f"{next_worker_url}/cancel", + json={ + "task_id": task_id, + "plan": [step.model_dump() for step in plan], + }, + ) + else: + forward_tasks.append( { "task_id": task_id, - "output_tokens": to_pass.tolist(), + "round": task_info.round, + "plan": plan, + "sampling_params": task_info.sampling_params, } - ], - ) - else: - # pass tensor to next node - to_pass = h + ) + forward_tensors.append(result.x) - next_index = (index + 1) % len(plan) - next_worker_url = plan[next_index].worker_url + self.network_executor.submit( + requests.post, + f"{self.controller_url}/update_tasks", + json=updated_tasks, + ) - if cancel: - self.network_executor.submit( - requests.post, - f"{next_worker_url}/cancel", - json={ - "task_id": task_id, - "plan": [step.model_dump() for step in plan], 
+ if len(forward_tasks) > 0: + x = torch.cat(forward_tensors) + data = dumps( + {"x": x}, + { + "task_infos": forward_tasks, }, ) - else: self.network_executor.submit( requests.post, - f"{next_worker_url}/forward", - data=dumps( - {"x": to_pass}, - { - "task_id": task_id, - "round": round, - "plan": plan, - "sampling_params": sampling_params, - }, - ), + f"{next_worker_url}/batch_forward", + data=data, ) - except Exception as e: - traceback.print_exc() def cancel(self, task_id: str, plan: list[PlanStep]) -> None: index = next( @@ -206,14 +223,14 @@ def cancel(self, task_id: str, plan: list[PlanStep]) -> None: if index is None: return - task_info = self.task_infos.pop(task_id, None) + task_info = self.task_datas.pop(task_id, None) if task_info is not None: for kvcache in task_info.kvcaches.values(): kvcache.clear() - - if index != len(plan) - 1: + next_index = (index + 1) % len(plan) + if next_index != len(plan) - 1: requests.post( - f"{plan[index + 1].worker_url}/cancel", + f"{plan[next_index].worker_url}/cancel", json={ "task_id": task_id, "plan": plan, diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index 5094dae..124f756 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -1,4 +1,5 @@ import sys +import time import traceback from concurrent.futures import ThreadPoolExecutor @@ -6,13 +7,26 @@ from pydantic import BaseModel from .model import loads -from .worker import PlanStep, SamplingParams, Worker +from .task import PlanStep, SamplingParams, TaskInfo +from .worker import Worker app = FastAPI() -worker = Worker(sys.argv[2], 64, "http://localhost:29980") +worker = Worker(sys.argv[2], 48, "http://localhost:29980") runtime_executor = ThreadPoolExecutor(max_workers=64) +@app.post("/batch_forward") +async def batch_forward(request: Request) -> str: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.batch_forward, + tensors["x"], + [TaskInfo.model_validate(task_info) for task_info in metadata["task_infos"]], + ) + return "ok" + + @app.post("/forward") async def forward(request: Request) -> str: try: From a858053e039715f76910c138f0b5903793963236 Mon Sep 17 00:00:00 2001 From: Celve Date: Tue, 30 Jul 2024 15:30:24 -0700 Subject: [PATCH 03/17] chore: simplify some structure --- deserve_worker/forward_engine.py | 20 +++++++++++--------- deserve_worker/kvcache.py | 5 ++--- deserve_worker/paged_kvcache.py | 4 ++-- deserve_worker/task.py | 2 -- deserve_worker/worker.py | 25 +++++++++++-------------- 5 files changed, 26 insertions(+), 30 deletions(-) diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index 61f3917..182776b 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -88,15 +88,14 @@ def run(self) -> None: q = self.handling_queue while True: forwards: list[LayerForward] = q.get() - begin = time.time() while True: try: news = q.get(block=False) forwards.extend(news) except queue.Empty: break - prefill_tasks = [task for task in forwards if task.seqlen > 1] - decode_tasks = [task for task in forwards if task.seqlen == 1] + prefill_tasks = [task for task in forwards if task.h.shape[1] > 1] + decode_tasks = [task for task in forwards if task.h.shape[1] == 1] # print( # f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {len(decode_tasks)}" # ) @@ -120,12 +119,15 @@ def forward(self, tasks: list[LayerForward]) -> torch.Tensor: with torch.inference_mode(): layer_storage = tasks[0].layer_storage h = torch.cat([t.h for t in tasks]) - 
bsz_list, start_pos_list = [1 for _ in tasks], [ - t.task_info.start_pos for t in tasks - ] + bsz_list = [] + start_pos_list = [] for task in tasks: + bsz_list.append(1) + start_pos_list.append(task.task_info.start_pos) for kvcache in task.task_info.kvcaches.values(): - kvcache.renew(task.h, task.task_info.start_pos) + kvcache.renew( + task.h.shape[0], task.h.shape[1], task.task_info.start_pos + ) return layer_storage.forward( h, bsz_list, @@ -155,7 +157,7 @@ def process(self, merged_h: torch.Tensor, tasks: list[LayerForward]) -> None: else: next_token = torch.argmax(h[:, -1], dim=-1) next_token = next_token.reshape(1, -1) - result.append(ResultBack(next_token.to("cpu"), task_info.task_id)) + result.append(ResultBack(next_token, task_info.task_id)) else: - result.append(ResultBack(h.to("cpu"), task_info.task_id)) + result.append(ResultBack(h, task_info.task_id)) self.sendback_queue.put(result) diff --git a/deserve_worker/kvcache.py b/deserve_worker/kvcache.py index 7501bcc..8ed461f 100644 --- a/deserve_worker/kvcache.py +++ b/deserve_worker/kvcache.py @@ -15,7 +15,7 @@ def del_tensor(t: torch.Tensor) -> None: class KVCacheBase(ABC): @abstractmethod - def renew(self, x: torch.Tensor, start_pos: int) -> None: + def renew(self, bsz: int, seqlen: int, start_pos: int) -> None: pass @abstractmethod @@ -66,8 +66,7 @@ def __init__( ) self.main_device = main_device - def renew(self, x: torch.Tensor, start_pos: int) -> None: - bsz, seqlen = x.shape[0], x.shape[1] + def renew(self, bsz: int, seqlen: int, start_pos: int) -> None: if start_pos + seqlen > self.cache_k.shape[1]: length = self.get_kv_cache_length(self.cache_k.shape[1], start_pos + seqlen) cache_k = torch.zeros( diff --git a/deserve_worker/paged_kvcache.py b/deserve_worker/paged_kvcache.py index 0465c25..1781b97 100644 --- a/deserve_worker/paged_kvcache.py +++ b/deserve_worker/paged_kvcache.py @@ -62,10 +62,10 @@ def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): def renew( self, - x: torch.Tensor, + bsz: int, + seqlen: int, start_pos: int, ) -> None: - bsz, seqlen = x.shape[0], x.shape[1] if ( start_pos + seqlen > self.block_table.shape[1] * global_paged_memory.block_size diff --git a/deserve_worker/task.py b/deserve_worker/task.py index 61c738e..bda6736 100644 --- a/deserve_worker/task.py +++ b/deserve_worker/task.py @@ -44,13 +44,11 @@ def __init__( self, layer_storage: LayerStorage, h: torch.Tensor, - seqlen: int, task_data: TaskData, need_sample: bool, ): self.layer_storage = layer_storage self.h = h - self.seqlen = seqlen self.task_info = task_data self.need_sample = need_sample diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index a466e1b..826ca9b 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -78,22 +78,21 @@ def batch_forward( ptr = 0 forwards = [] begin = time.time() + plan = task_infos[0].plan + index = self.locate_in_plan(plan) + assert index is not None + layer_storage = global_layer_manager.get_layer_storage( + task_infos[0].plan[index].layers + ) + xs_cuda = xs.to("cuda") for ptr, task_info in enumerate(task_infos): - x = xs[ptr : ptr + 1] - plan = task_infos[0].plan - index = self.locate_in_plan(plan) - assert index is not None - layer_storage = global_layer_manager.get_layer_storage( - task_infos[0].plan[index].layers - ) - bsz, seqlen = x.shape[:2] # currently bsz is not used + x_cuda = xs_cuda[ptr : ptr + 1] forwards.append( LayerForward( layer_storage=layer_storage, - h=x.to("cuda"), - seqlen=seqlen, + h=x_cuda, task_data=self.init_task_data( - x, + 
x_cuda, index, task_info, ), @@ -115,12 +114,10 @@ def forward( if index is None: return None - bsz, seqlen = x.shape[:2] # currently bsz is not used layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) layer_forward = LayerForward( layer_storage=layer_storage, h=x.to("cuda"), - seqlen=seqlen, task_data=self.init_task_data( x, index, @@ -203,7 +200,7 @@ def relay(self) -> None: ) if len(forward_tasks) > 0: - x = torch.cat(forward_tensors) + x = torch.cat(forward_tensors).to("cpu") data = dumps( {"x": x}, { From 7d6063ba7862f6374b8c21479122ce09f06612ab Mon Sep 17 00:00:00 2001 From: Celve Date: Tue, 30 Jul 2024 22:33:44 -0700 Subject: [PATCH 04/17] feat: support batch forward --- deserve_worker/forward_engine.py | 112 ++++++++++++++++--------------- deserve_worker/layer_storage.py | 12 +++- deserve_worker/model.py | 2 +- deserve_worker/paged_kvcache.py | 5 +- deserve_worker/task.py | 15 +++++ deserve_worker/worker.py | 72 ++++++++++---------- deserve_worker/worker_api.py | 3 +- 7 files changed, 121 insertions(+), 100 deletions(-) diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index 182776b..4a74c1f 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -1,15 +1,14 @@ +import itertools import queue -import time from dataclasses import dataclass from typing import Optional import torch -from deserve_worker.task import LayerForward, ResultBack - -from .kvcache import KVCacheBase +from .kvcache import KVCacheBase, main_device from .layer_storage import LayerStorage from .model import ENABLE_FLASH_ATTN +from .task import BatchForward, BatchResult, LayerForward, ResultBack EOS_TOKEN_ID = 128001 # for llama 3 only @@ -41,7 +40,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te return freqs_cis -global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to("cuda") +global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to(main_device) def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -70,83 +69,90 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: return next_token -def trace(begin: float, msg: str) -> float: - end = time.time() - print(f"{msg}: {(end - begin)*1000:.3f}ms") - return end - - class ForwardEngine: def __init__( self, max_total_bsz: int, sendback_queue: queue.Queue[list[ResultBack]] ): self.max_total_bsz = max_total_bsz self.sendback_queue = sendback_queue - self.handling_queue = queue.Queue[list[LayerForward]]() + self.handling_queue = queue.Queue[BatchForward]() def run(self) -> None: q = self.handling_queue while True: - forwards: list[LayerForward] = q.get() + forwards: list[BatchForward] = [q.get()] while True: try: - news = q.get(block=False) - forwards.extend(news) + new = q.get(block=False) + forwards.append(new) except queue.Empty: break - prefill_tasks = [task for task in forwards if task.h.shape[1] > 1] - decode_tasks = [task for task in forwards if task.h.shape[1] == 1] - # print( - # f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {len(decode_tasks)}" - # ) + prefill_tasks = [task for task in forwards if task.xs.shape[1] > 1] + decode_tasks = [task for task in forwards if task.xs.shape[1] == 1] for task in prefill_tasks: - h = self.forward([task]) - self.process(h, [task]) + h = self.forward(task) + self.process(h, task) - for i in range(0, len(decode_tasks), self.max_total_bsz): - to_decode = decode_tasks[ - i : min(i + self.max_total_bsz, len(decode_tasks)) - ] - h = self.forward(to_decode) - 
self.process(h, to_decode) + print( + f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {sum(task.xs.shape[0] for task in decode_tasks)}" + ) - def add_layer_forward(self, forwards: list[LayerForward]) -> None: + decode_tasks.sort(key=lambda task: task.xs.shape[0], reverse=False) + while len(decode_tasks) > 0: + total_bsz = 0 + todo_tasks = [] + for i in reversed(range(len(decode_tasks))): + cur_bsz = decode_tasks[i].xs.shape[0] + if total_bsz + cur_bsz > self.max_total_bsz: + continue + total_bsz += cur_bsz + todo_tasks.append(decode_tasks.pop(i)) + new_task_datas = [] + for task in todo_tasks: + new_task_datas.extend(task.task_datas) + new_xs = torch.cat([task.xs for task in todo_tasks]) + new_task = BatchForward( + xs=new_xs, + layer_storage=todo_tasks[0].layer_storage, + task_datas=new_task_datas, + need_sample=todo_tasks[0].need_sample, + ) + h = self.forward(new_task) + self.process(h, new_task) + + def add_batch_forward(self, forwards: BatchForward) -> None: self.handling_queue.put(forwards) - def forward(self, tasks: list[LayerForward]) -> torch.Tensor: + def forward(self, tasks: BatchForward) -> torch.Tensor: # we need to check that all tasks share the same layer storage with torch.inference_mode(): - layer_storage = tasks[0].layer_storage - h = torch.cat([t.h for t in tasks]) - bsz_list = [] - start_pos_list = [] - for task in tasks: - bsz_list.append(1) - start_pos_list.append(task.task_info.start_pos) - for kvcache in task.task_info.kvcaches.values(): - kvcache.renew( - task.h.shape[0], task.h.shape[1], task.task_info.start_pos - ) - return layer_storage.forward( + torch.cuda.synchronize() + layer_storage = tasks.layer_storage + h = tasks.xs + bsz_list = [1 for _ in range(len(tasks.task_datas))] + start_pos_list = [task.start_pos for task in tasks.task_datas] + kvcache_list = [task.kvcaches for task in tasks.task_datas] + result = layer_storage.forward( h, bsz_list, start_pos_list, global_freqs_cis, - [task.task_info.kvcaches for task in tasks], + kvcache_list, ) + torch.cuda.synchronize() + return result - def process(self, merged_h: torch.Tensor, tasks: list[LayerForward]) -> None: + def process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: result: list[ResultBack] = [] - for ptr, task in enumerate(tasks): + for ptr, task_data in enumerate(tasks.task_datas): h = merged_h[ptr : ptr + 1] _, seqlen = h.shape[:2] - task_info = task.task_info - task_info.start_pos += seqlen - if task.need_sample: - task_info.round += 1 - sampling_params = task_info.sampling_params - if task_info.start_pos >= sampling_params.max_total_len: + task_data.start_pos += seqlen + if tasks.need_sample: + task_data.round += 1 + sampling_params = task_data.sampling_params + if task_data.start_pos >= sampling_params.max_total_len: next_token = torch.tensor([[EOS_TOKEN_ID]]) elif sampling_params.temperature > 0: probs = torch.softmax( @@ -157,7 +163,7 @@ def process(self, merged_h: torch.Tensor, tasks: list[LayerForward]) -> None: else: next_token = torch.argmax(h[:, -1], dim=-1) next_token = next_token.reshape(1, -1) - result.append(ResultBack(next_token, task_info.task_id)) + result.append(ResultBack(next_token, task_data.task_id)) else: - result.append(ResultBack(h, task_info.task_id)) + result.append(ResultBack(h, task_data.task_id)) self.sendback_queue.put(result) diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index d2de1c0..f350ffa 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -5,7 +5,7 @@ import requests import 
torch -from .kvcache import KVCacheBase +from .kvcache import KVCacheBase, main_device from .model import ENABLE_FLASH_ATTN, ModelArgs, RMSNorm, TransformerBlock llama_2_7b_args = { @@ -80,7 +80,7 @@ def get_layer_storage(self, layer_names: list[str]) -> "LayerStorage": return self.layer_storage_map[frozen_layer_names] -global_layer_manager = LayerManager(torch.device("cuda")) +global_layer_manager = LayerManager(main_device) class LayerStorage: @@ -163,6 +163,7 @@ def unload_layers(self) -> None: del self.layers[full_layer_name] torch.cuda.empty_cache() + @torch.inference_mode() def forward( self, h: torch.Tensor, @@ -171,18 +172,23 @@ def forward( global_freqs_cis: torch.Tensor, kv_cache_list: list[dict[int, KVCacheBase]], ) -> torch.Tensor: + _, seqlen = h.shape[:2] for full_layer_name in self.layers: _, layer_name = full_layer_name.split("/") if layer_name == "tok_embeddings": h = self.layers[full_layer_name](h) elif layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) + cur_kv_cache_list = [] + for i, kv_cache in enumerate(kv_cache_list): + kv_cache[layer_id].renew(1, seqlen, start_pos_list[i]) + cur_kv_cache_list.append(kv_cache[layer_id]) h = self.layers[full_layer_name]( h, bsz_list, start_pos_list, global_freqs_cis, - [kv_cache[layer_id] for kv_cache in kv_cache_list], + cur_kv_cache_list, ) elif layer_name == "norm": h = self.layers[full_layer_name](h) diff --git a/deserve_worker/model.py b/deserve_worker/model.py index 09701f5..ca0e5f5 100644 --- a/deserve_worker/model.py +++ b/deserve_worker/model.py @@ -292,7 +292,7 @@ def forward( cache_seqlens, dtype=torch.int32, device=x.device ) bsz = cache_seqlens_tch.shape[0] - paged_kv_cache_list: list[PagedKVCache] = kv_cache_list # type: ignore + paged_kv_cache_list = cast(list[PagedKVCache], kv_cache_list) max_len = max([kvcache.shape()[1] for kvcache in paged_kv_cache_list]) block_table = torch.zeros( diff --git a/deserve_worker/paged_kvcache.py b/deserve_worker/paged_kvcache.py index 1781b97..9b64019 100644 --- a/deserve_worker/paged_kvcache.py +++ b/deserve_worker/paged_kvcache.py @@ -54,10 +54,7 @@ def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): ) for i in range(length): for j in range(bsz): - try: - blk = global_paged_memory.avai_blocks.get(block=False) - except queue.Empty: - assert False, "No available block" + blk = global_paged_memory.avai_blocks.get(block=False) self.block_table[j, i] = blk def renew( diff --git a/deserve_worker/task.py b/deserve_worker/task.py index bda6736..9dd975d 100644 --- a/deserve_worker/task.py +++ b/deserve_worker/task.py @@ -53,6 +53,21 @@ def __init__( self.need_sample = need_sample +@dataclass +class BatchForward: + xs: torch.Tensor + layer_storage: LayerStorage + task_datas: list[TaskData] + need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage + + +@dataclass +class BatchResult: + xs: torch.Tensor + task_ids: list[str] + done_ids: list[str] + + @dataclass class ResultBack: x: torch.Tensor diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 826ca9b..19e77c9 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -1,6 +1,5 @@ import queue import threading -import time import traceback from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -12,11 +11,20 @@ from transformers import AutoTokenizer # type: ignore from .forward_engine import ForwardEngine -from .kvcache import KVCacheBase +from .kvcache import KVCacheBase, main_device from 
.layer_storage import global_layer_manager from .model import dumps from .paged_kvcache import PagedKVCache, global_paged_memory -from .task import LayerForward, PlanStep, ResultBack, SamplingParams, TaskData, TaskInfo +from .task import ( + BatchForward, + BatchResult, + LayerForward, + PlanStep, + ResultBack, + SamplingParams, + TaskData, + TaskInfo, +) EOS_TOKEN_ID = 128001 # for llama 3 only STOP_TOKEN_IDS = [128001, 128009] @@ -52,7 +60,7 @@ def init_task_data( _, layer_name = full_layer_name.split("/") if layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) - kvcaches[layer_id] = PagedKVCache(x, 0, torch.device("cuda")) + kvcaches[layer_id] = PagedKVCache(x, 0, main_device) # TODO: need double check whether request is repeated task_data = TaskData( @@ -75,32 +83,20 @@ def batch_forward( xs: torch.Tensor, task_infos: list[TaskInfo], ) -> None: - ptr = 0 - forwards = [] - begin = time.time() plan = task_infos[0].plan index = self.locate_in_plan(plan) assert index is not None layer_storage = global_layer_manager.get_layer_storage( task_infos[0].plan[index].layers ) - xs_cuda = xs.to("cuda") - for ptr, task_info in enumerate(task_infos): - x_cuda = xs_cuda[ptr : ptr + 1] - forwards.append( - LayerForward( - layer_storage=layer_storage, - h=x_cuda, - task_data=self.init_task_data( - x_cuda, - index, - task_info, - ), - need_sample=(index == len(plan) - 1), - ) + task_datas = [ + self.init_task_data(xs, index, task_info) for task_info in task_infos + ] + self.forward_engine.add_batch_forward( + BatchForward( + xs.to(main_device), layer_storage, task_datas, (index == len(plan) - 1) ) - print("process time:", (time.time() - begin) * 1000) - self.forward_engine.add_layer_forward(forwards) + ) def forward( self, @@ -115,22 +111,24 @@ def forward( return None layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) - layer_forward = LayerForward( + layer_forward = BatchForward( + xs=x.to(main_device), layer_storage=layer_storage, - h=x.to("cuda"), - task_data=self.init_task_data( - x, - index, - TaskInfo( - task_id=task_id, - plan=plan, - round=round, - sampling_params=sampling_params, - ), - ), + task_datas=[ + self.init_task_data( + x, + index, + TaskInfo( + task_id=task_id, + plan=plan, + round=round, + sampling_params=sampling_params, + ), + ) + ], need_sample=(index == len(plan) - 1), ) - self.forward_engine.add_layer_forward([layer_forward]) + self.forward_engine.add_batch_forward(layer_forward) def relay(self) -> None: q = self.relay_queue @@ -200,7 +198,7 @@ def relay(self) -> None: ) if len(forward_tasks) > 0: - x = torch.cat(forward_tensors).to("cpu") + x = torch.cat(forward_tensors) data = dumps( {"x": x}, { diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index 124f756..754028b 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -1,5 +1,4 @@ import sys -import time import traceback from concurrent.futures import ThreadPoolExecutor @@ -12,7 +11,7 @@ app = FastAPI() worker = Worker(sys.argv[2], 48, "http://localhost:29980") -runtime_executor = ThreadPoolExecutor(max_workers=64) +runtime_executor = ThreadPoolExecutor(max_workers=96) @app.post("/batch_forward") From 5c2f700dab3f90038283814053576032a0e3ca30 Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 31 Jul 2024 01:30:54 -0700 Subject: [PATCH 05/17] feat: support batch result --- deserve_worker/forward_engine.py | 48 +++++++++----- deserve_worker/task.py | 9 ++- deserve_worker/worker.py | 103 +++++++++++-------------------- 3 files changed, 77 
insertions(+), 83 deletions(-) diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index 4a74c1f..e30a635 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -1,5 +1,6 @@ import itertools import queue +import time from dataclasses import dataclass from typing import Optional @@ -8,9 +9,10 @@ from .kvcache import KVCacheBase, main_device from .layer_storage import LayerStorage from .model import ENABLE_FLASH_ATTN -from .task import BatchForward, BatchResult, LayerForward, ResultBack +from .task import BatchForward, BatchResult, BatchUpdate, LayerForward, ResultBack EOS_TOKEN_ID = 128001 # for llama 3 only +STOP_TOKEN_IDS = [128001, 128009] def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: @@ -71,7 +73,7 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: class ForwardEngine: def __init__( - self, max_total_bsz: int, sendback_queue: queue.Queue[list[ResultBack]] + self, max_total_bsz: int, sendback_queue: queue.Queue[BatchResult | BatchUpdate] ): self.max_total_bsz = max_total_bsz self.sendback_queue = sendback_queue @@ -127,7 +129,6 @@ def add_batch_forward(self, forwards: BatchForward) -> None: def forward(self, tasks: BatchForward) -> torch.Tensor: # we need to check that all tasks share the same layer storage with torch.inference_mode(): - torch.cuda.synchronize() layer_storage = tasks.layer_storage h = tasks.xs bsz_list = [1 for _ in range(len(tasks.task_datas))] @@ -140,16 +141,19 @@ def forward(self, tasks: BatchForward) -> torch.Tensor: global_freqs_cis, kvcache_list, ) - torch.cuda.synchronize() return result def process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: - result: list[ResultBack] = [] - for ptr, task_data in enumerate(tasks.task_datas): - h = merged_h[ptr : ptr + 1] - _, seqlen = h.shape[:2] - task_data.start_pos += seqlen - if tasks.need_sample: + if tasks.need_sample: + ongoing_tokens = [] + ongoing_ids = [] + all_tokens = [] + all_ids = [] + done_ids = [] + for ptr, task_data in enumerate(tasks.task_datas): + h = merged_h[ptr : ptr + 1] + _, seqlen = h.shape[:2] + task_data.start_pos += seqlen task_data.round += 1 sampling_params = task_data.sampling_params if task_data.start_pos >= sampling_params.max_total_len: @@ -163,7 +167,23 @@ def process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: else: next_token = torch.argmax(h[:, -1], dim=-1) next_token = next_token.reshape(1, -1) - result.append(ResultBack(next_token, task_data.task_id)) - else: - result.append(ResultBack(h, task_data.task_id)) - self.sendback_queue.put(result) + next_token = next_token.to("cpu") + all_ids.append(task_data.task_id) + all_tokens.append(next_token) + if next_token[0][0] in STOP_TOKEN_IDS: + done_ids.append(task_data.task_id) + else: + ongoing_ids.append(task_data.task_id) + ongoing_tokens.append(next_token) + if len(ongoing_tokens) > 0: + self.sendback_queue.put( + BatchResult(torch.cat(ongoing_tokens), ongoing_ids) + ) + self.sendback_queue.put(BatchUpdate(all_tokens, all_ids, done_ids)) + else: + seqlen = tasks.xs.shape[1] + for task in tasks.task_datas: + task.start_pos += seqlen + self.sendback_queue.put( + BatchResult(merged_h, [task.task_id for task in tasks.task_datas]) + ) diff --git a/deserve_worker/task.py b/deserve_worker/task.py index 9dd975d..779a9c0 100644 --- a/deserve_worker/task.py +++ b/deserve_worker/task.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import torch from pydantic import 
BaseModel @@ -65,7 +66,13 @@ class BatchForward: class BatchResult: xs: torch.Tensor task_ids: list[str] - done_ids: list[str] + + +@dataclass +class BatchUpdate: + tokens: list[torch.Tensor] + task_ids: list[str] + cancel_ids: list[str] @dataclass diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 19e77c9..3bffe5e 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -18,6 +18,7 @@ from .task import ( BatchForward, BatchResult, + BatchUpdate, LayerForward, PlanStep, ResultBack, @@ -27,9 +28,6 @@ ) EOS_TOKEN_ID = 128001 # for llama 3 only -STOP_TOKEN_IDS = [128001, 128009] - -stop_tokens = torch.tensor(STOP_TOKEN_IDS) tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") @@ -39,7 +37,7 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.worker_id = worker_id self.controller_url = controller_url self.task_datas: dict[str, TaskData] = {} - self.relay_queue = queue.Queue[list[ResultBack]]() + self.relay_queue = queue.Queue[BatchResult | BatchUpdate]() self.forward_engine = ForwardEngine(max_total_bsz, self.relay_queue) threading.Thread(target=self.forward_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() @@ -133,76 +131,29 @@ def forward( def relay(self) -> None: q = self.relay_queue while True: - results: list[ResultBack] = q.get() - while True: - try: - tasks = q.get(block=False) - results.extend(tasks) - except queue.Empty: - break - - updated_tasks = [] - forward_tasks = [] - forward_tensors = [] - for result in results: - task_id = result.task_id + result = q.get() + if isinstance(result, BatchResult): + task_id = result.task_ids[0] task_info = self.task_datas[task_id] plan = task_info.plan index = self.locate_in_plan(plan) assert index is not None - - cancel = False - if index == len(plan) - 1: - tokens = result.x.tolist() - - updated_tasks.append( - { - "task_id": task_id, - "output_tokens": tokens, - } - ) - - if tokens[0][0] in STOP_TOKEN_IDS: - cancel = True - next_index = (index + 1) % len(plan) - next_worker_url = task_info.plan[next_index].worker_url - if cancel: - task_info = self.task_datas.pop(task_id) - for kvcache in task_info.kvcaches.values(): - kvcache.clear() - if next_index != len(plan) - 1: - self.network_executor.submit( - requests.post, - f"{next_worker_url}/cancel", - json={ - "task_id": task_id, - "plan": [step.model_dump() for step in plan], - }, - ) - else: - forward_tasks.append( - { - "task_id": task_id, - "round": task_info.round, - "plan": plan, - "sampling_params": task_info.sampling_params, - } - ) - forward_tensors.append(result.x) - - self.network_executor.submit( - requests.post, - f"{self.controller_url}/update_tasks", - json=updated_tasks, - ) - - if len(forward_tasks) > 0: - x = torch.cat(forward_tensors) + next_worker_url = plan[next_index].worker_url data = dumps( - {"x": x}, + {"x": result.xs}, { - "task_infos": forward_tasks, + "task_infos": [ + { + "task_id": task_id, + "round": self.task_datas[task_id].round, + "plan": plan, + "sampling_params": self.task_datas[ + task_id + ].sampling_params, + } + for task_id in result.task_ids + ], }, ) self.network_executor.submit( @@ -210,6 +161,22 @@ def relay(self) -> None: f"{next_worker_url}/batch_forward", data=data, ) + elif isinstance(result, BatchUpdate): + updated_tasks = [] + for tokens, task_id in zip(result.tokens, result.task_ids): + updated_tasks.append( + { + "task_id": task_id, + "output_tokens": tokens.tolist(), + } + ) + self.network_executor.submit( + 
requests.post, + f"{self.controller_url}/update_tasks", + json=updated_tasks, + ) + for task_id in result.cancel_ids: + self.cancel(task_id, self.task_datas[task_id].plan) def cancel(self, task_id: str, plan: list[PlanStep]) -> None: index = next( @@ -228,6 +195,6 @@ def cancel(self, task_id: str, plan: list[PlanStep]) -> None: f"{plan[next_index].worker_url}/cancel", json={ "task_id": task_id, - "plan": plan, + "plan": [step.model_dump() for step in plan], }, ) From 710bd5a5743bfce677149fb6532e753a5c56323c Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 31 Jul 2024 14:26:27 -0700 Subject: [PATCH 06/17] chore: move llama to model dir --- deserve_worker/forward_engine.py | 2 +- deserve_worker/layer_storage.py | 16 +++++++-------- deserve_worker/model/__init__.py | 0 deserve_worker/{model.py => model/llama.py} | 22 ++++++++++----------- deserve_worker/worker.py | 2 +- deserve_worker/worker_api.py | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 deserve_worker/model/__init__.py rename deserve_worker/{model.py => model/llama.py} (96%) diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index e30a635..5a627d9 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -8,7 +8,7 @@ from .kvcache import KVCacheBase, main_device from .layer_storage import LayerStorage -from .model import ENABLE_FLASH_ATTN +from .model.llama import ENABLE_FLASH_ATTN from .task import BatchForward, BatchResult, BatchUpdate, LayerForward, ResultBack EOS_TOKEN_ID = 128001 # for llama 3 only diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index f350ffa..87e719b 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -6,7 +6,7 @@ import torch from .kvcache import KVCacheBase, main_device -from .model import ENABLE_FLASH_ATTN, ModelArgs, RMSNorm, TransformerBlock +from .model.llama import ENABLE_FLASH_ATTN, ModelArgs, RMSNorm, TransformerBlock llama_2_7b_args = { "dim": 4096, @@ -123,19 +123,19 @@ def preload_layers(self, layer_names: list[str]) -> None: path = thread.result() model_name, layer_name = full_layer_name.split("/") if model_name.startswith("llama-2-7b"): - model_args = ModelArgs(**llama_2_7b_args) + model_args = ModelArgs(**llama_2_7b_args) # type: ignore elif model_name.startswith("llama-2-13b"): - model_args = ModelArgs(**llama_2_13b_args) + model_args = ModelArgs(**llama_2_13b_args) # type: ignore elif model_name.startswith("llama-2-70b"): - model_args = ModelArgs(**llama_2_70b_args) + model_args = ModelArgs(**llama_2_70b_args) # type: ignore elif model_name.startswith("llama-3-8b"): - model_args = ModelArgs(**llama_3_8b_args) + model_args = ModelArgs(**llama_3_8b_args) # type: ignore elif model_name.startswith("llama-3-70b"): - model_args = ModelArgs(**llama_3_70b_args) + model_args = ModelArgs(**llama_3_70b_args) # type: ignore else: raise NotImplementedError("Unknown model") if layer_name == "tok_embeddings": - l = torch.nn.utils.skip_init( + l = torch.nn.utils.skip_init( # type: ignore torch.nn.Embedding, model_args.vocab_size, model_args.dim ) elif layer_name.startswith("layer"): @@ -143,7 +143,7 @@ def preload_layers(self, layer_names: list[str]) -> None: elif layer_name == "norm": l = RMSNorm(model_args.dim, eps=model_args.norm_eps) elif layer_name == "output": - l = torch.nn.utils.skip_init( + l = torch.nn.utils.skip_init( # type: ignore torch.nn.Linear, model_args.dim, model_args.vocab_size, diff --git a/deserve_worker/model/__init__.py 
b/deserve_worker/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/model.py b/deserve_worker/model/llama.py similarity index 96% rename from deserve_worker/model.py rename to deserve_worker/model/llama.py index ca0e5f5..e02f84a 100644 --- a/deserve_worker/model.py +++ b/deserve_worker/model/llama.py @@ -12,13 +12,13 @@ from deserve_worker.paged_kvcache import PagedKVCache -from .kvcache import KVCache, KVCacheBase +from ..kvcache import KVCache, KVCacheBase ENABLE_FLASH_ATTN = False try: from flash_attn import flash_attn_with_kvcache # type: ignore - from .paged_kvcache import global_paged_memory + from ..paged_kvcache import global_paged_memory ENABLE_FLASH_ATTN = True except ImportError as e: @@ -217,25 +217,25 @@ def __init__(self, args: ModelArgs): self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads - self.wq = torch.nn.utils.skip_init( + self.wq = torch.nn.utils.skip_init( # type: ignore nn.Linear, args.dim, args.n_heads * self.head_dim, bias=False, ) - self.wk = torch.nn.utils.skip_init( + self.wk = torch.nn.utils.skip_init( # type: ignore nn.Linear, args.dim, self.n_kv_heads * self.head_dim, bias=False, ) - self.wv = torch.nn.utils.skip_init( + self.wv = torch.nn.utils.skip_init( # type: ignore nn.Linear, args.dim, self.n_kv_heads * self.head_dim, bias=False, ) - self.wo = torch.nn.utils.skip_init( + self.wo = torch.nn.utils.skip_init( # type: ignore nn.Linear, args.n_heads * self.head_dim, args.dim, @@ -385,7 +385,7 @@ def forward( output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output_list.append(output) output = torch.cat([x for x in output_list]) - return self.wo(output) + return self.wo(output) # type: ignore class FeedForward(nn.Module): @@ -418,19 +418,19 @@ def __init__( hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.w1 = torch.nn.utils.skip_init( + self.w1 = torch.nn.utils.skip_init( # type: ignore nn.Linear, dim, hidden_dim, bias=False, ) - self.w2 = torch.nn.utils.skip_init( + self.w2 = torch.nn.utils.skip_init( # type: ignore nn.Linear, hidden_dim, dim, bias=False, ) - self.w3 = torch.nn.utils.skip_init( + self.w3 = torch.nn.utils.skip_init( # type: ignore nn.Linear, dim, hidden_dim, @@ -439,7 +439,7 @@ def __init__( @torch.inference_mode() def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) # type: ignore class TransformerBlock(nn.Module): diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 3bffe5e..9d04d55 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -13,7 +13,7 @@ from .forward_engine import ForwardEngine from .kvcache import KVCacheBase, main_device from .layer_storage import global_layer_manager -from .model import dumps +from .model.llama import dumps from .paged_kvcache import PagedKVCache, global_paged_memory from .task import ( BatchForward, diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index 754028b..f309615 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -5,7 +5,7 @@ from fastapi import FastAPI, Request from pydantic import BaseModel -from .model import loads +from .model.llama import loads from .task import PlanStep, SamplingParams, TaskInfo from .worker import Worker From 5e19f67a841e4740e3f04b03d4ec1b66519760c7 Mon Sep 17 00:00:00 2001 From: Celve Date: Wed, 31 Jul 2024 
17:58:46 -0700 Subject: [PATCH 07/17] feat: support cancel at anytime --- deserve_worker/worker.py | 11 ++++++++--- deserve_worker/worker_api.py | 5 ++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 9d04d55..38aa330 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -176,25 +176,30 @@ def relay(self) -> None: json=updated_tasks, ) for task_id in result.cancel_ids: - self.cancel(task_id, self.task_datas[task_id].plan) + self.cancel(task_id, None, self.task_datas[task_id].plan) - def cancel(self, task_id: str, plan: list[PlanStep]) -> None: + def cancel( + self, task_id: str, start_index: Optional[int], plan: list[PlanStep] + ) -> None: index = next( (i for i, x in enumerate(plan) if x.worker_id == self.worker_id), None ) if index is None: return + if start_index is None: + start_index = index task_info = self.task_datas.pop(task_id, None) if task_info is not None: for kvcache in task_info.kvcaches.values(): kvcache.clear() next_index = (index + 1) % len(plan) - if next_index != len(plan) - 1: + if next_index != start_index: requests.post( f"{plan[next_index].worker_url}/cancel", json={ "task_id": task_id, + "start_index": index, "plan": [step.model_dump() for step in plan], }, ) diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index f309615..1e9a05b 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -46,12 +46,15 @@ async def forward(request: Request) -> str: class CancelRequest(BaseModel): task_id: str + start_index: int plan: list[PlanStep] @app.post("/cancel") async def cancel(request: CancelRequest) -> str: - runtime_executor.submit(worker.cancel, request.task_id, request.plan) + runtime_executor.submit( + worker.cancel, request.task_id, request.start_index, request.plan + ) return "ok" From 2ab64f46488bf537d71310ff8a7afc015c53aa3e Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 1 Aug 2024 14:18:23 -0700 Subject: [PATCH 08/17] feat: support shared layer storage --- deserve_worker/layer_storage.py | 71 ++++++++++++++------------------- deserve_worker/worker.py | 12 +++--- 2 files changed, 34 insertions(+), 49 deletions(-) diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index 87e719b..1aa98f1 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -1,12 +1,12 @@ import os import threading -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor import requests import torch -from .kvcache import KVCacheBase, main_device -from .model.llama import ENABLE_FLASH_ATTN, ModelArgs, RMSNorm, TransformerBlock +from .kvcache import KVCacheBase +from .model.llama import ModelArgs, RMSNorm, TransformerBlock llama_2_7b_args = { "dim": 4096, @@ -67,39 +67,24 @@ def __init__(self, main_device: torch.device): self.main_device = main_device self.network_executor = ThreadPoolExecutor(max_workers=80) self.cache_dir = os.path.expanduser("~/.cache/fleece-worker/models") - self.layer_storage_map: dict[frozenset[str], LayerStorage] = {} + self.layer_storages: dict[frozenset[str], LayerStorage] = {} + self.layers: dict[str, torch.nn.Module] = {} self.mutex = threading.Lock() def get_layer_storage(self, layer_names: list[str]) -> "LayerStorage": frozen_layer_names = frozenset(layer_names) - with self.mutex: - if frozen_layer_names not in self.layer_storage_map: - self.layer_storage_map[frozen_layer_names] = LayerStorage( - layer_names, self.main_device + 
if frozen_layer_names not in self.layer_storages: + with self.mutex: + self.layer_storages[frozen_layer_names] = LayerStorage( + self.preload_layers(layer_names), self.main_device ) - return self.layer_storage_map[frozen_layer_names] - - -global_layer_manager = LayerManager(main_device) - - -class LayerStorage: - def __init__(self, layer_names: list[str], main_device: torch.device): - self.layer_names = layer_names - self.main_device = main_device - self.layers: dict[str, torch.nn.Module] = {} - - self.preload_layers(self.layer_names) + return self.layer_storages[frozen_layer_names] def fetch_layer(self, full_layer_name: str) -> str: model_name, layer_name = full_layer_name.split("/") - path = os.path.join( - global_layer_manager.cache_dir, model_name, f"{layer_name}.pt" - ) + path = os.path.join(self.cache_dir, model_name, f"{layer_name}.pt") if not os.path.exists(path): # TODO lock - os.makedirs( - os.path.join(global_layer_manager.cache_dir, model_name), exist_ok=True - ) + os.makedirs(os.path.join(self.cache_dir, model_name), exist_ok=True) with requests.get( f"https://huggingface.co/colearn/{model_name}/resolve/main/{layer_name}.pt", stream=True, @@ -110,14 +95,11 @@ def fetch_layer(self, full_layer_name: str) -> str: f.write(chunk) return path - def preload_layers(self, layer_names: list[str]) -> None: - threads = [] - for full_layer_name in layer_names: - if full_layer_name in self.layers: - continue - thread = global_layer_manager.network_executor.submit( - self.fetch_layer, full_layer_name - ) + def preload_layers(self, full_layer_names: list[str]) -> dict[str, torch.nn.Module]: + threads: list[tuple[str, Future[str]]] = [] + result = {} + for full_layer_name in full_layer_names: + thread = self.network_executor.submit(self.fetch_layer, full_layer_name) threads.append((full_layer_name, thread)) for full_layer_name, thread in threads: path = thread.result() @@ -153,15 +135,20 @@ def preload_layers(self, layer_names: list[str]) -> None: raise NotImplementedError("Unknown layers") l.load_state_dict(torch.load(path, map_location="cpu")) l.to(self.main_device) - self.layers[full_layer_name] = l print("Loaded", full_layer_name) + self.layers[full_layer_name] = l + for full_layer_name in full_layer_names: + result[full_layer_name] = self.layers[full_layer_name] + return result + + +class LayerStorage: + def __init__(self, layers: dict[str, torch.nn.Module], main_device: torch.device): + self.main_device = main_device + self.layers = layers - def unload_layers(self) -> None: - for full_layer_name in self.layer_names: - if full_layer_name not in self.layers: - continue # TODO: continue or warning? 
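As a usage note on the shared storage introduced here: get_layer_storage keys its cache by frozenset, so the same layer set requested in any order resolves to one LayerStorage and its weights are fetched and loaded only once. A minimal sketch, where the model-slice names are illustrative and not part of the patch:

import torch
from deserve_worker.layer_storage import LayerManager

manager = LayerManager(torch.device("cuda"))
names = ["llama-3-8b-slice/layers.0", "llama-3-8b-slice/layers.1"]
first = manager.get_layer_storage(names)
second = manager.get_layer_storage(list(reversed(names)))  # same set, different order
assert first is second  # cached under frozenset(names), so nothing is reloaded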
- del self.layers[full_layer_name] - torch.cuda.empty_cache() + def clear(self) -> None: + self.layers.clear() @torch.inference_mode() def forward( diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 38aa330..7b8dc29 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -7,21 +7,18 @@ import requests import torch -from pydantic import BaseModel from transformers import AutoTokenizer # type: ignore from .forward_engine import ForwardEngine from .kvcache import KVCacheBase, main_device -from .layer_storage import global_layer_manager +from .layer_storage import LayerManager from .model.llama import dumps -from .paged_kvcache import PagedKVCache, global_paged_memory +from .paged_kvcache import PagedKVCache from .task import ( BatchForward, BatchResult, BatchUpdate, - LayerForward, PlanStep, - ResultBack, SamplingParams, TaskData, TaskInfo, @@ -39,6 +36,7 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.task_datas: dict[str, TaskData] = {} self.relay_queue = queue.Queue[BatchResult | BatchUpdate]() self.forward_engine = ForwardEngine(max_total_bsz, self.relay_queue) + self.layer_manager = LayerManager(main_device) threading.Thread(target=self.forward_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) @@ -84,7 +82,7 @@ def batch_forward( plan = task_infos[0].plan index = self.locate_in_plan(plan) assert index is not None - layer_storage = global_layer_manager.get_layer_storage( + layer_storage = self.layer_manager.get_layer_storage( task_infos[0].plan[index].layers ) task_datas = [ @@ -108,7 +106,7 @@ def forward( if index is None: return None - layer_storage = global_layer_manager.get_layer_storage(plan[index].layers) + layer_storage = self.layer_manager.get_layer_storage(plan[index].layers) layer_forward = BatchForward( xs=x.to(main_device), layer_storage=layer_storage, From 9fedc6152490a0a2fe5c19f9a6221ea6ebfa79e4 Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 1 Aug 2024 16:00:11 -0700 Subject: [PATCH 09/17] feat: store block table in torch tensor --- deserve_worker/forward_engine.py | 5 +- deserve_worker/kvcache/__init__.py | 0 deserve_worker/kvcache/kvcache.py | 40 ++++++ .../{kvcache.py => kvcache/packed_kvcache.py} | 36 ++---- deserve_worker/kvcache/paged_kvcache.py | 116 ++++++++++++++++++ deserve_worker/layer_storage.py | 14 ++- deserve_worker/model/llama.py | 25 ++-- deserve_worker/paged_kvcache.py | 101 --------------- deserve_worker/task.py | 5 +- deserve_worker/worker.py | 18 ++- 10 files changed, 208 insertions(+), 152 deletions(-) create mode 100644 deserve_worker/kvcache/__init__.py create mode 100644 deserve_worker/kvcache/kvcache.py rename deserve_worker/{kvcache.py => kvcache/packed_kvcache.py} (78%) create mode 100644 deserve_worker/kvcache/paged_kvcache.py delete mode 100644 deserve_worker/paged_kvcache.py diff --git a/deserve_worker/forward_engine.py b/deserve_worker/forward_engine.py index 5a627d9..9dd1ecd 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/forward_engine.py @@ -6,7 +6,7 @@ import torch -from .kvcache import KVCacheBase, main_device +from .kvcache.kvcache import KVCache, main_device from .layer_storage import LayerStorage from .model.llama import ENABLE_FLASH_ATTN from .task import BatchForward, BatchResult, BatchUpdate, LayerForward, ResultBack @@ -114,11 +114,13 @@ def run(self) -> None: for task in todo_tasks: new_task_datas.extend(task.task_datas) new_xs = 
torch.cat([task.xs for task in todo_tasks]) + # TODO: check if all tasks share same information new_task = BatchForward( xs=new_xs, layer_storage=todo_tasks[0].layer_storage, task_datas=new_task_datas, need_sample=todo_tasks[0].need_sample, + kvcache_manager=todo_tasks[0].kvcache_manager, ) h = self.forward(new_task) self.process(h, new_task) @@ -140,6 +142,7 @@ def forward(self, tasks: BatchForward) -> torch.Tensor: start_pos_list, global_freqs_cis, kvcache_list, + tasks.kvcache_manager, ) return result diff --git a/deserve_worker/kvcache/__init__.py b/deserve_worker/kvcache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_worker/kvcache/kvcache.py b/deserve_worker/kvcache/kvcache.py new file mode 100644 index 0000000..a0102d7 --- /dev/null +++ b/deserve_worker/kvcache/kvcache.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +KV_CACHE_BLOCK_SIZE = 256 + +main_dtype = torch.float16 +main_device = torch.device("cuda") +torch.set_default_dtype(main_dtype) # type: ignore + + +def del_tensor(t: torch.Tensor) -> None: + t.detach() + t.grad = None + t.untyped_storage().resize_(0) + + +class KVCache(ABC): + @abstractmethod + def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: + pass + + @abstractmethod + def clear(self) -> None: + pass + + +class KVCacheManager(ABC): + @abstractmethod + def alloc(self, bsz: int, seqlen: int) -> Optional[KVCache]: + pass + + @abstractmethod + def recycle(self, kvcache: KVCache) -> None: + pass + + @abstractmethod + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + pass diff --git a/deserve_worker/kvcache.py b/deserve_worker/kvcache/packed_kvcache.py similarity index 78% rename from deserve_worker/kvcache.py rename to deserve_worker/kvcache/packed_kvcache.py index 8ed461f..3ff8651 100644 --- a/deserve_worker/kvcache.py +++ b/deserve_worker/kvcache/packed_kvcache.py @@ -1,32 +1,15 @@ -from abc import ABC, abstractmethod - import torch -main_dtype = torch.float16 -main_device = torch.device("cuda") -torch.set_default_dtype(main_dtype) # type: ignore - - -def del_tensor(t: torch.Tensor) -> None: - t.detach() - t.grad = None - t.untyped_storage().resize_(0) - - -class KVCacheBase(ABC): - @abstractmethod - def renew(self, bsz: int, seqlen: int, start_pos: int) -> None: - pass - - @abstractmethod - def clear(self) -> None: - pass - - -KV_CACHE_BLOCK_SIZE = 256 +from deserve_worker.kvcache.kvcache import ( + KV_CACHE_BLOCK_SIZE, + KVCache, + del_tensor, + main_device, + main_dtype, +) -class KVCache(KVCacheBase): +class PackedKVCache(KVCache): def get_kv_cache_length(self, cur: int, seqlen: int) -> int: while cur < seqlen: cur += KV_CACHE_BLOCK_SIZE @@ -66,7 +49,7 @@ def __init__( ) self.main_device = main_device - def renew(self, bsz: int, seqlen: int, start_pos: int) -> None: + def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: if start_pos + seqlen > self.cache_k.shape[1]: length = self.get_kv_cache_length(self.cache_k.shape[1], start_pos + seqlen) cache_k = torch.zeros( @@ -95,6 +78,7 @@ def renew(self, bsz: int, seqlen: int, start_pos: int) -> None: del_tensor(self.cache_v) self.cache_k = cache_k self.cache_v = cache_v + return True def clear(self) -> None: del_tensor(self.cache_k) diff --git a/deserve_worker/kvcache/paged_kvcache.py b/deserve_worker/kvcache/paged_kvcache.py new file mode 100644 index 0000000..6d34527 --- /dev/null +++ b/deserve_worker/kvcache/paged_kvcache.py @@ -0,0 +1,116 @@ +import queue +from typing import 
Optional, cast + +import torch + +from .kvcache import KVCache, KVCacheManager, main_device, main_dtype + + +class PagedKVCacheManager(KVCacheManager): + def __init__( + self, + num_blocks: int, + block_size: int, + main_device: torch.device, + main_dtype: torch.dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.cache_k_paged = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.cache_v_paged = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.block_bitmap = torch.zeros( + (num_blocks,), device=main_device, dtype=torch.bool + ) + self.block_buffer = torch.arange( + 0, num_blocks, device=main_device, dtype=torch.int32 + ) + + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += self.block_size + return cur + + def alloc_blocks(self, size: int) -> Optional[torch.Tensor]: + if size > self.block_buffer.shape[0]: + block_avails = torch.nonzero(self.block_bitmap) + self.block_bitmap[block_avails] = False + self.block_buffer = torch.cat([self.block_buffer, block_avails]) + if size > self.block_buffer.shape[0]: + return None + result = self.block_buffer[:size] + self.block_buffer = self.block_buffer[size:] + return result + + def alloc(self, bsz: int, seqlen: int) -> Optional["PagedKVCache"]: + len_token = self.get_kv_cache_length(0, seqlen) + len_block = len_token // self.block_size + total_block = len_block * bsz + blocks = self.alloc_blocks(total_block) + if blocks is None: + return None + else: + return PagedKVCache(blocks.reshape(bsz, -1), self) + + def recycle(self, kvcache: KVCache) -> None: + kvcache = cast(PagedKVCache, kvcache) + self.block_bitmap[kvcache.block_table.flatten()] = True + kvcache.block_table = torch.empty(0, device=main_device, dtype=torch.int32) + + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + kvcache = cast(PagedKVCache, kvcache) + if start_pos + seqlen > kvcache.block_table.shape[1] * self.block_size: + len_block = ( + self.get_kv_cache_length( + kvcache.block_table.shape[1] * self.block_size, start_pos + seqlen + ) + // self.block_size + ) + total_block = (len_block - kvcache.block_table.shape[1]) * bsz + blocks = self.alloc_blocks(total_block) + if blocks is None: + return False + else: + new_block_table = torch.zeros( + ( + bsz, + len_block, + ), + device=main_device, + dtype=torch.int32, + ) + new_block_table[:, : kvcache.block_table.shape[1]] = ( + kvcache.block_table[:, :] + ) + new_block_table[:, kvcache.block_table.shape[1] :] = blocks.reshape( + bsz, -1 + ) + kvcache.block_table = new_block_table + return True + + +class PagedKVCache(KVCache): + def __init__( + self, + block_table: torch.Tensor, + manager: PagedKVCacheManager, + ): + self.block_table = block_table + self.manager = manager + + def renew( + self, + bsz: int, + seqlen: int, + start_pos: int, + ) -> bool: + return self.manager.renew(self, bsz, seqlen, start_pos) + + def clear(self) -> None: + self.manager.recycle(self) + + def shape(self) -> torch.Size: + return self.block_table.shape diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index 1aa98f1..1a70303 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -5,7 +5,7 @@ import requests import torch -from .kvcache import KVCacheBase +from .kvcache.kvcache import KVCache, KVCacheManager from .model.llama import ModelArgs, RMSNorm, TransformerBlock llama_2_7b_args = { @@ -157,7 +157,8 @@ def 
forward( bsz_list: list[int], start_pos_list: list[int], global_freqs_cis: torch.Tensor, - kv_cache_list: list[dict[int, KVCacheBase]], + kvcache_list: list[dict[int, KVCache]], + kvcache_manager: KVCacheManager, ) -> torch.Tensor: _, seqlen = h.shape[:2] for full_layer_name in self.layers: @@ -166,16 +167,17 @@ def forward( h = self.layers[full_layer_name](h) elif layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) - cur_kv_cache_list = [] - for i, kv_cache in enumerate(kv_cache_list): + cur_kvcache_list = [] + for i, kv_cache in enumerate(kvcache_list): kv_cache[layer_id].renew(1, seqlen, start_pos_list[i]) - cur_kv_cache_list.append(kv_cache[layer_id]) + cur_kvcache_list.append(kv_cache[layer_id]) h = self.layers[full_layer_name]( h, bsz_list, start_pos_list, global_freqs_cis, - cur_kv_cache_list, + cur_kvcache_list, + kvcache_manager, ) elif layer_name == "norm": h = self.layers[full_layer_name](h) diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index e02f84a..ed45970 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -10,16 +10,15 @@ import torch.nn.functional as F from torch import nn -from deserve_worker.paged_kvcache import PagedKVCache +from deserve_worker.kvcache.paged_kvcache import PagedKVCache, PagedKVCacheManager -from ..kvcache import KVCache, KVCacheBase +from ..kvcache.kvcache import KVCache, KVCacheManager +from ..kvcache.packed_kvcache import PackedKVCache ENABLE_FLASH_ATTN = False try: from flash_attn import flash_attn_with_kvcache # type: ignore - from ..paged_kvcache import global_paged_memory - ENABLE_FLASH_ATTN = True except ImportError as e: print( @@ -266,7 +265,8 @@ def forward( bsz_list: List[int], start_pos_list: List[int], global_freqs_cis: torch.Tensor, - kv_cache_list: list[KVCacheBase], + kvcache_list: list[KVCache], + kvcache_manager: KVCacheManager, ) -> torch.Tensor: """ Forward pass of the attention module. @@ -292,7 +292,7 @@ def forward( cache_seqlens, dtype=torch.int32, device=x.device ) bsz = cache_seqlens_tch.shape[0] - paged_kv_cache_list = cast(list[PagedKVCache], kv_cache_list) + paged_kv_cache_list = cast(list[PagedKVCache], kvcache_list) max_len = max([kvcache.shape()[1] for kvcache in paged_kv_cache_list]) block_table = torch.zeros( @@ -311,10 +311,11 @@ def forward( xv = xv_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) cos = global_freqs_cis[0].type_as(xq) sin = global_freqs_cis[1].type_as(xq) + kvcache_manager = cast(PagedKVCacheManager, kvcache_manager) output = flash_attn_with_kvcache( xq, - global_paged_memory.cache_k_paged, - global_paged_memory.cache_v_paged, + kvcache_manager.cache_k_paged, + kvcache_manager.cache_v_paged, xk, xv, rotary_cos=cos, @@ -341,7 +342,7 @@ def forward( start += bsz start_pos = start_pos_list[i] - kv_cache: KVCache = cast(KVCache, kv_cache_list[i]) + kv_cache: PackedKVCache = cast(PackedKVCache, kvcache_list[i]) cache_k, cache_v = kv_cache.cache_k, kv_cache.cache_v freqs_cis = global_freqs_cis[start_pos : start_pos + seqlen] @@ -484,7 +485,8 @@ def forward( bsz_list: List[int], start_pos_list: List[int], global_freqs_cis: torch.Tensor, - kv_cache_list: list[KVCacheBase], + kvcache_list: list[KVCache], + kvcache_manager: KVCacheManager, ) -> torch.Tensor: """ Perform a forward pass through the TransformerBlock. 
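As a usage sketch for the kvcache_manager now threaded through the attention layers: the paged pool hands out fixed-size blocks, grows a request's block table on demand, and reclaims the blocks when the request finishes. The pool size below is illustrative; the worker itself constructs the manager with 11600 blocks of 256 tokens.

import torch
from deserve_worker.kvcache.paged_kvcache import PagedKVCacheManager

manager = PagedKVCacheManager(num_blocks=64, block_size=256,
                              main_device=torch.device("cuda"),
                              main_dtype=torch.float16)
kvcache = manager.alloc(bsz=1, seqlen=100)          # prefill of 100 tokens rounds up to one block
assert kvcache is not None                          # None would mean the pool is exhausted
ok = kvcache.renew(bsz=1, seqlen=1, start_pos=256)  # decoding past 256 tokens appends a block
assert ok and kvcache.block_table.shape == (1, 2)
kvcache.clear()                                     # marks the blocks free again in the bitmap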
@@ -504,7 +506,8 @@ def forward( bsz_list, start_pos_list, global_freqs_cis, - kv_cache_list, + kvcache_list, + kvcache_manager, ) out = h + self.feed_forward.forward(self.ffn_norm(h)) return out diff --git a/deserve_worker/paged_kvcache.py b/deserve_worker/paged_kvcache.py deleted file mode 100644 index 9b64019..0000000 --- a/deserve_worker/paged_kvcache.py +++ /dev/null @@ -1,101 +0,0 @@ -import queue -from typing import Optional - -import torch - -from .kvcache import KVCacheBase, main_device, main_dtype - - -class PagedMemory: - def __init__( - self, - num_blocks: int, - block_size: int, - main_device: torch.device, - main_dtype: torch.dtype, - ): - self.num_blocks = num_blocks - self.block_size = block_size - self.cache_k_paged = torch.randn( - num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype - ) - self.cache_v_paged = torch.randn( - num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype - ) - self.avai_blocks = queue.Queue[int]() - for i in range(1, num_blocks): - self.avai_blocks.put(i) - - -global_paged_memory = PagedMemory(11600, 256, main_device, main_dtype) - - -class PagedKVCache(KVCacheBase): - def get_kv_cache_length(self, cur: int, seqlen: int) -> int: - while cur < seqlen: - cur += global_paged_memory.block_size - return cur - - def __init__(self, x: torch.Tensor, start_pos: int, main_device: torch.device): - self.main_device = main_device - self.is_clear = False - bsz, seqlen = x.shape[0], x.shape[1] - length = ( - self.get_kv_cache_length(0, start_pos + seqlen) - // global_paged_memory.block_size - ) - self.block_table = torch.zeros( - ( - bsz, - length, - ), - device=self.main_device, - dtype=torch.int32, - ) - for i in range(length): - for j in range(bsz): - blk = global_paged_memory.avai_blocks.get(block=False) - self.block_table[j, i] = blk - - def renew( - self, - bsz: int, - seqlen: int, - start_pos: int, - ) -> None: - if ( - start_pos + seqlen - > self.block_table.shape[1] * global_paged_memory.block_size - ): - # enlarge block table - length = ( - self.get_kv_cache_length( - self.block_table.shape[1] * global_paged_memory.block_size, - start_pos + seqlen, - ) - // global_paged_memory.block_size - ) - block_table = torch.zeros( - ( - bsz, - length, - ), - device=self.main_device, - dtype=torch.int32, - ) - block_table[:, : self.block_table.shape[1]] = self.block_table[:, :] - for i in range(self.block_table.shape[1], length): - for j in range(bsz): - block_table[j, i] = global_paged_memory.avai_blocks.get() - self.block_table = block_table - - def clear(self) -> None: - if self.is_clear: - assert False, "Already cleared" - self.is_clear = True - for row in self.block_table.tolist(): - for item in row: - global_paged_memory.avai_blocks.put(item) - - def shape(self) -> torch.Size: - return self.block_table.shape diff --git a/deserve_worker/task.py b/deserve_worker/task.py index 779a9c0..402afa9 100644 --- a/deserve_worker/task.py +++ b/deserve_worker/task.py @@ -4,7 +4,7 @@ import torch from pydantic import BaseModel -from .kvcache import KVCacheBase +from .kvcache.kvcache import KVCache, KVCacheManager from .layer_storage import LayerStorage @@ -34,7 +34,7 @@ class TaskData: plan: list[PlanStep] round: int sampling_params: SamplingParams - kvcaches: dict[int, KVCacheBase] + kvcaches: dict[int, KVCache] """ When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. 
""" @@ -59,6 +59,7 @@ class BatchForward: xs: torch.Tensor layer_storage: LayerStorage task_datas: list[TaskData] + kvcache_manager: KVCacheManager need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 7b8dc29..8c40e6d 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -10,10 +10,10 @@ from transformers import AutoTokenizer # type: ignore from .forward_engine import ForwardEngine -from .kvcache import KVCacheBase, main_device +from .kvcache.kvcache import KVCache, main_device, main_dtype +from .kvcache.paged_kvcache import PagedKVCache, PagedKVCacheManager from .layer_storage import LayerManager from .model.llama import dumps -from .paged_kvcache import PagedKVCache from .task import ( BatchForward, BatchResult, @@ -37,6 +37,7 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.relay_queue = queue.Queue[BatchResult | BatchUpdate]() self.forward_engine = ForwardEngine(max_total_bsz, self.relay_queue) self.layer_manager = LayerManager(main_device) + self.kvcache_manager = PagedKVCacheManager(11600, 256, main_device, main_dtype) threading.Thread(target=self.forward_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) @@ -56,7 +57,9 @@ def init_task_data( _, layer_name = full_layer_name.split("/") if layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) - kvcaches[layer_id] = PagedKVCache(x, 0, main_device) + kvcaches[layer_id] = self.kvcache_manager.alloc( + x.shape[0], x.shape[1] + ) # TODO: need double check whether request is repeated task_data = TaskData( @@ -65,7 +68,7 @@ def init_task_data( plan=task_info.plan, round=0, sampling_params=task_info.sampling_params, - kvcaches=cast(dict[int, KVCacheBase], kvcaches), + kvcaches=cast(dict[int, KVCache], kvcaches), ) self.task_datas[task_info.task_id] = task_data else: @@ -90,7 +93,11 @@ def batch_forward( ] self.forward_engine.add_batch_forward( BatchForward( - xs.to(main_device), layer_storage, task_datas, (index == len(plan) - 1) + xs=xs.to(main_device), + layer_storage=layer_storage, + task_datas=task_datas, + need_sample=(index == len(plan) - 1), + kvcache_manager=self.kvcache_manager, ) ) @@ -123,6 +130,7 @@ def forward( ) ], need_sample=(index == len(plan) - 1), + kvcache_manager=self.kvcache_manager, ) self.forward_engine.add_batch_forward(layer_forward) From b0cb5431b4a7c80d1c97991c6c978bb592f91107 Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 1 Aug 2024 16:27:19 -0700 Subject: [PATCH 10/17] chore: move some code --- deserve_worker/command.py | 29 +++++++ deserve_worker/layer_storage.py | 65 ++++++++++++++++ .../{forward_engine.py => llm_engine.py} | 77 +++---------------- deserve_worker/model/llama.py | 27 ------- deserve_worker/task.py | 47 +---------- deserve_worker/worker.py | 19 ++--- 6 files changed, 110 insertions(+), 154 deletions(-) create mode 100644 deserve_worker/command.py rename deserve_worker/{forward_engine.py => llm_engine.py} (63%) diff --git a/deserve_worker/command.py b/deserve_worker/command.py new file mode 100644 index 0000000..db8c0fd --- /dev/null +++ b/deserve_worker/command.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass + +import torch + +from deserve_worker.kvcache.kvcache import KVCacheManager +from deserve_worker.layer_storage import LayerStorage +from deserve_worker.task import TaskData + + +@dataclass 
+class BatchForward: + xs: torch.Tensor + layer_storage: LayerStorage + task_datas: list[TaskData] + kvcache_manager: KVCacheManager + need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage + + +@dataclass +class BatchResult: + xs: torch.Tensor + task_ids: list[str] + + +@dataclass +class BatchUpdate: + tokens: list[torch.Tensor] + task_ids: list[str] + cancel_ids: list[str] diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index 1a70303..021bf27 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -5,9 +5,14 @@ import requests import torch +from deserve_worker.task import TaskData + from .kvcache.kvcache import KVCache, KVCacheManager from .model.llama import ModelArgs, RMSNorm, TransformerBlock +EOS_TOKEN_ID = 128001 # for llama 3 only +STOP_TOKEN_IDS = [128001, 128009] + llama_2_7b_args = { "dim": 4096, "multiple_of": 256, @@ -62,6 +67,32 @@ } +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + class LayerManager: def __init__(self, main_device: torch.device): self.main_device = main_device @@ -186,3 +217,37 @@ def forward( else: raise NotImplementedError("Unknown layers") return h + + @torch.inference_mode() + def sample( + self, merged_h: torch.Tensor, task_datas: list[TaskData] + ) -> tuple[list[torch.Tensor], list[str], list[torch.Tensor], list[str], list[str]]: + ongoing_tokens = [] + ongoing_ids = [] + all_tokens = [] + all_ids = [] + done_ids = [] + for ptr, task_data in enumerate(task_datas): + h = merged_h[ptr : ptr + 1] + _, seqlen = h.shape[:2] + task_data.start_pos += seqlen + task_data.round += 1 + sampling_params = task_data.sampling_params + if task_data.start_pos >= sampling_params.max_total_len: + next_token = torch.tensor([[EOS_TOKEN_ID]]) + elif sampling_params.temperature > 0: + probs = torch.softmax(h[:, -1] / sampling_params.temperature, dim=-1) + next_token = sample_top_p(probs, sampling_params.top_p) + next_token = next_token.reshape(1, -1) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(1, -1) + next_token = next_token.to("cpu") + all_ids.append(task_data.task_id) + all_tokens.append(next_token) + if next_token[0][0] in STOP_TOKEN_IDS: + done_ids.append(task_data.task_id) + else: + ongoing_ids.append(task_data.task_id) + ongoing_tokens.append(next_token) + return ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids diff --git a/deserve_worker/forward_engine.py b/deserve_worker/llm_engine.py similarity index 63% rename from deserve_worker/forward_engine.py rename to deserve_worker/llm_engine.py index 9dd1ecd..9e5df2f 100644 --- a/deserve_worker/forward_engine.py +++ b/deserve_worker/llm_engine.py @@ -1,15 
+1,10 @@ -import itertools import queue -import time -from dataclasses import dataclass -from typing import Optional import torch -from .kvcache.kvcache import KVCache, main_device -from .layer_storage import LayerStorage +from .command import BatchForward, BatchResult, BatchUpdate +from .kvcache.kvcache import main_device from .model.llama import ENABLE_FLASH_ATTN -from .task import BatchForward, BatchResult, BatchUpdate, LayerForward, ResultBack EOS_TOKEN_ID = 128001 # for llama 3 only STOP_TOKEN_IDS = [128001, 128009] @@ -45,33 +40,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to(main_device) -def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. - - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - - -class ForwardEngine: +class LLMEngine: def __init__( self, max_total_bsz: int, sendback_queue: queue.Queue[BatchResult | BatchUpdate] ): @@ -94,7 +63,7 @@ def run(self) -> None: for task in prefill_tasks: h = self.forward(task) - self.process(h, task) + self.post_process(h, task) print( f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {sum(task.xs.shape[0] for task in decode_tasks)}" @@ -123,7 +92,7 @@ def run(self) -> None: kvcache_manager=todo_tasks[0].kvcache_manager, ) h = self.forward(new_task) - self.process(h, new_task) + self.post_process(h, new_task) def add_batch_forward(self, forwards: BatchForward) -> None: self.handling_queue.put(forwards) @@ -146,38 +115,12 @@ def forward(self, tasks: BatchForward) -> torch.Tensor: ) return result - def process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: + def post_process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: if tasks.need_sample: - ongoing_tokens = [] - ongoing_ids = [] - all_tokens = [] - all_ids = [] - done_ids = [] - for ptr, task_data in enumerate(tasks.task_datas): - h = merged_h[ptr : ptr + 1] - _, seqlen = h.shape[:2] - task_data.start_pos += seqlen - task_data.round += 1 - sampling_params = task_data.sampling_params - if task_data.start_pos >= sampling_params.max_total_len: - next_token = torch.tensor([[EOS_TOKEN_ID]]) - elif sampling_params.temperature > 0: - probs = torch.softmax( - h[:, -1] / sampling_params.temperature, dim=-1 - ) - next_token = sample_top_p(probs, sampling_params.top_p) - next_token = next_token.reshape(1, -1) - else: - next_token = torch.argmax(h[:, -1], dim=-1) - next_token = next_token.reshape(1, -1) - next_token = next_token.to("cpu") - all_ids.append(task_data.task_id) - all_tokens.append(next_token) - if next_token[0][0] in STOP_TOKEN_IDS: - done_ids.append(task_data.task_id) - else: - ongoing_ids.append(task_data.task_id) - ongoing_tokens.append(next_token) + layer_storage = tasks.layer_storage + 
ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids = ( + layer_storage.sample(merged_h, tasks.task_datas) + ) if len(ongoing_tokens) > 0: self.sendback_queue.put( BatchResult(torch.cat(ongoing_tokens), ongoing_ids) diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index ed45970..58cf1e4 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -89,33 +89,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output * self.weight -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs) - if ENABLE_FLASH_ATTN: - freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. diff --git a/deserve_worker/task.py b/deserve_worker/task.py index 402afa9..718da07 100644 --- a/deserve_worker/task.py +++ b/deserve_worker/task.py @@ -1,11 +1,8 @@ from dataclasses import dataclass -from typing import Optional -import torch from pydantic import BaseModel -from .kvcache.kvcache import KVCache, KVCacheManager -from .layer_storage import LayerStorage +from .kvcache.kvcache import KVCache class PlanStep(BaseModel): @@ -38,45 +35,3 @@ class TaskData: """ When flash attention is enabled, we use paged attention, otherwise the standard attention is adopted. 
""" - - -class LayerForward: - def __init__( - self, - layer_storage: LayerStorage, - h: torch.Tensor, - task_data: TaskData, - need_sample: bool, - ): - self.layer_storage = layer_storage - self.h = h - self.task_info = task_data - self.need_sample = need_sample - - -@dataclass -class BatchForward: - xs: torch.Tensor - layer_storage: LayerStorage - task_datas: list[TaskData] - kvcache_manager: KVCacheManager - need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage - - -@dataclass -class BatchResult: - xs: torch.Tensor - task_ids: list[str] - - -@dataclass -class BatchUpdate: - tokens: list[torch.Tensor] - task_ids: list[str] - cancel_ids: list[str] - - -@dataclass -class ResultBack: - x: torch.Tensor - task_id: str diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 8c40e6d..7ec24df 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -1,28 +1,19 @@ import queue import threading -import traceback from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass from typing import Optional, cast import requests import torch from transformers import AutoTokenizer # type: ignore -from .forward_engine import ForwardEngine +from .command import BatchForward, BatchResult, BatchUpdate +from .llm_engine import LLMEngine from .kvcache.kvcache import KVCache, main_device, main_dtype -from .kvcache.paged_kvcache import PagedKVCache, PagedKVCacheManager +from .kvcache.paged_kvcache import PagedKVCacheManager from .layer_storage import LayerManager from .model.llama import dumps -from .task import ( - BatchForward, - BatchResult, - BatchUpdate, - PlanStep, - SamplingParams, - TaskData, - TaskInfo, -) +from .task import PlanStep, SamplingParams, TaskData, TaskInfo EOS_TOKEN_ID = 128001 # for llama 3 only @@ -35,7 +26,7 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.controller_url = controller_url self.task_datas: dict[str, TaskData] = {} self.relay_queue = queue.Queue[BatchResult | BatchUpdate]() - self.forward_engine = ForwardEngine(max_total_bsz, self.relay_queue) + self.forward_engine = LLMEngine(max_total_bsz, self.relay_queue) self.layer_manager = LayerManager(main_device) self.kvcache_manager = PagedKVCacheManager(11600, 256, main_device, main_dtype) threading.Thread(target=self.forward_engine.run, daemon=True).start() From 29e5c24adc8bf55ec7745e2d652e834f3d2689b7 Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 2 Aug 2024 01:51:22 -0700 Subject: [PATCH 11/17] feat: support dumping traces --- deserve_worker/command.py | 17 ++ deserve_worker/kvcache/kvcache.py | 6 - deserve_worker/kvcache/packed_kvcache.py | 162 ++++++++++-------- deserve_worker/kvcache/paged_kvcache.py | 2 +- deserve_worker/layer_storage.py | 37 ++++- deserve_worker/llm_engine.py | 200 ++++++++++++++++------- deserve_worker/model/llama.py | 175 +++++++++++++++----- deserve_worker/trace.py | 28 ++++ deserve_worker/worker.py | 121 ++++++++++++-- deserve_worker/worker_api.py | 18 ++ 10 files changed, 561 insertions(+), 205 deletions(-) create mode 100644 deserve_worker/trace.py diff --git a/deserve_worker/command.py b/deserve_worker/command.py index db8c0fd..b64892d 100644 --- a/deserve_worker/command.py +++ b/deserve_worker/command.py @@ -5,6 +5,7 @@ from deserve_worker.kvcache.kvcache import KVCacheManager from deserve_worker.layer_storage import LayerStorage from deserve_worker.task import TaskData +from deserve_worker.trace import OpId @dataclass @@ -16,6 +17,15 @@ class BatchForward: 
need_sample: bool # to be eliminated in the future, because we can infer this from LayerStorage +@dataclass +class SingleTrace: + x: torch.Tensor + layer_storage: LayerStorage + task_data: TaskData + kvcache_manager: KVCacheManager + need_sample: bool + + @dataclass class BatchResult: xs: torch.Tensor @@ -27,3 +37,10 @@ class BatchUpdate: tokens: list[torch.Tensor] task_ids: list[str] cancel_ids: list[str] + + +@dataclass +class TraceResult: + x: torch.Tensor + task_id: str + trace: dict[OpId, torch.Tensor] diff --git a/deserve_worker/kvcache/kvcache.py b/deserve_worker/kvcache/kvcache.py index a0102d7..8e0e4c6 100644 --- a/deserve_worker/kvcache/kvcache.py +++ b/deserve_worker/kvcache/kvcache.py @@ -10,12 +10,6 @@ torch.set_default_dtype(main_dtype) # type: ignore -def del_tensor(t: torch.Tensor) -> None: - t.detach() - t.grad = None - t.untyped_storage().resize_(0) - - class KVCache(ABC): @abstractmethod def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: diff --git a/deserve_worker/kvcache/packed_kvcache.py b/deserve_worker/kvcache/packed_kvcache.py index 3ff8651..dbdd6f1 100644 --- a/deserve_worker/kvcache/packed_kvcache.py +++ b/deserve_worker/kvcache/packed_kvcache.py @@ -1,86 +1,114 @@ +from typing import Optional, cast + import torch -from deserve_worker.kvcache.kvcache import ( - KV_CACHE_BLOCK_SIZE, - KVCache, - del_tensor, - main_device, - main_dtype, -) +from deserve_worker.kvcache.kvcache import KVCache, KVCacheManager -class PackedKVCache(KVCache): - def get_kv_cache_length(self, cur: int, seqlen: int) -> int: - while cur < seqlen: - cur += KV_CACHE_BLOCK_SIZE - return cur +def del_tensor(t: torch.Tensor) -> None: + t.detach() + t.grad = None + t.untyped_storage().resize_(0) + +class PackedKVCacheManager(KVCacheManager): def __init__( self, - x: torch.Tensor, - start_pos: int, - n_local_kv_heads: int, - head_dim: int, + num_blocks: int, + block_size: int, + main_device: torch.device, + main_dtype: torch.dtype, ): - self.n_local_kv_heads = n_local_kv_heads - self.head_dim = head_dim - - bsz, seqlen = x.shape[0], x.shape[1] - length = self.get_kv_cache_length(0, start_pos + seqlen) - self.cache_k = torch.zeros( - ( - bsz, - length, - n_local_kv_heads, - head_dim, - ), - device=main_device, - dtype=main_dtype, - ) - self.cache_v = torch.zeros( - ( - bsz, - length, - n_local_kv_heads, - head_dim, - ), - device=main_device, - dtype=main_dtype, - ) + self.num_blocks = num_blocks + self.block_size = block_size self.main_device = main_device + self.main_dtype = main_dtype - def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: - if start_pos + seqlen > self.cache_k.shape[1]: - length = self.get_kv_cache_length(self.cache_k.shape[1], start_pos + seqlen) + def get_kv_cache_length(self, cur: int, seqlen: int) -> int: + while cur < seqlen: + cur += self.block_size + return cur + + def alloc(self, bsz: int, seqlen: int) -> Optional[KVCache]: + len_token = self.get_kv_cache_length(0, seqlen) + len_block = len_token // self.block_size + if bsz * len_block <= self.num_blocks: + self.num_blocks -= bsz * len_block cache_k = torch.zeros( - ( - bsz, - length, - self.n_local_kv_heads, - self.head_dim, - ), + (bsz, len_token, 8, 128), device=self.main_device, + dtype=self.main_dtype, ) cache_v = torch.zeros( - ( - bsz, - length, - self.n_local_kv_heads, - self.head_dim, - ), + (bsz, len_token, 8, 128), device=self.main_device, + dtype=self.main_dtype, ) - cache_k[:, :start_pos, :, :], cache_v[:, :start_pos, :, :] = ( - self.cache_k[:, :start_pos, :, :], - self.cache_v[:, 
:start_pos, :, :], + return PackedKVCache(cache_k, cache_v, self) + else: + return None + + def recycle(self, kvcache: KVCache) -> None: + kvcache = cast(PackedKVCache, kvcache) + bsz, seqlen = kvcache.cache_k.shape[:2] + self.num_blocks += bsz * seqlen + + del_tensor(kvcache.cache_k) + del_tensor(kvcache.cache_v) + kvcache.cache_k = torch.empty( + (0, 0), device=self.main_device, dtype=self.main_dtype + ) + kvcache.cache_v = torch.empty( + (0, 0), device=self.main_device, dtype=self.main_dtype + ) + torch.cuda.empty_cache() + + def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: + kvcache = cast(PackedKVCache, kvcache) + if start_pos + seqlen > kvcache.cache_k.shape[1]: + len_token = self.get_kv_cache_length( + kvcache.cache_k.shape[1], start_pos + seqlen ) - del_tensor(self.cache_k) - del_tensor(self.cache_v) - self.cache_k = cache_k - self.cache_v = cache_v - return True + len_block = len_token // self.block_size + if bsz * len_block <= self.num_blocks: + self.num_blocks -= bsz * len_token + cache_k = torch.zeros( + (bsz, len_token, 8, 128), + device=self.main_device, + dtype=self.main_dtype, + ) + cache_v = torch.zeros( + (bsz, len_token, 8, 128), + device=self.main_device, + dtype=self.main_dtype, + ) + cache_k[:, :start_pos, :, :], cache_v[:, :start_pos, :, :] = ( + kvcache.cache_k[:, :start_pos, :, :], + kvcache.cache_v[:, :start_pos, :, :], + ) + original_shape = bsz * kvcache.cache_k.shape[1] + del_tensor(kvcache.cache_k) + del_tensor(kvcache.cache_v) + self.num_blocks += original_shape + kvcache.cache_k = cache_k + kvcache.cache_v = cache_v + return True + return False + + +class PackedKVCache(KVCache): + def __init__( + self, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + manager: PackedKVCacheManager, + ): + self.cache_k = cache_k + self.cache_v = cache_v + self.manager = manager + + def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: + return self.manager.renew(self, bsz, seqlen, start_pos) def clear(self) -> None: - del_tensor(self.cache_k) - del_tensor(self.cache_v) - torch.cuda.empty_cache() + return self.manager.recycle(self) diff --git a/deserve_worker/kvcache/paged_kvcache.py b/deserve_worker/kvcache/paged_kvcache.py index 6d34527..6f094f8 100644 --- a/deserve_worker/kvcache/paged_kvcache.py +++ b/deserve_worker/kvcache/paged_kvcache.py @@ -58,7 +58,7 @@ def alloc(self, bsz: int, seqlen: int) -> Optional["PagedKVCache"]: def recycle(self, kvcache: KVCache) -> None: kvcache = cast(PagedKVCache, kvcache) self.block_bitmap[kvcache.block_table.flatten()] = True - kvcache.block_table = torch.empty(0, device=main_device, dtype=torch.int32) + kvcache.block_table = torch.empty((0, 0), device=main_device, dtype=torch.int32) def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: kvcache = cast(PagedKVCache, kvcache) diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index 021bf27..bb5055a 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -1,14 +1,22 @@ import os import threading from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional import requests import torch from deserve_worker.task import TaskData +from deserve_worker.trace import ComponentId, LayerId, OpId from .kvcache.kvcache import KVCache, KVCacheManager -from .model.llama import ModelArgs, RMSNorm, TransformerBlock +from .model.llama import ( + ModelArgs, + RMSNorm, + TraceEmbedding, + TraceLinear, + TransformerBlock, +) EOS_TOKEN_ID = 128001 # for llama 
3 only STOP_TOKEN_IDS = [128001, 128009] @@ -149,18 +157,25 @@ def preload_layers(self, full_layer_names: list[str]) -> dict[str, torch.nn.Modu raise NotImplementedError("Unknown model") if layer_name == "tok_embeddings": l = torch.nn.utils.skip_init( # type: ignore - torch.nn.Embedding, model_args.vocab_size, model_args.dim + # torch.nn.Embedding, + TraceEmbedding, + ComponentId("tok_embeddings", "main"), + model_args.vocab_size, + model_args.dim, ) elif layer_name.startswith("layer"): - l = TransformerBlock(model_args) + l = TransformerBlock(LayerId(f"layer_{layer_name[6:]}"), model_args) elif layer_name == "norm": - l = RMSNorm(model_args.dim, eps=model_args.norm_eps) + l = RMSNorm( + ComponentId("norm", "main"), model_args.dim, eps=model_args.norm_eps + ) elif layer_name == "output": l = torch.nn.utils.skip_init( # type: ignore - torch.nn.Linear, + # torch.nn.Linear, + TraceLinear, + ComponentId("output", "main"), model_args.dim, model_args.vocab_size, - bias=False, ) else: raise NotImplementedError("Unknown layers") @@ -190,12 +205,14 @@ def forward( global_freqs_cis: torch.Tensor, kvcache_list: list[dict[int, KVCache]], kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], ) -> torch.Tensor: _, seqlen = h.shape[:2] for full_layer_name in self.layers: _, layer_name = full_layer_name.split("/") if layer_name == "tok_embeddings": - h = self.layers[full_layer_name](h) + h = self.layers[full_layer_name](h, traces) + # h = self.layers[full_layer_name](h) elif layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) cur_kvcache_list = [] @@ -209,11 +226,13 @@ def forward( global_freqs_cis, cur_kvcache_list, kvcache_manager, + traces, ) elif layer_name == "norm": - h = self.layers[full_layer_name](h) + h = self.layers[full_layer_name](h, traces) elif layer_name == "output": - h = self.layers[full_layer_name](h) + h = self.layers[full_layer_name](h, traces) + # h = self.layers[full_layer_name](h) else: raise NotImplementedError("Unknown layers") return h diff --git a/deserve_worker/llm_engine.py b/deserve_worker/llm_engine.py index 9e5df2f..84d83fb 100644 --- a/deserve_worker/llm_engine.py +++ b/deserve_worker/llm_engine.py @@ -1,16 +1,22 @@ import queue +from typing import Optional import torch -from .command import BatchForward, BatchResult, BatchUpdate -from .kvcache.kvcache import main_device -from .model.llama import ENABLE_FLASH_ATTN +from deserve_worker.layer_storage import LayerStorage +from deserve_worker.task import TaskData +from deserve_worker.trace import OpId + +from .command import BatchForward, BatchResult, BatchUpdate, SingleTrace, TraceResult +from .kvcache.kvcache import KVCacheManager, main_device EOS_TOKEN_ID = 128001 # for llama 3 only STOP_TOKEN_IDS = [128001, 128009] -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, enable_flash_attn: bool = False +) -> torch.Tensor: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 
@@ -30,106 +36,174 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device, dtype=torch.float32) freqs = torch.outer(t, freqs) - if ENABLE_FLASH_ATTN: + if enable_flash_attn: freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis -global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to(main_device) +global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0, False).to(main_device) +flash_global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0, True).to(main_device) class LLMEngine: def __init__( - self, max_total_bsz: int, sendback_queue: queue.Queue[BatchResult | BatchUpdate] + self, + max_total_bsz: int, + sender: queue.Queue[BatchResult | BatchUpdate | TraceResult], ): self.max_total_bsz = max_total_bsz - self.sendback_queue = sendback_queue - self.handling_queue = queue.Queue[BatchForward]() + self.sender = sender + self.receiver = queue.Queue[BatchForward | SingleTrace]() def run(self) -> None: - q = self.handling_queue + q = self.receiver while True: - forwards: list[BatchForward] = [q.get()] + commands: list[BatchForward | SingleTrace] = [q.get()] while True: try: new = q.get(block=False) - forwards.append(new) + commands.append(new) except queue.Empty: break - prefill_tasks = [task for task in forwards if task.xs.shape[1] > 1] - decode_tasks = [task for task in forwards if task.xs.shape[1] == 1] - - for task in prefill_tasks: - h = self.forward(task) - self.post_process(h, task) - - print( - f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {sum(task.xs.shape[0] for task in decode_tasks)}" + traces = [ + command for command in commands if isinstance(command, SingleTrace) + ] + forwards = [ + command for command in commands if isinstance(command, BatchForward) + ] + self.handle_trace(traces) + self.handle_forward(forwards) + + def handle_forward(self, forwards: list[BatchForward]) -> None: + prefill_tasks = [task for task in forwards if task.xs.shape[1] > 1] + decode_tasks = [task for task in forwards if task.xs.shape[1] == 1] + + for task in prefill_tasks: + h = self.step_forward( + task.xs, + task.layer_storage, + task.task_datas, + task.kvcache_manager, + flash_global_freqs_cis, + None, ) - - decode_tasks.sort(key=lambda task: task.xs.shape[0], reverse=False) - while len(decode_tasks) > 0: - total_bsz = 0 - todo_tasks = [] - for i in reversed(range(len(decode_tasks))): - cur_bsz = decode_tasks[i].xs.shape[0] - if total_bsz + cur_bsz > self.max_total_bsz: - continue - total_bsz += cur_bsz - todo_tasks.append(decode_tasks.pop(i)) - new_task_datas = [] - for task in todo_tasks: - new_task_datas.extend(task.task_datas) - new_xs = torch.cat([task.xs for task in todo_tasks]) - # TODO: check if all tasks share same information - new_task = BatchForward( - xs=new_xs, - layer_storage=todo_tasks[0].layer_storage, - task_datas=new_task_datas, - need_sample=todo_tasks[0].need_sample, - kvcache_manager=todo_tasks[0].kvcache_manager, - ) - h = self.forward(new_task) - self.post_process(h, new_task) - - def add_batch_forward(self, forwards: BatchForward) -> None: - self.handling_queue.put(forwards) - - def forward(self, tasks: BatchForward) -> torch.Tensor: + self.post_forward(h, task) + + print( + f"prefill_tasks: {len(prefill_tasks)}, decode_tasks: {sum(task.xs.shape[0] for task in decode_tasks)}" + ) + + 
decode_tasks.sort(key=lambda task: task.xs.shape[0], reverse=False) + while len(decode_tasks) > 0: + total_bsz = 0 + todo_tasks = [] + for i in reversed(range(len(decode_tasks))): + cur_bsz = decode_tasks[i].xs.shape[0] + if total_bsz + cur_bsz > self.max_total_bsz: + continue + total_bsz += cur_bsz + todo_tasks.append(decode_tasks.pop(i)) + new_task_datas = [] + for task in todo_tasks: + new_task_datas.extend(task.task_datas) + new_xs = torch.cat([task.xs for task in todo_tasks]) + # TODO: check if all tasks share same information + new_task = BatchForward( + xs=new_xs, + layer_storage=todo_tasks[0].layer_storage, + task_datas=new_task_datas, + need_sample=todo_tasks[0].need_sample, + kvcache_manager=todo_tasks[0].kvcache_manager, + ) + h = self.step_forward( + new_task.xs, + new_task.layer_storage, + new_task.task_datas, + new_task.kvcache_manager, + flash_global_freqs_cis, + None, + ) + self.post_forward(h, new_task) + + def handle_trace(self, tasks: list[SingleTrace]) -> None: + for task in tasks: + traces: dict[OpId, torch.Tensor] = {} + h = self.step_forward( + task.x, + task.layer_storage, + [task.task_data], + task.kvcache_manager, + global_freqs_cis, + traces, + ) + self.post_trace(h, traces, task) + + def step_forward( + self, + h: torch.Tensor, + layer_storage: LayerStorage, + task_datas: list[TaskData], + kvcache_manager: KVCacheManager, + global_freqs_cis: torch.Tensor, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: # we need to check that all tasks share the same layer storage with torch.inference_mode(): - layer_storage = tasks.layer_storage - h = tasks.xs - bsz_list = [1 for _ in range(len(tasks.task_datas))] - start_pos_list = [task.start_pos for task in tasks.task_datas] - kvcache_list = [task.kvcaches for task in tasks.task_datas] + bsz_list = [1 for _ in range(len(task_datas))] + start_pos_list = [task.start_pos for task in task_datas] + kvcache_list = [task.kvcaches for task in task_datas] result = layer_storage.forward( h, bsz_list, start_pos_list, global_freqs_cis, kvcache_list, - tasks.kvcache_manager, + kvcache_manager, + traces, ) return result - def post_process(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: + def post_forward(self, merged_h: torch.Tensor, tasks: BatchForward) -> None: if tasks.need_sample: layer_storage = tasks.layer_storage ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids = ( layer_storage.sample(merged_h, tasks.task_datas) ) if len(ongoing_tokens) > 0: - self.sendback_queue.put( - BatchResult(torch.cat(ongoing_tokens), ongoing_ids) - ) - self.sendback_queue.put(BatchUpdate(all_tokens, all_ids, done_ids)) + self.sender.put(BatchResult(torch.cat(ongoing_tokens), ongoing_ids)) + self.sender.put(BatchUpdate(all_tokens, all_ids, done_ids)) else: seqlen = tasks.xs.shape[1] for task in tasks.task_datas: task.start_pos += seqlen - self.sendback_queue.put( + self.sender.put( BatchResult(merged_h, [task.task_id for task in tasks.task_datas]) ) + + def post_trace( + self, h: torch.Tensor, traces: dict[OpId, torch.Tensor], task: SingleTrace + ) -> None: + task_data = task.task_data + if task.need_sample: + layer_storage = task.layer_storage + ongoing_tokens, ongoing_ids, all_tokens, all_ids, done_ids = ( + layer_storage.sample(h, [task_data]) + ) + if len(ongoing_tokens) > 0: + # at most have one + self.sender.put( + TraceResult(torch.cat(ongoing_tokens), ongoing_ids[0], traces) + ) + self.sender.put(BatchUpdate(all_tokens, all_ids, done_ids)) + else: + seqlen = task.x.shape[1] + task_data.start_pos += seqlen + 
self.sender.put(TraceResult(h, task_data.task_id, traces)) + + def add_batch_forward(self, forwards: BatchForward) -> None: + self.receiver.put(forwards) + + def add_trace(self, trace: SingleTrace) -> None: + self.receiver.put(trace) diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index 58cf1e4..41751d0 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -3,28 +3,20 @@ import math import pickle from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, cast +from typing import Any, List, Mapping, Optional, Tuple, cast import safetensors.torch import torch import torch.nn.functional as F +from flash_attn import flash_attn_with_kvcache # type: ignore from torch import nn from deserve_worker.kvcache.paged_kvcache import PagedKVCache, PagedKVCacheManager +from deserve_worker.trace import ComponentId, LayerId, OpId from ..kvcache.kvcache import KVCache, KVCacheManager from ..kvcache.packed_kvcache import PackedKVCache -ENABLE_FLASH_ATTN = False -try: - from flash_attn import flash_attn_with_kvcache # type: ignore - - ENABLE_FLASH_ATTN = True -except ImportError as e: - print( - "Package flash-attn is not found. Please install it for better performance. https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features" - ) - @dataclass class ModelArgs: @@ -42,8 +34,15 @@ class ModelArgs: max_seq_len: int = 2048 +def trace_op( + traces: Optional[dict[OpId, torch.Tensor]], op_id: OpId, op_value: torch.Tensor +) -> None: + if traces is not None: + traces[op_id] = op_value + + class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, component_id: ComponentId, dim: int, eps: float = 1e-6): """ Initialize the RMSNorm normalization layer. @@ -57,6 +56,7 @@ def __init__(self, dim: int, eps: float = 1e-6): """ super().__init__() + self.component_id = component_id self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) @@ -74,7 +74,11 @@ def _norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) @torch.inference_mode() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + traces: Optional[dict[OpId, torch.Tensor]], + ) -> torch.Tensor: """ Forward pass through the RMSNorm layer. @@ -86,7 +90,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ output = self._norm(x.float()).type_as(x) - return output * self.weight + # trace_op(traces, self.component_id.with_op("output"), output) + result = output * self.weight + # trace_op(traces, self.component_id.with_op("weighted_output"), result) + return result def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: @@ -160,7 +167,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class Attention(nn.Module): """Multi-head attention module.""" - def __init__(self, args: ModelArgs): + def __init__(self, component_id: ComponentId, args: ModelArgs): """ Initialize the Attention module. 
@@ -182,6 +189,7 @@ def __init__(self, args: ModelArgs): """ super().__init__() + self.component_id = component_id self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads # model_parallel_size = fs_init.get_model_parallel_world_size() self.n_local_heads = args.n_heads # // model_parallel_size @@ -214,23 +222,6 @@ def __init__(self, args: ModelArgs): bias=False, ) - # self.cache_k = torch.zeros( - # ( - # args.max_batch_size, - # args.max_seq_len, - # self.n_local_kv_heads, - # self.head_dim, - # ) - # ) - # self.cache_v = torch.zeros( - # ( - # args.max_batch_size, - # args.max_seq_len, - # self.n_local_kv_heads, - # self.head_dim, - # ) - # ) - @torch.inference_mode() def forward( self, @@ -240,6 +231,7 @@ def forward( global_freqs_cis: torch.Tensor, kvcache_list: list[KVCache], kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], ) -> torch.Tensor: """ Forward pass of the attention module. @@ -257,7 +249,7 @@ def forward( _, seqlen, _ = x.shape xq_, xk_, xv_ = self.wq(x), self.wk(x), self.wv(x) - if ENABLE_FLASH_ATTN: + if isinstance(kvcache_manager, PagedKVCacheManager): cache_seqlens = [] for i, bsz in enumerate(bsz_list): cache_seqlens += [start_pos_list[i]] * bsz @@ -284,7 +276,6 @@ def forward( xv = xv_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) cos = global_freqs_cis[0].type_as(xq) sin = global_freqs_cis[1].type_as(xq) - kvcache_manager = cast(PagedKVCacheManager, kvcache_manager) output = flash_attn_with_kvcache( xq, kvcache_manager.cache_k_paged, @@ -299,6 +290,7 @@ def forward( rotary_interleaved=True, ) output = output.view(bsz, seqlen, -1) + return self.wo(output) # type: ignore else: start = 0 output_list = [] @@ -312,6 +304,9 @@ def forward( xv = xv_[start : start + bsz].view( bsz, seqlen, self.n_local_kv_heads, self.head_dim ) + # trace_op(traces, self.component_id.with_op("xq"), xq) + # trace_op(traces, self.component_id.with_op("xk"), xk) + # trace_op(traces, self.component_id.with_op("xv"), xv) start += bsz start_pos = start_pos_list[i] @@ -327,6 +322,8 @@ def forward( mask = torch.triu(mask, diagonal=start_pos + 1).type_as(x) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + # trace_op(traces, self.component_id.with_op("xq_rotary"), xq) + # trace_op(traces, self.component_id.with_op("xk_rotary"), xk) cache_k[:bsz, start_pos : start_pos + seqlen] = xk cache_v[:bsz, start_pos : start_pos + seqlen] = xv @@ -348,6 +345,7 @@ def forward( scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( self.head_dim ) + # trace_op(traces, self.component_id.with_op("scores"), scores) if mask is not None: scores = ( scores + mask @@ -356,15 +354,19 @@ def forward( output = torch.matmul( scores, values ) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output_list.append(output) output = torch.cat([x for x in output_list]) - return self.wo(output) # type: ignore + result = self.wo(output) + # trace_op(traces, self.component_id.with_op("weighted_output"), result) + return result # type: ignore class FeedForward(nn.Module): def __init__( self, + component_id: ComponentId, dim: int, hidden_dim: int, multiple_of: int, @@ -386,6 +388,7 @@ def __init__( """ super().__init__() + self.component_id = component_id hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier if ffn_dim_multiplier is not None: @@ -412,12 +415,85 @@ def __init__( ) @torch.inference_mode() - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * 
self.w3(x)) # type: ignore + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + w1 = F.silu(self.w1(x)) + w3 = self.w3(x) + w2 = self.w2(w1 * w3) + # trace_op(traces, self.component_id.with_op("w1"), w1) + # trace_op(traces, self.component_id.with_op("w3"), w3) + # trace_op(traces, self.component_id.with_op("w2"), w2) + + return w2 # type: ignore + + +class TraceLinear(nn.Module): + def __init__( + self, + component_id: ComponentId, + in_features: int, + out_features: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.linear = nn.Linear( + in_features, out_features, bias=False, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + out = self.linear(x) + # trace_op(traces, self.component_id.with_op("output"), out) + return out # type: ignore + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.linear.load_state_dict(state_dict, strict, assign) # type: ignore + + +class TraceEmbedding(nn.Module): + def __init__( + self, + component_id: ComponentId, + num_embeddings: int, + embedding_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] + ) -> torch.Tensor: + out = self.embedding(x) + # trace_op(traces, self.component_id.with_op("output"), out) + return out # type: ignore + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.embedding.load_state_dict(state_dict, strict, assign) # type: ignore class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, layer_id: LayerId, args: ModelArgs): """ Initialize a TransformerBlock. @@ -437,19 +513,24 @@ def __init__(self, args: ModelArgs): """ super().__init__() + self.layer_id = layer_id self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) + self.attention = Attention(layer_id.with_component("attention"), args) self.feed_forward = FeedForward( + layer_id.with_component("feed_forward"), dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier, ) - # self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm( + layer_id.with_component("attention_norm"), args.dim, eps=args.norm_eps + ) + self.ffn_norm = RMSNorm( + layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps + ) @torch.inference_mode() def forward( @@ -460,6 +541,7 @@ def forward( global_freqs_cis: torch.Tensor, kvcache_list: list[KVCache], kvcache_manager: KVCacheManager, + traces: Optional[dict[OpId, torch.Tensor]], ) -> torch.Tensor: """ Perform a forward pass through the TransformerBlock. 
@@ -475,14 +557,19 @@ def forward( """ h = x + self.attention.forward( - self.attention_norm(x), + self.attention_norm(x, traces), bsz_list, start_pos_list, global_freqs_cis, kvcache_list, kvcache_manager, + traces, ) - out = h + self.feed_forward.forward(self.ffn_norm(h)) + # trace_op(traces, self.layer_id.with_component("attention").with_op("res"), h) + out = h + self.feed_forward.forward(self.ffn_norm(h, traces), traces) + # trace_op( + # traces, self.layer_id.with_component("feed_forward").with_op("res"), out + # ) return out diff --git a/deserve_worker/trace.py b/deserve_worker/trace.py new file mode 100644 index 0000000..a3e0e37 --- /dev/null +++ b/deserve_worker/trace.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass + + +@dataclass +class LayerId: + layer: str + + def with_component(self, component: str) -> "ComponentId": + return ComponentId(self.layer, component) + + +@dataclass +class ComponentId: + layer: str + component: str + + def with_op(self, op: str) -> "OpId": + return OpId(self.layer, self.component, op) + + +@dataclass +class OpId: + layer: str + component: str + op: str + + def __hash__(self) -> int: + return hash((self.layer, self.component, self.op)) diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 7ec24df..5949679 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -1,5 +1,6 @@ import queue import threading +import traceback from concurrent.futures import ThreadPoolExecutor from typing import Optional, cast @@ -7,11 +8,13 @@ import torch from transformers import AutoTokenizer # type: ignore -from .command import BatchForward, BatchResult, BatchUpdate -from .llm_engine import LLMEngine +from deserve_worker.kvcache.packed_kvcache import PackedKVCacheManager # type: ignore + +from .command import BatchForward, BatchResult, BatchUpdate, SingleTrace, TraceResult from .kvcache.kvcache import KVCache, main_device, main_dtype from .kvcache.paged_kvcache import PagedKVCacheManager from .layer_storage import LayerManager +from .llm_engine import LLMEngine from .model.llama import dumps from .task import PlanStep, SamplingParams, TaskData, TaskInfo @@ -25,11 +28,17 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.worker_id = worker_id self.controller_url = controller_url self.task_datas: dict[str, TaskData] = {} - self.relay_queue = queue.Queue[BatchResult | BatchUpdate]() - self.forward_engine = LLMEngine(max_total_bsz, self.relay_queue) + self.relay_queue = queue.Queue[BatchResult | BatchUpdate | TraceResult]() + self.llm_engine = LLMEngine(max_total_bsz, self.relay_queue) self.layer_manager = LayerManager(main_device) - self.kvcache_manager = PagedKVCacheManager(11600, 256, main_device, main_dtype) - threading.Thread(target=self.forward_engine.run, daemon=True).start() + # TODO: in future, different cache manager could allocate on same memory + self.paged_kvcache_manager = PagedKVCacheManager( + 11600, 256, main_device, main_dtype + ) + self.packed_kvcache_manager = PackedKVCacheManager( + 0, 256, main_device, main_dtype + ) + threading.Thread(target=self.llm_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) @@ -39,7 +48,7 @@ def locate_in_plan(self, plan: list[PlanStep]) -> Optional[int]: None, ) - def init_task_data( + def init_forward_task_data( self, x: torch.Tensor, index: int, task_info: TaskInfo ) -> TaskData: if task_info.round == 0: @@ -48,7 +57,7 @@ def init_task_data( _, 
layer_name = full_layer_name.split("/") if layer_name.startswith("layers."): layer_id = int(layer_name.split(".")[1]) - kvcaches[layer_id] = self.kvcache_manager.alloc( + kvcaches[layer_id] = self.paged_kvcache_manager.alloc( x.shape[0], x.shape[1] ) @@ -68,6 +77,34 @@ def init_task_data( return task_data + def init_trace_task_data( + self, x: torch.Tensor, index: int, task_info: TaskInfo + ) -> TaskData: + if task_info.round == 0: + kvcaches = {} + for full_layer_name in task_info.plan[index].layers: + _, layer_name = full_layer_name.split("/") + if layer_name.startswith("layers."): + layer_id = int(layer_name.split(".")[1]) + kvcaches[layer_id] = self.packed_kvcache_manager.alloc( + x.shape[0], x.shape[1] + ) + + task_data = TaskData( + task_id=task_info.task_id, + start_pos=0, + plan=task_info.plan, + round=0, + sampling_params=task_info.sampling_params, + kvcaches=cast(dict[int, KVCache], kvcaches), + ) + self.task_datas[task_info.task_id] = task_data + else: + task_data = self.task_datas[task_info.task_id] + task_data.round = task_info.round + + return task_data + def batch_forward( self, xs: torch.Tensor, @@ -80,15 +117,16 @@ def batch_forward( task_infos[0].plan[index].layers ) task_datas = [ - self.init_task_data(xs, index, task_info) for task_info in task_infos + self.init_forward_task_data(xs, index, task_info) + for task_info in task_infos ] - self.forward_engine.add_batch_forward( + self.llm_engine.add_batch_forward( BatchForward( xs=xs.to(main_device), layer_storage=layer_storage, task_datas=task_datas, need_sample=(index == len(plan) - 1), - kvcache_manager=self.kvcache_manager, + kvcache_manager=self.paged_kvcache_manager, ) ) @@ -105,11 +143,11 @@ def forward( return None layer_storage = self.layer_manager.get_layer_storage(plan[index].layers) - layer_forward = BatchForward( + forward = BatchForward( xs=x.to(main_device), layer_storage=layer_storage, task_datas=[ - self.init_task_data( + self.init_forward_task_data( x, index, TaskInfo( @@ -121,9 +159,40 @@ def forward( ) ], need_sample=(index == len(plan) - 1), - kvcache_manager=self.kvcache_manager, + kvcache_manager=self.paged_kvcache_manager, ) - self.forward_engine.add_batch_forward(layer_forward) + self.llm_engine.add_batch_forward(forward) + + def trace( + self, + x: torch.Tensor, + task_id: str, + round: int, + plan: list[PlanStep], + sampling_params: SamplingParams, + ) -> None: + index = self.locate_in_plan(plan) + if index is None: + return None + + layer_storage = self.layer_manager.get_layer_storage(plan[index].layers) + trace = SingleTrace( + x=x.to(main_device), + layer_storage=layer_storage, + task_data=self.init_trace_task_data( + x, + index, + TaskInfo( + task_id=task_id, + plan=plan, + round=round, + sampling_params=sampling_params, + ), + ), + kvcache_manager=self.packed_kvcache_manager, + need_sample=(index == len(plan) - 1), + ) + self.llm_engine.add_trace(trace) def relay(self) -> None: q = self.relay_queue @@ -174,6 +243,28 @@ def relay(self) -> None: ) for task_id in result.cancel_ids: self.cancel(task_id, None, self.task_datas[task_id].plan) + elif isinstance(result, TraceResult): + task_id = result.task_id + task_info = self.task_datas[task_id] + plan = task_info.plan + index = self.locate_in_plan(plan) + assert index is not None + next_index = (index + 1) % len(plan) + next_worker_url = plan[next_index].worker_url + data = dumps( + {"x": result.x}, + { + "task_id": task_id, + "round": self.task_datas[task_id].round, + "plan": plan, + "sampling_params": self.task_datas[task_id].sampling_params, + 
}, + ) + self.network_executor.submit( + requests.post, + f"{next_worker_url}/trace", + data=data, + ) def cancel( self, task_id: str, start_index: Optional[int], plan: list[PlanStep] diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index 1e9a05b..c56ae43 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -44,6 +44,24 @@ async def forward(request: Request) -> str: return "ok" +@app.post("/trace") +async def trace(request: Request) -> str: + try: + body = await request.body() + tensors, metadata = loads(body) + runtime_executor.submit( + worker.trace, + tensors["x"], + metadata["task_id"], + metadata["round"], + [PlanStep.model_validate(step) for step in metadata["plan"]], + SamplingParams.model_validate(metadata["sampling_params"]), + ) + except Exception as e: + traceback.print_exc() + return "ok" + + class CancelRequest(BaseModel): task_id: str start_index: int From bee4e616dc2569c653ae677ea7b8bedb9ffa3ec3 Mon Sep 17 00:00:00 2001 From: Celve Date: Sat, 3 Aug 2024 01:36:34 -0700 Subject: [PATCH 12/17] feat: share memory for different KV cache --- deserve_worker/kvcache/block_pool.py | 55 +++++++++++ deserve_worker/kvcache/packed_kvcache.py | 118 ++++++++++------------- deserve_worker/kvcache/paged_kvcache.py | 53 +++------- deserve_worker/model/llama.py | 19 +++- deserve_worker/worker.py | 10 +- 5 files changed, 140 insertions(+), 115 deletions(-) create mode 100644 deserve_worker/kvcache/block_pool.py diff --git a/deserve_worker/kvcache/block_pool.py b/deserve_worker/kvcache/block_pool.py new file mode 100644 index 0000000..3ddc432 --- /dev/null +++ b/deserve_worker/kvcache/block_pool.py @@ -0,0 +1,55 @@ +from typing import Optional +import torch + + +class BlockPool: + def __init__( + self, + num_blocks: int, + block_size: int, + main_device: torch.device, + main_dtype: torch.dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.main_device = main_device + self.main_dtype = main_dtype + self.fetch_size = 1024 + + self.block_ks = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.block_vs = torch.randn( + num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype + ) + self.block_bitmap = torch.ones( + (num_blocks,), device=main_device, dtype=torch.bool + ) + self.block_buffer = torch.empty(0, device=main_device, dtype=torch.int32) + + def alloc(self, size: int) -> Optional[torch.Tensor]: + if size > self.block_buffer.shape[0]: + fetch_size = max(self.fetch_size, size - self.block_buffer.shape[0]) + block_avails = torch.nonzero(self.block_bitmap)[:fetch_size] + self.block_bitmap[block_avails] = False + self.block_buffer = torch.cat([self.block_buffer, block_avails]) + if size > self.block_buffer.shape[0]: + return None + result = self.block_buffer[:size] + self.block_buffer = self.block_buffer[size:] + return result + + def alloc_consecutive(self, size: int) -> Optional[torch.Tensor]: + output, invert_indices, counts = torch.unique_consecutive( + self.block_bitmap, return_counts=True, return_inverse=True + ) + avail_bitmap: torch.Tensor = (counts >= size) & output + avail_indices = avail_bitmap.nonzero().flatten() + if avail_indices.shape[0] == 0: + return None + else: + index = avail_indices[0] + return (invert_indices == index).nonzero().flatten() + + def recycle(self, blocks: torch.Tensor) -> None: + self.block_bitmap[blocks] = True diff --git a/deserve_worker/kvcache/packed_kvcache.py b/deserve_worker/kvcache/packed_kvcache.py index dbdd6f1..a717aa0 
100644 --- a/deserve_worker/kvcache/packed_kvcache.py +++ b/deserve_worker/kvcache/packed_kvcache.py @@ -2,7 +2,8 @@ import torch -from deserve_worker.kvcache.kvcache import KVCache, KVCacheManager +from deserve_worker.kvcache.block_pool import BlockPool +from deserve_worker.kvcache.kvcache import KVCache, KVCacheManager, main_device def del_tensor(t: torch.Tensor) -> None: @@ -12,17 +13,9 @@ def del_tensor(t: torch.Tensor) -> None: class PackedKVCacheManager(KVCacheManager): - def __init__( - self, - num_blocks: int, - block_size: int, - main_device: torch.device, - main_dtype: torch.dtype, - ): - self.num_blocks = num_blocks - self.block_size = block_size - self.main_device = main_device - self.main_dtype = main_dtype + def __init__(self, block_pool: BlockPool): + self.block_pool = block_pool + self.block_size = block_pool.block_size def get_kv_cache_length(self, cur: int, seqlen: int) -> int: while cur < seqlen: @@ -32,79 +25,70 @@ def get_kv_cache_length(self, cur: int, seqlen: int) -> int: def alloc(self, bsz: int, seqlen: int) -> Optional[KVCache]: len_token = self.get_kv_cache_length(0, seqlen) len_block = len_token // self.block_size - if bsz * len_block <= self.num_blocks: - self.num_blocks -= bsz * len_block - cache_k = torch.zeros( - (bsz, len_token, 8, 128), - device=self.main_device, - dtype=self.main_dtype, - ) - cache_v = torch.zeros( - (bsz, len_token, 8, 128), - device=self.main_device, - dtype=self.main_dtype, - ) - return PackedKVCache(cache_k, cache_v, self) - else: + total_block = len_block * bsz + blocks = self.block_pool.alloc(total_block) + # the consecutive block table is in shape of [bsz, len_block], which corresponds to [bsz, len_block * block_size, 8, 128] in memory + if blocks is None: return None + else: + return PackedKVCache(blocks.reshape(bsz, -1), self) def recycle(self, kvcache: KVCache) -> None: kvcache = cast(PackedKVCache, kvcache) - bsz, seqlen = kvcache.cache_k.shape[:2] - self.num_blocks += bsz * seqlen - - del_tensor(kvcache.cache_k) - del_tensor(kvcache.cache_v) - kvcache.cache_k = torch.empty( - (0, 0), device=self.main_device, dtype=self.main_dtype - ) - kvcache.cache_v = torch.empty( - (0, 0), device=self.main_device, dtype=self.main_dtype + self.block_pool.recycle(kvcache.csct_block_table.flatten()) + kvcache.csct_block_table = torch.empty( + (0, 0), device=main_device, dtype=torch.int32 ) - torch.cuda.empty_cache() def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: kvcache = cast(PackedKVCache, kvcache) - if start_pos + seqlen > kvcache.cache_k.shape[1]: + if ( + start_pos + seqlen + > kvcache.csct_block_table.shape[1] * self.block_pool.block_size + ): len_token = self.get_kv_cache_length( - kvcache.cache_k.shape[1], start_pos + seqlen + kvcache.csct_block_table.shape[1] * self.block_size, start_pos + seqlen ) len_block = len_token // self.block_size - if bsz * len_block <= self.num_blocks: - self.num_blocks -= bsz * len_token - cache_k = torch.zeros( - (bsz, len_token, 8, 128), - device=self.main_device, - dtype=self.main_dtype, - ) - cache_v = torch.zeros( - (bsz, len_token, 8, 128), - device=self.main_device, - dtype=self.main_dtype, - ) - cache_k[:, :start_pos, :, :], cache_v[:, :start_pos, :, :] = ( - kvcache.cache_k[:, :start_pos, :, :], - kvcache.cache_v[:, :start_pos, :, :], - ) - original_shape = bsz * kvcache.cache_k.shape[1] - del_tensor(kvcache.cache_k) - del_tensor(kvcache.cache_v) - self.num_blocks += original_shape - kvcache.cache_k = cache_k - kvcache.cache_v = cache_v - return True - return 
False + total_block = len_block * bsz + blocks = self.block_pool.alloc(total_block) + if blocks is None: + return False + else: + # the original blocks are viewed as [bsz, old_len_block * block_size, 8, 128] + # the new blocks are viewed as [bsz, len_block * block_size, 8, 128] + # we need to copy the old blocks to the new blocks + old_len_block = kvcache.csct_block_table.shape[1] + old_blocks = kvcache.csct_block_table.flatten() + old_block_ks = self.block_pool.block_ks[ + old_blocks[0] : old_blocks[-1] + 1 + ].view(bsz, old_len_block * self.block_size, 8, 128) + new_block_ks = self.block_pool.block_ks[ + blocks[0] : blocks[-1] + 1 + ].view(bsz, len_block * self.block_size, 8, 128) + new_block_ks[:, :start_pos, :, :] = old_block_ks[:, :start_pos, :, :] + + old_block_vs = self.block_pool.block_vs[ + old_blocks[0] : old_blocks[-1] + 1 + ].view(bsz, old_len_block * self.block_size, 8, 128) + new_block_vs = self.block_pool.block_vs[ + blocks[0] : blocks[-1] + 1 + ].view(bsz, len_block * self.block_size, 8, 128) + new_block_vs[:, :start_pos, :, :] = old_block_vs[:, :start_pos, :, :] + + self.block_pool.recycle(old_blocks) + kvcache.csct_block_table = blocks.reshape(bsz, -1) + + return True class PackedKVCache(KVCache): def __init__( self, - cache_k: torch.Tensor, - cache_v: torch.Tensor, + csct_block_table: torch.Tensor, manager: PackedKVCacheManager, ): - self.cache_k = cache_k - self.cache_v = cache_v + self.csct_block_table = csct_block_table # consecutive block table self.manager = manager def renew(self, bsz: int, seqlen: int, start_pos: int) -> bool: diff --git a/deserve_worker/kvcache/paged_kvcache.py b/deserve_worker/kvcache/paged_kvcache.py index 6f094f8..d140ca0 100644 --- a/deserve_worker/kvcache/paged_kvcache.py +++ b/deserve_worker/kvcache/paged_kvcache.py @@ -3,53 +3,28 @@ import torch +from deserve_worker.kvcache.block_pool import BlockPool + from .kvcache import KVCache, KVCacheManager, main_device, main_dtype class PagedKVCacheManager(KVCacheManager): def __init__( self, - num_blocks: int, - block_size: int, - main_device: torch.device, - main_dtype: torch.dtype, + block_pool: BlockPool, ): - self.num_blocks = num_blocks - self.block_size = block_size - self.cache_k_paged = torch.randn( - num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype - ) - self.cache_v_paged = torch.randn( - num_blocks, block_size, 8, 128, device=main_device, dtype=main_dtype - ) - self.block_bitmap = torch.zeros( - (num_blocks,), device=main_device, dtype=torch.bool - ) - self.block_buffer = torch.arange( - 0, num_blocks, device=main_device, dtype=torch.int32 - ) + self.block_pool = block_pool def get_kv_cache_length(self, cur: int, seqlen: int) -> int: while cur < seqlen: - cur += self.block_size + cur += self.block_pool.block_size return cur - def alloc_blocks(self, size: int) -> Optional[torch.Tensor]: - if size > self.block_buffer.shape[0]: - block_avails = torch.nonzero(self.block_bitmap) - self.block_bitmap[block_avails] = False - self.block_buffer = torch.cat([self.block_buffer, block_avails]) - if size > self.block_buffer.shape[0]: - return None - result = self.block_buffer[:size] - self.block_buffer = self.block_buffer[size:] - return result - def alloc(self, bsz: int, seqlen: int) -> Optional["PagedKVCache"]: len_token = self.get_kv_cache_length(0, seqlen) - len_block = len_token // self.block_size + len_block = len_token // self.block_pool.block_size total_block = len_block * bsz - blocks = self.alloc_blocks(total_block) + blocks = self.block_pool.alloc(total_block) if blocks 
is None: return None else: @@ -57,20 +32,24 @@ def alloc(self, bsz: int, seqlen: int) -> Optional["PagedKVCache"]: def recycle(self, kvcache: KVCache) -> None: kvcache = cast(PagedKVCache, kvcache) - self.block_bitmap[kvcache.block_table.flatten()] = True + self.block_pool.recycle(kvcache.block_table.flatten()) kvcache.block_table = torch.empty((0, 0), device=main_device, dtype=torch.int32) def renew(self, kvcache: KVCache, bsz: int, seqlen: int, start_pos: int) -> bool: kvcache = cast(PagedKVCache, kvcache) - if start_pos + seqlen > kvcache.block_table.shape[1] * self.block_size: + if ( + start_pos + seqlen + > kvcache.block_table.shape[1] * self.block_pool.block_size + ): len_block = ( self.get_kv_cache_length( - kvcache.block_table.shape[1] * self.block_size, start_pos + seqlen + kvcache.block_table.shape[1] * self.block_pool.block_size, + start_pos + seqlen, ) - // self.block_size + // self.block_pool.block_size ) total_block = (len_block - kvcache.block_table.shape[1]) * bsz - blocks = self.alloc_blocks(total_block) + blocks = self.block_pool.alloc(total_block) if blocks is None: return False else: diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index 41751d0..489b06d 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -15,7 +15,7 @@ from deserve_worker.trace import ComponentId, LayerId, OpId from ..kvcache.kvcache import KVCache, KVCacheManager -from ..kvcache.packed_kvcache import PackedKVCache +from ..kvcache.packed_kvcache import PackedKVCache, PackedKVCacheManager @dataclass @@ -278,8 +278,8 @@ def forward( sin = global_freqs_cis[1].type_as(xq) output = flash_attn_with_kvcache( xq, - kvcache_manager.cache_k_paged, - kvcache_manager.cache_v_paged, + kvcache_manager.block_pool.block_ks, + kvcache_manager.block_pool.block_vs, xk, xv, rotary_cos=cos, @@ -292,6 +292,7 @@ def forward( output = output.view(bsz, seqlen, -1) return self.wo(output) # type: ignore else: + kvcache_manager = cast(PackedKVCacheManager, kvcache_manager) start = 0 output_list = [] for i, bsz in enumerate(bsz_list): @@ -310,8 +311,16 @@ def forward( start += bsz start_pos = start_pos_list[i] + # remember consecutive block table [bsz, len] corresponds to memory [bsz, len * block_size, 8, 128] kv_cache: PackedKVCache = cast(PackedKVCache, kvcache_list[i]) - cache_k, cache_v = kv_cache.cache_k, kv_cache.cache_v + csct_block_table = kv_cache.csct_block_table.flatten() + block_bsz, block_len = kv_cache.csct_block_table.shape[:2] + cache_k = kvcache_manager.block_pool.block_ks[ + csct_block_table[0] : csct_block_table[-1] + 1 + ].view(block_bsz, block_len * kvcache_manager.block_size, 8, 128) + cache_v = kvcache_manager.block_pool.block_vs[ + csct_block_table[0] : csct_block_table[-1] + 1 + ].view(block_bsz, block_len * kvcache_manager.block_size, 8, 128) freqs_cis = global_freqs_cis[start_pos : start_pos + seqlen] mask = None @@ -425,7 +434,7 @@ def forward( # trace_op(traces, self.component_id.with_op("w3"), w3) # trace_op(traces, self.component_id.with_op("w2"), w2) - return w2 # type: ignore + return w2 # type: ignore class TraceLinear(nn.Module): diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 5949679..01d4836 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -8,6 +8,7 @@ import torch from transformers import AutoTokenizer # type: ignore +from deserve_worker.kvcache.block_pool import BlockPool from deserve_worker.kvcache.packed_kvcache import PackedKVCacheManager # type: ignore from .command import BatchForward, 
BatchResult, BatchUpdate, SingleTrace, TraceResult @@ -31,13 +32,10 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): self.relay_queue = queue.Queue[BatchResult | BatchUpdate | TraceResult]() self.llm_engine = LLMEngine(max_total_bsz, self.relay_queue) self.layer_manager = LayerManager(main_device) + self.block_pool = BlockPool(11600, 256, main_device, main_dtype) # TODO: in future, different cache manager could allocate on same memory - self.paged_kvcache_manager = PagedKVCacheManager( - 11600, 256, main_device, main_dtype - ) - self.packed_kvcache_manager = PackedKVCacheManager( - 0, 256, main_device, main_dtype - ) + self.paged_kvcache_manager = PagedKVCacheManager(self.block_pool) + self.packed_kvcache_manager = PackedKVCacheManager(self.block_pool) threading.Thread(target=self.llm_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) From a706df12eb05f0b65235e39a86ea0d7739e4083e Mon Sep 17 00:00:00 2001 From: Celve Date: Sat, 3 Aug 2024 01:48:46 -0700 Subject: [PATCH 13/17] docs(worker): add readme --- deserve_worker/README.md | 67 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 deserve_worker/README.md diff --git a/deserve_worker/README.md b/deserve_worker/README.md new file mode 100644 index 0000000..09ae28b --- /dev/null +++ b/deserve_worker/README.md @@ -0,0 +1,67 @@ +# deserve worker + +## How to run + +```bash +python3 -m deserve_worker.worker_api +``` + +For example, + +```bash +python3 -m deserve_worker.worker_api 8080 worker0 +``` + +## API + +### Inference + +To inference, you need to pass a plan and other metadata in the request body. You have to send it to the first worker. The plan is a list of workers with their layers. The first worker will send the request to the next worker in the plan. The last worker will return the token to the controller. Here is an example: + +```python +plan = [ + { + "worker_id": worker_id0, + "worker_url": "http://localhost:8080", + "layers": [ + "llama-3-8b-instruct-slice/tok_embeddings", + *[f"llama-3-8b-instruct-slice/layers.{i}" for i in range(0, 16)], + ], + }, + { + "worker_id": worker_id1, + "worker_url": "http://localhost:8081", + "layers": [ + *[f"llama-3-8b-instruct-slice/layers.{i}" for i in range(16, 32)], + "llama-3-8b-instruct-slice/norm", + "llama-3-8b-instruct-slice/output", + ], + }, +] + +metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "max_total_len": 2048, + }, +} + +tensors = {"x": tokens} + +requests.post( + "http://localhost:8080/forward", data=dumps(tensors, metadata) +) +``` + +### Trace + +To trace, the plan is also required. It is worth noting that trace use different kernel for computation and dumping. + + +### Cancel + +You should not cancel a task. It's used for freeing resources like KV caches. 
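As a concrete illustration of the Trace section above: a trace request reuses the same `plan`, `metadata`, and `dumps` helper shown in the Inference example; only the endpoint changes (a minimal sketch):

```python
requests.post(
    "http://localhost:8080/trace", data=dumps(tensors, metadata)
)
```

The dumped tensors are then relayed to the controller's `/update_traces` endpoint, keyed by `layer.component.op` op ids.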
\ No newline at end of file From 757301b4335084a5730d8ff8e60688e40eb7d19f Mon Sep 17 00:00:00 2001 From: Celve Date: Sat, 3 Aug 2024 11:10:53 -0700 Subject: [PATCH 14/17] fix(worker): enable tracing --- deserve_worker/model/llama.py | 36 +++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index 489b06d..5e2f701 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -90,9 +90,9 @@ def forward( """ output = self._norm(x.float()).type_as(x) - # trace_op(traces, self.component_id.with_op("output"), output) + trace_op(traces, self.component_id.with_op("output"), output) result = output * self.weight - # trace_op(traces, self.component_id.with_op("weighted_output"), result) + trace_op(traces, self.component_id.with_op("weighted_output"), result) return result @@ -305,9 +305,9 @@ def forward( xv = xv_[start : start + bsz].view( bsz, seqlen, self.n_local_kv_heads, self.head_dim ) - # trace_op(traces, self.component_id.with_op("xq"), xq) - # trace_op(traces, self.component_id.with_op("xk"), xk) - # trace_op(traces, self.component_id.with_op("xv"), xv) + trace_op(traces, self.component_id.with_op("xq"), xq) + trace_op(traces, self.component_id.with_op("xk"), xk) + trace_op(traces, self.component_id.with_op("xv"), xv) start += bsz start_pos = start_pos_list[i] @@ -331,8 +331,8 @@ def forward( mask = torch.triu(mask, diagonal=start_pos + 1).type_as(x) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - # trace_op(traces, self.component_id.with_op("xq_rotary"), xq) - # trace_op(traces, self.component_id.with_op("xk_rotary"), xk) + trace_op(traces, self.component_id.with_op("xq_rotary"), xq) + trace_op(traces, self.component_id.with_op("xk_rotary"), xk) cache_k[:bsz, start_pos : start_pos + seqlen] = xk cache_v[:bsz, start_pos : start_pos + seqlen] = xv @@ -354,7 +354,7 @@ def forward( scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( self.head_dim ) - # trace_op(traces, self.component_id.with_op("scores"), scores) + trace_op(traces, self.component_id.with_op("scores"), scores) if mask is not None: scores = ( scores + mask @@ -368,7 +368,7 @@ def forward( output_list.append(output) output = torch.cat([x for x in output_list]) result = self.wo(output) - # trace_op(traces, self.component_id.with_op("weighted_output"), result) + trace_op(traces, self.component_id.with_op("weighted_output"), result) return result # type: ignore @@ -430,9 +430,9 @@ def forward( w1 = F.silu(self.w1(x)) w3 = self.w3(x) w2 = self.w2(w1 * w3) - # trace_op(traces, self.component_id.with_op("w1"), w1) - # trace_op(traces, self.component_id.with_op("w3"), w3) - # trace_op(traces, self.component_id.with_op("w2"), w2) + trace_op(traces, self.component_id.with_op("w1"), w1) + trace_op(traces, self.component_id.with_op("w3"), w3) + trace_op(traces, self.component_id.with_op("w2"), w2) return w2 # type: ignore @@ -457,7 +457,7 @@ def forward( self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] ) -> torch.Tensor: out = self.linear(x) - # trace_op(traces, self.component_id.with_op("output"), out) + trace_op(traces, self.component_id.with_op("output"), out) return out # type: ignore def load_state_dict( @@ -489,7 +489,7 @@ def forward( self, x: torch.Tensor, traces: Optional[dict[OpId, torch.Tensor]] ) -> torch.Tensor: out = self.embedding(x) - # trace_op(traces, self.component_id.with_op("output"), out) + trace_op(traces, self.component_id.with_op("output"), out) return out # 
type: ignore def load_state_dict( @@ -574,11 +574,11 @@ def forward( kvcache_manager, traces, ) - # trace_op(traces, self.layer_id.with_component("attention").with_op("res"), h) + trace_op(traces, self.layer_id.with_component("attention").with_op("res"), h) out = h + self.feed_forward.forward(self.ffn_norm(h, traces), traces) - # trace_op( - # traces, self.layer_id.with_component("feed_forward").with_op("res"), out - # ) + trace_op( + traces, self.layer_id.with_component("feed_forward").with_op("res"), out + ) return out From b08ec03fde25f393dcb468f94e87879ac2abbe27 Mon Sep 17 00:00:00 2001 From: Celve Date: Sun, 4 Aug 2024 17:03:51 -0700 Subject: [PATCH 15/17] feat: add controller and client --- deserve_client/README.md | 18 + deserve_client/__init__.py | 0 deserve_client/client.py | 99 +++++ deserve_client/model.py | 625 +++++++++++++++++++++++++++ deserve_client/py.typed | 0 deserve_client/pyproject.toml | 13 + deserve_controller/README.md | 7 + deserve_controller/__init__.py | 0 deserve_controller/controller_api.py | 266 ++++++++++++ deserve_controller/py.typed | 0 deserve_controller/pyproject.toml | 13 + deserve_worker/README.md | 2 +- deserve_worker/kvcache/block_pool.py | 1 + deserve_worker/layer_storage.py | 4 +- deserve_worker/llm_engine.py | 1 + deserve_worker/model/llama.py | 3 +- deserve_worker/trace.py | 23 + deserve_worker/worker.py | 47 +- deserve_worker/worker_api.py | 19 +- 19 files changed, 1124 insertions(+), 17 deletions(-) create mode 100644 deserve_client/README.md create mode 100644 deserve_client/__init__.py create mode 100644 deserve_client/client.py create mode 100644 deserve_client/model.py create mode 100644 deserve_client/py.typed create mode 100644 deserve_client/pyproject.toml create mode 100644 deserve_controller/README.md create mode 100644 deserve_controller/__init__.py create mode 100644 deserve_controller/controller_api.py create mode 100644 deserve_controller/py.typed create mode 100644 deserve_controller/pyproject.toml diff --git a/deserve_client/README.md b/deserve_client/README.md new file mode 100644 index 0000000..ff7a99d --- /dev/null +++ b/deserve_client/README.md @@ -0,0 +1,18 @@ +# DeServe Client + +## How To Run + +For completion: +```bash +python3 -m deserve_client.client complete meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." +``` + +For dumping traces of prefill: +```bash +python3 -m deserve_client.client trace meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." +``` + +For verifying the correctness of the trace: +```bash +python3 -m deserve_client.client verify meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt." 
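# Verify re-runs the traced prompt locally with the reference model in
# deserve_client/model.py and reports the first op whose tensor differs
# from the dumped trace by more than the 0.03 tolerance.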
+``` diff --git a/deserve_client/__init__.py b/deserve_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_client/client.py b/deserve_client/client.py new file mode 100644 index 0000000..ae99e32 --- /dev/null +++ b/deserve_client/client.py @@ -0,0 +1,99 @@ +import pickle +from typing import Any + +import requests +import safetensors.torch +import torch +import typer +from transformers import AutoTokenizer # type: ignore + +from deserve_client.model import ( + CheckCtx, + Transformer, + VerifyCtx, + llama_3_8b_args, + main_device, +) +from deserve_controller.controller_api import app +from deserve_worker.trace import OpId + +cli = typer.Typer() +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata + + +@cli.command() +def complete(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/complete", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + + for chunk in response.iter_content(): + if chunk: + print(chunk.decode("utf-8"), end="", flush=True) + + +@cli.command() +def trace(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/trace", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + + tensors = {} + for chunk in response.iter_content(chunk_size=None): + if chunk: + temp_tensors, _ = loads(chunk) + tensors.update(temp_tensors) + print(list(tensors.keys())) + +@cli.command() +def verify(model: str, prompt: str, entry_point: str = "http://localhost:19000"): + response = requests.post( + f"{entry_point}/trace", + json={"model": model, "prompt": prompt}, + stream=True, + ) + if response.status_code != 200: + typer.echo("Error") + return + tensors: dict[str, torch.Tensor] = {} + for chunk in response.iter_content(chunk_size=None): + if chunk: + temp_tensors, _ = loads(chunk) + tensors.update(temp_tensors) + + traces = {OpId.from_str(k): v for k, v in tensors.items()} + transformer = Transformer(llama_3_8b_args) + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to(main_device) + result = transformer.forward(tokens, CheckCtx(0.03, traces)) + if isinstance(result, torch.Tensor): + print("No difference found") + else: + if not transformer.verify(tokens, VerifyCtx(result.op_id, 0.03, traces)): + print("Difference found for", result.op_id) + else: + print("Difference found but verification failed") + + +if __name__ == "__main__": + cli() diff --git a/deserve_client/model.py b/deserve_client/model.py new file mode 100644 index 0000000..4c7d3b9 --- /dev/null +++ b/deserve_client/model.py @@ -0,0 +1,625 @@ +import math +import os +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoTokenizer # type: ignore + +from deserve_worker.trace import ComponentId, LayerId, OpId + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") +torch.set_default_dtype(torch.float16) +main_device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +llama_3_8b_args = ModelArgs( + n_kv_heads=8, + vocab_size=128256, + multiple_of=1024, + ffn_dim_multiplier=1.3, + norm_eps=1e-5, + rope_theta=500000.0, + ) + + +@dataclass +class Diff: + op_id: OpId + diff: float + +@dataclass +class CheckCtx: + threshold: float + traces: dict[OpId, torch.Tensor] + + def check(self, op_id: OpId, x: torch.Tensor) -> torch.Tensor | Diff: + y = self.traces[op_id].to(main_device) + if torch.allclose(x, y, atol=self.threshold): + return y + else: + return Diff(op_id, torch.max(torch.abs(x - y)).item()) + +@dataclass +class VerifyCtx: + op_id: OpId + threshold: float + traces: dict[OpId, torch.Tensor] + + def get_trace(self, op_id: OpId) -> torch.Tensor: + return self.traces[op_id].to(main_device) + + def verify(self, x: torch.Tensor) -> bool: + y = self.traces[self.op_id].to(main_device) + return torch.allclose(x, y, atol=self.threshold) + + +class RMSNorm(nn.Module): + def __init__(self, component_id: ComponentId, dim: int, eps: float = 1e-6): + super().__init__() + self.component_id = component_id + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool: + op = ctx.op_id.op + if op == "output": + return ctx.verify(self._norm(x.float()).type_as(x)) + else: + output = ctx.get_trace(self.component_id.with_op("weighted_output")) + return ctx.verify(output * self.weight) + + + def forward( + self, + x: torch.Tensor, + ctx: CheckCtx + ) -> torch.Tensor | Diff: + output = ctx.check(self.component_id.with_op("output"), self._norm(x.float()).type_as(x)) + if isinstance(output, Diff): + return output + return ctx.check(self.component_id.with_op("weighted_output"), output * self.weight) + + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if 
n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, component_id: ComponentId, args: ModelArgs): + super().__init__() + self.component_id = component_id + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.utils.skip_init( + nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.utils.skip_init( + nn.Linear, + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + def verify( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: VerifyCtx, + ) -> bool: + bsz, seqlen, _ = x.shape + op = ctx.op_id.op + if op == "xq": + xq = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + return ctx.verify(xq) + elif op == "xk": + xk = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + return ctx.verify(xk) + elif op == "xv": + xv = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + return ctx.verify(xv) + elif op == "xq_rotary" or op == "xk_rotary": + xq, xk = self.wq(x), self.wk(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + if op == "xq_rotary": + return ctx.verify(xq) + else: + return ctx.verify(xk) + elif op == "scores": + xq = ctx.get_trace(self.component_id.with_op("xq_rotary")) + keys = ctx.get_trace(self.component_id.with_op("xk_rotary")) + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + return ctx.verify(scores) + elif op == "output": + scores = ctx.get_trace(self.component_id.with_op("scores")) + values = ctx.get_trace(self.component_id.with_op("xv")) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + output = torch.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return ctx.verify(output) + elif op == "weighted_output": + output = ctx.get_trace(self.component_id.with_op("output")) + return ctx.verify(self.wo(output)) + assert False + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: CheckCtx, + ) -> torch.Tensor | Diff: + bsz, seqlen, _ = x.shape + + xq = ctx.check(self.component_id.with_op("xq"), self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)) + if isinstance(xq, Diff): + return xq + + xk = ctx.check(self.component_id.with_op("xk"), self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)) + if isinstance(xk, Diff): + 
return xk + + xv = ctx.check(self.component_id.with_op("xv"), self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)) + if isinstance(xv, Diff): + return xv + + xq_new, xk_new = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq = ctx.check(self.component_id.with_op("xq_rotary"), xq_new) + if isinstance(xq, Diff): + return xq + + xk = ctx.check(self.component_id.with_op("xk_rotary"), xk_new) + if isinstance(xk, Diff): + return xk + + keys = xk.clone() + values = xv.clone() + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores_new = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores_new = scores_new + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores_new = F.softmax(scores_new.float(), dim=-1).type_as(xq) + + # check scores + scores = ctx.check(self.component_id.with_op("scores"), scores_new) + if isinstance(scores, Diff): + return scores + + output_new = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output_new = output_new.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output = ctx.check(self.component_id.with_op("output"), output_new) + if isinstance(output, Diff): + return output + + return ctx.check(self.component_id.with_op("weighted_output"), self.wo(output)) + + +class FeedForward(nn.Module): + def __init__( + self, + component_id: ComponentId, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + self.component_id = component_id + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + self.w2 = torch.nn.utils.skip_init( + nn.Linear, + hidden_dim, + dim, + bias=False, + ) + self.w3 = torch.nn.utils.skip_init( + nn.Linear, + dim, + hidden_dim, + bias=False, + ) + + def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool: + op = ctx.op_id.op + if op == "w1": + return ctx.verify(F.silu(self.w1(x))) + elif op == "w3": + return ctx.verify(self.w3(x)) + elif op == "w2": + w1 = ctx.get_trace(self.component_id.with_op("w1")) + w3 = ctx.get_trace(self.component_id.with_op("w3")) + return ctx.verify(self.w2(w1 * w3)) + assert False + + def forward( + self, x: torch.Tensor, ctx: CheckCtx, + ) -> torch.Tensor | Diff: + # check w1, w3, w2 + w1 = ctx.check(self.component_id.with_op("w1"), F.silu(self.w1(x))) + if isinstance(w1, Diff): + return w1 + + w3 = ctx.check(self.component_id.with_op("w3"), self.w3(x)) + if isinstance(w3, Diff): + return w3 + + return ctx.check(self.component_id.with_op("w2"), self.w2(w1 * w3)) + +class TraceLinear(nn.Module): + def __init__( + self, + component_id: ComponentId, + in_features: int, + out_features: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.linear = nn.Linear( + in_features, 
out_features, bias=False, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, ctx: CheckCtx + ) -> torch.Tensor | Diff: + return ctx.check(self.component_id.with_op("output"), self.linear(x)) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.linear.load_state_dict(state_dict, strict, assign) # type: ignore + +class TraceEmbedding(nn.Module): + def __init__( + self, + component_id: ComponentId, + num_embeddings: int, + embedding_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.component_id = component_id + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, device=device, dtype=dtype + ) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, ctx: CheckCtx + ) -> torch.Tensor | Diff: + return ctx.check(self.component_id.with_op("output"), self.embedding(x)) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ) -> torch.nn.modules.module._IncompatibleKeys: + return self.embedding.load_state_dict(state_dict, strict, assign) # type: ignore + + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: LayerId, args: ModelArgs): + super().__init__() + self.layer_id = layer_id + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(layer_id.with_component("attention"), args) + self.feed_forward = FeedForward( + layer_id.with_component("feed_forward"), + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm( + layer_id.with_component("attention_norm"), args.dim, eps=args.norm_eps + ) + self.ffn_norm = RMSNorm(layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps) + + def verify( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: VerifyCtx + ) -> bool: + layer = ctx.op_id.layer + component = ctx.op_id.component + op = ctx.op_id.op + if component == "feed_forward": + if op == "res": + return ctx.verify(ctx.get_trace(OpId(layer, "attention", "res")) + ctx.get_trace(OpId(layer, "feed_forward", "w2"))) + else: + return self.feed_forward.verify( + ctx.get_trace(OpId(layer, "ffn_norm", "weighted_output")), ctx + ) + elif component == "ffn_norm": + return self.ffn_norm.verify(ctx.get_trace(OpId(layer, "attention", "res")), ctx) + elif component == "attention_norm": + return self.attention_norm.verify(x, ctx) + elif component == "attention": + if op == "res": + return ctx.verify(x + ctx.get_trace(OpId(layer, "attention", "weighted_output"))) + else: + return self.attention.verify( + ctx.get_trace(OpId(layer, "attention_norm", "weighted_output")), + freqs_cis, + mask, + ctx + ) + assert False + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ctx: CheckCtx, + ) -> torch.Tensor | Diff: + attn_norm = self.attention_norm.forward(x, ctx) + if isinstance(attn_norm, Diff): + return attn_norm + + attn = self.attention.forward(attn_norm, freqs_cis, mask, ctx) + if isinstance(attn, Diff): + return attn + + h = ctx.check(self.layer_id.with_component("attention").with_op("res"), x + attn) + if isinstance(h, Diff): + return h + + ffn_norm = self.ffn_norm.forward(h, ctx) + if isinstance(ffn_norm, Diff): + 
return ffn_norm + + ffn = self.feed_forward.forward(ffn_norm, ctx) + if isinstance(ffn, Diff): + return ffn + + return ctx.check(self.layer_id.with_component("feed_forward").with_op("res"), h + ffn) + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + cache_dir = "~/.cache/fleece-worker/models/llama-3-8b-instruct-slice/" + cache_dir = os.path.expanduser(cache_dir) + + self.tok_embeddings = torch.nn.utils.skip_init( + TraceEmbedding, ComponentId("tok_embeddings", "main"), params.vocab_size, params.dim + ) + self.tok_embeddings.load_state_dict( + torch.load(cache_dir + "tok_embeddings.pt", map_location="cpu") + ) + self.tok_embeddings.to(main_device) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + layer = TransformerBlock(LayerId.from_str(f"{layer_id:02}"), params) + layer.load_state_dict( + torch.load(cache_dir + f"layers.{layer_id}.pt", map_location="cpu") + ) + layer.to(main_device) + self.layers.append(layer) + + self.norm = RMSNorm(ComponentId("norm", "main"), params.dim, eps=params.norm_eps) + self.norm.load_state_dict(torch.load(cache_dir + "norm.pt", map_location="cpu")) + self.norm.to(main_device) + self.output = torch.nn.utils.skip_init( + TraceLinear, ComponentId("output", "main"), params.dim, params.vocab_size + ) + self.output.load_state_dict( + torch.load(cache_dir + "output.pt", map_location="cpu") + ) + self.output.to(main_device) + + self.freqs_cis = precompute_freqs_cis( + params.dim // params.n_heads, + params.max_seq_len * 2, + params.rope_theta, + ) + + @torch.inference_mode() + def verify( + self, tokens: torch.Tensor, ctx: VerifyCtx + ) -> bool: + _bsz, seqlen = tokens.shape + layer = ctx.op_id.layer + if layer.isdigit(): + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + mask = torch.triu(mask, diagonal=1) + mask = torch.hstack( + [torch.zeros((seqlen, 0), device=tokens.device), mask] + ).type_as(tokens) + + layer_int = int(layer) + if layer_int < 0 or layer_int >= self.n_layers: + assert False + + if layer_int == 0: + input = ctx.get_trace(OpId("tok_embeddings", "main", "output")) + else: + input = ctx.get_trace(OpId(f"{layer_int - 1:02}", "feed_forward", "res")) + + return self.layers[layer_int].verify(input, self.freqs_cis, mask, ctx) + elif layer == "tok_embeddings": + return self.tok_embeddings.verify(tokens, ctx) + elif layer == "norm": + num_layers = self.n_layers + return self.norm.verify(ctx.get_trace(OpId(f"{num_layers - 1:02}", "feed_forward", "res")), ctx) + elif layer == "output": + return self.output.verify(ctx.get_trace(OpId("norm", "main", "output")), ctx) + assert False + + @torch.inference_mode() + def forward( + self, tokens: torch.Tensor, ctx: CheckCtx + ) -> torch.Tensor | Diff: + _bsz, seqlen = tokens.shape + + h = self.tok_embeddings.forward(tokens, ctx) + if isinstance(h, Diff): + return h + + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[0:seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. 
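For intuition, here is a standalone sketch of the mask built here (toy sizes, assuming `seqlen = 3` and an empty cache):

```python
import torch

seqlen, cache_len = 3, 0  # toy sizes for illustration only
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# prepend cache_len zero columns: entry (i, j) is masked only when j > cache_len + i
mask = torch.hstack([torch.zeros((seqlen, cache_len)), mask])
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```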
+ mask = torch.hstack( + [torch.zeros((seqlen, 0), device=tokens.device), mask] + ).type_as(h) + + for layer in self.layers: + h = layer.forward(h, freqs_cis, mask, ctx) + if isinstance(h, Diff): + return h + h = self.norm.forward(h, ctx) + if isinstance(h, Diff): + return h + + output = self.output.forward(h, ctx) + if isinstance(output, Diff): + return output + return output.float() diff --git a/deserve_client/py.typed b/deserve_client/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_client/pyproject.toml b/deserve_client/pyproject.toml new file mode 100644 index 0000000..b73ffc6 --- /dev/null +++ b/deserve_client/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_client" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Client" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_controller/README.md b/deserve_controller/README.md new file mode 100644 index 0000000..8098686 --- /dev/null +++ b/deserve_controller/README.md @@ -0,0 +1,7 @@ +# DeServe Controller + +## How to run + +```bash +python3 -m deserve_controller.controller_api --port= +``` \ No newline at end of file diff --git a/deserve_controller/__init__.py b/deserve_controller/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deserve_controller/controller_api.py b/deserve_controller/controller_api.py new file mode 100644 index 0000000..5e2c1da --- /dev/null +++ b/deserve_controller/controller_api.py @@ -0,0 +1,266 @@ +import argparse +import logging +import pickle +import queue +import traceback +import uuid +from typing import Any, Generator, Optional + +import requests +import safetensors.torch +import torch +from cachetools import TTLCache +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from transformers import AutoTokenizer # type: ignore + +controller_url: str +app = FastAPI() +logger = logging.getLogger("uvicorn") +workers: TTLCache[str, str] = TTLCache(maxsize=128, ttl=2) +model2layers = { + "meta-llama/Meta-Llama-3-70B-Instruct": 80, + "meta-llama/Meta-Llama-3-8B-Instruct": 32, +} +model2alias = { + "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b-instruct-slice", + "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b-instruct-slice", +} +token_channels: dict[str, queue.Queue[Optional[str]]] = {} +trace_channels: dict[str, queue.Queue[dict[str, torch.Tensor]]] = {} +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + +def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes: + """ + Dump tensors and metadata into bytes + """ + + metadata_bytes = pickle.dumps(metadata) + tensors_bytes = safetensors.torch.save(tensors) + return ( + len(metadata_bytes).to_bytes(4, byteorder="big") + + metadata_bytes + + tensors_bytes + ) + + +def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """ + Load tensors and metadata from bytes + """ + + metadata_length = int.from_bytes(b[:4], byteorder="big") + metadata = pickle.loads(b[4 : 4 + metadata_length]) + tensors = safetensors.torch.load(b[4 + metadata_length :]) + return tensors, metadata + + +class RegisterRequest(BaseModel): + worker_id: str + worker_url: str + + +@app.post("/register") +def register(request: RegisterRequest) -> str: + 
workers[request.worker_id] = request.worker_url + return "ok" + + +class HeartbeatRequest(BaseModel): + worker_id: str + worker_url: str + + +@app.post("/heartbeat") +def heartbeat(request: HeartbeatRequest) -> str: + workers[request.worker_id] = request.worker_url + return "ok" + + +class CompleteRequest: + pass # discuss about implementation details (how to send, how to retrieve) + + +class PlanStep(BaseModel): + worker_id: str + worker_url: str + layers: list[str] + + +def generate_plan(model: str, worker_ids: list[str]) -> list[PlanStep]: + alias = model2alias[model] + num_layer_total = model2layers[model] + num_layer_worker = num_layer_total // len(worker_ids) + layers = [ + (i * num_layer_worker, (i + 1) * num_layer_worker) + for i in range(len(worker_ids) - 1) + ] + if len(layers) == 0: + layers.append((0, num_layer_total)) + else: + layers.append((layers[-1][1], num_layer_total)) + plans: list[PlanStep] = [] + for worker_id, layer in zip(worker_ids, layers): + plans.append( + PlanStep( + worker_id=worker_id, + worker_url=workers[worker_id], + layers=[f"{alias}/layers.{i}" for i in range(layer[0], layer[1])], + ) + ) + plans[0].layers.insert(0, f"{alias}/tok_embeddings") + plans[-1].layers.append(f"{alias}/norm") + plans[-1].layers.append(f"{alias}/output") + return plans + + +def relay_tokens( + channel: queue.Queue[Optional[str]], +) -> Generator[bytes, None, None]: + while True: + value = channel.get() + if value is None: + break + yield value.encode("utf-8") + + +class OnlineCompleteRequest(BaseModel): + model: str + prompt: str + + +@app.post("/complete") +def complete(request: OnlineCompleteRequest) -> StreamingResponse: + model = request.model + prompt = request.prompt + + if model not in model2layers: + raise HTTPException(status_code=404, detail="Model not found") + + task_id = str(uuid.uuid4()) + + # init channel for relay + token_channel = queue.Queue[Optional[str]]() + token_channels[task_id] = token_channel + + # generate request + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"] + plan = generate_plan(model, list(workers.keys())) + tensors = {"x": tokens} + metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": {"temperature": 0.0, "top_p": 1.0, "max_total_len": 2048}, + } + first_worker_url = plan[0].worker_url + response = requests.post( + f"{first_worker_url}/forward", data=dumps(tensors, metadata) + ) + if response.status_code != 200: + raise HTTPException(status_code=500, detail="Worker error") + + return StreamingResponse(relay_tokens(token_channel)) + + +class OfflineCompleteRequest(BaseModel): + model: str + prompts: list[str] + + +@app.post("/offline-complete") +def offline_complete(request: OfflineCompleteRequest) -> None: + pass + +def relay_traces(channel: queue.Queue[dict[str, torch.Tensor]], total: int) -> Generator[bytes, None, None]: + cnt = 0 + while cnt < total: + value = channel.get() + cnt += 1 + if value is None: + break + bytes = dumps(value, {}) + yield bytes + + +class TraceRequest(BaseModel): + model: str + prompt: str + + +@app.post("/trace") +def trace(request: TraceRequest) -> Response: + model = request.model + prompt = request.prompt + + if model not in model2layers: + raise HTTPException(status_code=404, detail="Model not found") + + task_id = str(uuid.uuid4()) + + # init channel for relay, but we don't handle it inside tracing + token_channel = queue.Queue[Optional[str]]() + token_channels[task_id] = token_channel + + # init traces + trace_channel = queue.Queue[dict[str, torch.Tensor]]() + 
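A worked example of `generate_plan`: with two registered workers, the 32 layers of the 8B model are split evenly, the first step additionally gets `tok_embeddings`, and the last step gets `norm` and `output`, matching the worker README. The sketch below assumes hypothetical worker ids and URLs (importing the module also loads the tokenizer):

```python
from deserve_controller.controller_api import generate_plan, workers

workers["worker0"] = "http://localhost:8080"  # hypothetical registrations
workers["worker1"] = "http://localhost:8081"

plan = generate_plan("meta-llama/Meta-Llama-3-8B-Instruct", ["worker0", "worker1"])
assert plan[0].layers[0] == "llama-3-8b-instruct-slice/tok_embeddings"
assert plan[0].layers[1:] == [f"llama-3-8b-instruct-slice/layers.{i}" for i in range(16)]
assert plan[1].layers[-2:] == [
    "llama-3-8b-instruct-slice/norm",
    "llama-3-8b-instruct-slice/output",
]
```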
trace_channels[task_id] = trace_channel + + # generate request + tokens = tokenizer(prompt, return_tensors="pt")["input_ids"] + online_workers = list(workers.keys()) + plan = generate_plan(model, online_workers) + tensors = {"x": tokens} + metadata = { + "task_id": task_id, + "round": 0, + "plan": plan, + "sampling_params": {"temperature": 0.0, "top_p": 1.0, "max_total_len": 2048}, + } + first_worker_url = plan[0].worker_url + response = requests.post(f"{first_worker_url}/trace", data=dumps(tensors, metadata)) + if response.status_code != 200: + raise HTTPException(status_code=500, detail="Worker error") + return StreamingResponse(relay_traces(trace_channel, len(online_workers))) + + +class UpdateTaskRequest(BaseModel): + task_id: str + output_tokens: list[list[int]] # [bsz, seqlen], in normal case, bsz=1 and seqlen=1 + + +@app.post("/update_tasks") +def update_tasks(requests: list[UpdateTaskRequest]) -> None: + for request in requests: + task_id = request.task_id + for token_id in request.output_tokens: + token = tokenizer.decode(token_id) + if task_id in token_channels: + token_channels[task_id].put(token) + else: + logger.warning(f"Task {task_id} not found") + + +@app.post("/update_traces") +async def update_traces(requests: Request) -> None: + body = await requests.body() + tensors, metadata = loads(body) + task_id = metadata["task_id"] + if task_id in trace_channels: + trace_channels[task_id].put(tensors) + else: + logger.warning(f"Task {task_id} not found") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=19000) + args = parser.parse_args() + + controller_url = f"http://localhost:{args.port}" + + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=args.port) diff --git a/deserve_controller/py.typed b/deserve_controller/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deserve_controller/pyproject.toml b/deserve_controller/pyproject.toml new file mode 100644 index 0000000..c80b66b --- /dev/null +++ b/deserve_controller/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "deserve_controller" +version = "0.0.1" +authors = [ + { name="Example Author", email="author@example.com" }, +] +description = "Deserve Controller" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] diff --git a/deserve_worker/README.md b/deserve_worker/README.md index 09ae28b..4778647 100644 --- a/deserve_worker/README.md +++ b/deserve_worker/README.md @@ -1,4 +1,4 @@ -# deserve worker +# DeServe Worker ## How to run diff --git a/deserve_worker/kvcache/block_pool.py b/deserve_worker/kvcache/block_pool.py index 3ddc432..7700d11 100644 --- a/deserve_worker/kvcache/block_pool.py +++ b/deserve_worker/kvcache/block_pool.py @@ -1,4 +1,5 @@ from typing import Optional + import torch diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py index bb5055a..4494dc0 100644 --- a/deserve_worker/layer_storage.py +++ b/deserve_worker/layer_storage.py @@ -164,7 +164,9 @@ def preload_layers(self, full_layer_names: list[str]) -> dict[str, torch.nn.Modu model_args.dim, ) elif layer_name.startswith("layer"): - l = TransformerBlock(LayerId(f"layer_{layer_name[6:]}"), model_args) + l = TransformerBlock( + LayerId(f"{int(layer_name[7:]):02}"), model_args + ) elif layer_name == "norm": l = RMSNorm( ComponentId("norm", "main"), model_args.dim, eps=model_args.norm_eps diff --git a/deserve_worker/llm_engine.py 
b/deserve_worker/llm_engine.py index 84d83fb..333e0fc 100644 --- a/deserve_worker/llm_engine.py +++ b/deserve_worker/llm_engine.py @@ -128,6 +128,7 @@ def handle_forward(self, forwards: list[BatchForward]) -> None: self.post_forward(h, new_task) def handle_trace(self, tasks: list[SingleTrace]) -> None: + print(f"trace_tasks: {len(tasks)}") for task in tasks: traces: dict[OpId, torch.Tensor] = {} h = self.step_forward( diff --git a/deserve_worker/model/llama.py b/deserve_worker/model/llama.py index 5e2f701..61f6132 100644 --- a/deserve_worker/model/llama.py +++ b/deserve_worker/model/llama.py @@ -354,17 +354,18 @@ def forward( scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt( self.head_dim ) - trace_op(traces, self.component_id.with_op("scores"), scores) if mask is not None: scores = ( scores + mask ) # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) + trace_op(traces, self.component_id.with_op("scores"), scores) output = torch.matmul( scores, values ) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + trace_op(traces, self.component_id.with_op("output"), output) output_list.append(output) output = torch.cat([x for x in output_list]) result = self.wo(output) diff --git a/deserve_worker/trace.py b/deserve_worker/trace.py index a3e0e37..2f0c039 100644 --- a/deserve_worker/trace.py +++ b/deserve_worker/trace.py @@ -8,6 +8,13 @@ class LayerId: def with_component(self, component: str) -> "ComponentId": return ComponentId(self.layer, component) + def __str__(self) -> str: + return self.layer + + @staticmethod + def from_str(s: str) -> "LayerId": + return LayerId(s) + @dataclass class ComponentId: @@ -17,6 +24,14 @@ class ComponentId: def with_op(self, op: str) -> "OpId": return OpId(self.layer, self.component, op) + def __str__(self) -> str: + return f"{self.layer}.{self.component}" + + @staticmethod + def from_str(s: str) -> "ComponentId": + layer, component = s.split(".") + return ComponentId(layer, component) + @dataclass class OpId: @@ -26,3 +41,11 @@ class OpId: def __hash__(self) -> int: return hash((self.layer, self.component, self.op)) + + def __str__(self) -> str: + return f"{self.layer}.{self.component}.{self.op}" + + @staticmethod + def from_str(s: str) -> "OpId": + layer, component, op = s.split(".") + return OpId(layer, component, op) diff --git a/deserve_worker/worker.py b/deserve_worker/worker.py index 01d4836..6cb1914 100644 --- a/deserve_worker/worker.py +++ b/deserve_worker/worker.py @@ -1,6 +1,6 @@ import queue import threading -import traceback +import time from concurrent.futures import ThreadPoolExecutor from typing import Optional, cast @@ -25,8 +25,11 @@ class Worker: - def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): + def __init__( + self, worker_id: str, worker_url: str, max_total_bsz: int, controller_url: str + ): self.worker_id = worker_id + self.worker_url = worker_url self.controller_url = controller_url self.task_datas: dict[str, TaskData] = {} self.relay_queue = queue.Queue[BatchResult | BatchUpdate | TraceResult]() @@ -36,9 +39,11 @@ def __init__(self, worker_id: str, max_total_bsz: int, controller_url: str): # TODO: in future, different cache manager could allocate on same memory self.paged_kvcache_manager = PagedKVCacheManager(self.block_pool) self.packed_kvcache_manager = PackedKVCacheManager(self.block_pool) + self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) + 
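The dotted op-id format added in `trace.py` above is the key format used on the wire: the worker serializes trace keys with `str(op_id)` before posting them to `/update_traces`, and the client parses them back with `OpId.from_str`. A small round-trip sketch:

```python
from deserve_worker.trace import OpId

op_id = OpId(layer="00", component="attention", op="scores")
assert str(op_id) == "00.attention.scores"
assert OpId.from_str("00.attention.scores") == op_id
```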
threading.Thread(target=self.llm_engine.run, daemon=True).start() threading.Thread(target=self.relay, daemon=True).start() - self.network_executor = ThreadPoolExecutor(max_workers=max_total_bsz) + threading.Thread(target=self.heartbeat, daemon=True).start() def locate_in_plan(self, plan: list[PlanStep]) -> Optional[int]: return next( @@ -248,19 +253,31 @@ def relay(self) -> None: index = self.locate_in_plan(plan) assert index is not None next_index = (index + 1) % len(plan) - next_worker_url = plan[next_index].worker_url + if next_index != 0: + next_worker_url = plan[next_index].worker_url + data = dumps( + {"x": result.x}, + { + "task_id": task_id, + "round": self.task_datas[task_id].round, + "plan": plan, + "sampling_params": self.task_datas[task_id].sampling_params, + }, + ) + self.network_executor.submit( + requests.post, + f"{next_worker_url}/trace", + data=data, + ) data = dumps( - {"x": result.x}, + {str(key): value for key, value in result.trace.items()}, { "task_id": task_id, - "round": self.task_datas[task_id].round, - "plan": plan, - "sampling_params": self.task_datas[task_id].sampling_params, }, ) self.network_executor.submit( requests.post, - f"{next_worker_url}/trace", + f"{self.controller_url}/update_traces", data=data, ) @@ -289,3 +306,15 @@ def cancel( "plan": [step.model_dump() for step in plan], }, ) + + def heartbeat(self): + while True: + self.network_executor.submit( + requests.post, + f"{self.controller_url}/heartbeat", + json={ + "worker_id": self.worker_id, + "worker_url": self.worker_url, + }, + ) + time.sleep(1) diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py index c56ae43..d6ff97f 100644 --- a/deserve_worker/worker_api.py +++ b/deserve_worker/worker_api.py @@ -1,4 +1,5 @@ -import sys +import argparse +import uvicorn import traceback from concurrent.futures import ThreadPoolExecutor @@ -10,7 +11,7 @@ from .worker import Worker app = FastAPI() -worker = Worker(sys.argv[2], 48, "http://localhost:29980") +worker: Worker runtime_executor = ThreadPoolExecutor(max_workers=96) @@ -77,6 +78,14 @@ async def cancel(request: CancelRequest) -> str: if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="127.0.0.1", port=int(sys.argv[1])) + parser = argparse.ArgumentParser() + parser.add_argument("id", type=str) + parser.add_argument("--batch-size", type=int) + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str) + args = parser.parse_args() + + worker = Worker( + args.id, f"http://localhost:{args.port}", args.batch_size, args.controller_url + ) + uvicorn.run(app, host="127.0.0.1", port=args.port) From a8dd3e91dcfdfb12c6754bb831b0608cbbc8a5f1 Mon Sep 17 00:00:00 2001 From: Celve Date: Sun, 4 Aug 2024 20:07:06 -0700 Subject: [PATCH 16/17] fix: end output when receives end of text --- deserve_controller/controller_api.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/deserve_controller/controller_api.py b/deserve_controller/controller_api.py index 5e2c1da..643821e 100644 --- a/deserve_controller/controller_api.py +++ b/deserve_controller/controller_api.py @@ -31,6 +31,8 @@ trace_channels: dict[str, queue.Queue[dict[str, torch.Tensor]]] = {} tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") +STOP_TOKEN_IDS = [128001, 128009] + def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes: """ @@ -174,7 +176,10 @@ class OfflineCompleteRequest(BaseModel): def offline_complete(request: OfflineCompleteRequest) -> 
None: pass -def relay_traces(channel: queue.Queue[dict[str, torch.Tensor]], total: int) -> Generator[bytes, None, None]: + +def relay_traces( + channel: queue.Queue[dict[str, torch.Tensor]], total: int +) -> Generator[bytes, None, None]: cnt = 0 while cnt < total: value = channel.get() @@ -235,12 +240,16 @@ class UpdateTaskRequest(BaseModel): def update_tasks(requests: list[UpdateTaskRequest]) -> None: for request in requests: task_id = request.task_id - for token_id in request.output_tokens: - token = tokenizer.decode(token_id) - if task_id in token_channels: - token_channels[task_id].put(token) + for token_ids in request.output_tokens: + token_id = token_ids[0] + if token_id in STOP_TOKEN_IDS: + token_channels[task_id].put(None) else: - logger.warning(f"Task {task_id} not found") + token = tokenizer.decode(token_id) + if task_id in token_channels: + token_channels[task_id].put(token) + else: + logger.warning(f"Task {task_id} not found") @app.post("/update_traces") From 6ec459ae128b8fabb19adf5c2a01437bbc74fc78 Mon Sep 17 00:00:00 2001 From: Celve Date: Sun, 4 Aug 2024 20:27:15 -0700 Subject: [PATCH 17/17] chore: format some files --- deserve_client/client.py | 15 +-- deserve_client/model.py | 185 +++++++++++++++++++------------- deserve_worker/layer_storage.py | 4 +- deserve_worker/worker_api.py | 2 +- 4 files changed, 119 insertions(+), 87 deletions(-) diff --git a/deserve_client/client.py b/deserve_client/client.py index ae99e32..9f9d090 100644 --- a/deserve_client/client.py +++ b/deserve_client/client.py @@ -58,13 +58,14 @@ def trace(model: str, prompt: str, entry_point: str = "http://localhost:19000"): if response.status_code != 200: typer.echo("Error") return - - tensors = {} + + tensors = {} for chunk in response.iter_content(chunk_size=None): if chunk: temp_tensors, _ = loads(chunk) tensors.update(temp_tensors) - print(list(tensors.keys())) + print(list(tensors.keys())) + @cli.command() def verify(model: str, prompt: str, entry_point: str = "http://localhost:19000"): @@ -81,17 +82,17 @@ def verify(model: str, prompt: str, entry_point: str = "http://localhost:19000") if chunk: temp_tensors, _ = loads(chunk) tensors.update(temp_tensors) - + traces = {OpId.from_str(k): v for k, v in tensors.items()} transformer = Transformer(llama_3_8b_args) tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to(main_device) result = transformer.forward(tokens, CheckCtx(0.03, traces)) if isinstance(result, torch.Tensor): print("No difference found") - else: - if not transformer.verify(tokens, VerifyCtx(result.op_id, 0.03, traces)): + else: + if not transformer.verify(tokens, VerifyCtx(result.op_id, 0.03, traces)): print("Difference found for", result.op_id) - else: + else: print("Difference found but verification failed") diff --git a/deserve_client/model.py b/deserve_client/model.py index 4c7d3b9..ca92562 100644 --- a/deserve_client/model.py +++ b/deserve_client/model.py @@ -14,6 +14,7 @@ torch.set_default_dtype(torch.float16) main_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @dataclass class ModelArgs: dim: int = 4096 @@ -31,45 +32,47 @@ class ModelArgs: llama_3_8b_args = ModelArgs( - n_kv_heads=8, - vocab_size=128256, - multiple_of=1024, - ffn_dim_multiplier=1.3, - norm_eps=1e-5, - rope_theta=500000.0, - ) + n_kv_heads=8, + vocab_size=128256, + multiple_of=1024, + ffn_dim_multiplier=1.3, + norm_eps=1e-5, + rope_theta=500000.0, +) @dataclass class Diff: op_id: OpId diff: float - + + @dataclass -class CheckCtx: - threshold: float +class CheckCtx: + threshold: float 
     traces: dict[OpId, torch.Tensor]
-
+
     def check(self, op_id: OpId, x: torch.Tensor) -> torch.Tensor | Diff:
         y = self.traces[op_id].to(main_device)
         if torch.allclose(x, y, atol=self.threshold):
             return y
-        else:
+        else:
             return Diff(op_id, torch.max(torch.abs(x - y)).item())
-
-@dataclass
-class VerifyCtx:
-    op_id: OpId
+
+
+@dataclass
+class VerifyCtx:
+    op_id: OpId
     threshold: float
-    traces: dict[OpId, torch.Tensor]
-
-    def get_trace(self, op_id: OpId) -> torch.Tensor:
+    traces: dict[OpId, torch.Tensor]
+
+    def get_trace(self, op_id: OpId) -> torch.Tensor:
         return self.traces[op_id].to(main_device)
 
     def verify(self, x: torch.Tensor) -> bool:
         y = self.traces[self.op_id].to(main_device)
         return torch.allclose(x, y, atol=self.threshold)
-
+
 
 class RMSNorm(nn.Module):
     def __init__(self, component_id: ComponentId, dim: int, eps: float = 1e-6):
@@ -85,23 +88,21 @@ def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool:
         op = ctx.op_id.op
         if op == "output":
             return ctx.verify(self._norm(x.float()).type_as(x))
-        else:
+        else:
             output = ctx.get_trace(self.component_id.with_op("weighted_output"))
-            return ctx.verify(output * self.weight)
-
+            return ctx.verify(output * self.weight)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        ctx: CheckCtx
-    ) -> torch.Tensor | Diff:
-        output = ctx.check(self.component_id.with_op("output"), self._norm(x.float()).type_as(x))
+    def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff:
+        output = ctx.check(
+            self.component_id.with_op("output"), self._norm(x.float()).type_as(x)
+        )
         if isinstance(output, Diff):
             return output
-        return ctx.check(self.component_id.with_op("weighted_output"), output * self.weight)
-
+        return ctx.check(
+            self.component_id.with_op("weighted_output"), output * self.weight
+        )
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
@@ -183,7 +184,7 @@ def verify(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         mask: Optional[torch.Tensor],
-        ctx: VerifyCtx,
+        ctx: VerifyCtx,
     ) -> bool:
         bsz, seqlen, _ = x.shape
         op = ctx.op_id.op
@@ -246,23 +247,32 @@ def forward(
     ) -> torch.Tensor | Diff:
         bsz, seqlen, _ = x.shape
 
-        xq = ctx.check(self.component_id.with_op("xq"), self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim))
+        xq = ctx.check(
+            self.component_id.with_op("xq"),
+            self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim),
+        )
         if isinstance(xq, Diff):
             return xq
 
-        xk = ctx.check(self.component_id.with_op("xk"), self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim))
+        xk = ctx.check(
+            self.component_id.with_op("xk"),
+            self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim),
+        )
         if isinstance(xk, Diff):
             return xk
 
-        xv = ctx.check(self.component_id.with_op("xv"), self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim))
+        xv = ctx.check(
+            self.component_id.with_op("xv"),
+            self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim),
+        )
         if isinstance(xv, Diff):
             return xv
 
         xq_new, xk_new = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
         xq = ctx.check(self.component_id.with_op("xq_rotary"), xq_new)
         if isinstance(xq, Diff):
-            return xq
-
+            return xq
+
         xk = ctx.check(self.component_id.with_op("xk_rotary"), xk_new)
         if isinstance(xk, Diff):
             return xk
@@ -285,7 +295,9 @@ def forward(
         )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores_new = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
-            scores_new = scores_new + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+            scores_new = (
+                scores_new + mask
+            )  # (bs, n_local_heads, seqlen, cache_len + seqlen)
         scores_new = F.softmax(scores_new.float(), dim=-1).type_as(xq)
 
         # check scores
@@ -293,7 +305,9 @@ def forward(
         if isinstance(scores, Diff):
             return scores
 
-        output_new = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
+        output_new = torch.matmul(
+            scores, values
+        )  # (bs, n_local_heads, seqlen, head_dim)
         output_new = output_new.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         output = ctx.check(self.component_id.with_op("output"), output_new)
         if isinstance(output, Diff):
@@ -351,19 +365,22 @@ def verify(self, x: torch.Tensor, ctx: VerifyCtx) -> bool:
         assert False
 
     def forward(
-        self, x: torch.Tensor, ctx: CheckCtx,
+        self,
+        x: torch.Tensor,
+        ctx: CheckCtx,
     ) -> torch.Tensor | Diff:
         # check w1, w3, w2
         w1 = ctx.check(self.component_id.with_op("w1"), F.silu(self.w1(x)))
         if isinstance(w1, Diff):
             return w1
-
+
         w3 = ctx.check(self.component_id.with_op("w3"), self.w3(x))
         if isinstance(w3, Diff):
             return w3
-
+
         return ctx.check(self.component_id.with_op("w2"), self.w2(w1 * w3))
 
+
 class TraceLinear(nn.Module):
     def __init__(
         self,
@@ -380,9 +397,7 @@ def __init__(
         )
 
     @torch.inference_mode()
-    def forward(
-        self, x: torch.Tensor, ctx: CheckCtx
-    ) -> torch.Tensor | Diff:
+    def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff:
         return ctx.check(self.component_id.with_op("output"), self.linear(x))
 
     def load_state_dict(
@@ -393,6 +408,7 @@ def load_state_dict(
    ) -> torch.nn.modules.module._IncompatibleKeys:
         return self.linear.load_state_dict(state_dict, strict, assign)  # type: ignore
 
+
 class TraceEmbedding(nn.Module):
     def __init__(
         self,
@@ -409,9 +425,7 @@ def __init__(
         )
 
     @torch.inference_mode()
-    def forward(
-        self, x: torch.Tensor, ctx: CheckCtx
-    ) -> torch.Tensor | Diff:
+    def forward(self, x: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff:
         return ctx.check(self.component_id.with_op("output"), self.embedding(x))
 
     def load_state_dict(
@@ -423,7 +437,6 @@ def load_state_dict(
         return self.embedding.load_state_dict(state_dict, strict, assign)  # type: ignore
 
 
-
 class TransformerBlock(nn.Module):
     def __init__(self, layer_id: LayerId, args: ModelArgs):
         super().__init__()
@@ -442,38 +455,47 @@ def __init__(self, layer_id: LayerId, args: ModelArgs):
         self.attention_norm = RMSNorm(
             layer_id.with_component("attention_norm"), args.dim, eps=args.norm_eps
         )
-        self.ffn_norm = RMSNorm(layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(
+            layer_id.with_component("ffn_norm"), args.dim, eps=args.norm_eps
+        )
 
     def verify(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         mask: Optional[torch.Tensor],
-        ctx: VerifyCtx
+        ctx: VerifyCtx,
     ) -> bool:
         layer = ctx.op_id.layer
-        component = ctx.op_id.component
+        component = ctx.op_id.component
         op = ctx.op_id.op
         if component == "feed_forward":
             if op == "res":
-                return ctx.verify(ctx.get_trace(OpId(layer, "attention", "res")) + ctx.get_trace(OpId(layer, "feed_forward", "w2")))
-            else:
+                return ctx.verify(
+                    ctx.get_trace(OpId(layer, "attention", "res"))
+                    + ctx.get_trace(OpId(layer, "feed_forward", "w2"))
+                )
+            else:
                 return self.feed_forward.verify(
                     ctx.get_trace(OpId(layer, "ffn_norm", "weighted_output")), ctx
                 )
         elif component == "ffn_norm":
-            return self.ffn_norm.verify(ctx.get_trace(OpId(layer, "attention", "res")), ctx)
+            return self.ffn_norm.verify(
+                ctx.get_trace(OpId(layer, "attention", "res")), ctx
+            )
         elif component == "attention_norm":
             return self.attention_norm.verify(x, ctx)
         elif component == "attention":
-            if op == "res":
-                return ctx.verify(x + ctx.get_trace(OpId(layer, "attention", "weighted_output")))
-            else:
+            if op == "res":
+                return ctx.verify(
+                    x + ctx.get_trace(OpId(layer, "attention", "weighted_output"))
+                )
+            else:
                 return self.attention.verify(
                     ctx.get_trace(OpId(layer, "attention_norm", "weighted_output")),
                     freqs_cis,
                     mask,
-                    ctx
+                    ctx,
                 )
         assert False
@@ -492,7 +514,9 @@ def forward(
         if isinstance(attn, Diff):
             return attn
 
-        h = ctx.check(self.layer_id.with_component("attention").with_op("res"), x + attn)
+        h = ctx.check(
+            self.layer_id.with_component("attention").with_op("res"), x + attn
+        )
         if isinstance(h, Diff):
             return h
 
@@ -504,7 +528,9 @@ def forward(
         if isinstance(ffn, Diff):
             return ffn
 
-        return ctx.check(self.layer_id.with_component("feed_forward").with_op("res"), h + ffn)
+        return ctx.check(
+            self.layer_id.with_component("feed_forward").with_op("res"), h + ffn
+        )
 
 
 class Transformer(nn.Module):
@@ -517,7 +543,10 @@ def __init__(self, params: ModelArgs):
         cache_dir = os.path.expanduser(cache_dir)
 
         self.tok_embeddings = torch.nn.utils.skip_init(
-            TraceEmbedding, ComponentId("tok_embeddings", "main"), params.vocab_size, params.dim
+            TraceEmbedding,
+            ComponentId("tok_embeddings", "main"),
+            params.vocab_size,
+            params.dim,
         )
         self.tok_embeddings.load_state_dict(
             torch.load(cache_dir + "tok_embeddings.pt", map_location="cpu")
@@ -533,7 +562,9 @@ def __init__(self, params: ModelArgs):
             layer.to(main_device)
             self.layers.append(layer)
 
-        self.norm = RMSNorm(ComponentId("norm", "main"), params.dim, eps=params.norm_eps)
+        self.norm = RMSNorm(
+            ComponentId("norm", "main"), params.dim, eps=params.norm_eps
+        )
         self.norm.load_state_dict(torch.load(cache_dir + "norm.pt", map_location="cpu"))
         self.norm.to(main_device)
         self.output = torch.nn.utils.skip_init(
@@ -551,9 +582,7 @@ def __init__(self, params: ModelArgs):
         )
 
     @torch.inference_mode()
-    def verify(
-        self, tokens: torch.Tensor, ctx: VerifyCtx
-    ) -> bool:
+    def verify(self, tokens: torch.Tensor, ctx: VerifyCtx) -> bool:
         _bsz, seqlen = tokens.shape
         layer = ctx.op_id.layer
         if layer.isdigit():
@@ -568,26 +597,30 @@ def verify(
             layer_int = int(layer)
             if layer_int < 0 or layer_int >= self.n_layers:
                 assert False
-
-            if layer_int == 0:
+
+            if layer_int == 0:
                 input = ctx.get_trace(OpId("tok_embeddings", "main", "output"))
-            else:
-                input = ctx.get_trace(OpId(f"{layer_int - 1:02}", "feed_forward", "res"))
-
+            else:
+                input = ctx.get_trace(
+                    OpId(f"{layer_int - 1:02}", "feed_forward", "res")
+                )
+
             return self.layers[layer_int].verify(input, self.freqs_cis, mask, ctx)
         elif layer == "tok_embeddings":
             return self.tok_embeddings.verify(tokens, ctx)
         elif layer == "norm":
             num_layers = self.n_layers
-            return self.norm.verify(ctx.get_trace(OpId(f"{num_layers - 1:02}", "feed_forward", "res")), ctx)
+            return self.norm.verify(
+                ctx.get_trace(OpId(f"{num_layers - 1:02}", "feed_forward", "res")), ctx
+            )
         elif layer == "output":
-            return self.output.verify(ctx.get_trace(OpId("norm", "main", "output")), ctx)
+            return self.output.verify(
+                ctx.get_trace(OpId("norm", "main", "output")), ctx
+            )
         assert False
 
     @torch.inference_mode()
-    def forward(
-        self, tokens: torch.Tensor, ctx: CheckCtx
-    ) -> torch.Tensor | Diff:
+    def forward(self, tokens: torch.Tensor, ctx: CheckCtx) -> torch.Tensor | Diff:
         _bsz, seqlen = tokens.shape
 
         h = self.tok_embeddings.forward(tokens, ctx)
diff --git a/deserve_worker/layer_storage.py b/deserve_worker/layer_storage.py
index 4494dc0..df0eb5c 100644
--- a/deserve_worker/layer_storage.py
+++ b/deserve_worker/layer_storage.py
@@ -164,9 +164,7 @@ def preload_layers(self, full_layer_names: list[str]) -> dict[str, torch.nn.Modu
                 model_args.dim,
             )
         elif layer_name.startswith("layer"):
-            l = TransformerBlock(
-                LayerId(f"{int(layer_name[7:]):02}"), model_args
-            )
+            l = TransformerBlock(LayerId(f"{int(layer_name[7:]):02}"), model_args)
         elif layer_name == "norm":
             l = RMSNorm(
                 ComponentId("norm", "main"), model_args.dim, eps=model_args.norm_eps
diff --git a/deserve_worker/worker_api.py b/deserve_worker/worker_api.py
index d6ff97f..19aa6eb 100644
--- a/deserve_worker/worker_api.py
+++ b/deserve_worker/worker_api.py
@@ -1,8 +1,8 @@
 import argparse
-import uvicorn
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 
+import uvicorn
 from fastapi import FastAPI, Request
 from pydantic import BaseModel