diff --git a/worker/.gitignore b/worker/.gitignore new file mode 100644 index 0000000..5b93702 --- /dev/null +++ b/worker/.gitignore @@ -0,0 +1,5 @@ +*.pyc +/build/ +/dist/ +/fleece_worker.egg-info/ +/.vscode/ diff --git a/worker/README.md b/worker/README.md new file mode 100644 index 0000000..535b708 --- /dev/null +++ b/worker/README.md @@ -0,0 +1,46 @@ +## Installation + +### Install From PyPI +``` +pip install fleece-worker +``` + +### Install From Source +``` +pip install -e . +``` + +### (Optional) Install FlashAttention +https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features + +## Connect to a controller + +``` +python -m fleece-worker -c <controller_url> -t <api_token> +``` +Optional: `--worker-nickname abc`, `--heartbeat-interval 10`, `-w <worker_url>` + +For example: + +``` +python -m fleece-worker -c https://serving-api.colearn.cloud:8443 -t <api_token> +``` + +## Try it out (deprecated) + +``` +export CUDA_VISIBLE_DEVICES=0 +python -m fleece-worker -w http://127.0.0.1:8080 +``` + +``` +python send_forward.py +``` + +``` +curl localhost:8080/forward -H 'Content-Type: application/json' -d '{"task_id":"123","step":0,"round":0,"plan":[["local",["llama-3-8b-instruct-slice/tok_embeddings", "llama-3-8b-instruct-slice/layers.0", "llama-3-8b-instruct-slice/layers.1", "llama-3-8b-instruct-slice/layers.2", "llama-3-8b-instruct-slice/layers.3", "llama-3-8b-instruct-slice/layers.4", "llama-3-8b-instruct-slice/layers.5", "llama-3-8b-instruct-slice/layers.6", "llama-3-8b-instruct-slice/layers.7", "llama-3-8b-instruct-slice/layers.8", "llama-3-8b-instruct-slice/layers.9", "llama-3-8b-instruct-slice/layers.10", "llama-3-8b-instruct-slice/layers.11", "llama-3-8b-instruct-slice/layers.12", "llama-3-8b-instruct-slice/layers.13", "llama-3-8b-instruct-slice/layers.14", "llama-3-8b-instruct-slice/layers.15", "llama-3-8b-instruct-slice/layers.16", "llama-3-8b-instruct-slice/layers.17", "llama-3-8b-instruct-slice/layers.18", "llama-3-8b-instruct-slice/layers.19", "llama-3-8b-instruct-slice/layers.20", "llama-3-8b-instruct-slice/layers.21", "llama-3-8b-instruct-slice/layers.22", "llama-3-8b-instruct-slice/layers.23", "llama-3-8b-instruct-slice/layers.24", "llama-3-8b-instruct-slice/layers.25", "llama-3-8b-instruct-slice/layers.26", "llama-3-8b-instruct-slice/layers.27", "llama-3-8b-instruct-slice/layers.28", "llama-3-8b-instruct-slice/layers.29", "llama-3-8b-instruct-slice/layers.30", "llama-3-8b-instruct-slice/layers.31", "llama-3-8b-instruct-slice/norm", "llama-3-8b-instruct-slice/output"]]],"payload":[[128000, 128006, 882, 128007, 271, 12840, 374, 279, 11363, 315, 1253, 13767, 1082, 30, 128009, 128006, 78191, 128007, 271]]}' +``` ``` +curl localhost:8080/forward -H 'Content-Type: application/json' -d '{"task_id":"123","step":0,"round":0,"plan":[["local",["llama-3-8b-instruct-slice/tok_embeddings", "llama-3-8b-instruct-slice/layers.0", "llama-3-8b-instruct-slice/layers.1", "llama-3-8b-instruct-slice/layers.2", "llama-3-8b-instruct-slice/layers.3", "llama-3-8b-instruct-slice/layers.4", "llama-3-8b-instruct-slice/layers.5", "llama-3-8b-instruct-slice/layers.6", "llama-3-8b-instruct-slice/layers.7", "llama-3-8b-instruct-slice/layers.8", "llama-3-8b-instruct-slice/layers.9", "llama-3-8b-instruct-slice/layers.10", "llama-3-8b-instruct-slice/layers.11", "llama-3-8b-instruct-slice/layers.12", "llama-3-8b-instruct-slice/layers.13", "llama-3-8b-instruct-slice/layers.14", "llama-3-8b-instruct-slice/layers.15", "llama-3-8b-instruct-slice/layers.16", "llama-3-8b-instruct-slice/layers.17", "llama-3-8b-instruct-slice/layers.18", 
"llama-3-8b-instruct-slice/layers.19", "llama-3-8b-instruct-slice/layers.20", "llama-3-8b-instruct-slice/layers.21", "llama-3-8b-instruct-slice/layers.22", "llama-3-8b-instruct-slice/layers.23", "llama-3-8b-instruct-slice/layers.24", "llama-3-8b-instruct-slice/layers.25", "llama-3-8b-instruct-slice/layers.26", "llama-3-8b-instruct-slice/layers.27", "llama-3-8b-instruct-slice/layers.28", "llama-3-8b-instruct-slice/layers.29", "llama-3-8b-instruct-slice/layers.30", "llama-3-8b-instruct-slice/layers.31", "llama-3-8b-instruct-slice/norm", "llama-3-8b-instruct-slice/output"]]],"payload":[[128000, 128006, 882, 128007, 271, 12840, 374, 279, 11363, 315, 1253, 13767, 1082, 30, 128009, 128006, 78191, 128007, 271], [128000, 128006, 9125, 128007, 271, 38195, 4320, 449, 14433, 39342, 128009, 128006, 882, 128007, 271, 40, 1097, 2133, 311, 12366, 11, 1148, 1288, 358, 1518, 30, 128009, 128006, 78191, 128007, 271], [128000, 128006, 9125, 128007, 271, 38195, 4320, 449, 100166, 128009, 128006, 882, 128007, 271, 4438, 311, 733, 505, 27647, 311, 12551, 30, 128009, 128006, 78191, 128007, 271]]}' +``` +> note that the model will be automatically downloaded to `~/.cache` diff --git a/worker/fleece-worker/__init__.py b/worker/fleece-worker/__init__.py new file mode 100644 index 0000000..485f44a --- /dev/null +++ b/worker/fleece-worker/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.1" diff --git a/worker/fleece-worker/__main__.py b/worker/fleece-worker/__main__.py new file mode 100644 index 0000000..bbeaad6 --- /dev/null +++ b/worker/fleece-worker/__main__.py @@ -0,0 +1,239 @@ +from typing import List, Tuple, Optional +from fastapi import FastAPI, HTTPException, Request +from fleece_network import Peer, loads +from pydantic import BaseModel +import anyio +import uvicorn +from .worker import Worker +from .__init__ import __version__ +import argparse +import requests +import json +import torch +import concurrent.futures +from anyio.from_thread import BlockingPortal +import uuid + +app = FastAPI() +worker = Worker() + + +class LayersRequest(BaseModel): + layer_names: List[str] + + +def preload_layers(req: LayersRequest): + try: + worker.preload_layers(req.layer_names) + return None + except Exception as e: + print(e) + raise HTTPException(status_code=500, detail="Internal Server Error") + + +def unload_layers(req: LayersRequest): + try: + worker.unload_layers(req.layer_names) + return None + except Exception as e: + print(e) + raise HTTPException(status_code=500, detail="Internal Server Error") + + +class ForwardRequest(BaseModel): + task_id: str + plan: List[Tuple[str, List[str]]] + step: int + round: int = -1 + payload: Optional[List] = None + max_total_len: int = 2048 + temperature: float = 0.0 + top_p: float = 0.9 + task_manager_url: Optional[str] = None + signature: Optional[str] = None + timestamp: Optional[int] = None + + +executor = concurrent.futures.ThreadPoolExecutor(max_workers=64) + + +def forward(req: bytes): + try: + tensors, metadata = loads(req) + if isinstance(metadata, dict): + executor.submit( + worker.forward, + **tensors, + **metadata, + ) + elif isinstance(metadata, list): + executor.submit( + worker.forward_merged, + tensors, + metadata, + ) + else: + raise + return None + except Exception as e: + print(e) + raise HTTPException(status_code=500, detail="Internal Server Error") + + +async def app_forward(request: Request): + buffer = await request.body() + try: + tensors, metadata = loads(buffer) + if isinstance(metadata, dict): + executor.submit( + worker.forward, + **tensors, + **metadata, + ) + 
elif isinstance(metadata, list): + executor.submit( + worker.forward_merged, + tensors, + metadata, + ) + else: + raise + return None + except Exception as e: + print(e) + raise HTTPException(status_code=500, detail="Internal Server Error") + + +class GetInfoRequest(BaseModel): + node_list: List[str] = [] + timeout: int = 30 + + +class GetInfoResponse(BaseModel): + worker_nickname: Optional[str] + gpu_mem_info: Tuple[int, int] = [0, 0] + latency_list: List[Optional[float]] = [] + + +def get_info(req: GetInfoRequest) -> GetInfoResponse: + try: + worker_nickname, gpu_mem_info, latency_list = worker.get_info( + req.node_list, req.timeout + ) + return GetInfoResponse( + worker_nickname=worker_nickname, + gpu_mem_info=gpu_mem_info, + latency_list=latency_list, + ) + except Exception as e: + print(e) + raise HTTPException(status_code=500, detail="Internal Server Error") + + +async def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--controller-url") + parser.add_argument("-w", "--worker-url") + parser.add_argument("-t", "--api-token") + parser.add_argument("--port") + parser.add_argument("--worker-nickname") + parser.add_argument("--heartbeat-interval") + args = parser.parse_args() + if args.worker_url is not None: + worker_url = args.worker_url + parsed = worker_url.split(':') + if len(parsed) >= 3: + port = int(parsed[2]) + else: + port = 8080 + else: + worker_url = "none" + port = 8080 + if args.port is not None: + port = int(args.port) + worker.port = port + if args.api_token is not None: + worker.api_token = args.api_token + if args.worker_nickname is not None: + worker.worker_nickname = args.worker_nickname + if args.heartbeat_interval is not None: + worker.heartbeat_interval = int(args.heartbeat_interval) + if args.controller_url is not None: + worker.controller_url = args.controller_url + data = {"url": worker_url, "version": __version__} + if worker.worker_nickname is not None: + data["nickname"] = worker.worker_nickname + if torch.cuda.is_available(): + model = torch.cuda.get_device_name() + memory = torch.cuda.mem_get_info() + data["gpu_model"] = model + data["gpu_total_memory"] = memory[1] + data["gpu_remaining_memory"] = memory[0] + else: + data["gpu_model"] = "CPU" + data["gpu_total_memory"] = 0 + data["gpu_remaining_memory"] = 0 + r = requests.post(f"{args.controller_url}/register_worker", + json=data, + headers={"api-token": worker.api_token}) + res = json.loads(r.content) + worker.worker_id = res["id"] + worker.pull_worker_url() + worker.start_heartbeat_daemon() + worker.start_layer_forward_engine() + + print("Worker ID: ", worker.worker_id) + + r = requests.get( + f"{args.controller_url}/get_network_servers", + headers={"api-token": worker.api_token} + ) + + servers = json.loads(r.content) + signaling = servers["signaling"]["url"] + turns = servers["turn"] + async with BlockingPortal() as portal: + worker.async_portal = portal + async with anyio.create_task_group() as tg: + worker.peer = Peer( + worker.worker_id, + signaling, + [(turn["url"], turn["username"], turn["password"]) for turn in turns], + { + "preload_layers": preload_layers, + "unload_layers": unload_layers, + "forward": forward, + "get_info": get_info, + }, + tg, + ) + + # start the FastAPI server when public IP is available + if worker_url != "none": + app.add_api_route("/preload_layers", preload_layers, methods=["POST"]) + app.add_api_route("/unload_layers", unload_layers, methods=["POST"]) + app.add_api_route("/forward", app_forward, methods=["POST"]) + 
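
The HTTP routes registered here mirror the peer RPC handlers. As a small illustrative sketch (assuming the deprecated local setup on 127.0.0.1:8080), `/get_info` can be exercised directly; the request and response shapes are given by the `GetInfoRequest` and `GetInfoResponse` models above.

```
# Illustrative only: query a locally running worker's /get_info endpoint.
# Request fields follow GetInfoRequest; the response follows GetInfoResponse,
# i.e. worker_nickname, gpu_mem_info=(free_bytes, total_bytes), latency_list.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/get_info",
    json={"node_list": ["http://127.0.0.1:8080"], "timeout": 10},
)
print(resp.json())
```
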
app.add_api_route("/get_info", get_info, methods=["POST"]) + + uviconfig = uvicorn.Config(app, host="0.0.0.0", port=port, access_log=False) + uviserver = uvicorn.Server(uviconfig) + tg.start_soon(uviserver.serve) + await portal.sleep_until_stopped() + else: + worker.worker_id = "local"+uuid.uuid4().hex[:8] + print("Worker ID: ", worker.worker_id) + worker.start_layer_forward_engine() + async with anyio.create_task_group() as tg: + if worker_url != "none": + app.add_api_route("/preload_layers", preload_layers, methods=["POST"]) + app.add_api_route("/unload_layers", unload_layers, methods=["POST"]) + app.add_api_route("/forward", app_forward, methods=["POST"]) + app.add_api_route("/get_info", get_info, methods=["POST"]) + + uviconfig = uvicorn.Config(app, host="0.0.0.0", port=port, access_log=True) + uviserver = uvicorn.Server(uviconfig) + tg.start_soon(uviserver.serve) + + +if __name__ == '__main__': + anyio.run(main) diff --git a/worker/fleece-worker/model.py b/worker/fleece-worker/model.py new file mode 100644 index 0000000..38697bd --- /dev/null +++ b/worker/fleece-worker/model.py @@ -0,0 +1,535 @@ +# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, List + +import torch +import torch.nn.functional as F +from torch import nn + +ENABLE_FLASH_ATTN = False +try: + from flash_attn import flash_attn_with_kvcache + ENABLE_FLASH_ATTN = True +except ImportError as e: + pass + # print("Package flash-attn is not found. Please install it for better performance. https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features") + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + @torch.inference_mode() + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. 
The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + + + + + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + if ENABLE_FLASH_ATTN: + freqs_cis = torch.stack([freqs.cos(), freqs.sin()]) # flash_attn + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + + + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """Multi-head attention module.""" + + def __init__(self, args: ModelArgs): + """ + Initialize the Attention module. + + Args: + args (ModelArgs): Model configuration parameters. 
+ + Attributes: + n_kv_heads (int): Number of key and value heads. + n_local_heads (int): Number of local query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (ColumnParallelLinear): Linear transformation for queries. + wk (ColumnParallelLinear): Linear transformation for keys. + wv (ColumnParallelLinear): Linear transformation for values. + wo (RowParallelLinear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. + + """ + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + # model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads # // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads # // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.utils.skip_init(nn.Linear, + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.utils.skip_init(nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.utils.skip_init(nn.Linear, + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.utils.skip_init(nn.Linear, + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + # self.cache_k = torch.zeros( + # ( + # args.max_batch_size, + # args.max_seq_len, + # self.n_local_kv_heads, + # self.head_dim, + # ) + # ) + # self.cache_v = torch.zeros( + # ( + # args.max_batch_size, + # args.max_seq_len, + # self.n_local_kv_heads, + # self.head_dim, + # ) + # ) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kv_cache_paged: Tuple[torch.Tensor, torch.Tensor], + kv_cache_list: List, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for caching. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + mask (torch.Tensor, optional): Attention mask tensor. + + Returns: + torch.Tensor: Output tensor after attention. 
+ + """ + _, seqlen, _ = x.shape + xq_, xk_, xv_ = self.wq(x), self.wk(x), self.wv(x) + + if ENABLE_FLASH_ATTN: + cache_k, cache_v = kv_cache_paged + cache_seqlens = [] + for i, bsz in enumerate(bsz_list): + cache_seqlens += [start_pos_list[i]]*bsz + cache_seqlens = torch.tensor(cache_seqlens, dtype=torch.int32, device=x.device) + bsz = cache_seqlens.shape[0] + + max_len = max([x.shape[1] for x in kv_cache_list]) + block_table = torch.zeros((bsz, max_len), dtype=torch.int32, device=x.device) + start = 0 + for i, bsz in enumerate(bsz_list): + block_table[start:start+bsz, :kv_cache_list[i].shape[1]] = kv_cache_list[i] + start += bsz + + bsz = cache_seqlens.shape[0] + xq = xq_.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv_.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + cos = global_freqs_cis[0].type_as(xq) + sin = global_freqs_cis[1].type_as(xq) + output = flash_attn_with_kvcache(xq, cache_k, cache_v, xk, xv, + rotary_cos=cos, rotary_sin=sin, + cache_seqlens=cache_seqlens, block_table=block_table, causal=True, rotary_interleaved=True) + output = output.view(bsz, seqlen, -1) + else: + start = 0 + output_list = [] + for i, bsz in enumerate(bsz_list): + xq = xq_[start:start+bsz].view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk_[start:start+bsz].view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv_[start:start+bsz].view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + start += bsz + + start_pos = start_pos_list[i] + kv_cache = kv_cache_list[i] + cache_k, cache_v = kv_cache + + freqs_cis = global_freqs_cis[start_pos: start_pos + seqlen] + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=x.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(x) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # self.cache_k = self.cache_k.to(xq) + # self.cache_v = self.cache_v.to(xq) + + cache_k[:bsz, start_pos: start_pos + seqlen] = xk + cache_v[:bsz, start_pos: start_pos + seqlen] = xv + + keys = cache_k[:bsz, : start_pos + seqlen] + values = cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output_list.append(output) + output = torch.cat([x for x in output_list]) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. 
+ + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.utils.skip_init(nn.Linear, + dim, hidden_dim, bias=False, + ) + self.w2 = torch.nn.utils.skip_init(nn.Linear, + hidden_dim, dim, bias=False, + ) + self.w3 = torch.nn.utils.skip_init(nn.Linear, + dim, hidden_dim, bias=False, + ) + + @torch.inference_mode() + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + args (ModelArgs): Model configuration parameters. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + # self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + @torch.inference_mode() + def forward( + self, + x: torch.Tensor, + bsz_list: List[int], + start_pos_list: List[int], + global_freqs_cis: torch.Tensor, + kv_cache_paged: Tuple[torch.Tensor, torch.Tensor], + kv_cache_list: List, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention.forward( + self.attention_norm(x), bsz_list, start_pos_list, global_freqs_cis, kv_cache_paged, kv_cache_list + ) + out = h + self.feed_forward.forward(self.ffn_norm(h)) + return out + + +# class Transformer(nn.Module): +# def __init__(self, params: ModelArgs): +# """ +# Initialize a Transformer model. + +# Args: +# params (ModelArgs): Model configuration parameters. + +# Attributes: +# params (ModelArgs): Model configuration parameters. +# vocab_size (int): Vocabulary size. +# n_layers (int): Number of layers in the model. +# tok_embeddings (ParallelEmbedding): Token embeddings. +# layers (torch.nn.ModuleList): List of Transformer blocks. +# norm (RMSNorm): Layer normalization for the model output. +# output (ColumnParallelLinear): Linear layer for final output. +# freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
+ +# """ +# super().__init__() +# self.params = params +# self.vocab_size = params.vocab_size +# self.n_layers = params.n_layers + +# self.tok_embeddings = torch.nn.utils.skip_init(nn.Embedding, +# params.vocab_size, params.dim +# ) + +# self.layers = torch.nn.ModuleList() +# for _ in range(params.n_layers): +# self.layers.append(TransformerBlock(params)) + +# self.norm = RMSNorm(params.dim, eps=params.norm_eps) +# self.output = torch.nn.utils.skip_init(nn.Linear, +# params.dim, params.vocab_size, bias=False, +# ) + +# self.freqs_cis = precompute_freqs_cis( +# # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. +# # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. +# self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 +# ) + +# @torch.inference_mode() +# def forward(self, tokens: torch.Tensor, start_pos: int): +# """ +# Perform a forward pass through the Transformer model. + +# Args: +# tokens (torch.Tensor): Input token indices. +# start_pos (int): Starting position for attention caching. + +# Returns: +# torch.Tensor: Output logits after applying the Transformer model. + +# """ +# _bsz, seqlen = tokens.shape +# h = self.tok_embeddings(tokens) +# self.freqs_cis = self.freqs_cis.to(h.device) +# freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + +# mask = None +# if seqlen > 1: +# mask = torch.full( +# (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device +# ) +# mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + +# for layer in self.layers: +# h = layer(h, start_pos, freqs_cis, mask) +# h = self.norm(h) +# output = self.output(h).float() +# return output diff --git a/worker/fleece-worker/worker.py b/worker/fleece-worker/worker.py new file mode 100644 index 0000000..bc9b23f --- /dev/null +++ b/worker/fleece-worker/worker.py @@ -0,0 +1,1197 @@ +from typing import List, Optional, Tuple, Dict, Any, Set +import os +import torch +from torch import Tensor, nn +from .model import ModelArgs, TransformerBlock, RMSNorm, precompute_freqs_cis +from fleece_network import Peer, dumps +import requests +import threading +import concurrent.futures +import time +import socket +from urllib.parse import urlparse +import json +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes +import queue +import traceback + +torch.set_default_device("cpu") + +llama_2_7b_args = {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000} +llama_2_13b_args = {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000} +llama_2_70b_args = {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000} +llama_3_8b_args = {"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8, "vocab_size": 128256, "multiple_of": 1024, "ffn_dim_multiplier": 1.3, "norm_eps": 1e-05, "rope_theta": 500000.0} +llama_3_70b_args = {"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 128256, "rope_theta": 500000.0} + +if torch.cuda.is_available(): + main_device = "cuda" + # if torch.cuda.is_bf16_supported(): + # main_dtype = torch.bfloat16 + # torch.set_default_dtype(torch.bfloat16) + # else: + main_dtype = torch.float16 + 
torch.set_default_dtype(torch.float16) +else: + main_device = "cpu" + main_dtype = torch.float32 + torch.set_default_dtype(torch.float32) + +# llama 2 +# global_freqs_cis = precompute_freqs_cis(128, 4096).to(main_device) +# EOS_ID = 2 +# STOP_TOKEN_IDS = [2] +# VOCAL_SIZE = 32000 +# tokenizer = Tokenizer(model_path="/home/ubuntu/llama/tokenizer.model") +# print(tokenizer.bos_id) 1 +# print(tokenizer.eos_id) 2 +# print(tokenizer.pad_id) -1 +# print(tokenizer.n_words) 32000 + +# llama 3 +global_freqs_cis = precompute_freqs_cis(128, 8192, 500000.0).to(main_device) +EOS_ID = 128001 +STOP_TOKEN_IDS = [128001, 128009] +VOCAL_SIZE = 128256 +# tokenizer = Tokenizer(model_path="./Meta-Llama-3-8B-Instruct/tokenizer.model") +# print(tokenizer.bos_id) 128000 +# print(tokenizer.eos_id) 128001 +# print(tokenizer.pad_id) -1 +# print(tokenizer.n_words) 128256 + +stop_tokens = torch.tensor(STOP_TOKEN_IDS, device=main_device) +stop_tokens_cpu = torch.tensor(STOP_TOKEN_IDS) + + +def parse_layer_name(layer_name: str): + s = layer_name.split('/') + return s[0], s[1] + + +KV_CACHE_BLOCK = 256 + + +def get_kv_cache_length(cur, seqlen): + while cur < seqlen: + cur += KV_CACHE_BLOCK + return cur + + +ENABLE_FLASH_ATTN = False +try: + from flash_attn import flash_attn_with_kvcache + ENABLE_FLASH_ATTN = True +except ImportError as e: + print("Package flash-attn is not found. Please install it for better performance. https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features") + +NUM_BLOCKS = 12096 +PAGE_BLOCK_SIZE = 256 +MAX_TOTAL_BSZ = 64 +if ENABLE_FLASH_ATTN: + k_cache_paged = torch.randn( + NUM_BLOCKS, PAGE_BLOCK_SIZE, 8, 128, device=main_device + ) + v_cache_paged = torch.randn( + NUM_BLOCKS, PAGE_BLOCK_SIZE, 8, 128, device=main_device + ) + page_queue = queue.Queue() + _ = [page_queue.put(x) for x in range(1, NUM_BLOCKS)] + + +def get_kv_cache_paged(x, start_pos, block_table, model): + bsz, seqlen = x.shape[0], x.shape[1] + if block_table is None: + length = get_kv_cache_length(0, start_pos + seqlen)//PAGE_BLOCK_SIZE + block_table = torch.zeros( + ( + bsz, + length, + ), + device=main_device, + dtype=torch.int32 + ) + for i in range(length): + for j in range(bsz): + block_table[j, i] = page_queue.get() + return block_table + old_block_table = block_table + if start_pos + seqlen > block_table.shape[1]*PAGE_BLOCK_SIZE: + length = get_kv_cache_length(block_table.shape[1]*PAGE_BLOCK_SIZE, start_pos + seqlen)//PAGE_BLOCK_SIZE + block_table = torch.zeros( + ( + bsz, + length, + ), + device=main_device, + dtype=torch.int32 + ) + block_table[:, :old_block_table.shape[1]] = old_block_table[:, :] + for i in range(old_block_table.shape[1], length): + for j in range(bsz): + block_table[j, i] = page_queue.get() + return block_table + else: + return block_table + + +def get_kv_cache(x, start_pos, kv_cache, model): + bsz, seqlen = x.shape[0], x.shape[1] + if kv_cache is None: + length = get_kv_cache_length(0, start_pos + seqlen) + cache_k = torch.zeros( + ( + bsz, + length, + model.attention.n_local_kv_heads, + model.attention.head_dim, + ), + device=main_device + ) + cache_v = torch.zeros( + ( + bsz, + length, + model.attention.n_local_kv_heads, + model.attention.head_dim, + ), + device=main_device + ) + return (cache_k, cache_v) + old_cache_k, old_cache_v = kv_cache + if start_pos + seqlen > old_cache_k.shape[1]: + length = get_kv_cache_length(old_cache_k.shape[1], start_pos + seqlen) + cache_k = torch.zeros( + ( + bsz, + length, + model.attention.n_local_kv_heads, + model.attention.head_dim, 
+ ), + device=main_device + ) + cache_v = torch.zeros( + ( + bsz, + length, + model.attention.n_local_kv_heads, + model.attention.head_dim, + ), + device=main_device + ) + cache_k[:, :start_pos, :, :], cache_v[:, :start_pos, :, :] = old_cache_k[:, :start_pos, :, :], old_cache_v[:, :start_pos, :, :] + del_tensor(old_cache_k) + del_tensor(old_cache_v) + del kv_cache + return (cache_k, cache_v) + else: + return kv_cache + + +def del_tensor(t): + t.detach() + t.grad = None + t.untyped_storage().resize_(0) + + +executor = concurrent.futures.ThreadPoolExecutor(max_workers=400) +executor_forward = concurrent.futures.ThreadPoolExecutor(max_workers=40) + + +def requests_post(url, headers=None, data=None, json=None, worker=None, to_worker_id=None): + try: + if to_worker_id is not None: + st = time.monotonic() + # time.sleep(0.01) + r = requests.post(url, headers=headers, data=data, json=json) + assert r.status_code == 200 + if to_worker_id is not None: + en = time.monotonic() + latency = (en-st)*1000 + worker.perf_network.append((to_worker_id, latency)) + except Exception: + if worker is not None: + worker.cancel_task(json["task_id"]) + + +def send_request(url, headers=None, data=None, exec=None, worker=None, to_worker_id=None): + if exec is None: + executor.submit(requests_post, url=url, headers=headers, data=data, worker=worker, to_worker_id=to_worker_id) + else: + exec.submit(requests_post, url=url, headers=headers, data=data, worker=worker, to_worker_id=to_worker_id) + + +executor_latency_test = concurrent.futures.ThreadPoolExecutor(max_workers=40) + + +def latency_test(host: str, port: int, timeout=60): + st = time.monotonic() + try: + s = socket.create_connection((host, port), timeout=timeout) + s.shutdown(socket.SHUT_RD) + except socket.timeout: + return None + except OSError: + return None + en = time.monotonic() + return (en-st)*1000 + + +def measure_latency(node_list: List[str], timeout): + # executor_latency_test + jobs = [] + for node in node_list: + parsed_url = urlparse(node) + host = parsed_url.hostname + if parsed_url.port is not None: + port = parsed_url.port + elif parsed_url.scheme == "http": + port = 80 + elif parsed_url.scheme == "https": + port = 443 + else: + port = 22 + jobs.append(executor_latency_test.submit(latency_test, host, port, timeout)) + ans = [] + for job in jobs: + ans.append(job.result()) + return ans + + +class LayerForward: + def __init__( + self, + h: torch.Tensor, + layer_names: List, + bsz: int, + is_new_task: bool, + round: int, + start_pos: int, + seqlen: int, + kv_cache_dict: Dict, + metadata: Dict, + ): + self.h = h + self.layer_names = layer_names + self.bsz = bsz + self.is_new_task = is_new_task + self.round = round + self.start_pos = start_pos + self.seqlen = seqlen + self.kv_cache_dict = kv_cache_dict + self.metadata = metadata + + +class Worker: + def __init__( + self, + worker_id: str = None, + # mirror_url: str = "TODO", + cache_dir: str = "~/.cache/fleece-worker/models", + ): + self.worker_id = worker_id + # self.mirror_url = mirror_url + self.controller_url = None + self.api_token = None + self.worker_nickname = worker_id + self.heartbeat_interval = 300 + self.tm_pubkeys = {} + self.worker_urls = {} + self.perf_computation = [] + self.perf_network = [] + self.peer: Optional[Peer] = None + self.async_portal = None + + self.cache_dir = os.path.expanduser(cache_dir) + self.layers = dict() + self.task_info: Dict[(str, int), Tuple[int, Dict[str, Any]]] = dict() + self.mutex = threading.Lock() + self.task_prompt_tokens: Dict[str, torch.Tensor] = 
dict() + self.task_eos_reached: Dict[str, torch.Tensor] = dict() + self.task_local_steps: Dict[str, List[int]] = dict() + self.task_update_queue: Dict[str, queue.Queue[Tuple[int, List[int]]]] = dict() + self.layer_forward_engine_queue: queue.Queue[LayerForward] = queue.Queue() + self.canceled_task: Set[str] = set() + + def fetch_layer(self, full_layer_name): + model_name, layer_name = parse_layer_name(full_layer_name) + if model_name.startswith("dummy"): + return None + path = os.path.join(self.cache_dir, model_name, f"{layer_name}.pt") + if not os.path.exists(path): # TODO lock + os.makedirs(os.path.join(self.cache_dir, model_name), exist_ok=True) + with requests.get(f"https://huggingface.co/colearn/{model_name}/resolve/main/{layer_name}.pt", stream=True) as r: + r.raise_for_status() + with open(path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + return path + + def preload_layers(self, layer_names: List[str]): + with self.mutex: # TODO + ths = [] + for full_layer_name in layer_names: + if full_layer_name in self.layers: + continue + th = executor.submit(self.fetch_layer, full_layer_name) + ths.append((full_layer_name, th)) + for full_layer_name, th in ths: + path = th.result() + model_name, layer_name = parse_layer_name(full_layer_name) + if model_name.startswith("dummy"): + continue + if model_name.startswith("llama-2-7b"): + model_args = ModelArgs(**llama_2_7b_args) + elif model_name.startswith("llama-2-13b"): + model_args = ModelArgs(**llama_2_13b_args) + elif model_name.startswith("llama-2-70b"): + model_args = ModelArgs(**llama_2_70b_args) + elif model_name.startswith("llama-3-8b"): + model_args = ModelArgs(**llama_3_8b_args) + elif model_name.startswith("llama-3-70b"): + model_args = ModelArgs(**llama_3_70b_args) + else: + raise NotImplementedError("Unknown model") + if layer_name == "tok_embeddings": + l = torch.nn.utils.skip_init(nn.Embedding, model_args.vocab_size, model_args.dim) + elif layer_name.startswith("layer"): + l = TransformerBlock(model_args) + elif layer_name == "norm": + l = RMSNorm(model_args.dim, eps=model_args.norm_eps) + elif layer_name == "output": + l = torch.nn.utils.skip_init(nn.Linear, model_args.dim, model_args.vocab_size, bias=False) + else: + raise NotImplementedError("Unknown layers") + l.load_state_dict(torch.load(path, map_location="cpu")) + l.to(main_device) + self.layers[full_layer_name] = l + + def unload_layers(self, layer_names: List[str]): + for full_layer_name in layer_names: + if full_layer_name not in self.layers: + continue # TODO continue or warning? 
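
As a reference for the README note about `~/.cache`, here is a small sketch, mirroring `fetch_layer` above, of how a `model/layer` name maps to its on-disk cache path and the checkpoint URL it is fetched from on a cache miss; the `describe_layer` helper is hypothetical and for illustration only.

```
# Sketch of the layer-resolution convention used by Worker.fetch_layer:
# "<model_name>/<layer_name>" -> ~/.cache/fleece-worker/models/<model_name>/<layer_name>.pt,
# downloaded from the colearn HuggingFace repos when not cached locally.
import os

def describe_layer(full_layer_name: str, cache_dir: str = "~/.cache/fleece-worker/models"):
    model_name, layer_name = full_layer_name.split("/")  # same split as parse_layer_name
    local_path = os.path.join(os.path.expanduser(cache_dir), model_name, f"{layer_name}.pt")
    remote_url = f"https://huggingface.co/colearn/{model_name}/resolve/main/{layer_name}.pt"
    return local_path, remote_url

print(describe_layer("llama-3-8b-instruct-slice/layers.0"))
```
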
+ del self.layers[full_layer_name] + torch.cuda.empty_cache() + + def cancel_task(self, task_id: str, del_task=False): + if del_task: + self.del_task(task_id) + self.canceled_task.add(task_id) + + def del_task(self, task_id: str): + steps = self.task_local_steps.pop(task_id, None) + if steps is None: + return + if task_id in self.task_prompt_tokens: + del self.task_prompt_tokens[task_id] + if task_id in self.task_eos_reached: + del self.task_eos_reached[task_id] + if task_id in self.task_update_queue: + self.task_update_queue[task_id].put((None, None)) + for step in steps: + _, kv_cache_dict = self.task_info[(task_id, step)] + for _, kv_cache in kv_cache_dict.items(): + if ENABLE_FLASH_ATTN: + block_table = kv_cache.tolist() + for x in block_table: + for y in x: + page_queue.put(y) + else: + k_cache, v_cache = kv_cache + del_tensor(k_cache) + del_tensor(v_cache) + del self.task_info[(task_id, step)] + if not ENABLE_FLASH_ATTN: + torch.cuda.empty_cache() + + def pull_worker_url(self): + r = requests.get(f"{self.controller_url}/get_worker_list", + headers={"api-token": self.api_token}) + res = json.loads(r.content) + for worker in res["workers"]: + self.worker_urls[worker["worker_id"]] = worker["url"] + + def get_worker_url(self, worker_id): + if worker_id not in self.worker_urls: + self.pull_worker_url() + return self.worker_urls.get(worker_id) + + def verify(self, tm_url, task_id, plan, timestamp, signature_hex): + public_key_bytes = bytes.fromhex(self.tm_pubkeys[tm_url]) + public_key = ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256K1(), public_key_bytes + ) + signed_bytes = task_id.encode()+str(timestamp).encode() + for x in plan: + signed_bytes += x[0].encode() + for y in x[1]: + signed_bytes += y.encode() + try: + public_key.verify( + bytes.fromhex(signature_hex), + signed_bytes, + ec.ECDSA(hashes.SHA256()) + ) + return True + except: + return False + + def send_forward(self, to_worker_id, tensors: dict[str, Tensor], metadata: dict[str, Any]): + if to_worker_id == self.worker_id: + if isinstance(metadata, dict): + executor.submit( + self.forward, + **tensors, + **metadata, + ) + elif isinstance(metadata, list): + executor.submit( + self.forward_merged, + tensors, + metadata, + ) + else: + raise + return + url = self.get_worker_url(to_worker_id) + buffer = dumps(tensors, metadata) + if (url is not None and url != "none") and to_worker_id != self.worker_id: + if to_worker_id == self.worker_id: + # self.forward(**data) + send_request( + f"http://127.0.0.1:{self.port}/forward", + data=buffer, + exec=executor_forward, + worker=self, + to_worker_id=to_worker_id) + else: + send_request( + f"{self.get_worker_url(to_worker_id)}/forward", + data=buffer, + exec=executor_forward, + worker=self, + to_worker_id=to_worker_id) + else: + async def send(): + connection = await self.peer.connect(to_worker_id) + reply = await connection.send("forward", buffer) + if reply.status_code != 200: + self.cancel_task(metadata["task_id"]) + self.async_portal.call(self.peer.tg.start_soon, send) + + def layer_forward_engine_step(self, task_list: List[LayerForward]): + task = task_list[0] + with torch.inference_mode(): + input_shapes = [list(t.h.shape) for t in task_list] + st = time.monotonic() + h = torch.cat([t.h for t in task_list]) + bsz_list, start_pos_list = [t.bsz for t in task_list], [t.start_pos for t in task_list] + for full_layer_name in task.layer_names: + model_name, layer_name = parse_layer_name(full_layer_name) + if model_name.startswith("dummy"): + if layer_name == "output": + for t in 
task_list: + t.h = torch.zeros((t.bsz, 1, VOCAL_SIZE), dtype=main_dtype, device=main_device) + t.h[:, :, t.round+10] = 1.0 + if t.round >= 320: + t.h = torch.zeros((t.bsz, 1, VOCAL_SIZE), dtype=main_dtype, device=main_device) + t.h[:, :, EOS_ID] = 1.0 + # time.sleep(0.01) + h = torch.cat([t.h for t in task_list]) + continue + if layer_name == "tok_embeddings": + h = self.layers[full_layer_name](h) + elif layer_name.startswith("layers."): + kv_cache_list = [] + for t in task_list: + if t.is_new_task: + # if torch.cuda.is_available(): + # gpu_mem_info = torch.cuda.mem_get_info() + # if gpu_mem_info[0]/gpu_mem_info[1] < 0.05 and gpu_mem_info[0] < 2e9: + # return None, None # TODO need fix + if ENABLE_FLASH_ATTN: + kv_cache_list.append(get_kv_cache_paged(t.h, t.start_pos, None, self.layers[full_layer_name])) + else: + kv_cache_list.append(get_kv_cache(t.h, t.start_pos, None, self.layers[full_layer_name])) + else: + if ENABLE_FLASH_ATTN: + kv_cache_list.append(get_kv_cache_paged(t.h, t.start_pos, t.kv_cache_dict[full_layer_name], self.layers[full_layer_name])) + else: + kv_cache_list.append(get_kv_cache(t.h, t.start_pos, t.kv_cache_dict[full_layer_name], self.layers[full_layer_name])) + if ENABLE_FLASH_ATTN: + h = self.layers[full_layer_name](h, bsz_list, start_pos_list, global_freqs_cis, (k_cache_paged, v_cache_paged), kv_cache_list) + else: + h = self.layers[full_layer_name](h, bsz_list, start_pos_list, global_freqs_cis, None, kv_cache_list) + for i, t in enumerate(task_list): + t.kv_cache_dict[full_layer_name] = kv_cache_list[i] + elif layer_name == "norm": + h = self.layers[full_layer_name](h) + elif layer_name == "output": + h = self.layers[full_layer_name](h) + else: + raise NotImplementedError("Unknown layers") + # start = 0 + # for t in task_list: + # bsz = t.bsz + # t.h = h[start:start+bsz] + # start += bsz + en = time.monotonic() + latency = (en-st)*1000 + self.perf_computation.append(((str(task.layer_names), str(input_shapes)), latency)) + return h + # for task in task_list: + # task.call_back_queue.put((task.h, task.kv_cache_dict)) + + def post_layer_forward_engine_step(self, task_list: List[LayerForward], merged_h): + start = 0 + next_token_list = [] + task_update_list = [] + metadata_list = [{"plan": task_list[0].metadata["plan"]}] + for task in task_list: + self.task_info[(task.metadata["task_id"], task.metadata["step"])] = (task.start_pos+task.seqlen, task.kv_cache_dict) + + bsz = task.bsz + h = merged_h[start:start+bsz] + start += bsz + + # last node + if task.metadata["step"] == len(task.metadata["plan"])-1: + if task.metadata["temperature"] > 0: + probs = torch.softmax(h[:, -1] / task.metadata["temperature"], dim=-1) + next_token = sample_top_p(probs, task.metadata["top_p"]) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(-1) + if task.start_pos > task.metadata["max_total_len"]: + next_token = torch.tensor([EOS_ID] * task.bsz) # FIXME fake max length limit + # print(next_token) + next_token = next_token.to("cpu") + + # eos_reached + self.task_eos_reached[task.metadata["task_id"]] |= torch.isin(next_token, stop_tokens_cpu) # eos_id + if not all(self.task_eos_reached[task.metadata["task_id"]]): + # next node + # tensors[str(len(metadata_list))] = next_token + next_token_list.append(next_token) + metadata_list.append({ + "task_id": task.metadata["task_id"], + # "plan": task.metadata["plan"], + "step": 0, + "round": task.metadata["round"]+1, + "max_total_len": task.metadata["max_total_len"], + "temperature": task.metadata["temperature"], + 
"top_p": task.metadata["top_p"], + "task_manager_url": task.metadata["task_manager_url"], + "signature": task.metadata["signature"], + "timestamp": task.metadata["timestamp"], + "bsz": task.bsz, + }) + else: + self.send_forward( + task.metadata["plan"][0][0], + tensors={}, + metadata={ + "task_id": task.metadata["task_id"], + "plan": task.metadata["plan"], + "step": 0, + "task_manager_url": task.metadata["task_manager_url"], + "signature": task.metadata["signature"], + "timestamp": task.metadata["timestamp"], + }) + # update + if task.metadata["task_manager_url"] is not None: + # self.new_task_update(task.metadata["task_manager_url"], task.metadata["task_id"], task.metadata["step"], task.metadata["round"], next_token.tolist()) + task_update_list.append((task.metadata["task_manager_url"], task.metadata["task_id"], task.metadata["step"], task.metadata["round"], next_token.tolist())) + if all(self.task_eos_reached[task.metadata["task_id"]]): + self.cancel_task(task.metadata["task_id"], True) + else: + # next node + metadata_list.append({ + "task_id": task.metadata["task_id"], + # "plan": task.metadata["plan"], + "step": task.metadata["step"]+1, + "round": task.metadata["round"], + "max_total_len": task.metadata["max_total_len"], + "temperature": task.metadata["temperature"], + "top_p": task.metadata["top_p"], + "task_manager_url": task.metadata["task_manager_url"], + "signature": task.metadata["signature"], + "timestamp": task.metadata["timestamp"], + "bsz": task.bsz, + }) + # self.send_forward( + # task.metadata["plan"][task.metadata["step"]+1][0], + # tensors={ + # "payload": h, + # }, + # metadata=) + # torch.cuda.synchronize() + # print("1.1", time.monotonic()) + if task.metadata["step"] == len(task.metadata["plan"])-1: + if len(next_token_list) > 0: + merged_next_token = torch.cat(next_token_list) + self.send_forward(task.metadata["plan"][0][0], tensors={"payload": merged_next_token}, metadata=metadata_list) + else: + if "worker_urls" in task_list[0].metadata and task_list[0].metadata["worker_urls"] is not None: + metadata_list[0]["worker_urls"] = task_list[0].metadata["worker_urls"] + self.send_forward(task.metadata["plan"][task.metadata["step"]+1][0], tensors={"payload": merged_h}, metadata=metadata_list) + return task_update_list + + def layer_forward_engine(self): + q = self.layer_forward_engine_queue + while True: + q_buffered: list[list[LayerForward]] = [q.get()] + while True: + try: + tasks = q.get(block=False) + q_buffered.append(tasks) + except queue.Empty: + break + prefill_tasks_list = [tasks for tasks in q_buffered if tasks[0].seqlen > 1] + decode_tasks_list = [tasks for tasks in q_buffered if tasks[0].seqlen == 1] + + for tasks in prefill_tasks_list: + h = self.layer_forward_engine_step(tasks) + task_update_list = self.post_layer_forward_engine_step(tasks, h) + batch_update_len = sum([len(task[4]) for task in task_update_list]) + print(time.monotonic(), len(tasks), sum([task.bsz for task in tasks]), batch_update_len) + executor_forward.submit( + self.batch_update, + task_update_list + ) + + decode_tasks_list.sort(key=lambda x: x[0].bsz, reverse=False) + while len(decode_tasks_list) > 0: + total_bsz = 0 + task_list = [] + for i in reversed(range(len(decode_tasks_list))): + print(i) + cur_bsz = sum([task.bsz for task in decode_tasks_list[i]]) + if total_bsz + cur_bsz > MAX_TOTAL_BSZ: + continue + total_bsz += cur_bsz + task_list.extend(decode_tasks_list.pop(i)) + h = self.layer_forward_engine_step(task_list) + task_update_list = self.post_layer_forward_engine_step(task_list, 
h) + batch_update_len = sum([len(task[4]) for task in task_update_list]) + print(time.monotonic(), len(task_list), sum([task.bsz for task in task_list]), batch_update_len) + executor_forward.submit( + self.batch_update, + task_update_list + ) + + def layer_forward_engine_old(self): + q = self.layer_forward_engine_queue + while True: + task_list = [] + total_bsz = 0 + tasks = q.get() + for task in tasks: + total_bsz += task.bsz + task_list.append(task) + while True: + if total_bsz > MAX_TOTAL_BSZ: + break + try: + tasks = q.get(block=False) + if tasks[0].seqlen == task_list[0].seqlen and tasks[0].layer_names == task_list[0].layer_names: # FIXME + bsz = sum([task.bsz for task in tasks]) + if total_bsz+bsz > MAX_TOTAL_BSZ: + if bsz > total_bsz: + q.put(task_list) + task_list = tasks + total_bsz = bsz + else: + q.put(tasks) + break + for task in tasks: + task_list.append(task) + total_bsz += task.bsz + if total_bsz > MAX_TOTAL_BSZ: + break + else: + q.put(tasks) + break + except queue.Empty: + break + # print("layer_forward_engine_step: ", len(task_list), total_bsz) + # torch.cuda.synchronize() + # print("0", time.monotonic()) + h = self.layer_forward_engine_step(task_list) + # torch.cuda.synchronize() + # print("1", time.monotonic()) + task_update_list = self.post_layer_forward_engine_step(task_list, h) + # torch.cuda.synchronize() + # print("2", time.monotonic()) + batch_update_len = sum([len(task[4]) for task in task_update_list]) + print(time.monotonic(), len(task_list), total_bsz, batch_update_len) + executor_forward.submit( + self.batch_update, + task_update_list + ) + + def batch_update(self, task_update_list): + req_list = [] + for task_update in task_update_list: + # self.new_task_update(task_update[0], task_update[1], task_update[2], task_update[3], task_update[4]) + task_manager_url = task_update[0] + req_list.append({ + "task_id": task_update[1], + "plan_current_step": task_update[2], + "plan_current_round": task_update[3], + "output_tokens": [task_update[4]], + }) + requests_post( + f"{task_manager_url}/update_tasks", + headers={"worker-id": self.worker_id, "api-token": self.api_token}, + json=req_list) + + def start_layer_forward_engine(self): + heartbeat_thread = threading.Thread(target=self.layer_forward_engine) + heartbeat_thread.daemon = True + heartbeat_thread.start() + + def layers_forward(self, h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, metadata): + # q = queue.Queue() + self.layer_forward_engine_queue.put([LayerForward(h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, metadata)]) + # h, kv_cache_dict = q.get() + # del q + # return h, kv_cache_dict + + def send_update_task(self, task_manager_url, task_id, step): + q = self.task_update_queue[task_id] + while True: + output_tokens_list = [] + round, output_tokens = q.get() + if output_tokens is None: + break + output_tokens_list.append(output_tokens) + ret_flag = False + while True: + try: + _, output_tokens = q.get(block=False) + if output_tokens is None: + ret_flag = True + break + else: + output_tokens_list.append(output_tokens) + except queue.Empty: + break + # print("requests_post", round, output_tokens_list) + requests_post( + f"{task_manager_url}/update_task", + headers={"worker-id": self.worker_id, "api-token": self.api_token}, + json={ + "task_id": task_id, + "plan_current_step": step, + "plan_current_round": round, + "output_tokens": output_tokens_list, + }, + worker=self) + if ret_flag: + break + # TODO del self.task_update_queue[task_id]? 
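
The decode path of `layer_forward_engine` above packs pending request groups greedily, largest first, so that each merged batch stays within `MAX_TOTAL_BSZ`. Below is a simplified, self-contained sketch of that packing policy, operating on plain batch sizes rather than `LayerForward` objects; it is illustrative, not the worker's actual code path.

```
# Simplified sketch of the greedy decode batching in layer_forward_engine:
# groups are sorted by batch size and packed largest-first into batches whose
# total size stays within MAX_TOTAL_BSZ (groups larger than the cap are assumed absent).
from typing import List

MAX_TOTAL_BSZ = 64  # same cap as in worker.py

def pack_decode_batches(group_sizes: List[int]) -> List[List[int]]:
    groups = sorted(group_sizes)  # ascending, as in the worker
    batches: List[List[int]] = []
    while groups:
        total, batch = 0, []
        # walk from largest to smallest, taking every group that still fits
        for i in reversed(range(len(groups))):
            if total + groups[i] > MAX_TOTAL_BSZ:
                continue
            total += groups[i]
            batch.append(groups.pop(i))
        batches.append(batch)
    return batches

print(pack_decode_batches([40, 30, 20, 10, 5]))  # -> [[40, 20], [30, 10, 5]]
```
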
+ + def new_task_update(self, task_manager_url, task_id, _step, round, output_tokens): + if task_manager_url is not None: + self.task_update_queue[task_id].put((round, output_tokens)) + + # def forward_same_node(self, delta_round, h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, temperature, top_p, max_total_len, eos_reached, prompt_tokens, task_manager_url, task_id, step): + # ans_tokens = [] + # try: + # for i in range(delta_round): + # h, kv_cache_dict = self.layers_forward(h, layer_names, bsz, is_new_task, round+i, start_pos, seqlen, kv_cache_dict) + # # last node + # if temperature > 0: + # probs = torch.softmax(h[:, -1] / temperature, dim=-1) + # next_token = sample_top_p(probs, top_p) + # else: + # next_token = torch.argmax(h[:, -1], dim=-1) + # next_token = next_token.reshape(-1) + # if start_pos > max_total_len: + # next_token = torch.tensor([EOS_ID] * bsz, device=main_device) # FIXME fake max length limit + # # print(next_token) + # # eos_reached + # if all(eos_reached | torch.isin(next_token, stop_tokens)) or i == delta_round-1: + # return h, kv_cache_dict, ans_tokens, eos_reached + + # # loop + # eos_reached |= torch.isin(next_token, stop_tokens) # eos_id + # ans_tokens.append(next_token) + # start_pos = start_pos+seqlen + # seqlen = 1 + # is_new_task = False + + # # first node + # tokens = torch.zeros((bsz, 1), dtype=torch.long, device=main_device) + # for k, t in enumerate(prompt_tokens): + # if len(t) > start_pos: + # tokens[k, :] = torch.tensor([t[start_pos]], dtype=torch.long, device=main_device) + # else: + # tokens[k, :] = next_token[k] + # h = tokens + # finally: + # # update_task + # for i, output_tokens in enumerate(ans_tokens): + # self.new_task_update(task_manager_url, task_id, step, round+i, output_tokens.tolist()) + + def forward_merged(self, + tensors: Dict[str, torch.Tensor], + metadata_list: List[Dict], + ): + try: + start = 0 + layers_forward_list = [] + plan = metadata_list[0]["plan"] + if "worker_urls" in metadata_list[0] and metadata_list[0]["worker_urls"] is not None: + self.worker_urls = metadata_list[0]["worker_urls"] + metadata_list.pop(0) + for i, task in enumerate(metadata_list): + index = task["step"] + is_new_task = task["round"] == 0 + task_id = task["task_id"] + round = task["round"] + step = task["step"] + task["plan"] = plan + if is_new_task: + if task_id in self.task_local_steps: + self.task_local_steps[task_id].append(step) + else: + self.task_local_steps[task_id] = [step] + self.task_info[(task_id, step)] = (0, dict()) + else: + if not task_id in self.task_local_steps: + return + start_pos, kv_cache_dict = self.task_info[(task_id, step)] + + # first node + if index == 0: + # payload = tensors[str(i)] + bsz = task["bsz"] + payload = tensors["payload"][start:start+bsz] + start += bsz + if is_new_task: + min_prompt_len = min(len(t) for t in payload) + self.task_prompt_tokens[task_id] = payload + tokens = torch.zeros((bsz, min_prompt_len), dtype=torch.long) + for k, t in enumerate(payload): + tokens[k, :] = torch.tensor(t[:min_prompt_len], dtype=torch.long) + h = tokens.to(main_device) + else: + prompt_tokens = self.task_prompt_tokens[task_id] + tokens = torch.zeros((bsz, 1), dtype=torch.long) + for k, t in enumerate(prompt_tokens): + if len(t) > start_pos: + tokens[k, :] = torch.tensor([t[start_pos]], dtype=torch.long) + else: + tokens[k, :] = torch.tensor([payload[k]], dtype=torch.long) + h = tokens.to(main_device) + # print(h) + bsz, seqlen = h.shape + else: + bsz = task["bsz"] + payload = 
tensors["payload"][start:start+bsz] + start += bsz + h = payload.to(main_dtype).to(main_device) # h = torch.tensor(payload, dtype=main_dtype, device=main_device) + if len(h.shape) > 2: + bsz, seqlen, _ = h.shape + else: + bsz, seqlen = h.shape + + # last node init + if index == len(plan)-1 and is_new_task: + self.task_eos_reached[task_id] = torch.tensor([False] * bsz) + self.task_update_queue[task_id] = queue.Queue() + executor.submit(self.send_update_task, task["task_manager_url"], task_id, step) + + # forward + _, layer_names = plan[index] + self.preload_layers(layer_names) + layers_forward_list.append(LayerForward(h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, task)) + self.layer_forward_engine_queue.put(layers_forward_list) + except Exception: + print(traceback.format_exc()) + # print(tensors, metadata_list) + + def forward(self, + task_id: str, + plan: List[Tuple[str, List[str]]], + step: int, + round: int = -1, + payload: Optional[List] = None, + max_total_len: int = 2048, + temperature: float = 0.0, + top_p: float = 0.9, + task_manager_url: Optional[str] = None, + signature: Optional[str] = None, + timestamp: Optional[int] = None, + worker_urls: Dict = None, + ): + try: + # self.verify(task_manager_url, task_id, plan, timestamp, signature) + if worker_urls is not None: + self.worker_urls = worker_urls + + index = step + is_new_task = round == 0 + if payload is None or task_id in self.canceled_task: + self.del_task(task_id) + if index < len(plan)-1: + # next node + self.send_forward( + plan[index+1][0], + tensors={}, + metadata={ + "task_id": task_id, + "plan": plan, + "step": step+1, + "task_manager_url": task_manager_url, + "signature": signature, + "timestamp": timestamp, + }) + return + + if is_new_task: + if task_id in self.task_local_steps: + self.task_local_steps[task_id].append(step) + else: + self.task_local_steps[task_id] = [step] + self.task_info[(task_id, step)] = (0, dict()) + else: + if not task_id in self.task_local_steps: + return + start_pos, kv_cache_dict = self.task_info[(task_id, step)] + + # first node + if index == 0: + bsz = len(payload) + if is_new_task: + min_prompt_len = min(len(t) for t in payload) + self.task_prompt_tokens[task_id] = payload + tokens = torch.zeros((bsz, min_prompt_len), dtype=torch.long) + for k, t in enumerate(payload): + tokens[k, :] = torch.tensor(t[:min_prompt_len], dtype=torch.long) + h = tokens.to(main_device) + else: + prompt_tokens = self.task_prompt_tokens[task_id] + tokens = torch.zeros((bsz, 1), dtype=torch.long) + for k, t in enumerate(prompt_tokens): + if len(t) > start_pos: + tokens[k, :] = torch.tensor([t[start_pos]], dtype=torch.long) + else: + tokens[k, :] = torch.tensor([payload[k]], dtype=torch.long) + h = tokens.to(main_device) + # print(h) + bsz, seqlen = h.shape + else: + h = payload.to(main_dtype).to(main_device) # h = torch.tensor(payload, dtype=main_dtype, device=main_device) + if len(h.shape) > 2: + bsz, seqlen, _ = h.shape + else: + bsz, seqlen = h.shape + + # last node init + if index == len(plan)-1 and is_new_task: + self.task_eos_reached[task_id] = torch.tensor([False] * bsz) + self.task_update_queue[task_id] = queue.Queue() + executor.submit(self.send_update_task, task_manager_url, task_id, step) + + # forward + _, layer_names = plan[index] + self.preload_layers(layer_names) # preload + # if len(plan) == 1: + # delta_round = 16 + # eos_reached = self.task_eos_reached[task_id].to(main_device) + # prompt_tokens = self.task_prompt_tokens[task_id] + # h, kv_cache_dict, tokens, 
eos_reached = self.forward_same_node(delta_round, h, layer_names, bsz, is_new_task, round, start_pos, seqlen, + # kv_cache_dict, temperature, top_p, max_total_len, eos_reached, prompt_tokens, task_manager_url, task_id, step) + # self.task_eos_reached[task_id] = eos_reached.to("cpu") + # delta_round = len(tokens)+1 + # round = round+delta_round-1 + # start_pos = start_pos+delta_round-1 + # else: + metadata = { + "task_id": task_id, + "plan": plan, + "step": step, + "round": round, + "max_total_len": max_total_len, + "temperature": temperature, + "top_p": top_p, + "task_manager_url": task_manager_url, + "signature": signature, + "timestamp": timestamp, + "worker_urls": worker_urls, + } + self.layers_forward(h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, metadata) + return + h, kv_cache_dict = self.layers_forward(h, layer_names, bsz, is_new_task, round, start_pos, seqlen, kv_cache_dict, metadata) + # if h is None: + # return + # else: + self.task_info[(task_id, step)] = (start_pos+seqlen, kv_cache_dict) + + # last node + if index == len(plan)-1: + if temperature > 0: + probs = torch.softmax(h[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(h[:, -1], dim=-1) + next_token = next_token.reshape(-1) + if start_pos > max_total_len: + next_token = torch.tensor([EOS_ID] * bsz) # FIXME fake max length limit + print(next_token) + next_token = next_token.to("cpu") + + # eos_reached + self.task_eos_reached[task_id] |= torch.isin(next_token, stop_tokens_cpu) # eos_id + if not all(self.task_eos_reached[task_id]): + # next node + self.send_forward( + plan[0][0], + tensors={ + "payload": next_token, + }, + metadata={ + "task_id": task_id, + "plan": plan, + "step": 0, + "round": round+1, + "max_total_len": max_total_len, + "temperature": temperature, + "top_p": top_p, + "task_manager_url": task_manager_url, + "signature": signature, + "timestamp": timestamp, + }) + else: + self.send_forward( + plan[0][0], + tensors={}, + metadata={ + "task_id": task_id, + "plan": plan, + "step": 0, + "task_manager_url": task_manager_url, + "signature": signature, + "timestamp": timestamp, + }) + # update + if task_manager_url is not None: + self.new_task_update(task_manager_url, task_id, step, round, next_token.tolist()) + if all(self.task_eos_reached[task_id]): + self.cancel_task(task_id, True) + else: + # next node + self.send_forward( + plan[index+1][0], + tensors={ + "payload": h, + }, + metadata={ + "task_id": task_id, + "plan": plan, + "step": step+1, + "round": round, + "max_total_len": max_total_len, + "temperature": temperature, + "top_p": top_p, + "task_manager_url": task_manager_url, + "signature": signature, + "timestamp": timestamp, + }) + # update + # if task_manager_url is not None: + # send_request( + # f"{task_manager_url}/update_task", + # headers={"worker-id": self.worker_id, "api-token": self.api_token}, + # json={ + # "task_id": task_id, + # "plan_current_step": step, + # "plan_current_round": round, + # }, + # worker=self) + except Exception: + print(traceback.format_exc()) + + def get_info(self, node_list, timeout): + gpu_mem_info = torch.cuda.mem_get_info() + latency_list = measure_latency(node_list, timeout) + return self.worker_nickname, gpu_mem_info, latency_list + + def send_heartbeat(self): + info_data = { + "loaded_layers": json.dumps(list(self.layers.keys())), + "perf_computation": [], + "perf_network": [] + } + + s = {} + for k, v in self.perf_computation: + if k not in s: + s[k] = [v, 1] + else: + s[k][0] 
+= v + s[k][1] += 1 + for k, v in s.items(): + layers, input_shape = k + avg_latency = v[0]/v[1] + info_data["perf_computation"].append({"layers": layers, "input_shape": input_shape, "latency": avg_latency}) + s = {} + for k, v in self.perf_network: + if k not in s: + s[k] = [v, 1] + else: + s[k][0] += v + s[k][1] += 1 + for k, v in s.items(): + avg_latency = v[0]/v[1] + info_data["perf_network"].append({"to_worker_id": k, "latency": avg_latency}) + + if torch.cuda.is_available(): + memory = torch.cuda.mem_get_info() + info_data["gpu_remaining_memory"] = memory[0] + data = {"info_update": json.dumps(info_data)} + try: + r = requests.post(f"{self.controller_url}/worker_heartbeat", + json=data, + headers={"worker-id": self.worker_id, "api-token": self.api_token}) + res = json.loads(r.content) + self.tm_pubkeys = res["pubkeys"] + except: + pass + + def start_heartbeat_daemon(self): + def heartbeat_thread(): + while True: + self.send_heartbeat() + time.sleep(self.heartbeat_interval) + heartbeat_thread = threading.Thread(target=heartbeat_thread) + heartbeat_thread.daemon = True + heartbeat_thread.start() + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/worker/send_forward.py b/worker/send_forward.py new file mode 100644 index 0000000..af3c573 --- /dev/null +++ b/worker/send_forward.py @@ -0,0 +1,47 @@ +import requests +from fleece_network import dumps +import uuid + +plan_8b = [ + ["locale8da5c97", [ + "llama-3-8b-instruct-slice/tok_embeddings", + *[f"llama-3-8b-instruct-slice/layers.{i}" for i in range(0, 32)], + "llama-3-8b-instruct-slice/norm", + "llama-3-8b-instruct-slice/output", + ]] +] + +worker_urls = { + "locale16a9d78": "http://127.0.0.1:8081", + "local7ead4b11": "http://127.0.0.1:8082" +} + +plan_70b = [ + ["locale16a9d78", [ + "llama-3-70b-instruct-slice/tok_embeddings", + *[f"llama-3-70b-instruct-slice/layers.{i}" for i in range(0, 40)], + ]], + ["locale16a9d78", [ + *[f"llama-3-70b-instruct-slice/layers.{i}" for i in range(40, 80)], + "llama-3-70b-instruct-slice/norm", + "llama-3-70b-instruct-slice/output", + ]], +] + +input = [[128000, 128006, 882, 128007, 271, 12840, 374, 279, 11363, 315, 1253, 13767, 1082, 30, 128009, 128006, 78191, 128007, 271], [128000, 128006, 9125, 128007, 271, 38195, 4320, 449, 14433, 39342, 128009, 128006, 882, 128007, 271, 40, 1097, 2133, 311, 12366, 11, 1148, 1288, 358, 1518, 30, 128009, 128006, 78191, 128007, 271], + [128000, 128006, 9125, 128007, 271, 38195, 4320, 449, 100166, 128009, 128006, 882, 128007, 271, 4438, 311, 733, 505, 27647, 311, 12551, 30, 128009, 128006, 78191, 128007, 271], [128000, 128006, 882, 128007, 271, 12840, 374, 279, 11363, 315, 1253, 13767, 1082, 30, 128009, 128006, 78191, 128007, 271]] + +tensors = {} +metadata = { + "task_id": 
str(uuid.uuid4()), + "plan": plan_8b, + "step": 0, + "round": 0, + "max_total_len": 1024, + "temperature": 0, + "payload": input, + "worker_urls": None, +} + +data = dumps(tensors, metadata) +r = requests.post("http://127.0.0.1:8080/forward", data=data) diff --git a/worker/setup.py b/worker/setup.py new file mode 100644 index 0000000..7592124 --- /dev/null +++ b/worker/setup.py @@ -0,0 +1,41 @@ +from setuptools import setup +import os + + +def read(rel_path: str) -> str: + here = os.path.abspath(os.path.dirname(__file__)) + with open(os.path.join(here, rel_path)) as fp: + return fp.read() + + +def get_version(rel_path: str) -> str: + for line in read(rel_path).splitlines(): + if line.startswith("__version__"): + delim = '"' if '"' in line else "'" + return line.split(delim)[1] + raise RuntimeError("Unable to find version string.") + + +setup( + name="fleece-worker", + version=get_version('fleece-worker/__init__.py'), + description="fleece-worker", + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + author="stneng", + author_email="git@stneng.com", + url="https://github.com/CoLearn-Dev/fleece-worker", + packages=["fleece-worker"], + install_requires=[ + "numpy", + "torch", + "fire", + "sentencepiece", + "fastapi", + "uvicorn", + "requests", + "cryptography", + "fleece-network==0.2.2" + ], + python_requires=">=3.10", +)