Refactor how config and test_cases.json files are handled in iree_tests. (#164)

Assorted quality of life improvements to iree_tests before we scale up
to having more tests with real weights downloaded from Azure. These
changes should also help as we move towards having files cached on CI
test runners.

* Collapse/simplify how remote files are listed from
  ```json
        "remote_file_groups": [
          {
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
            "azure_container_name": "sharkpublic",
            "azure_base_blob_name": "sai/sd-unet-tank/",
            "files": [
              "inference_input.0.bin",
              "inference_input.1.bin",
  ```
  to
  ```json
        "remote_files": [

"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-unet-tank/inference_input.0.bin",

"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-unet-tank/inference_input.1.bin",
  ```
  (That was premature optimization; [KISS](https://en.wikipedia.org/wiki/KISS_principle), etc.)
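  The downloader re-derives the Azure account URL, container name, and blob name from each flat URL when it needs them; roughly (a sketch along the lines of the parsing in `download_remote_files.py`):
  ```python
  import re

  url = (
      "https://sharkpublic.blob.core.windows.net/sharkpublic/"
      "sai/sd-unet-tank/inference_input.0.bin"
  )
  # Split the flat URL back into the pieces the Azure blob client expects.
  match = re.search(r"(https.+\.net)/([^/]+)/(.+)", url)
  account_url, container_name, blob_name = match.groups()
  # account_url    -> https://sharkpublic.blob.core.windows.net
  # container_name -> sharkpublic
  # blob_name      -> sai/sd-unet-tank/inference_input.0.bin
  ```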
* Skip downloading remote files if they exist on disk with a matching
MD5 hash (requires the files to have MD5 hashes computed in Azure --
some files are missing them)
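  The skip check boils down to comparing the blob's `content_md5` property against a locally computed digest; a minimal sketch (helper names are illustrative, the real logic lives in `download_remote_files.py`):
  ```python
  import hashlib
  import mmap
  from pathlib import Path


  def local_md5(path: Path) -> bytes | None:
      """MD5 digest of a local file, or None if it is missing or empty."""
      if not path.exists() or path.stat().st_size == 0:
          return None
      with open(path, "rb") as f, mmap.mmap(
          f.fileno(), 0, access=mmap.ACCESS_READ
      ) as data:
          return hashlib.md5(data).digest()


  def needs_download(path: Path, remote_md5: bytes | None) -> bool:
      # Skip only when Azure reports a content_md5 and it matches the local copy.
      return not (remote_md5 and remote_md5 == local_md5(path))
  ```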
* Relax JSON parsing code to tolerate omitted fields (e.g. instead of
requiring `"skip_run_tests": [],`, you can now omit it)
* Add a `"file_format": "test_cases_v0",` marker to JSON files and look
for that instead of a specific file name
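  Test discovery then globs `*.json` next to each test and keys off the marker instead of a fixed file name; roughly (a sketch of the new behavior, the directory path is just an example):
  ```python
  from pathlib import Path

  import pyjson5

  test_dir = Path("future-pytorch-models/llama-tank")  # example directory
  for json_path in test_dir.glob("*.json"):
      with open(json_path) as f:
          data = pyjson5.load(f)
      # Only files that declare themselves as v0 test case lists are used.
      if data.get("file_format", "") != "test_cases_v0":
          continue
      for test_case in data["test_cases"]:
          print(test_case["name"])
  ```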
* Print file sizes while downloading, e.g.
  ```bash
  $ python download_remote_files.py

  Processing future-pytorch-models\llama-tank\test_cases.json
    Downloading 'llama-tank.mlirbc' (128.44 MiB) to 'future-pytorch-models\llama-tank'
  Processing future-pytorch-models\sd-clip-tank\test_cases.json
    Skipping 'inference_input.0.bin' download (local MD5 hash matches)
    Skipping 'inference_output.0.bin' download (local MD5 hash matches)
    Skipping 'inference_output.1.bin' download (local MD5 hash matches)
    Downloading 'real_weights.irpa' (469.46 MiB) to 'future-pytorch-models\sd-clip-tank'
  ```
ScottTodd authored Apr 16, 2024
1 parent cbca7b9 commit 07d8938
Showing 15 changed files with 268 additions and 224 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/test_iree.yml
@@ -79,12 +79,6 @@ jobs:
submodules: false
lfs: true

# not supported for Ubuntu 23.1
# - name: "Setting up Python"
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'

- name: "Setup Python venv"
run: python3 -m venv ${VENV_DIR}

2 changes: 0 additions & 2 deletions iree_tests/configs/config_onnx_cpu_llvm_sync.json
@@ -192,8 +192,6 @@
"test_eyelike_with_dtype",
"test_eyelike_without_dtype",
"test_gathernd_example_int32_batch_dim1",
"test_gelu_tanh_1",
"test_gelu_tanh_2",
"test_globalmaxpool",
"test_globalmaxpool_precomputed",
"test_gridsample",
91 changes: 51 additions & 40 deletions iree_tests/conftest.py
@@ -84,7 +84,20 @@ def pytest_sessionstart(session):
session.config.iree_test_configs = []
for config_file in session.config.getoption("config_files"):
with open(config_file) as f:
session.config.iree_test_configs.append(pyjson5.load(f))
test_config = pyjson5.load(f)

# Sanity check the config file structure before going any further.
def check_field(field_name):
if field_name not in test_config:
raise ValueError(
f"config file '{config_file}' is missing a '{field_name}' field"
)

check_field("config_name")
check_field("iree_compile_flags")
check_field("iree_run_module_flags")

session.config.iree_test_configs.append(test_config)


def pytest_collect_file(parent, file_path):
@@ -169,20 +182,23 @@ def check_for_lfs_files(self):

def check_for_remote_files(self, test_case_json):
"""Checks if all remote_files in a JSON test case exist on disk."""
if "remote_files" not in test_case_json:
return True

have_all_files = True
for remote_file_group in test_case_json["remote_file_groups"]:
for remote_file in remote_file_group["files"]:
if not (self.path.parent / remote_file).exists():
test_case_name = test_case_json["name"]
print(
f"Missing file '{remote_file}' for test {self.path.parent.name}::{test_case_name}"
)
have_all_files = False
break
for remote_file_url in test_case_json["remote_files"]:
remote_file = remote_file_url.rsplit("/", 1)[-1]
if not (self.path.parent / remote_file).exists():
test_case_name = test_case_json["name"]
print(
f"Missing file '{remote_file}' for test {self.path.parent.name}::{test_case_name}"
)
have_all_files = False
break
return have_all_files

def discover_test_cases(self):
"""Discovers test cases in either test_data_flags.txt or test_cases.json."""
"""Discovers test cases in either test_data_flags.txt or *.json files."""
test_cases = []

have_lfs_files = self.check_for_lfs_files()
@@ -197,23 +213,21 @@ def discover_test_cases(self):
)
)

test_cases_name = "test_cases.json"
test_cases_path = self.path.parent / test_cases_name
if not test_cases_path.exists():
return test_cases

with open(test_cases_path) as f:
test_cases_json = pyjson5.load(f)
for test_case_json in test_cases_json["test_cases"]:
test_case_name = test_case_json["name"]
have_all_files = self.check_for_remote_files(test_case_json)
test_cases.append(
MlirFile.TestCase(
name=test_case_name,
runtime_flagfile=test_case_json["runtime_flagfile"],
enabled=have_lfs_files and have_all_files,
for test_cases_path in self.path.parent.glob("*.json"):
with open(test_cases_path) as f:
test_cases_json = pyjson5.load(f)
if test_cases_json.get("file_format", "") != "test_cases_v0":
continue
for test_case_json in test_cases_json["test_cases"]:
test_case_name = test_case_json["name"]
have_all_files = self.check_for_remote_files(test_case_json)
test_cases.append(
MlirFile.TestCase(
name=test_case_name,
runtime_flagfile=test_case_json["runtime_flagfile"],
enabled=have_lfs_files and have_all_files,
)
)
)

return test_cases

@@ -234,21 +248,18 @@ def collect(self):
return []

for config in self.config.iree_test_configs:
if test_name in config["skip_compile_tests"]:
if test_name in config.get("skip_compile_tests", []):
continue

expect_compile_success = (
self.config.getoption("ignore_xfails")
or test_name not in config["expected_compile_failures"]
)
expect_run_success = (
self.config.getoption("ignore_xfails")
or test_name not in config["expected_run_failures"]
)
skip_run = (
self.config.getoption("skip_all_runs")
or test_name in config["skip_run_tests"]
)
expect_compile_success = self.config.getoption(
"ignore_xfails"
) or test_name not in config.get("expected_compile_failures", [])
expect_run_success = self.config.getoption(
"ignore_xfails"
) or test_name not in config.get("expected_run_failures", [])
skip_run = self.config.getoption(
"skip_all_runs"
) or test_name in config.get("skip_run_tests", [])
config_name = config["config_name"]

# TODO(scotttodd): don't compile once per test case?
154 changes: 127 additions & 27 deletions iree_tests/download_remote_files.py
@@ -4,45 +4,133 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from azure.storage.blob import ContainerClient
from azure.storage.blob import BlobClient, BlobProperties
from pathlib import Path
import argparse
import hashlib
import logging
import mmap
import pyjson5
import re

THIS_DIR = Path(__file__).parent
logger = logging.getLogger(__name__)


# TODO(scotttodd): multithread? async?
# TODO(scotttodd): skip download if already exists? check some metadata
def download_azure_remote_files(
test_dir: Path, container_client: ContainerClient, remote_file_group: dict
):
base_blob_name = remote_file_group["azure_base_blob_name"]
def human_readable_size(size, decimal_places=2):
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
if size < 1024.0 or unit == "PiB":
break
size /= 1024.0
return f"{size:.{decimal_places}f} {unit}"

for remote_file in remote_file_group["files"]:
print(f" Downloading {remote_file} to {test_dir.relative_to(THIS_DIR)}")
blob_name = base_blob_name + remote_file
dest = test_dir / remote_file

with open(dest, mode="wb") as local_blob:
download_stream = container_client.download_blob(
blob_name, max_concurrency=4
def get_remote_md5(remote_file: str, blob_properties: BlobProperties):
content_settings = blob_properties.get("content_settings")
if not content_settings:
return None
remote_md5 = content_settings.get("content_md5")
if not remote_md5:
logger.warning(
f" Remote file '{remote_file}' on Azure is missing the "
"'content_md5' property, can't check if local matches remote"
)
return remote_md5


def get_local_md5(local_file_path: Path):
if not local_file_path.exists() or local_file_path.stat().st_size == 0:
return None

with open(local_file_path) as file, mmap.mmap(
file.fileno(), 0, access=mmap.ACCESS_READ
) as file:
return hashlib.md5(file).digest()


def download_azure_remote_file(test_dir: Path, remote_file: str):
remote_file_name = remote_file.rsplit("/", 1)[-1]
relative_dir = test_dir.relative_to(THIS_DIR)

# Extract path components from Azure URL to use with the Azure Storage Blobs
# client library for Python (https://pypi.org/project/azure-storage-blob/).
#
# For example:
# https://sharkpublic.blob.core.windows.net/sharkpublic/path/to/blob.txt
# ^ ^
# account_url: https://sharkpublic.blob.core.windows.net
# container_name: sharkpublic
# blob_name: path/to/blob.txt
#
# Note: we could also use the generic handler (e.g. wget, 'requests'), but
# the client library offers other APIs.

result = re.search(r"(https.+\.net)/([^/]+)/(.+)", remote_file)
account_url = result.groups()[0]
container_name = result.groups()[1]
blob_name = result.groups()[2]

with BlobClient(
account_url,
container_name,
blob_name,
max_chunk_get_size=1024 * 1024 * 32, # 32 MiB
max_single_get_size=1024 * 1024 * 32, # 32 MiB
) as blob_client:
blob_properties = blob_client.get_blob_properties()
blob_size_str = human_readable_size(blob_properties.size)
remote_md5 = get_remote_md5(remote_file, blob_properties)

local_file_path = test_dir / remote_file_name
local_md5 = get_local_md5(local_file_path)

if remote_md5 and remote_md5 == local_md5:
logger.info(
f" Skipping '{remote_file_name}' download ({blob_size_str}) "
"- local MD5 hash matches"
)
return

if not local_md5:
logger.info(
f" Downloading '{remote_file_name}' ({blob_size_str}) "
f"to '{relative_dir}'"
)
else:
logger.info(
f" Downloading '{remote_file_name}' ({blob_size_str}) "
f"to '{relative_dir}' (local MD5 does not match)"
)

with open(local_file_path, mode="wb") as local_blob:
download_stream = blob_client.download_blob(max_concurrency=4)
local_blob.write(download_stream.readall())


def download_generic_remote_file(test_dir: Path, remote_file: str):
# TODO(scotttodd): use https://pypi.org/project/requests/
raise NotImplementedError("generic remote file downloads not implemented yet")


def download_for_test_case(test_dir: Path, test_case_json: dict):
for remote_file_group in test_case_json["remote_file_groups"]:
account_url = remote_file_group["azure_account_url"]
container_name = remote_file_group["azure_container_name"]
# This is naive (greedy, serial) for now. We could batch downloads that
# share a source:
# * Iterate over all files (across all included paths), building a list
# of files to download (checking hashes / local references before
# adding to the list)
# * (Optionally) Determine disk space needed/available and ask before
# continuing
# * Group files based on source (e.g. Azure container)
# * Start batched/parallel downloads

if "remote_files" not in test_case_json:
return

with ContainerClient(
account_url,
container_name,
max_chunk_get_size=1024 * 1024 * 32, # 32 MiB
max_single_get_size=1024 * 1024 * 32, # 32 MiB
) as container_client:
download_azure_remote_files(test_dir, container_client, remote_file_group)
for remote_file in test_case_json["remote_files"]:
if "blob.core.windows.net" in remote_file:
download_azure_remote_file(test_dir, remote_file)
else:
download_generic_remote_file(test_dir, remote_file)


if __name__ == "__main__":
@@ -54,11 +142,23 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):
)
args = parser.parse_args()

for test_cases_path in (THIS_DIR / args.root_dir).rglob("test_cases.json"):
print(f"Processing {test_cases_path.relative_to(THIS_DIR)}")
# Adjust logging levels.
logging.basicConfig(level=logging.INFO)
for log_name, log_obj in logging.Logger.manager.loggerDict.items():
if log_name.startswith("azure"):
logging.getLogger(log_name).setLevel(logging.WARNING)

# TODO(scotttodd): build list of files _then_ download
# TODO(scotttodd): report size needed for requested files and size available on disk

test_dir = test_cases_path.parent
for test_cases_path in (THIS_DIR / args.root_dir).rglob("*.json"):
with open(test_cases_path) as f:
test_cases_json = pyjson5.load(f)
if test_cases_json.get("file_format", "") != "test_cases_v0":
continue

logger.info(f"Processing {test_cases_path.relative_to(THIS_DIR)}")

test_dir = test_cases_path.parent
for test_case_json in test_cases_json["test_cases"]:
download_for_test_case(test_dir, test_case_json)
2 changes: 2 additions & 0 deletions iree_tests/future-pytorch-models/llama-tank/.gitignore
@@ -0,0 +1,2 @@
# Not stored in Git (128MB file, too large for Git LFS), fetch on demand
llama-tank.mlirbc
12 changes: 3 additions & 9 deletions iree_tests/future-pytorch-models/llama-tank/test_cases.json
@@ -1,17 +1,11 @@
{
"file_format": "test_cases_v0",
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/llama-tank/",
"files": [
"llama-tank.mlirbc"
]
}
"remote_files": [
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/llama-tank/llama-tank.mlirbc"
]
}
]
20 changes: 7 additions & 13 deletions iree_tests/future-pytorch-models/sd-clip-tank/test_cases.json
@@ -1,25 +1,19 @@
{
"file_format": "test_cases_v0",
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
"remote_files": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sd-clip-tank/",
"files": [
"inference_input.0.bin",
"inference_output.0.bin",
"inference_output.1.bin",
"real_weights.irpa"
]
}
"remote_files": [
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-clip-tank/inference_input.0.bin",
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-clip-tank/inference_output.0.bin",
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-clip-tank/inference_output.1.bin",
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd-clip-tank/real_weights.irpa"
]
}
]