Skip to content

Commit

Permalink
Quality of life improvements for model tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd committed Mar 15, 2024
1 parent 9104b9b commit 131a51b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 60 deletions.
1 change: 1 addition & 0 deletions iree_tests/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ config_*.txt
# Large files fetched on-demand
*.bin
real_weights.irpa
pytorch/models/**/*.npy
40 changes: 39 additions & 1 deletion iree_tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ $ pytest iree_tests -n auto
Run tests using custom config files:

```bash
$ pytest iree_tests --config-files ./configs/config_gpu_vulkan.json

# OR set an environment variable
$ export IREE_TEST_CONFIG_FILES=/iree/config_cpu_llvm_sync.json;/iree/config_gpu_vulkan.json
$ pytest iree_tests
```
Expand Down Expand Up @@ -124,7 +127,23 @@ collected 1047 items
======================== 1047 tests collected in 4.34s ========================
```

Run a subset of tests (see
Run tests from a specific subdirectory:

```bash
$ pytest iree_tests/simple

======================================= test session starts ======================================= platform win32 -- Python 3.11.2, pytest-8.0.2, pluggy-1.4.0
rootdir: D:\dev\projects\SHARK-TestSuite\iree_tests
configfile: pytest.ini
plugins: retry-1.6.2, timeout-2.2.0, xdist-3.5.0
collected 2 items

simple\abs\simple_abs.mlir . [ 50%] simple\abs_bc\simple_abs.mlirbc . [100%]

======================================== 2 passed in 2.48s ========================================
```

Run a filtered subset of tests (see
[Specifying which tests to run](https://docs.pytest.org/en/8.0.x/how-to/usage.html#specifying-which-tests-to-run)):

```bash
Expand Down Expand Up @@ -295,3 +314,22 @@ To simply strip weights:
```bash
iree-ir-tool strip-data model.mlir -o model_stripped.mlir
```

### Working with parameter files

To convert from .safetensors to .irpa (real weights):

```bash
iree-convert-parameters \
--parameters=path/to/file.safetensors \
--output=path/to/output.irpa
```

To strip constants and replace them with splats:

```bash
iree-convert-parameters \
--parameters=path/to/parameters.[safetensors,irpa] \
--strip \
--output=path/to/output.irpa
```
137 changes: 79 additions & 58 deletions iree_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,84 @@
import subprocess


# --------------------------------------------------------------------------- #
# pytest hooks
# https://docs.pytest.org/en/stable/reference/reference.html#initialization-hooks
# https://docs.pytest.org/en/stable/reference/reference.html#collection-hooks


def pytest_addoption(parser):
# List of configuration files following this schema:
# {
# "config_name": str,
# "iree_compile_flags": list of str,
# "iree_run_module_flags": list of str,
# "skip_compile_tests": list of str,
# "skip_run_tests": list of str,
# "expected_compile_failures": list of str,
# "expected_run_failures": list of str
# }
#
# For example, to test on CPU with the `llvm-cpu` backend and `local-task` device:
# {
# "config_name": "cpu_llvm_task",
# "iree_compile_flags": ["--iree-hal-target-backends=llvm-cpu"],
# "iree_run_module_flags": ["--device=local-task"],
# "skip_compile_tests": [],
# "skip_run_tests": [],
# "expected_compile_failures": ["test_abs"],
# "expected_run_failures": ["test_add"],
# }
#
# The list of files can be specified in (by order of preference):
# 1. The `--config-files` argument
# e.g. `pytest ... --config-files foo.json bar.json`
# 2. The `IREE_TEST_CONFIG_FILES` environment variable
# e.g. `set IREE_TEST_CONFIG_FILES=foo.json;bar.json`
# 3. A default config file used for testing the test suite itself
default_config_files = [
f for f in os.getenv("IREE_TEST_CONFIG_FILES", "").split(";") if f
]
if not default_config_files:
this_dir = Path(__file__).parent
repo_root = this_dir.parent
default_config_files = [
repo_root / "iree_tests/configs/config_cpu_llvm_sync.json",
# repo_root / "iree_tests/configs/config_gpu_vulkan.json",
]
parser.addoption(
"--config-files",
action="store",
nargs="*",
default=default_config_files,
help="List of config JSON files used to build test cases",
)


def pytest_sessionstart(session):
session.config.iree_test_configs = []
for config_file in session.config.getoption("config_files"):
with open(config_file) as f:
session.config.iree_test_configs.append(pyjson5.load(f))


def pytest_collect_file(parent, file_path):
if file_path.name.endswith(".mlir") or file_path.name.endswith(".mlirbc"):
return MlirFile.from_parent(parent, path=file_path)


# TODO(scotttodd): other hooks hook may help with updating XFAIL sets
#
# * load config file(s) and lists of tests requested
# `pytest_collection_finish`
# * after each test finishes, record result
# `pytest_runtest_logfinish`
# * after all tests finish, join results back into a config file
# `pytest_sessionfinish`
# * let the user accept the new config file in place of their original

# --------------------------------------------------------------------------- #

@dataclass(frozen=True)
class IreeCompileAndRunTestSpec:
"""Specification for an IREE "compile and run" test."""
Expand Down Expand Up @@ -56,11 +134,6 @@ class IreeCompileAndRunTestSpec:
skip_test: bool


def pytest_collect_file(parent, file_path):
if file_path.name.endswith(".mlir") or file_path.name.endswith(".mlirbc"):
return MlirFile.from_parent(parent, path=file_path)


class MlirFile(pytest.File):
"""Collector for MLIR files accompanied by input/output."""

Expand Down Expand Up @@ -134,8 +207,7 @@ def collect(self):
print(f"No test cases for '{test_name}'")
return []

global _iree_test_configs
for config in _iree_test_configs:
for config in self.config.iree_test_configs:
if test_name in config["skip_compile_tests"]:
continue

Expand Down Expand Up @@ -283,54 +355,3 @@ def __init__(
f"Run with:\n"
f" cd {cwd} && {' '.join(process.args)}\n\n"
)


# TODO(scotttodd): move this setup code into a (scoped) function?
# Is there some way to share state across pytest functions?

# Load a list of configuration files following this schema:
# {
# "config_name": str,
# "iree_compile_flags": list of str,
# "iree_run_module_flags": list of str,
# "skip_compile_tests": list of str,
# "skip_run_tests": list of str,
# "expected_compile_failures": list of str,
# "expected_run_failures": list of str
# }
#
# For example, to test the on CPU with the `llvm-cpu`` backend on the `local-task` device:
# {
# "config_name": "cpu",
# "iree_compile_flags": ["--iree-hal-target-backends=llvm-cpu"],
# "iree_run_module_flags": ["--device=local-task"],
# "skip_compile_tests": [],
# "skip_run_tests": [],
# "expected_compile_failures": ["test_abs"],
# "expected_run_failures": ["test_add"],
# }
#
# TODO(scotttodd): expand schema with more flexible include_tests/exclude_tests fields.
# * One use case is wanting to run only a small, controlled subset of tests, without needing to
# manually exclude any new tests that might be added in the future.
#
# First check for the `IREE_TEST_CONFIG_FILES` environment variable. If defined,
# this should point to a semicolon-delimited list of config file paths, e.g.
# `export IREE_TEST_CONFIG_FILES=/iree/config_cpu.json;/iree/config_gpu.json`.
_iree_test_configs = []
_iree_test_config_files = [
config for config in os.getenv("IREE_TEST_CONFIG_FILES", "").split(";") if config
]

# If no config files were specified via the environment variable, default to in-tree config files.
if not _iree_test_config_files:
THIS_DIR = Path(__file__).parent
REPO_ROOT = THIS_DIR.parent
_iree_test_config_files = [
REPO_ROOT / "iree_tests/configs/config_cpu_llvm_sync.json",
# REPO_ROOT / "iree_tests/configs/config_gpu_vulkan.json",
]

for config_file in _iree_test_config_files:
with open(config_file) as f:
_iree_test_configs.append(pyjson5.load(f))
11 changes: 10 additions & 1 deletion iree_tests/download_remote_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from azure.storage.blob import ContainerClient
from pathlib import Path
import argparse
import pyjson5

THIS_DIR = Path(__file__).parent
Expand Down Expand Up @@ -45,7 +46,15 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):


if __name__ == "__main__":
for test_cases_path in Path(THIS_DIR).rglob("test_cases.json"):
parser = argparse.ArgumentParser(description="Remote file downloader.")
parser.add_argument(
"--root-dir",
default="",
help="Root directory to search for files to download from (e.g. 'pytorch/models/resnet50')",
)
args = parser.parse_args()

for test_cases_path in (THIS_DIR / args.root_dir).rglob("test_cases.json"):
print(f"Processing {test_cases_path.relative_to(THIS_DIR)}")

test_dir = test_cases_path.parent
Expand Down

0 comments on commit 131a51b

Please sign in to comment.