Skip to content

Commit

Permalink
[DevX] Skip disabled configs for benchmarking (pytorch#7868)
Browse files Browse the repository at this point in the history
Skip disabled configs for benchmarking

Co-authored-by: Github Executorch <[email protected]>
  • Loading branch information
2 people authored and Zonglin Peng committed Jan 30, 2025
1 parent 93e6525 commit 1936f0c
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 22 deletions.
50 changes: 49 additions & 1 deletion .ci/scripts/gather_benchmark_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import re
import sys
from typing import Any, Dict, List
from typing import Any, Dict, List, NamedTuple

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from examples.models import MODEL_NAME_TO_MODEL
Expand Down Expand Up @@ -47,6 +47,46 @@
}


class DisabledConfig(NamedTuple):
config_name: str
github_issue: str # Link to the GitHub issue


# Updated DISABLED_CONFIGS
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
"resnet50": [
DisabledConfig(
config_name="qnn_q8",
github_issue="https://github.com/pytorch/executorch/issues/7892",
),
],
"w2l": [
DisabledConfig(
config_name="qnn_q8",
github_issue="https://github.com/pytorch/executorch/issues/7634",
),
],
"mobilebert": [
DisabledConfig(
config_name="mps",
github_issue="https://github.com/pytorch/executorch/issues/7904",
),
],
"edsr": [
DisabledConfig(
config_name="mps",
github_issue="https://github.com/pytorch/executorch/issues/7905",
),
],
"llama": [
DisabledConfig(
config_name="mps",
github_issue="https://github.com/pytorch/executorch/issues/7907",
),
],
}


def extract_all_configs(data, target_os=None):
if isinstance(data, dict):
# If target_os is specified, include "xplat" and the specified branch
Expand Down Expand Up @@ -117,6 +157,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
# Skip unknown models with a warning
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

# Remove disabled configs for the given model
disabled_configs = DISABLED_CONFIGS.get(model_name, [])
disabled_config_names = {disabled.config_name for disabled in disabled_configs}
for disabled in disabled_configs:
print(
f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
)
configs = [config for config in configs if config not in disabled_config_names]
return configs


Expand Down
110 changes: 89 additions & 21 deletions .ci/scripts/tests/test_gather_benchmark_configs.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,41 @@
import importlib.util
import os
import re
import subprocess
import sys
import unittest
from unittest.mock import mock_open, patch

import pytest

# Dynamically import the script
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gather_benchmark_configs)


@pytest.mark.skipif(
sys.platform != "linux", reason="The script under test runs on Linux runners only"
)
class TestGatehrBenchmarkConfigs(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Dynamically import the script
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location(
"gather_benchmark_configs", script_path
)
cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cls.gather_benchmark_configs)

def test_extract_all_configs_android(self):
android_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
android_configs = self.gather_benchmark_configs.extract_all_configs(
self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
)
self.assertIn("xnnpack_q8", android_configs)
self.assertIn("qnn_q8", android_configs)
self.assertIn("llama3_spinquant", android_configs)
self.assertIn("llama3_qlora", android_configs)

def test_extract_all_configs_ios(self):
ios_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
ios_configs = self.gather_benchmark_configs.extract_all_configs(
self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
)

self.assertIn("xnnpack_q8", ios_configs)
Expand All @@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
self.assertIn("llama3_spinquant", ios_configs)
self.assertIn("llama3_qlora", ios_configs)

def test_skip_disabled_configs(self):
# Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
with patch.dict(
self.gather_benchmark_configs.DISABLED_CONFIGS,
{
"mv3": [
self.gather_benchmark_configs.DisabledConfig(
config_name="disabled_config1",
github_issue="https://github.com/org/repo/issues/123",
),
self.gather_benchmark_configs.DisabledConfig(
config_name="disabled_config2",
github_issue="https://github.com/org/repo/issues/124",
),
]
},
), patch.dict(
self.gather_benchmark_configs.BENCHMARK_CONFIGS,
{
"ios": [
"disabled_config1",
"disabled_config2",
"enabled_config1",
"enabled_config2",
]
},
):
result = self.gather_benchmark_configs.generate_compatible_configs(
"mv3", target_os="ios"
)

# Assert that disabled configs are excluded
self.assertNotIn("disabled_config1", result)
self.assertNotIn("disabled_config2", result)
# Assert enabled configs are included
self.assertIn("enabled_config1", result)
self.assertIn("enabled_config2", result)

def test_disabled_configs_have_github_links(self):
github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")

for (
model_name,
disabled_configs,
) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
for disabled in disabled_configs:
with self.subTest(model_name=model_name, config=disabled.config_name):
# Assert that disabled is an instance of DisabledConfig
self.assertIsInstance(
disabled, self.gather_benchmark_configs.DisabledConfig
)

# Assert that github_issue is provided and matches the expected pattern
self.assertTrue(
disabled.github_issue
and github_issue_regex.match(disabled.github_issue),
f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
)

def test_generate_compatible_configs_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16", "llama3_coreml_ane"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_quantized_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, None
)
expected = ["llama3_spinquant"]
self.assertEqual(result, expected)

model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, None
)
expected = ["llama3_qlora"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_non_genai_model(self):
model_name = "mv2"
target_os = "xplat"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "qnn_q8"]
self.assertEqual(result, expected)

target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
Expand All @@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
def test_generate_compatible_configs_unknown_model(self):
model_name = "unknown_model"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
result = self.gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
self.assertEqual(result, [])

def test_is_valid_huggingface_model_id_valid(self):
valid_model = "meta-llama/Llama-3.2-1B"
self.assertTrue(
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
)

@patch("builtins.open", new_callable=mock_open)
@patch("os.getenv", return_value=None)
def test_set_output_no_github_env(self, mock_getenv, mock_file):
with patch("builtins.print") as mock_print:
gather_benchmark_configs.set_output("test_name", "test_value")
self.gather_benchmark_configs.set_output("test_name", "test_value")
mock_print.assert_called_with("::set-output name=test_name::test_value")

def test_device_pools_contains_all_devices(self):
Expand All @@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
"google_pixel_8_pro",
]
for device in expected_devices:
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)

def test_gather_benchmark_configs_cli(self):
args = {
Expand Down

0 comments on commit 1936f0c

Please sign in to comment.