Add a design alternative for e2eshark (#214)
The broad goal of this design rework is to make it easier to add
functionality and new tests to e2eshark. Towards this, there are
sub-goals:

1. break up the complexity of run.py into a more coherent file
structure.
2. remove the need to generate excessive boilerplate for model.py files,
and generally aim to make adding tests as simple as possible.
3. avoid running command-line scripts through python. This includes
transitioning to using python bindings for torch-mlir and iree, and
developing an import-friendly alternative to model.py files.
4. avoid concatenating python files (again, by building a better import
structure).
5. minimize redundancy in the log (test-run) directories. For example, I
believe the input and output tensors are actually stored three different
ways (as a .bin, .pt, and .pkl), and this should not need to be the
case.

Abstractly there seem to be four input pieces to manage:

1. The model
2. The choice of frontend importer (e.g. iree-import-onnx, torch_mlir's
fx importer, etc.)
3. The choice of backend (e.g. compiling with iree from torch)
4. The input

Individual tests should just need to determine how to access or generate
the model and how to generate a sample input. Command-line arguments
should determine which frontend/backend combination is used during the
test (see the sketch below).
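
As a rough, hypothetical sketch of this split (not the literal classes added in this commit; the real ones live under e2e_testing/), each test would own model and input construction, while a config object built from command-line arguments would own the importer/backend pairing:

```python
# Hypothetical sketch of the intended separation of concerns.

class ExampleTest:
    def construct_model(self) -> str:
        """Download or build model.onnx and return its path."""
        return "model.onnx"

    def construct_inputs(self):
        """Return sample inputs for the model."""
        return []


class ExampleConfig:
    def __init__(self, importer, backend):
        self.importer = importer  # frontend: model file -> MLIR module
        self.backend = backend    # backend: MLIR module -> python callable

    def run(self, test):
        module = self.importer(test.construct_model())
        compiled = self.backend.compile(module)
        func = self.backend.load(compiled, func_name="main")
        return func(test.construct_inputs())
```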

The desired output pieces seem to be:

1. input/output/golden_output tensors
2. mlir/vmfb files used to generate output
3. stage and config-specific debugging info for tests which fail

### Current high-priority TODOs:

1. look into using pytest or another testing framework.
2. generate command-line reproducers for ease of debugging individual
stages
3. come up with a better name for the directory than "alt_e2eshark".
4. identify functionality present in e2eshark that is currently lacking
and should be migrated over. E.g. Chi's comment about the --report flag.
5. add more type hints to make it easier to track down definitions.
6. add (much) more to the README about adding tests and running them
with various command line options.
7. look into removing torch-mlir as a python dependency entirely
(perhaps optionally, so torch-mlir developers can still use the test
suite for debugging op conversions).
8. Consider alternative options for storing cache directory paths, log
directory paths, onnx model paths, etc. OnnxModelInfo is currently a bit
clunky in needing to sometimes infer the log directory from the model
path, etc.
zjgarvey authored Jul 25, 2024
1 parent f02ed80 commit 3a2dff8
Showing 25 changed files with 1,953 additions and 0 deletions.
110 changes: 110 additions & 0 deletions alt_e2eshark/README.md
@@ -0,0 +1,110 @@
# e2eshark framework-to-iree-to-inference tests

This test suite enables developers to add small (operator-level) to large (full-model)
end-to-end tests that compare the output of running a model in a framework
(e.g. PyTorch, ONNX) to the output of running the IREE-compiled artifact of
the same model on a target backend (e.g. CPU, AIE). If the difference in outputs
is within a tolerable limit, the test is reported as having passed; otherwise it is
reported as having failed. For a failing test, the stage of the failure is reported.

The test suite is organized starting with a framework name: pytorch, tensorflow, onnx.
For each framework category, multiple modes are tested.

- pytorch : starting model is a pytorch model (planned for later)
- tensorflow : starting model is a tensorflow model (planned for later)
- onnx : starting model is an onnx model generated using onnx python API or an existing onnx model

The target backend can be any IREE supported backend: llvm-cpu, amd-aie etc.

## Contents
The contents are as follows.
- e2e_testing/azutils.py : util functions for interfacing with azure
- e2e_testing/backends.py : where test backends are defined. Add other backends here.
- e2e_testing/framework.py : contains two types of classes: framework-specific base classes for storing model info, and generic classes for testing infrastructure.
- e2e_testing/onnx_utils.py : onnx related util functions. These either infer information from an onnx model or modify an onnx model.
- e2e_testing/registry.py : this contains the GLOBAL_TEST_REGISTRY, which gets updated when importing files with instances of `register_test(TestInfoClass, 'testname')`.
- e2e_testing/storage.py : contains helper functions and classes for managing the storage of tensors.
- e2e_testing/test_configs/onnxconfig.py : defines the onnx frontend test config. Other configs (e.g. pytorch, tensorflow) should be created in sibling files.
- onnx_tests/ : contains files that define OnnxModelInfo child classes, which customize model/input generation for various kinds of tests. Individual tests are also registered here together with their corresponding OnnxModelInfo child class.
- dev_requirements.txt : `pip install -r dev_requirements.txt` to install additional packages if you are using local builds of torch-mlir and iree.
- requirements.txt : `pip install -r requirements.txt` to install packages for getting started immediately. This is mostly useful if you aren't trying to test local builds of IREE or torch-mlir.
- run.py : Run `python run.py --help` to learn about the script. This is the script to run tests.

Logs are created as .log files in the test-run subdirectory. Examine the logs to find
and fix the cause of any failure. You can pass `-r 'your dir name'` to run.py to choose
the name of your test-run directory; the default name is 'test-run'.

Note that you will need to pass the --cachedir argument to run.py, pointing to a directory
where model weights etc. from external model-serving repositories such as Torchvision and
Hugging Face will be downloaded. The downloaded data can be large, so choose a location
outside your home directory, preferably with 100 GB or more of free space.

## Setting up (Quick Start)

By default, nightly builds of torch-mlir and IREE are installed when you run:

```bash
python -m venv test_suite.venv
source test_suite.venv/bin/activate
pip install --upgrade pip
pip install -r ./requirements.txt
```

Therefore, you are not required to have a local build of either torch-mlir or IREE.

## Setting up (using local build of torch-mlir or iree)

If you want to use a custom build of torch-mlir or iree, you need to build those projects with python bindings enabled.

If you already installed `requirements.txt` into your venv, you can uninstall whichever packages you want to replace, then source the appropriate `.env` file for the project you want to use. For example,

```bash
# if starting with dev_requirements, this line is unnecessary:
pip uninstall iree-compiler iree-runtime
# set up python to find iree compiler and iree runtime
source /path/to/iree-build/.env && export PYTHONPATH
```

If you installed `dev_requirements.txt`, you won't need to uninstall iree-compiler, iree-runtime, or torch-mlir, since these aren't included there.

Unfortunately, the `.env` files in torch-mlir and iree completely replace the PYTHONPATH instead of adding to it. So if you want to use a local build of both torch-mlir and iree, you could do something like:

```bash
export IREE_BUILD_DIR="<path to iree build dir>"
export TORCH_MLIR_BUILD_DIR="<path to torch-mlir build dir>"
source ${IREE_BUILD_DIR}/.env && export PYTHONPATH="${TORCH_MLIR_BUILD_DIR}/tools/torch-mlir/python_packages/torch_mlir/:${PYTHONPATH}"
```

If you are just a torch-mlir developer and don't want a custom IREE build, you can either make an `.env` file for torch-mlir with `torch-mlir/build_tools/write_env_file.sh` and use that to set your python path, or just use:

```bash
export PYTHONPATH="${TORCH_MLIR_BUILD_DIR}/tools/torch-mlir/python_packages/torch_mlir/"
```

## Adding a test

For onnx framework tests, you add a test in one of the model.py files contained in '/e2eshark/onnx_tests/'.

The OnnxModelInfo class simply requires that you define a method called `construct_model`, which specifies how to build the model.onnx file (be sure the model.onnx file gets saved to the filepath stored in your class's `self.model`).

A convenience function generates inputs by default; to override this for an individual test, redefine `construct_inputs` in your test class. A minimal example test is sketched below.
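
For illustration, here is a minimal sketch of such a test, assuming `OnnxModelInfo` is importable from `e2e_testing.framework` (the exact import path and constructor behavior may differ):

```python
import onnx
from onnx import TensorProto, helper

from e2e_testing.framework import OnnxModelInfo  # assumed import path


class ReluModelInfo(OnnxModelInfo):
    def construct_model(self):
        # build a tiny single-op onnx graph and save it to self.model,
        # the filepath where the test expects model.onnx to live
        x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])
        y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])
        relu = helper.make_node("Relu", ["x"], ["y"])
        graph = helper.make_graph([relu], "relu_graph", [x], [y])
        onnx.save(helper.make_model(graph), self.model)
```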

Once a test class is defined, register it with the test suite:

```python
register_test(YourTestClassName,"name_of_test")
```

## Running a test

Here is an example of running the test we made in the previous section with some commonly used flags:

```bash
python run.py --torchtolinalg --cachedir="../cache_dir" -t name_of_test
```

This will generate a new folder './test-run/name_of_test/' containing artifacts generated during the test. These artifacts can be used with command-line scripts to debug failures at various stages.



14 changes: 14 additions & 0 deletions alt_e2eshark/dev_requirements.txt
@@ -0,0 +1,14 @@
-r https://raw.githubusercontent.com/llvm/torch-mlir/main/requirements.txt
-r https://raw.githubusercontent.com/llvm/torch-mlir/main/torchvision-requirements.txt
tabulate
simplejson
ml_dtypes
onnx
onnxruntime
transformers
huggingface-hub
sentencepiece
accelerate
auto-gptq
optimum
azure-storage-blob
Empty file.
113 changes: 113 additions & 0 deletions alt_e2eshark/e2e_testing/azutils.py
@@ -0,0 +1,113 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os, sys, argparse, shutil, zipfile
from azure.storage.blob import ContainerClient
from azure.core.exceptions import ResourceNotFoundError
from pathlib import Path
from zipfile import ZipFile

PRIVATE_CONN_STRING = os.environ.get("AZ_PRIVATE_CONNECTION", default="")
priv_container_name = "onnxprivatestorage"


def pre_test_onnx_model_azure_download(name, cache_dir, model_path):
    # This util helps set up an onnx model test by ensuring it has the
    # required model.onnx file: the model archive is downloaded to the cache
    # dir (if needed) and unzipped into the test's model directory.

# if cache directory doesn't exist, then make it
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)

# for the testList download all the onnx/models in cache_path
download_and_setup_onnxmodel(cache_dir, name)

    # strip the trailing "model.onnx" to get the model directory
    # (str.rstrip would strip a character set rather than the suffix)
    model_dir = model_path[: -len("model.onnx")]
    # unzip the cached model archive into the model test directory
    dest_file = cache_dir + "model.onnx.zip"
    print(f"Unzipping - {dest_file}...", "\t")
    # the cached zip may not exist for models which were not correctly downloaded;
    # skip unzipping such models and only extract archives that exist
if os.path.exists(dest_file):
# Unzip the model in the model test dir
with ZipFile(dest_file, "r") as zf:
# onnx/model/testname already present in the zip file structure
zf.extractall(model_dir)
        print(f'Unzipping succeeded. Look for extracted contents in {model_dir}')
else:
print(f'Failed: path {dest_file} does not exist!')


def download_and_setup_onnxmodel(cache_dir, name):
    # Utility to download the specified model (zip file) to the cache dir.
    # A download failure should not stop the whole test run, so downloads are
    # allowed to fail and the corresponding tests will fail with a
    # "No model.onnx file found" error.

# Azure Storage Creds for Public Onnx Models
account_url = "https://onnxstorage.blob.core.windows.net"
container_name = "onnxstorage"

# Azure Storage Creds for Private Onnx Models - AZURE Login Required for access
priv_account_url = "https://onnxprivatestorage.blob.core.windows.net"
priv_container_name = "onnxprivatestorage"

blob_dir = "e2eshark/onnx/models/" + name
blob_name = blob_dir + "/model.onnx.zip"
dest_file = cache_dir + "model.onnx.zip"
if os.path.exists(dest_file):
# model already in cache dir, skip download.
# TODO: skip model downloading based on some comparison / update flag
return
# TODO: better organisation of models in tank and cache
print(f"Begin download for {blob_name} to {dest_file}")

try_private = False
try:
download_azure_blob(account_url, container_name, blob_name, dest_file)
except Exception as e:
try_private = True
print(
f"Unable to download model from public for {name}.\nError - {type(e).__name__}"
)

if try_private:
print("Trying download from private storage")
try:
download_azure_blob(
priv_account_url, priv_container_name, blob_name, dest_file
)
except Exception as e:
print(f"Unable to download model for {name}.\nError - {type(e).__name__}")


def download_azure_blob(account_url, container_name, blob_name, dest_file):
if container_name == priv_container_name:
if PRIVATE_CONN_STRING == "":
print(
"Please set AZ_PRIVATE_CONNECTION environment variable with connection string for private azure storage account"
)
with ContainerClient.from_connection_string(
conn_str=PRIVATE_CONN_STRING,
container_name=container_name,
) as container_client:
download_stream = container_client.download_blob(blob_name)
with open(dest_file, mode="wb") as local_blob:
local_blob.write(download_stream.readall())
else:
with ContainerClient(
account_url,
container_name,
max_chunk_get_size=1024 * 1024 * 32, # 32 MiB
max_single_get_size=1024 * 1024 * 32, # 32 MiB
) as container_client:
download_stream = container_client.download_blob(
blob_name, max_concurrency=4
)
with open(dest_file, mode="wb") as local_blob:
local_blob.write(download_stream.readall())
66 changes: 66 additions & 0 deletions alt_e2eshark/e2e_testing/backends.py
@@ -0,0 +1,66 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import abc
from typing import TypeVar
from e2e_testing.storage import TestTensors

CompiledArtifact = TypeVar("CompiledArtifact")
Invoker = TypeVar("Invoker")

# This file should contain customizations for how to compile mlir from various entrypoints


class BackendBase(abc.ABC):

@abc.abstractmethod
def compile(self, module) -> CompiledArtifact:
"""specifies how to compile an MLIR Module"""

@abc.abstractmethod
def load(self, artifact: CompiledArtifact, func_name: str) -> Invoker:
"""loads the function with name func_name from compiled artifact. This method should return a function callable from python."""


from iree import compiler as ireec
from iree import runtime as ireert


class SimpleIREEBackend(BackendBase):
'''This backend uses iree to compile and run MLIR modules for a specified hal_target_backend'''
def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu"):
self.device = device
self.hal_target_backend = hal_target_backend

def compile(self, module, *, save_to: str = None):
        # compile the module to a vmfb for the configured hal_target_backend
b = ireec.tools.compile_str(
str(module),
target_backends=[self.hal_target_backend],
extra_args=["--iree-input-demote-i64-to-i32"],
)
# log the vmfb
if save_to:
with open(save_to + "compiled_model.vmfb", "wb") as f:
f.write(b)
return b

def load(self, artifact, *, func_name="main"):
config = ireert.Config(self.device)
ctx = ireert.SystemContext(config=config)
vm_module = ireert.VmModule.copy_buffer(ctx.instance, artifact)
ctx.add_vm_module(vm_module)

def func(x):
x = x.data
device_array = ctx.modules.module[func_name](*x)
if isinstance(device_array, tuple):
np_array = []
for d in device_array:
np_array.append(d.to_host())
return TestTensors(np_array)
return TestTensors((device_array.to_host(),))

return func
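
For illustration, here is a hedged usage sketch (not part of this commit) of how the backend above composes; the MLIR module and inputs below are placeholders:

```python
import numpy as np

from e2e_testing.storage import TestTensors

# a trivial elementwise-multiply module, just to exercise compile/load/run
MLIR_SRC = """
func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
  return %0 : tensor<4xf32>
}
"""

backend = SimpleIREEBackend(device="local-task", hal_target_backend="llvm-cpu")
artifact = backend.compile(MLIR_SRC)                 # vmfb bytes
run_main = backend.load(artifact, func_name="main")  # python callable
inputs = TestTensors((np.ones(4, np.float32), np.full(4, 2.0, np.float32)))
outputs = run_main(inputs)                           # TestTensors of numpy results
```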