Skip to content

Commit

Permalink
Use --iree-hal-target-device flag in perplexity tests
Browse files — browse the repository at this point in the history
  • Loading branch information
archana-ramalingam committed Jan 6, 2025
1 parent 1a3c4cb commit d747d89
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
- name: Run perplexity test with IREE
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_eval_short.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ jobs:
- name: Run perplexity test with vmfb
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
18 changes: 9 additions & 9 deletions sharktank/sharktank/evaluate/perplexity_iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
torch_device,
iree_device,
iree_hip_target,
iree_hal_target_backends,
iree_hal_target_device,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
Expand All @@ -73,7 +73,7 @@ def __init__(
self.torch_device = torch_device
self.iree_device = iree_device
self.iree_hip_target = iree_hip_target
self.iree_hal_target_backends = iree_hal_target_backends
self.iree_hal_target_device = iree_hal_target_device
self.kv_cache_type = kv_cache_type
self.block_seq_stride = block_seq_stride
self.activation_dtype = torch.float16
Expand Down Expand Up @@ -135,7 +135,7 @@ def compile_model(self, weight_path_str):
irpa_path=self.weight_path_str,
batch_size=self.bs,
iree_hip_target=self.iree_hip_target,
iree_hal_target_backends=self.iree_hal_target_backends,
iree_hal_target_device=self.iree_hal_target_device,
attention_kernel=self.attention_kernel,
tensor_parallelism_size=self.tensor_parallelism_size,
block_seq_stride=self.block_seq_stride,
Expand Down Expand Up @@ -392,7 +392,7 @@ def run_perplexity(
torch_device,
iree_device,
iree_hip_target,
iree_hal_target_backends,
iree_hal_target_device,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
Expand All @@ -404,7 +404,7 @@ def run_perplexity(
torch_device=torch_device,
iree_device=iree_device,
iree_hip_target=iree_hip_target,
iree_hal_target_backends=iree_hal_target_backends,
iree_hal_target_device=iree_hal_target_device,
kv_cache_type=kv_cache_type,
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=attention_kernel,
Expand Down Expand Up @@ -450,10 +450,10 @@ def main(argv):
help="Specify the iree-hip target version (e.g., gfx942)",
)
parser.add_argument(
"--iree-hal-target-backends",
"--iree-hal-target-device",
action="store",
default="rocm",
help="Specify the iree-hal target backends (e.g., rocm)",
default="hip",
help="Specify the iree-hal target device (e.g., hip, cpu)",
)
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument(
Expand Down Expand Up @@ -485,7 +485,7 @@ def main(argv):
torch_device=torch_device,
iree_device=args.iree_device,
iree_hip_target=args.iree_hip_target,
iree_hal_target_backends=args.iree_hal_target_backends,
iree_hal_target_device=args.iree_hal_target_device,
kv_cache_type=args.kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
attention_kernel=args.attention_kernel,
Expand Down
16 changes: 8 additions & 8 deletions sharktank/tests/evaluate/perplexity_iree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_llama3_8B_f16_decomposed(self):
f"--irpa-file={self.llama3_8b_f16_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=decomposed",
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_llama3_8B_f16(self):
f"--irpa-file={self.llama3_8b_f16_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_llama3_8B_fp8_decomposed(self):
f"--irpa-file={self.llama3_8b_fp8_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=decomposed",
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_llama3_8B_fp8(self):
f"--irpa-file={self.llama3_8b_fp8_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_llama3_405B_f16_decomposed(self):
f"--irpa-file={self.llama3_405b_f16_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=decomposed",
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_llama3_405B_f16(self):
f"--irpa-file={self.llama3_405b_f16_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_llama3_405B_fp8_decomposed(self):
f"--irpa-file={self.llama3_405b_fp8_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=decomposed",
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_llama3_405B_fp8(self):
f"--irpa-file={self.llama3_405b_fp8_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
f"--iree-device={self.iree_device}",
f"--iree-hal-target-backends={self.iree_hal_target_backends}",
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
Expand Down

0 comments on commit d747d89

Please sign in to comment.