Skip to content

Commit

Permalink
Enable composable benchmark configs for flexible model+device+optimization scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
Github Executorch committed Dec 18, 2024
1 parent 72bb7b7 commit a7dc617
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 69 deletions.
183 changes: 183 additions & 0 deletions .ci/scripts/gather_benchmark_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import json
import os
import logging
from typing import Any, Dict

from examples.models import MODEL_NAME_TO_MODEL


# Device pools for AWS Device Farm
# Maps a human-readable device name (as accepted by the --devices CLI flag)
# to the ARN of the AWS Device Farm device pool that runs the benchmark.
DEVICE_POOLS = {
"apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
"samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
"samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
"google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
}

# Predefined benchmark configurations
# The "android"/"ios" lists are appended for in-tree ExecuTorch models based
# on the --os argument. NOTE(review): "xplat" presumably means cross-platform
# configs, but nothing in this file reads it — confirm it is consumed elsewhere
# or wired up intentionally later.
BENCHMARK_CONFIGS = {
"xplat": [
"xnnpack_q8",
"hf_xnnpack_fp32",
"llama3_fb16",
"llama3_spinquant",
"llama3_qlora",
],
"android": [
"qnn_q8",
],
"ios": [
"coreml_fp16",
"mps",
],
}


def parse_args() -> Any:
    """
    Parse command-line arguments.

    All three flags are required: get_benchmark_configs() iterates over
    ``models`` and ``devices`` and indexes ``BENCHMARK_CONFIGS[os]``, so a
    missing flag previously surfaced only later as a TypeError/KeyError.

    Returns:
        argparse.Namespace: Parsed command-line arguments with fields
            ``os`` (str), ``models`` (list of str), ``devices`` (list of str).

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
                                  os='android',
                                  devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        required=True,  # BENCHMARK_CONFIGS[os] would raise KeyError on None
        help="the target OS",
    )
    parser.add_argument(
        "--models",
        nargs="+",  # Accept one or more space-separated model names
        required=True,  # caller iterates models; None would raise TypeError
        help=f"either HuggingFace model IDs or names in [{MODEL_NAME_TO_MODEL}]",
    )
    parser.add_argument(
        "--devices",
        nargs="+",  # Accept one or more space-separated devices
        choices=list(DEVICE_POOLS.keys()),  # Convert dict_keys to a list
        required=True,  # caller iterates devices; None would raise TypeError
        help=f"devices to run the benchmark on. Pass as space-separated values. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()


def set_output(name: str, val: Any) -> None:
    """
    Set an output value to be consumed by downstream GitHub Actions jobs.

    Appends ``name=val`` to the file named by the GITHUB_OUTPUT environment
    variable when it is set; otherwise falls back to the legacy
    ``::set-output`` workflow command on stdout.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    # Fix: the original message omitted the output name ("Setting <val> ...").
    logging.info(f"Setting {name}={val} to GitHub output")

    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a") as env:
            print(f"{name}={val}", file=env)
    else:
        # NOTE: ::set-output is deprecated by GitHub Actions; kept only as a
        # fallback for environments that do not define GITHUB_OUTPUT.
        print(f"::set-output name={name}::{val}")


def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Check whether *model_name* looks like a HuggingFace model ID.

    A valid ID is ``<org>/<repo>`` where the org part is alphanumeric with
    hyphens/underscores and the repo part additionally allows dots.

    Args:
        model_name (str): The candidate model name.

    Returns:
        bool: True if the name matches the HuggingFace ID pattern.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2') -> True
    """
    hf_id_pattern = r'^[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+$'
    return re.match(hf_id_pattern, model_name) is not None


def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target operating system and devices.

    Per-model selection logic:
      * "meta-llama/*" HuggingFace IDs map to llama3 configs: qlora or
        spinquant variants when the repo name contains those substrings,
        otherwise "llama3_fb16".
      * Any other valid HuggingFace ID gets "hf_xnnpack_fp32".
      * In-tree ExecuTorch model names (MODEL_NAME_TO_MODEL) get "xnnpack_q8"
        plus the OS-specific configs from BENCHMARK_CONFIGS.
      * Unknown names are skipped with a warning.

    The result is also written to the GitHub job output via set_output().

    Args:
        None

    Returns:
        Dict[str, Dict]: A dictionary containing the benchmark configurations.

    Example:
        get_benchmark_configs() -> {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    target_os = args.os
    devices = args.devices
    models = args.models

    # Matrix consumed by the GitHub Actions workflow ("include" strategy).
    benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models
                repo_name = model_name.split("meta-llama/")[1]
                if "qlora" in repo_name.lower():
                    configs.append("llama3_qlora")
                elif "spinquant" in repo_name.lower():
                    configs.append("llama3_spinquant")
                configs.append("llama3_fb16")
            else:
                # Non-LLaMA models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree models
            configs.append("xnnpack_q8")
            configs.extend(BENCHMARK_CONFIGS[target_os])
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device
        # NOTE(review): devices are not cross-checked against target_os, so an
        # iOS device pool can be paired with android configs — confirm whether
        # the workflow guards against that upstream.
        for device in devices:
            if device not in DEVICE_POOLS:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                record = {
                    "model": model_name,
                    "config": config,
                    "device": DEVICE_POOLS[device],
                }
                benchmark_configs["include"].append(record)

    # Expose the matrix to downstream GitHub jobs as JSON.
    set_output("benchmark_configs", json.dumps(benchmark_configs))


# Script entry point: parse the CLI args and emit the benchmark matrix.
if __name__ == "__main__":
    get_benchmark_configs()
Loading

0 comments on commit a7dc617

Please sign in to comment.