Skip to content

Commit 2e2ab00

Browse files
author
Github Executorch
committed
Enable composable benchmark configs for flexible model+device+optimization scheduling
1 parent 72bb7b7 commit 2e2ab00

File tree

2 files changed

+254
-67
lines changed

2 files changed

+254
-67
lines changed

Diff for: .ci/scripts/gather_benchmark_configs.py

+219
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import itertools
9+
import json
10+
import logging
11+
import os
12+
import re
13+
import sys
14+
from typing import Any, Dict, List
15+
16+
from executorch.examples.models import MODEL_NAME_TO_MODEL
17+
18+
19+
# Device pools for AWS Device Farm.
# Maps a human-readable device name (the value accepted by the --devices CLI
# flag) to the ARN of the corresponding AWS Device Farm device pool; the ARN
# is what ends up in the emitted benchmark matrix.
DEVICE_POOLS = {
    "apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
    "samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
    "samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
    "google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
}
27+
# Predefined benchmark configurations.
# Top-level keys are target groups: "xplat" (configs applicable to any OS),
# "android", and "ios". Each inner key is a benchmark config name; its dict
# holds extra CLI flags for that config (empty dict == no extra flags).
# NOTE(review): get_benchmark_configs() currently uses only the config NAMES
# (it iterates the keys); the flag dicts appear to be consumed downstream —
# confirm with the benchmark runner.
BENCHMARK_CONFIGS = {
    "xplat": {
        "xnnpack_q8": {},
        "hf_xnnpack_fp32": {},
        # Plain bf16 LLaMA 3.2 export via XNNPACK.
        "llama3_fb16": {
            "--model": "llama3_2",
            "--use_kv_cache": True,
            "--use_sdpa_with_kv_cache": True,
            "--xnnpack": True,
            "--dtype-override": "bf16",
            "--metadata": '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
        },
        # SpinQuant-prequantized LLaMA 3.2 (8da4w weights, 8da8w output).
        "llama3_spinquant": {
            "--model": "llama3_2",
            "--use_kv_cache": True,
            "--use_sdpa_with_kv_cache": True,
            "--xnnpack": True,
            "--xnnpack-extended-ops": True,
            "--dtype-override": "fp32",
            "--max_seq_length": 2048,
            "--preq_mode": "8da4w_output_8da8w",
            "--preq_group_size": 32,
            "--preq_embedding_quantize": "8,0",
            "--use_spin_quant": "native",
            "--metadata": '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
        },
        # QAT + LoRA (QLoRA) prequantized LLaMA 3.2.
        "llama3_qlora": {
            "--model": "llama3_2",
            "--use_kv_cache": True,
            "--use_sdpa_with_kv_cache": True,
            "--xnnpack": True,
            "--xnnpack-extended-ops": True,
            "--dtype-override": "fp32",
            "--max_seq_length": 2048,
            "--preq_mode": "8da4w_output_8da8w",
            "--preq_group_size": 32,
            "--preq_embedding_quantize": "8,0",
            "--use_qat": True,
            "--use_lora": 16,
            "--metadata": '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
        },
    },
    "android": {
        "qnn_q8": {},
    },
    "ios": {
        "coreml_fp16": {},
        "mps": {},
    },
}
78+
79+
80+
def parse_args() -> Any:
    """
    Parse command-line arguments.

    All three flags are required: downstream code iterates over ``models``
    and ``devices`` and indexes BENCHMARK_CONFIGS with ``os``, so letting
    any of them default to None would only defer the failure to an opaque
    TypeError/KeyError later on.

    Returns:
        argparse.Namespace: Parsed command-line arguments with attributes
        ``os`` (str), ``models`` (List[str]) and ``devices`` (List[str]).

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
            os='android',
            devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        required=True,  # BENCHMARK_CONFIGS[os] is looked up downstream
        help="the target OS",
    )
    parser.add_argument(
        "--models",
        nargs="+",  # Accept one or more space-separated model names
        required=True,  # iterated downstream; None would raise TypeError
        help=f"either HuggingFace model IDs or names in [{MODEL_NAME_TO_MODEL}]",
    )
    parser.add_argument(
        "--devices",
        nargs="+",  # Accept one or more space-separated devices
        choices=list(DEVICE_POOLS.keys()),  # Convert dict_keys to a list
        required=True,  # iterated downstream; None would raise TypeError
        help=f"devices to run the benchmark on. Pass as space-separated values. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()
115+
116+
def set_output(name: str, val: Any) -> None:
    """
    Publish a key/value pair for consumption by downstream GitHub jobs.

    Appends ``name=val`` to the file pointed at by the GITHUB_OUTPUT
    environment variable when it is set; otherwise falls back to printing
    the legacy ``::set-output`` workflow command on stdout.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    logging.info(f"Setting {val} to GitHub output")

    github_output = os.getenv("GITHUB_OUTPUT")
    if not github_output:
        # NOTE(review): ::set-output is deprecated by GitHub Actions; this
        # branch only runs outside Actions (GITHUB_OUTPUT unset).
        print(f"::set-output name={name}::{val}")
        return

    with open(str(github_output), "a") as env:
        print(f"{name}={val}", file=env)
135+
136+
def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Check whether *model_name* looks like a HuggingFace model ID.

    A valid ID has the shape ``<org>/<repo>`` where the org part allows
    alphanumerics plus ``-``/``_`` and the repo part additionally allows
    ``.``.

    Args:
        model_name (str): The model name to validate.

    Returns:
        bool: True if the model name matches the valid pattern, False otherwise.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2') -> True
    """
    hf_id_pattern = re.compile(r"^[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+$")
    return hf_id_pattern.match(model_name) is not None
152+
153+
def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target operating system and devices.

    Reads --os/--models/--devices from the command line, maps each model to
    the benchmark config names that apply to it, expands the result across
    the requested device pools, publishes the matrix via set_output(), and
    returns it.

    Returns:
        Dict[str, Dict]: A dictionary containing the benchmark configurations.

    Example:
        get_benchmark_configs() -> {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    target_os = args.os
    devices = args.devices
    models = args.models

    benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models: pick a quantization-specific config from the
                # repo name, and always include the plain bf16 config.
                repo_name = model_name.split("meta-llama/")[1]
                if "qlora" in repo_name.lower():
                    configs.append("llama3_qlora")
                elif "spinquant" in repo_name.lower():
                    configs.append("llama3_spinquant")
                configs.append("llama3_fb16")
            else:
                # Non-LLaMA HuggingFace models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree models: cross-platform config plus the
            # OS-specific config names. Use .get() so an unset or unknown
            # --os degrades to "no OS-specific configs" instead of raising
            # KeyError mid-loop.
            configs.append("xnnpack_q8")
            configs.extend(BENCHMARK_CONFIGS.get(target_os, {}))
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device. argparse already restricts
        # --devices to DEVICE_POOLS keys; this check is defense in depth for
        # direct callers.
        for device in devices:
            if device not in DEVICE_POOLS:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                record = {
                    "model": model_name,
                    "config": config,
                    "device": DEVICE_POOLS[device],
                }
                benchmark_configs["include"].append(record)

    set_output("benchmark_configs", json.dumps(benchmark_configs))
    # Return the matrix as promised by the annotation; the original fell off
    # the end and returned None despite the declared Dict return type.
    return benchmark_configs
217+
218+
# Script entry point: build the benchmark matrix and publish it to
# GITHUB_OUTPUT (or stdout) for the CI workflow to consume.
if __name__ == "__main__":
    get_benchmark_configs()

0 commit comments

Comments
 (0)