Refactor pytest collection to anchor on test case files, not .mlir files. #282

Merged 2 commits on Jul 11, 2024
iree_tests/conftest.py (100 changes: 62 additions & 38 deletions)
@@ -16,6 +16,7 @@
 from ireers import *
 
 IREE_TESTS_ROOT = Path(__file__).parent
+TEST_DATA_FLAGFILE_NAME = "test_data_flags.txt"
 
 # --------------------------------------------------------------------------- #
 # pytest hooks
@@ -112,10 +113,14 @@ def check_field(field_name):
 
 
 def pytest_collect_file(parent, file_path):
-    if not file_path.name.endswith("_spec.mlir") and (
-        file_path.name.endswith(".mlir") or file_path.name.endswith(".mlirbc")
-    ):
-        return MlirFile.from_parent(parent, path=file_path)
+    if file_path.name == TEST_DATA_FLAGFILE_NAME:
+        return MlirCompileRunTest.from_parent(parent, path=file_path)
+
+    if file_path.suffix == ".json":
+        with open(file_path) as f:
+            test_cases_json = pyjson5.load(f)
+            if test_cases_json.get("file_format", "") == "test_cases_v0":
+                return MlirCompileRunTest.from_parent(parent, path=file_path)
 
 
 # --------------------------------------------------------------------------- #
@@ -164,12 +169,13 @@ class IreeCompileAndRunTestSpec:
     skip_test: bool
 
 
-class MlirFile(pytest.File):
-    """Collector for MLIR files accompanied by input/output."""
+class MlirCompileRunTest(pytest.File):
+    """Collector for MLIR -> compile -> run tests anchored on a file."""
 
     @dataclass(frozen=True)
     class TestCase:
         name: str
+        mlir_file: str
         runtime_flagfile: str
         enabled: bool
 
@@ -206,45 +212,54 @@ def discover_test_cases(self):
         skip_missing = self.config.getoption("skip_tests_missing_files")
         have_lfs_files = self.check_for_lfs_files()
 
-        test_data_flagfile_name = "test_data_flags.txt"
-        if (self.path.parent / test_data_flagfile_name).exists():
+        mlir_files = sorted(self.path.parent.glob("*.mlir*"))
+        if len(mlir_files) == 0:
+            if not skip_missing:
+                raise FileNotFoundError(
+                    f"Missing .mlir file for test {self.path.parent.name}"
+                )
+            return test_cases
+        assert len(mlir_files) <= 1, "Test directories may only contain one .mlir file"
+        mlir_file = mlir_files[0]
+
+        if self.path.name == TEST_DATA_FLAGFILE_NAME:
             test_cases.append(
-                MlirFile.TestCase(
-                    name="test",
-                    runtime_flagfile=test_data_flagfile_name,
+                MlirCompileRunTest.TestCase(
+                    name="",
+                    mlir_file=mlir_file,
+                    runtime_flagfile=TEST_DATA_FLAGFILE_NAME,
                     enabled=have_lfs_files,
                 )
             )
Comment on lines +225 to 233

Member Author:
Could remove the need for this branch by generating boilerplate test_cases.json files in the ONNX test suite:

{
  "file_format": "test_cases_v0",
  "test_cases": [
    {
      "name": "",
      "runtime_flagfile": "test_data_flags.txt",
      "remote_files": []
    }
  ]
}

Trying to keep this conftest file somewhat simple, and I'm a bit concerned this is sliding the wrong way 🤔

Contributor:
Yeah, it would probably be best to have test_cases.json files for all the ONNX tests too. That way all our tests can follow a unified format and the conftest won't get confusing. We should probably be anchoring on a test configuration file rather than the runtime flagfile for the ONNX tests.
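
For reference, a rough sketch of what generating those boilerplate files could look like (a hypothetical generate_test_cases.py helper; the onnx/ location and the use of test_data_flags.txt as the anchor are assumptions for illustration, not part of this PR):

    # generate_test_cases.py (hypothetical): drop a boilerplate test_cases.json
    # into every ONNX test directory that currently only carries test_data_flags.txt.
    import json
    from pathlib import Path

    ONNX_TESTS_ROOT = Path(__file__).parent / "onnx"  # assumed suite location

    BOILERPLATE = {
        "file_format": "test_cases_v0",
        "test_cases": [
            {
                "name": "",
                "runtime_flagfile": "test_data_flags.txt",
                "remote_files": [],
            }
        ],
    }

    for flagfile in sorted(ONNX_TESTS_ROOT.rglob("test_data_flags.txt")):
        out_path = flagfile.parent / "test_cases.json"
        if not out_path.exists():
            # Write the shared boilerplate so every test follows the same schema.
            out_path.write_text(json.dumps(BOILERPLATE, indent=2) + "\n")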

Member Author:
Okay, I can make that change too. How does a follow-up PR sound? That will require touching lots of files.

Member Author:
Might also be able to replace test_cases.json with tests.py and then further simplify (or remove) the conftest.py file... 🤔 Mainly trying to keep the large test suites autogenerated, concise, and directly compatible with iree-compile -> iree-run-module.
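
For context, each collected test case in this scheme boils down to a compile-then-run pair roughly like the following (a minimal sketch; the file names, backend, and device flags are illustrative stand-ins for whatever the config JSON and the per-test runtime flagfile actually specify):

    # Sketch of the iree-compile -> iree-run-module pair a test case maps to;
    # real flag values come from the config's iree_compile_flags and the test's
    # runtime flagfile (e.g. test_data_flags.txt with function/input/output flags).
    import subprocess

    subprocess.run(
        [
            "iree-compile",
            "model.mlir",
            "--iree-hal-target-backends=llvm-cpu",  # illustrative backend choice
            "-o",
            "model_cpu_llvm_sync.vmfb",
        ],
        check=True,
    )
    subprocess.run(
        [
            "iree-run-module",
            "--module=model_cpu_llvm_sync.vmfb",
            "--device=local-sync",  # illustrative device choice
            "--flagfile=test_data_flags.txt",
        ],
        check=True,
    )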


-        for test_cases_path in self.path.parent.glob("*.json"):
-            with open(test_cases_path) as f:
+        elif self.path.suffix == ".json":
+            with open(self.path) as f:
                 test_cases_json = pyjson5.load(f)
-                if test_cases_json.get("file_format", "") != "test_cases_v0":
-                    continue
-                for test_case_json in test_cases_json["test_cases"]:
-                    test_case_name = test_case_json["name"]
-                    have_remote_files = self.check_for_remote_files(test_case_json)
-                    have_all_files = have_lfs_files and have_remote_files
-
-                    if not skip_missing and not have_all_files:
-                        raise FileNotFoundError(
-                            f"Missing files for test {self.path.parent.name}::{test_case_name}"
-                        )
-                    test_cases.append(
-                        MlirFile.TestCase(
-                            name=test_case_name,
-                            runtime_flagfile=test_case_json["runtime_flagfile"],
-                            enabled=have_all_files,
-                        )
-                    )
+                if test_cases_json.get("file_format", "") == "test_cases_v0":
+                    for test_case_json in test_cases_json["test_cases"]:
+                        test_case_name = test_case_json["name"]
+                        have_remote_files = self.check_for_remote_files(test_case_json)
+                        have_all_files = have_lfs_files and have_remote_files
+
+                        if not skip_missing and not have_all_files:
+                            raise FileNotFoundError(
+                                f"Missing files for test {self.path.parent.name}::{test_case_name}"
+                            )
+                        test_cases.append(
+                            MlirCompileRunTest.TestCase(
+                                name=test_case_name,
+                                mlir_file=mlir_file,
+                                runtime_flagfile=test_case_json["runtime_flagfile"],
+                                enabled=have_all_files,
+                            )
+                        )
 
         return test_cases
 
     def collect(self):
         # Expected directory structure:
         #   path/to/test_some_ml_operator/
         #     - *.mlir[bc]
-        #     - test_data_flags.txt OR test_cases.json
+        #     - test_data_flags.txt OR *.json using "test cases" schema
         #   path/to/test_some_ml_model/
         #   ...
 
@@ -274,14 +289,21 @@
             ) or relative_test_directory in config.get("skip_run_tests", [])
             config_name = config["config_name"]
 
+            # TODO(scotttodd): don't compile once per test case?
+            #   try pytest-dependency or pytest-depends
             for test_case in test_cases:
+                # Generate test item names like 'model.mlir::cpu_llvm_sync::splats'.
+                # These show up in pytest output.
+                mlir_file = test_case.mlir_file
+                name_parts = [
+                    e for e in [mlir_file.name, config_name, test_case.name] if e
+                ]
+                item_name = "::".join(name_parts)
                 # Generate simpler test names to use in filenames.
                 test_name = config_name + "_" + test_case.name
 
                 spec = IreeCompileAndRunTestSpec(
                     test_directory=test_directory,
-                    input_mlir_name=self.path.name,
-                    input_mlir_stem=self.path.stem,
+                    input_mlir_name=mlir_file.name,
+                    input_mlir_stem=mlir_file.stem,
                     data_flagfile_name=test_case.runtime_flagfile,
                     test_name=test_name,
                     iree_compile_flags=config["iree_compile_flags"],
@@ -291,7 +313,7 @@
                     skip_run=skip_run,
                     skip_test=not test_case.enabled,
                 )
-                yield IreeCompileRunItem.from_parent(self, name=test_name, spec=spec)
+                yield IreeCompileRunItem.from_parent(self, name=item_name, spec=spec)
 
 
 class IreeCompileRunItem(pytest.Item):
@@ -409,7 +431,9 @@ def repr_failure(self, excinfo):
         return super().repr_failure(excinfo)
 
     def reportinfo(self):
-        display_name = f"{self.path.parent.name}::{self.name}"
+        display_name = (
+            f"{self.path.parent.name}::{self.spec.input_mlir_name}::{self.name}"
+        )
         return self.path, 0, f"IREE compile and run: {display_name}"
 
     # Defining this for pytest-retry to avoid an AttributeError.