Skip to content

Commit

Permalink
[sharktank][llama] fix dump bins
Browse files (browse the repository at this point in the history)
  • Loading branch information
dan-garvey committed Jan 29, 2025
1 parent 1392a2e commit 03dffd9
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions sharktank/sharktank/examples/paged_llm_v1.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,15 +6,11 @@

"""Inference support for the PagedLLMV1 protocol of models."""

from typing import Optional

from safetensors import safe_open

import math
import sys

from ..models.llama.tools.data_utils import write_ndarray_to_bin
import torch

import numpy as np
from ..layers import *
from ..types import *

Expand Down Expand Up @@ -158,20 +154,20 @@ def prefill(self):
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)

if self.dump_bins:
torch.save(
token_ids,
write_ndarray_to_bin(
token_ids.numpy(),
f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin",
)
torch.save(
torch.tensor(token_ids.shape[0]).to(torch.int64),
write_ndarray_to_bin(
np.array(token_ids.shape[0], dtype=np.int64),
f"prefill_seq_lens_1.bin",
)
torch.save(
seq_block_ids_tensor,
write_ndarray_to_bin(
seq_block_ids_tensor.numpy(),
f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin",
)
torch.save(
self.cache_state[0].to(torch.float8_e4m3fnuz),
write_ndarray_to_bin(
self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(),
f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
)
logits = model.prefill(
Expand Down Expand Up @@ -219,24 +215,24 @@ def decode(self):
decode_attention_mask = replicate(decode_attention_mask, tp)

if self.dump_bins:
torch.save(
self.next_tokens,
write_ndarray_to_bin(
self.next_tokens.numpy(),
f"decode_next_tokens_{'_'.join([str(x)for x in self.next_tokens.shape])}.bin",
)
torch.save(
start_positions,
write_ndarray_to_bin(
start_positions.numpy(),
f"decode_start_positions_{'_'.join([str(x)for x in start_positions.shape])}.bin",
)
torch.save(
seq_block_ids_tensor,
write_ndarray_to_bin(
seq_block_ids_tensor.numpy(),
f"decode_seq_block_ids_tensor_{'_'.join([str(x)for x in seq_block_ids_tensor.shape])}.bin",
)
torch.save(
torch.tensor(self.next_tokens.shape[0]).to(torch.int64),
write_ndarray_to_bin(
torch.tensor(self.next_tokens.shape[0]).to(torch.int64).numpy(),
f"decode_seq_lens_1.bin",
)
torch.save(
self.cache_state[0].to(torch.float8_e4m3fnuz),
write_ndarray_to_bin(
self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(),
f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
)
logits = model.decode(
Expand Down

0 comments on commit 03dffd9

Please sign in to comment.