
Commit

Add better logging
archana-ramalingam committed Nov 21, 2024
1 parent 9e965a2 commit 3089e61
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions sharktank/tests/evaluate/perplexity_vmfb_test.py
@@ -9,6 +9,10 @@
 import json
 
 from sharktank.evaluate import perplexity_vmfb
+from sharktank.utils.export_artifacts import (
+    ExportMlirException,
+    IreeCompileException,
+)
 
 skipif_run_quick_llama_test = pytest.mark.skipif(
     'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")',
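The two exception types imported here come from the export pipeline and are what the reworked xfail marks below key on (ExportMlirException is imported alongside, though only IreeCompileException appears in the visible hunks). Their real definitions live in sharktank/utils/export_artifacts.py; a minimal stand-in, for illustration only:

# Illustrative stand-ins only; the real classes are defined in
# sharktank/utils/export_artifacts.py and may carry extra context
# (captured stderr, compile flags, etc. -- an assumption here).
class ExportMlirException(Exception):
    """Raised when exporting the model to MLIR fails."""

class IreeCompileException(Exception):
    """Raised when iree-compile fails to produce a .vmfb module."""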
@@ -46,6 +50,7 @@ def test_llama3_8B_f16_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=decomposed",
+                f"--num-prompts=5",
             ]
         )
 
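The new --num-prompts=5 flag presumably caps how many prompts the perplexity run evaluates, keeping this test cheap enough for the quick suite. The flag's wiring is not shown in this diff; a hypothetical argparse declaration, for illustration only:

# Hypothetical wiring for the flag above; the real definition lives in
# sharktank/evaluate/perplexity_vmfb.py and the default here is assumed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-prompts",
    type=int,
    default=100,  # assumed default, for illustration
    help="Evaluate perplexity over only the first N prompts",
)

args = parser.parse_args(["--num-prompts=5"])
assert args.num_prompts == 5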
@@ -61,10 +66,8 @@ def test_llama3_8B_f16_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="Non-decomposed attention is not supported yet",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_f16(self):
 
         # Llama 3.1 8B non-decomposed
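This replacement pattern repeats for each test below and tightens the xfail contract: raises=IreeCompileException means only that exception counts as the expected failure (any other error is a real failure), and strict=True turns an unexpected pass into a reported failure, so a fixed compile can no longer slip by silently. A self-contained sketch of the semantics, with a stand-in exception class:

import pytest

class CompileError(Exception):
    """Stand-in for IreeCompileException in this sketch."""

@pytest.mark.xfail(reason="Compile Error", strict=True, raises=CompileError)
def test_known_compile_failure():
    # raises CompileError        -> reported as XFAIL (expected failure)
    # raises any other exception -> reported as FAILED
    # raises nothing (test pass) -> strict=True reports the XPASS as FAILED
    raise CompileError("iree-compile could not build the vmfb")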
@@ -96,10 +99,8 @@ def test_llama3_8B_f16(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_fp8_decomposed(self):
 
         # Llama 3.1 8B decomposed
@@ -131,10 +132,8 @@ def test_llama3_8B_fp8_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_8B_fp8(self):
 
         # Llama 3.1 8B non-decomposed
@@ -166,10 +165,10 @@ def test_llama3_8B_fp8(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
+    @skipif_run_quick_llama_test
     @pytest.mark.xfail(
         reason="Sharding is unsupported",
     )
-    @skipif_run_quick_llama_test
     def test_llama3_405B_f16_decomposed(self):
 
         # Llama 3.1 405B decomposed
@@ -201,10 +200,8 @@ def test_llama3_405B_f16_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="Non-decomposed attention is not supported yet",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_f16(self):
 
         # Llama 3.1 405B non-decomposed
@@ -236,10 +233,8 @@ def test_llama3_405B_f16(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_fp8_decomposed(self):
 
         # Llama 3.1 405B decomposed
@@ -271,10 +266,8 @@ def test_llama3_405B_fp8_decomposed(self):
             msg=f"Current perplexity deviates baseline by {perplexity_difference}",
         )
 
-    @pytest.mark.xfail(
-        reason="FP8 model is unsupported",
-    )
     @skipif_run_quick_llama_test
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def test_llama3_405B_fp8(self):
 
         # Llama 3.1 405B non-decomposed
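The test bodies are truncated in this view; judging from the visible flag lists, each one drives the evaluator roughly like the sketch below. The call shape, paths, and HIP target are placeholders and assumptions, not taken from the commit:

# Hypothetical standalone run mirroring the truncated test bodies above.
from sharktank.evaluate import perplexity_vmfb

current_perplexity = perplexity_vmfb.main(
    [
        "--irpa-file=/models/llama3_8b_f16.irpa",  # placeholder path
        "--iree-hip-target=gfx942",                # placeholder target
        "--tensor-parallelism-size=1",
        "--attention-kernel=decomposed",
        "--num-prompts=5",
    ]
)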
