From fc55c2d3f25a804d25d2048c9d1bf2f7d6794595 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 29 Jan 2025 16:13:09 -0800 Subject: [PATCH] address comments --- sharktank/sharktank/examples/paged_llm_v1.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 056e9ce33..780f7bc13 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -156,19 +156,19 @@ def prefill(self): if self.dump_bins: write_ndarray_to_bin( token_ids.numpy(), - f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin", + f"prefill_token_ids_{'x'.join([str(x) for x in token_ids.shape])}xi64.bin", ) write_ndarray_to_bin( np.array(token_ids.shape[0], dtype=np.int64), - f"prefill_seq_lens_1.bin", + f"prefill_seq_lens_1xi64.bin", ) write_ndarray_to_bin( seq_block_ids_tensor.numpy(), - f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin", + f"prefill_seq_block_ids_{'x'.join([str(x) for x in seq_block_ids_tensor.shape])}xi64.bin", ) write_ndarray_to_bin( self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(), - f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin", + f"prefill_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin", ) logits = model.prefill( token_ids, @@ -217,23 +217,23 @@ def decode(self): if self.dump_bins: write_ndarray_to_bin( self.next_tokens.numpy(), - f"decode_next_tokens_{'_'.join([str(x)for x in self.next_tokens.shape])}.bin", + f"decode_next_tokens_{'x'.join([str(x)for x in self.next_tokens.shape])}xi64.bin", ) write_ndarray_to_bin( start_positions.numpy(), - f"decode_start_positions_{'_'.join([str(x)for x in start_positions.shape])}.bin", + f"decode_start_positions_{'x'.join([str(x)for x in start_positions.shape])}xi64.bin", ) write_ndarray_to_bin( seq_block_ids_tensor.numpy(), - f"decode_seq_block_ids_tensor_{'_'.join([str(x)for x in seq_block_ids_tensor.shape])}.bin", + f"decode_seq_block_ids_tensor_{'x'.join([str(x)for x in seq_block_ids_tensor.shape])}xi64.bin", ) write_ndarray_to_bin( torch.tensor(self.next_tokens.shape[0]).to(torch.int64).numpy(), - f"decode_seq_lens_1.bin", + f"decode_seq_lens_1xi64.bin", ) write_ndarray_to_bin( self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(), - f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin", + f"decode_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin", ) logits = model.decode( self.next_tokens,