From d747d891d1962a45f434939cf42eb071baba1ad5 Mon Sep 17 00:00:00 2001
From: archana-ramalingam
Date: Mon, 6 Jan 2025 19:30:28 +0000
Subject: [PATCH] Use --iree-hal-target-device flag in perplexity tests

---
 .github/workflows/ci_eval.yaml              |  2 +-
 .github/workflows/ci_eval_short.yaml        |  2 +-
 .../sharktank/evaluate/perplexity_iree.py   | 18 +++++++++---------
 .../tests/evaluate/perplexity_iree_test.py  | 16 ++++++++--------
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
index f2db697d7..fe29f54d5 100644
--- a/.github/workflows/ci_eval.yaml
+++ b/.github/workflows/ci_eval.yaml
@@ -70,7 +70,7 @@ jobs:
       - name: Run perplexity test with IREE
         run: |
           source ${VENV_DIR}/bin/activate
-          pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
+          pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html

       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml
index 385a54261..05c7fa415 100644
--- a/.github/workflows/ci_eval_short.yaml
+++ b/.github/workflows/ci_eval_short.yaml
@@ -69,4 +69,4 @@ jobs:
       - name: Run perplexity test with vmfb
         run: |
           source ${VENV_DIR}/bin/activate
-          pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+          pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py
index c47726f0e..f42a4cf4a 100644
--- a/sharktank/sharktank/evaluate/perplexity_iree.py
+++ b/sharktank/sharktank/evaluate/perplexity_iree.py
@@ -64,7 +64,7 @@ def __init__(
         torch_device,
         iree_device,
         iree_hip_target,
-        iree_hal_target_backends,
+        iree_hal_target_device,
         kv_cache_type,
         tensor_parallelism_size,
         attention_kernel,
@@ -73,7 +73,7 @@
         self.torch_device = torch_device
         self.iree_device = iree_device
         self.iree_hip_target = iree_hip_target
-        self.iree_hal_target_backends = iree_hal_target_backends
+        self.iree_hal_target_device = iree_hal_target_device
         self.kv_cache_type = kv_cache_type
         self.block_seq_stride = block_seq_stride
         self.activation_dtype = torch.float16
@@ -135,7 +135,7 @@ def compile_model(self, weight_path_str):
             irpa_path=self.weight_path_str,
             batch_size=self.bs,
             iree_hip_target=self.iree_hip_target,
-            iree_hal_target_backends=self.iree_hal_target_backends,
+            iree_hal_target_device=self.iree_hal_target_device,
             attention_kernel=self.attention_kernel,
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=self.block_seq_stride,
@@ -392,7 +392,7 @@ def run_perplexity(
     torch_device,
     iree_device,
     iree_hip_target,
-    iree_hal_target_backends,
+    iree_hal_target_device,
     kv_cache_type,
     tensor_parallelism_size,
     attention_kernel,
@@ -404,7 +404,7 @@ def run_perplexity(
         torch_device=torch_device,
         iree_device=iree_device,
         iree_hip_target=iree_hip_target,
-        iree_hal_target_backends=iree_hal_target_backends,
+        iree_hal_target_device=iree_hal_target_device,
         kv_cache_type=kv_cache_type,
         tensor_parallelism_size=tensor_parallelism_size,
         attention_kernel=attention_kernel,
@@ -450,10 +450,10 @@ def main(argv):
         help="Specify the iree-hip target version (e.g., gfx942)",
     )
     parser.add_argument(
-        "--iree-hal-target-backends",
+        "--iree-hal-target-device",
         action="store",
-        default="rocm",
-        help="Specify the iree-hal target backends (e.g., rocm)",
+        default="hip",
+        help="Specify the iree-hal target device (e.g., hip, cpu)",
     )
     parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
     parser.add_argument(
@@ -485,7 +485,7 @@
         torch_device=torch_device,
         iree_device=args.iree_device,
         iree_hip_target=args.iree_hip_target,
-        iree_hal_target_backends=args.iree_hal_target_backends,
+        iree_hal_target_device=args.iree_hal_target_device,
         kv_cache_type=args.kv_cache_type,
         tensor_parallelism_size=args.tensor_parallelism_size,
         attention_kernel=args.attention_kernel,
diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py
index d10d9f5db..1e42bde9c 100644
--- a/sharktank/tests/evaluate/perplexity_iree_test.py
+++ b/sharktank/tests/evaluate/perplexity_iree_test.py
@@ -46,7 +46,7 @@ def test_llama3_8B_f16_decomposed(self):
             f"--irpa-file={self.llama3_8b_f16_model}",
             f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size=1",
             f"--attention-kernel=decomposed",
@@ -82,7 +82,7 @@ def test_llama3_8B_f16(self):
             f"--irpa-file={self.llama3_8b_f16_model}",
             f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size=1",
             f"--attention-kernel=torch_sdpa",
@@ -118,7 +118,7 @@ def test_llama3_8B_fp8_decomposed(self):
             f"--irpa-file={self.llama3_8b_fp8_model}",
             f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size=1",
             f"--attention-kernel=decomposed",
@@ -154,7 +154,7 @@ def test_llama3_8B_fp8(self):
             f"--irpa-file={self.llama3_8b_fp8_model}",
             f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size=1",
             f"--attention-kernel=torch_sdpa",
@@ -192,7 +192,7 @@ def test_llama3_405B_f16_decomposed(self):
             f"--irpa-file={self.llama3_405b_f16_model}",
             f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size={self.tensor_parallelism_size}",
             f"--attention-kernel=decomposed",
@@ -228,7 +228,7 @@ def test_llama3_405B_f16(self):
             f"--irpa-file={self.llama3_405b_f16_model}",
             f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size={self.tensor_parallelism_size}",
             f"--attention-kernel=torch_sdpa",
@@ -264,7 +264,7 @@ def test_llama3_405B_fp8_decomposed(self):
             f"--irpa-file={self.llama3_405b_fp8_model}",
             f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size={self.tensor_parallelism_size}",
             f"--attention-kernel=decomposed",
@@ -300,7 +300,7 @@ def test_llama3_405B_fp8(self):
             f"--irpa-file={self.llama3_405b_fp8_model}",
             f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
             f"--iree-device={self.iree_device}",
-            f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
             f"--iree-hip-target={self.iree_hip_target}",
             f"--tensor-parallelism-size={self.tensor_parallelism_size}",
             f"--attention-kernel=torch_sdpa",