Skip to content

Commit a7dc617

Browse files
author
GitHub ExecuTorch
committed
Enable composable benchmark configs for flexible model+device+optimization scheduling
1 parent 72bb7b7 commit a7dc617

File tree

2 files changed

+220
-69
lines changed

2 files changed

+220
-69
lines changed

Diff for: .ci/scripts/gather_benchmark_configs.py

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import re
9+
import json
10+
import os
11+
import logging
12+
from typing import Any, Dict
13+
14+
from examples.models import MODEL_NAME_TO_MODEL
15+
16+
17+
# Device pools for AWS Device Farm.
# Maps a short human-readable device name to the Device Farm device-pool ARN
# that benchmark jobs are scheduled onto.
DEVICE_POOLS = {
    "apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
    "samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
    "samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
    "google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
}

# Predefined benchmark configurations keyed by platform group.
# Only the "android"/"ios" entries are read in this file (extended onto the
# config list for in-tree models). "xplat" presumably holds cross-platform
# configs consumed elsewhere -- NOTE(review): confirm against the workflow.
BENCHMARK_CONFIGS = {
    "xplat": [
        "xnnpack_q8",
        "hf_xnnpack_fp32",
        "llama3_fb16",
        "llama3_spinquant",
        "llama3_qlora",
    ],
    "android": [
        "qnn_q8",
    ],
    "ios": [
        "coreml_fp16",
        "mps",
    ],
}
42+
43+
44+
def parse_args() -> Any:
    """
    Parse command-line arguments.

    All three flags are required: the rest of the script unconditionally
    iterates over ``models``/``devices`` and indexes ``BENCHMARK_CONFIGS``
    with ``os``, so a missing flag would otherwise surface later as an
    obscure TypeError/KeyError instead of a clear usage error.

    Returns:
        argparse.Namespace: Parsed command-line arguments.

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
                                  os='android',
                                  devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        required=True,  # BENCHMARK_CONFIGS[os] is indexed downstream
        help="the target OS",
    )
    parser.add_argument(
        "--models",
        nargs="+",  # Accept one or more space-separated model names
        required=True,  # downstream code iterates over this list
        help=f"either HuggingFace model IDs or names in [{MODEL_NAME_TO_MODEL}]",
    )
    parser.add_argument(
        "--devices",
        nargs="+",  # Accept one or more space-separated devices
        choices=list(DEVICE_POOLS.keys()),  # Convert dict_keys to a list
        required=True,  # downstream code iterates over this list
        help=f"devices to run the benchmark on. Pass as space-separated values. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()
78+
79+
80+
def set_output(name: str, val: Any) -> None:
    """
    Set the output value to be used by other GitHub jobs.

    Appends ``name=val`` to the file that the ``GITHUB_OUTPUT`` environment
    variable points at; when that variable is unset, falls back to the
    legacy ``::set-output`` workflow command on stdout.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    logging.info(f"Setting {val} to GitHub output")

    github_output = os.getenv("GITHUB_OUTPUT")
    if not github_output:
        # Legacy runners: emit the (deprecated) workflow command instead.
        print(f"::set-output name={name}::{val}")
        return

    with open(str(github_output), "a") as env:
        print(f"{name}={val}", file=env)
98+
99+
100+
def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Validate if the model name matches the pattern for HuggingFace model IDs.

    A valid ID has the form ``<org>/<repo>`` where the org part allows
    alphanumerics, ``-`` and ``_``, and the repo part additionally allows
    ``.``.

    Args:
        model_name (str): The model name to validate.

    Returns:
        bool: True if the model name matches the valid pattern, False otherwise.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2') -> True
    """
    pattern = r'[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+'
    # Use fullmatch instead of match(..., '^...$'): '$' also matches just
    # before a trailing newline, so "org/repo\n" would incorrectly validate.
    return re.fullmatch(pattern, model_name) is not None
115+
116+
117+
def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target operating system and devices.

    Reads the target OS, model list, and device list from the command line
    (see parse_args), builds one record per (model, config, device) triple,
    and publishes the result as the GitHub job output "benchmark_configs".

    Args:
        None

    Returns:
        Dict[str, Dict]: A dictionary containing the benchmark configurations.

    Example:
        get_benchmark_configs() -> {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "benchmark_config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "benchmark_config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    target_os = args.os
    devices = args.devices
    models = args.models

    benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models: pick exactly one config matching the
                # checkpoint flavor encoded in the repo name.
                repo_name = model_name.split("meta-llama/")[1]
                if "qlora" in repo_name.lower():
                    configs.append("llama3_qlora")
                elif "spinquant" in repo_name.lower():
                    configs.append("llama3_spinquant")
                else:
                    # Plain (non-quantized) checkpoints run the fb16 config.
                    # Previously fb16 was appended outside this else, which
                    # also scheduled it for qlora/spinquant checkpoints.
                    configs.append("llama3_fb16")
            else:
                # Non-LLaMA models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree models
            configs.append("xnnpack_q8")
            # OS-specific configs; .get guards against a missing/unknown OS
            # instead of raising KeyError.
            configs.extend(BENCHMARK_CONFIGS.get(target_os, []))
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device
        for device in devices:
            if device not in DEVICE_POOLS:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                record = {
                    "model": model_name,
                    "config": config,
                    "device": DEVICE_POOLS[device],
                }
                benchmark_configs["include"].append(record)

    set_output("benchmark_configs", json.dumps(benchmark_configs))
180+
181+
182+
if __name__ == "__main__":
    # Script entry point: gather the configs and publish them via GitHub
    # job output.
    get_benchmark_configs()

0 commit comments

Comments
 (0)