
Commit

remove redundant sample_logits, put back opaque status for process_prompt so we have a way of preemptively starting downloads
Authored and committed by Gary on Dec 16, 2024
1 parent b1ebf59 commit 26071eb
Showing 2 changed files with 36 additions and 25 deletions.
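For context on the commit message: broadcasting an opaque "start_process_prompt" status lets peer nodes begin fetching the relevant shard before any inference work reaches them. Below is a minimal sketch of how a receiving node might react; the import path, Shard.from_dict, and ShardDownloader.ensure_shard are assumed interfaces for illustration and are not part of this diff.

# Hypothetical consumer of the broadcast status (illustration only, not part of this commit).
# Shard.from_dict and ShardDownloader.ensure_shard are assumed interfaces here.
import asyncio
import json

from exo.inference.shard import Shard  # assumed import path

def maybe_prefetch_shard(opaque_status: str, shard_downloader) -> None:
  # Call from within a running event loop (as opaque-status handlers are).
  status = json.loads(opaque_status)
  if status.get("type") != "node_status" or status.get("status") != "start_process_prompt":
    return
  if "shard" in status:
    shard = Shard.from_dict(status["shard"])  # dict comes from the payload broadcast in process_prompt below
    # Kick off the weight download early so this node is warm before work arrives.
    asyncio.create_task(shard_downloader.ensure_shard(shard))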
21 changes: 0 additions & 21 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -13,27 +13,6 @@
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
 
-def sample_logits(
-  logits: mx.array,
-  temp: float = 0.0,
-  top_p: float = 1.0,
-  logit_bias: Optional[Dict[int, float]] = None
-) -> Tuple[mx.array, float]:
-  if logit_bias:
-    indices = mx.array(list(logit_bias.keys()))
-    values = mx.array(list(logit_bias.values()))
-    logits[:, indices] += values
-
-  if temp == 0:
-    token = mx.argmax(logits, axis=-1)
-  else:
-    if top_p > 0 and top_p < 1.0:
-      token = top_p_sampling(logits, top_p, temp)
-    else:
-      token = mx.random.categorical(logits*(1/temp))
-
-  return token
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
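For reference, a hedged illustration of how the deleted helper was typically invoked: greedy decoding when temp == 0, nucleus sampling when 0 < top_p < 1. The values below are made up; this snippet is not code from the repository.

# Illustration only: example calls to the removed helper.
import mlx.core as mx

logits = mx.random.normal((1, 32000))  # fake [batch, vocab] logits
greedy_token = sample_logits(logits, temp=0.0)  # argmax path when temp == 0
nucleus_token = sample_logits(logits, temp=0.7, top_p=0.9)  # top-p sampling path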
40 changes: 36 additions & 4 deletions exo/orchestration/node.py
@@ -70,25 +70,28 @@ async def stop(self) -> None:
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
+      status_type = status_data.get("type", "")
+      if status_type == "supported_inference_engines":
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
         self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
+      elif status_type == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
 
       download_progress = None
-      if status_data.get("type", "") == "download_progress":
+      if status_type == "download_progress":
         if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress
 
       if self.topology_viz:
         self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: print(f"Error on_node_status: {e}")
+      if DEBUG >= 1: traceback.print_exc()
 
   def get_supported_inference_engines(self):
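For orientation (not part of the diff): the handler above dispatches on three status types. The payloads it expects look roughly like the following Python dicts; field values are illustrative only.

# Illustrative payload shapes for the three status types handled above.
supported_engines_status = {
  "type": "supported_inference_engines",
  "node_id": "node-a",
  "engines": ["MLXDynamicShardInferenceEngine"],
}
node_status = {
  "type": "node_status",
  "node_id": "node-a",
  "status": "start_process_prompt",  # any "start_*" / "end_*" status is recognized
}
download_progress_status = {
  "type": "download_progress",
  "node_id": "node-a",
  "progress": {},  # dict consumed by RepoProgressEvent.from_dict
}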
@@ -153,10 +156,39 @@ async def process_prompt(
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+        }),
+      )
+    )
     start_time = time.perf_counter_ns()
     await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+        }),
+      )
+    )
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
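Since the "end_process_prompt" payload now carries elapsed_time_ns, a receiving node could surface per-node prompt-processing latency. A small sketch under that assumption; the function name and print destination are made up.

# Hypothetical sketch: reading the timing field broadcast above. Not part of this commit.
import json

def report_prompt_latency(opaque_status: str) -> None:
  status = json.loads(opaque_status)
  if status.get("type") == "node_status" and status.get("status") == "end_process_prompt":
    elapsed_ms = status.get("elapsed_time_ns", 0) / 1e6  # perf_counter_ns() reports nanoseconds
    print(f"[{status.get('request_id')}] node {status.get('node_id')} processed prompt in {elapsed_ms:.1f} ms")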
