Perf top on change #559

Open · wants to merge 4 commits into base: runners
27 changes: 11 additions & 16 deletions exo/api/chatgpt_api.py
@@ -367,15 +367,11 @@ async def handle_post_chat_completions(self, request):
)
await response.prepare(request)

async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
async def stream_result(_request_id: str, token: int, is_finished: bool):
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if token == eos_token_id:
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
@@ -386,7 +382,7 @@ async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
tokenizer,
prompt,
request_id,
new_tokens,
[token],
stream,
finish_reason,
"chat.completion",
@@ -398,12 +394,12 @@ async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
if DEBUG >= 2: print(f"Error streaming completion: {e}")
if DEBUG >= 2: traceback.print_exc()

def on_result(_request_id: str, tokens: List[int], is_finished: bool):
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
def on_result(_request_id: str, token: int, is_finished: bool):
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, token, is_finished))

return _request_id == request_id and is_finished

_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
_, token, _ = await callback.wait(on_result, timeout=self.response_timeout)
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
@@ -413,19 +409,18 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
await response.write_eof()
return response
else:
_, tokens, _ = await callback.wait(
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
_, token, _ = await callback.wait(
lambda _request_id, token, is_finished: _request_id == request_id and is_finished,
timeout=self.response_timeout,
)

finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
if token == eos_token_id:
finish_reason = "stop"

return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
except asyncio.TimeoutError:
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
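With this change, `on_result` and `stream_result` receive a single token per call instead of the accumulated `tokens` list, so the `prev_token_lens` bookkeeping above is no longer needed. A minimal sketch (not from the PR; every name besides `token` and `is_finished` is an assumption) of how a consumer can rebuild the full sequence under the new per-token signature:

```python
# Sketch only: accumulate per-token results from the new single-token callback.
from typing import List

collected: List[int] = []

def on_token(request_id: str, token: int, is_finished: bool) -> None:
    collected.append(token)  # exactly one new token per invocation
    if is_finished:
        print(f"{request_id}: finished with {len(collected)} tokens")
```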
2 changes: 1 addition & 1 deletion exo/main.py
@@ -152,7 +152,7 @@
default_model=args.default_model
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

def preemptively_start_download(request_id: str, opaque_status: str):
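The lambda now wraps the single token in a list before decoding, because `decode` expects a sequence of ids rather than a bare int. A quick sketch of the same call, assuming a Hugging Face tokenizer ("gpt2" and the token id are placeholders):

```python
# Sketch, assuming a Hugging Face tokenizer; "gpt2" is only a placeholder model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token = 9906  # hypothetical token id
# decode expects a sequence of ids, hence the [token] wrapper in the diff above
print(tokenizer.decode([token]))
```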
16 changes: 3 additions & 13 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -147,16 +147,6 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
response = await self.stub.GetInferenceResult(request)
if response.tensor is None:
return None, response.is_finished
return (
np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
response.is_finished,
)

async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
response = await self.stub.CollectTopology(request)
@@ -174,9 +164,9 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
topology.add_edge(node_id, conn.to_id, conn.description)
return topology

async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
await self.stub.SendResult(request)
async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
await self.stub.SendNewToken(request)

async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
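The peer handle now exposes `send_new_token` in place of `send_result`, taking a single `int` token. A hedged sketch of the assumed caller pattern (the helper name and the `peer` parameter are illustrative, not part of this PR):

```python
# Sketch (assumed caller pattern, not taken from the PR): forward each newly
# sampled token to a peer via the renamed RPC wrapper shown above.
async def broadcast_token(peer, request_id: str, token: int, is_finished: bool) -> None:
    # `peer` is anything exposing the new send_new_token coroutine,
    # e.g. the GRPCPeerHandle modified in this file.
    await peer.send_new_token(request_id, token, is_finished)
```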
8 changes: 4 additions & 4 deletions exo/networking/grpc/grpc_server.py
@@ -118,12 +118,12 @@ async def CollectTopology(self, request, context):
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

async def SendResult(self, request, context):
async def SendNewToken(self, request, context):
request_id = request.request_id
result = request.result
token = request.token
is_finished = request.is_finished
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
self.node.on_token.trigger_all(request_id, result, is_finished)
if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}")
self.node.on_token.trigger_all(request_id, token, is_finished)
return node_service_pb2.Empty()

async def SendOpaqueStatus(self, request, context):
16 changes: 3 additions & 13 deletions exo/networking/grpc/node_service.proto
@@ -6,9 +6,8 @@ service NodeService {
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
rpc SendResult (SendResultRequest) returns (Empty) {}
rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}
@@ -45,15 +44,6 @@ message Loss {
float loss = 1;
optional Tensor grads = 2;
}

message GetInferenceResultRequest {
string request_id = 1;
}

message InferenceResult {
optional Tensor tensor = 1;
bool is_finished = 2;
}

message Tensor {
bytes tensor_data = 1;
@@ -93,9 +83,9 @@ message DeviceCapabilities {
DeviceFlops flops = 4;
}

message SendResultRequest {
message SendNewTokenRequest {
string request_id = 1;
repeated int32 result = 2;
int32 token = 2;
bool is_finished = 3;
}

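Once the stubs are regenerated from this proto, the request carries a single `int32 token` instead of a `repeated int32 result`. An illustrative sketch of the client-side construction (the request id and token value are placeholders):

```python
# Sketch: building the new request, assuming node_service_pb2 has been
# regenerated from the proto above.
from exo.networking.grpc import node_service_pb2

request = node_service_pb2.SendNewTokenRequest(
    request_id="req-123",  # hypothetical request id
    token=42,              # single int32, replacing `repeated int32 result`
    is_finished=False,
)
```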
66 changes: 31 additions & 35 deletions exo/networking/grpc/node_service_pb2.py

Some generated files are not rendered by default.
