
Commit d030e94

Github Executorch committed
Enable composable benchmark configs for flexible model+device+optimization scheduling
1 parent 72bb7b7 commit d030e94
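
The "composable" scheduling this commit enables is a simple fan-out: each requested model resolves to one or more optimization configs, and every resulting (model, config) pair is paired with every requested device pool. A minimal conceptual sketch of that composition, using illustrative values drawn from the new script (the real mapping logic is in the file below):

    # Illustrative sketch only: one matrix record per (model, config, device) combination
    models_to_configs = {"mv3": ["xnnpack_q8", "qnn_q8"]}  # model -> benchmark configs
    devices = ["samsung_galaxy_s22"]                       # requested device pools
    matrix = [
        {"model": model, "config": config, "device": device}
        for model, configs in models_to_configs.items()
        for config in configs
        for device in devices
    ]
    # matrix == [
    #     {"model": "mv3", "config": "xnnpack_q8", "device": "samsung_galaxy_s22"},
    #     {"model": "mv3", "config": "qnn_q8", "device": "samsung_galaxy_s22"},
    # ]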

File tree: 2 files changed, +226 −72 lines changed
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import re
from typing import Any, Dict

from examples.models import MODEL_NAME_TO_MODEL


# Device pools for AWS Device Farm
DEVICE_POOLS = {
    "apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
    "samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
    "samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
    "google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
}

# Predefined benchmark configurations
BENCHMARK_CONFIGS = {
    "xplat": [
        "xnnpack_q8",
        "hf_xnnpack_fp32",
        "llama3_fb16",
        "llama3_spinquant",
        "llama3_qlora",
    ],
    "android": [
        "qnn_q8",
    ],
    "ios": [
        "coreml_fp16",
        "mps",
    ],
}


def parse_args() -> Any:
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace: Parsed command-line arguments.

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
                                  os='android',
                                  devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    def comma_separated(value: str):
        """
        Parse a comma-separated string into a list.
        """
        return value.split(",")

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        help="The target OS.",
    )
    parser.add_argument(
        "--models",
        type=comma_separated,  # Use the custom parser for comma-separated values
        help=f"Comma-separated model IDs or names. Valid values include {MODEL_NAME_TO_MODEL}.",
    )
    parser.add_argument(
        "--devices",
        type=comma_separated,  # Use the custom parser for comma-separated values
        help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()


def set_output(name: str, val: Any) -> None:
    """
    Set the output value to be used by other GitHub jobs.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    print(f"Setting {val} to GitHub output")

    if os.getenv("GITHUB_OUTPUT"):
        with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
            print(f"{name}={val}", file=env)
    else:
        print(f"::set-output name={name}::{val}")


def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Validate if the model name matches the pattern for HuggingFace model IDs.

    Args:
        model_name (str): The model name to validate.

    Returns:
        bool: True if the model name matches the valid pattern, False otherwise.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2-1B') -> True
    """
    pattern = r"^[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+$"
    return bool(re.match(pattern, model_name))


def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target operating system and devices.

    Args:
        None

    Returns:
        Dict[str, Dict]: A dictionary containing the benchmark configurations.

    Example:
        get_benchmark_configs() -> {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "benchmark_config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "benchmark_config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    target_os = args.os
    devices = args.devices
    models = args.models

    benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models
                repo_name = model_name.split("meta-llama/")[1]
                if "qlora" in repo_name.lower():
                    configs.append("llama3_qlora")
                elif "spinquant" in repo_name.lower():
                    configs.append("llama3_spinquant")
                else:
                    configs.append("llama3_fb16")
            else:
                # Non-LLaMA models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree models
            configs.append("xnnpack_q8")
            configs.extend(BENCHMARK_CONFIGS[target_os])
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device
        for device in devices:
            if device not in DEVICE_POOLS:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                record = {
                    "model": model_name,
                    "config": config,
                    "device": DEVICE_POOLS[device],
                }
                benchmark_configs["include"].append(record)

    set_output("benchmark_configs", json.dumps(benchmark_configs))


if __name__ == "__main__":
    get_benchmark_configs()
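
For reference, a minimal sketch of what the script emits end to end. The invocation below is illustrative (the file path is not shown in this diff), and it assumes "mv3" is an in-tree entry in MODEL_NAME_TO_MODEL:

    # Hypothetical invocation (script path is illustrative, not taken from this commit):
    #   python gather_benchmark_configs.py --os android --models mv3 --devices samsung_galaxy_s22
    #
    # Per the logic above, the "benchmark_configs" output written to $GITHUB_OUTPUT would be:
    expected = {
        "include": [
            # In-tree model: generic XNNPACK config plus the Android-specific configs
            {"model": "mv3", "config": "xnnpack_q8", "device": DEVICE_POOLS["samsung_galaxy_s22"]},
            {"model": "mv3", "config": "qnn_q8", "device": DEVICE_POOLS["samsung_galaxy_s22"]},
        ]
    }

A downstream job can read this output and use it as its job matrix, which is what lets the model, device, and optimization axes be scheduled independently.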

0 commit comments
