diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py
index 6c84c0fff..1fb8f3a6e 100644
--- a/sharktank/tests/evaluate/perplexity_vmfb_test.py
+++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py
@@ -9,6 +9,10 @@
 import json
 
 from sharktank.evaluate import perplexity_vmfb
+from sharktank.utils.export_artifacts import (
+    ExportMlirException,
+    IreeCompileException,
+)
 
 skipif_run_quick_llama_test = pytest.mark.skipif(
     'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")',
@@ -46,6 +50,7 @@ def test_llama3_8B_f16_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=decomposed",
+                f"--num-prompts=5",
             ]
         )
 
@@ -61,10 +66,8 @@ def test_llama3_8B_f16_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="Non-decomposed attention is not supported yet",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_f16(self):
 
         # Llama 3.1 8B non-decomposed
@@ -96,10 +99,8 @@ def test_llama3_8B_f16(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_fp8_decomposed(self):
 
         # Llama 3.1 8B decomposed
@@ -131,10 +132,8 @@ def test_llama3_8B_fp8_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_fp8(self):
 
         # Llama 3.1 8B non-decomposed
@@ -166,10 +165,10 @@ def test_llama3_8B_fp8(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
+    @skipif_run_quick_llama_test
     @pytest.mark.xfail(
         reason="Sharding is unsupported",
     )
-    @skipif_run_quick_llama_test
     def test_llama3_405B_f16_decomposed(self):
 
         # Llama 3.1 405B decomposed
@@ -201,10 +200,8 @@ def test_llama3_405B_f16_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="Non-decomposed attention is not supported yet",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_f16(self):
 
         # Llama 3.1 405B non-decomposed
@@ -236,10 +233,8 @@ def test_llama3_405B_f16(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_fp8_decomposed(self):
 
         # Llama 3.1 405B decomposed
@@ -271,10 +266,8 @@ def test_llama3_405B_fp8_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_fp8(self):
 
         # Llama 3.1 405B non-decomposed
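
Note on the xfail pattern this diff switches to: below is a minimal, self-contained sketch (not part of the diff) of how strict=True and raises= behave. The IreeCompileException here is a local stub standing in for the real class imported from sharktank.utils.export_artifacts; everything else is standard pytest.

# Minimal sketch of the strict xfail pattern used in the diff above.
# IreeCompileException is a hypothetical local stub for illustration only.
import pytest


class IreeCompileException(Exception):
    """Stub for sharktank.utils.export_artifacts.IreeCompileException."""


@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
def test_compile_error_is_expected():
    # raises= narrows the expected failure: only IreeCompileException is
    # reported as XFAIL; any other exception surfaces as a real test error.
    # strict=True turns an unexpected pass (XPASS) into a hard failure,
    # so the marker must be removed once the underlying compile bug is fixed.
    raise IreeCompileException("iree-compile returned a non-zero exit code")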