Skip to content

Commit

Permalink
address comments
Browse files: browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Jan 30, 2025
1 parent 03dffd9 commit fc55c2d
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions sharktank/sharktank/examples/paged_llm_v1.py
Original file line number | Diff line number | Diff line change
@@ -156,19 +156,19 @@ def prefill(self):
if self.dump_bins:
write_ndarray_to_bin(
token_ids.numpy(),
f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin",
f"prefill_token_ids_{'x'.join([str(x) for x in token_ids.shape])}xi64.bin",
)
write_ndarray_to_bin(
np.array(token_ids.shape[0], dtype=np.int64),
f"prefill_seq_lens_1.bin",
f"prefill_seq_lens_1xi64.bin",
)
write_ndarray_to_bin(
seq_block_ids_tensor.numpy(),
f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin",
f"prefill_seq_block_ids_{'x'.join([str(x) for x in seq_block_ids_tensor.shape])}xi64.bin",
)
write_ndarray_to_bin(
self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(),
f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
f"prefill_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin",
)
logits = model.prefill(
token_ids,
@@ -217,23 +217,23 @@ def decode(self):
if self.dump_bins:
write_ndarray_to_bin(
self.next_tokens.numpy(),
f"decode_next_tokens_{'_'.join([str(x)for x in self.next_tokens.shape])}.bin",
f"decode_next_tokens_{'x'.join([str(x)for x in self.next_tokens.shape])}xi64.bin",
)
write_ndarray_to_bin(
start_positions.numpy(),
f"decode_start_positions_{'_'.join([str(x)for x in start_positions.shape])}.bin",
f"decode_start_positions_{'x'.join([str(x)for x in start_positions.shape])}xi64.bin",
)
write_ndarray_to_bin(
seq_block_ids_tensor.numpy(),
f"decode_seq_block_ids_tensor_{'_'.join([str(x)for x in seq_block_ids_tensor.shape])}.bin",
f"decode_seq_block_ids_tensor_{'x'.join([str(x)for x in seq_block_ids_tensor.shape])}xi64.bin",
)
write_ndarray_to_bin(
torch.tensor(self.next_tokens.shape[0]).to(torch.int64).numpy(),
f"decode_seq_lens_1.bin",
f"decode_seq_lens_1xi64.bin",
)
write_ndarray_to_bin(
self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(),
f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
f"decode_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin",
)
logits = model.decode(
self.next_tokens,
Expand Down

0 comments on commit fc55c2d

Please sign in to comment.