-
Notifications
You must be signed in to change notification settings - Fork 403
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable composable benchmark configs for flexible model+device+optimization scheduling
- Loading branch information
Github Executorch
committed
Dec 18, 2024
1 parent
72bb7b7
commit a7dc617
Showing
2 changed files
with
220 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import re | ||
import json | ||
import os | ||
import logging | ||
from typing import Any, Dict | ||
|
||
from examples.models import MODEL_NAME_TO_MODEL | ||
|
||
|
||
# Device pools for AWS Device Farm.
# Maps a human-readable device name (the value accepted by --devices) to the
# ARN of the Device Farm device pool that CI schedules benchmark jobs on.
# All pools live under the same Device Farm project (account 308535385114).
DEVICE_POOLS = {
    "apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
    "samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
    "samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
    "google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
}
|
||
# Predefined benchmark configurations, keyed by target platform.
# "android"/"ios" entries are appended for in-tree ExecuTorch models via
# BENCHMARK_CONFIGS[target_os] in get_benchmark_configs().
# NOTE(review): the "xplat" key is not referenced by the code visible in this
# file (--os only accepts "android"/"ios") — presumably consumed by another
# job or kept for documentation; confirm before removing.
BENCHMARK_CONFIGS = {
    "xplat": [
        "xnnpack_q8",
        "hf_xnnpack_fp32",
        "llama3_fb16",
        "llama3_spinquant",
        "llama3_qlora",
    ],
    "android": [
        "qnn_q8",
    ],
    "ios": [
        "coreml_fp16",
        "mps",
    ],
}
|
||
|
||
def parse_args() -> Any:
    """
    Parse command-line arguments.

    All three options are mandatory: get_benchmark_configs() iterates over
    ``models`` and ``devices`` and indexes ``BENCHMARK_CONFIGS[os]``, so a
    missing option would otherwise crash later with TypeError/KeyError.
    Marking them ``required`` surfaces a clean argparse usage error instead.

    Returns:
        argparse.Namespace: Parsed command-line arguments.

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
                                  os='android',
                                  devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        required=True,  # selects BENCHMARK_CONFIGS[target_os] for in-tree models
        help="the target OS",
    )
    parser.add_argument(
        "--models",
        nargs='+',  # Accept one or more space-separated model names
        required=True,  # get_benchmark_configs() iterates this list
        help=f"either HuggingFace model IDs or names in [{MODEL_NAME_TO_MODEL}]",
    )
    parser.add_argument(
        "--devices",
        nargs='+',  # Accept one or more space-separated devices
        choices=list(DEVICE_POOLS.keys()),  # Convert dict_keys to a list
        required=True,  # each record needs a device-pool ARN
        help=f"devices to run the benchmark on. Pass as space-separated values. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()
|
||
|
||
def set_output(name: str, val: Any) -> None:
    """
    Set the output value to be used by other GitHub jobs.

    Appends ``name=val`` to the file named by the GITHUB_OUTPUT environment
    variable when it is set; otherwise falls back to printing the legacy
    ``::set-output`` workflow command to stdout.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    logging.info(f"Setting {val} to GitHub output")

    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        # Modern mechanism: append a name=value line to the output file.
        with open(str(github_output), "a") as env:
            print(f"{name}={val}", file=env)
    else:
        # Fallback for environments without GITHUB_OUTPUT (e.g. local runs).
        print(f"::set-output name={name}::{val}")
|
||
|
||
def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Validate if the model name matches the pattern for HuggingFace model IDs.

    A valid ID is ``org/repo`` where org is alphanumeric (plus ``-``/``_``)
    and repo additionally allows ``.``.

    Args:
        model_name (str): The model name to validate.

    Returns:
        bool: True if the model name matches the valid pattern, False otherwise.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2') -> True
    """
    return re.match(r'^[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+$', model_name) is not None
|
||
|
||
def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target
    operating system and devices.

    Reads --os/--models/--devices from the command line (via parse_args),
    builds one record per (model, config, device-pool ARN) combination, and
    publishes the result as the GitHub Actions output "benchmark_configs".

    Args:
        None

    Returns:
        Dict[str, Dict]: A dictionary containing the benchmark configurations.

    Example:
        get_benchmark_configs() -> {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "benchmark_config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "benchmark_config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    target_os = args.os
    devices = args.devices
    models = args.models

    # Shape expected by a GitHub Actions matrix strategy: {"include": [...]}.
    benchmark_configs = {"include": []}

    for model_name in models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models
                repo_name = model_name.split("meta-llama/")[1]
                if "qlora" in repo_name.lower():
                    configs.append("llama3_qlora")
                elif "spinquant" in repo_name.lower():
                    configs.append("llama3_spinquant")
                # NOTE(review): fb16 is appended for every meta-llama model,
                # including qlora/spinquant variants above — confirm that the
                # quantized repos are also meant to run the fb16 config.
                configs.append("llama3_fb16")
            else:
                # Non-LLaMA models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree models: baseline xnnpack_q8 plus the
            # OS-specific configs (HF models above do NOT get these).
            configs.append("xnnpack_q8")
            configs.extend(BENCHMARK_CONFIGS[target_os])
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device
        # (argparse choices already restrict --devices to DEVICE_POOLS keys,
        # so this guard is defensive).
        for device in devices:
            if device not in DEVICE_POOLS:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                record = {
                    "model": model_name,
                    "config": config,
                    "device": DEVICE_POOLS[device],
                }
                benchmark_configs["include"].append(record)

    # Publish the JSON-encoded matrix for downstream workflow jobs.
    set_output("benchmark_configs", json.dumps(benchmark_configs))
|
||
|
||
# Script entry point: gather configs from CLI args and emit GitHub output.
if __name__ == "__main__":
    get_benchmark_configs()
Oops, something went wrong.