diff --git a/iree_tests/conftest.py b/iree_tests/conftest.py index 697ca5323..21dc9c78d 100644 --- a/iree_tests/conftest.py +++ b/iree_tests/conftest.py @@ -16,6 +16,7 @@ from ireers import * IREE_TESTS_ROOT = Path(__file__).parent +TEST_DATA_FLAGFILE_NAME = "test_data_flags.txt" # --------------------------------------------------------------------------- # # pytest hooks @@ -112,10 +113,14 @@ def check_field(field_name): def pytest_collect_file(parent, file_path): - if not file_path.name.endswith("_spec.mlir") and ( - file_path.name.endswith(".mlir") or file_path.name.endswith(".mlirbc") - ): - return MlirFile.from_parent(parent, path=file_path) + if file_path.name == TEST_DATA_FLAGFILE_NAME: + return MlirCompileRunTest.from_parent(parent, path=file_path) + + if file_path.suffix == ".json": + with open(file_path) as f: + test_cases_json = pyjson5.load(f) + if test_cases_json.get("file_format", "") == "test_cases_v0": + return MlirCompileRunTest.from_parent(parent, path=file_path) # --------------------------------------------------------------------------- # @@ -164,12 +169,13 @@ class IreeCompileAndRunTestSpec: skip_test: bool -class MlirFile(pytest.File): - """Collector for MLIR files accompanied by input/output.""" +class MlirCompileRunTest(pytest.File): + """Collector for MLIR -> compile -> run tests anchored on a file.""" @dataclass(frozen=True) class TestCase: name: str + mlir_file: str runtime_flagfile: str enabled: bool @@ -206,37 +212,46 @@ def discover_test_cases(self): skip_missing = self.config.getoption("skip_tests_missing_files") have_lfs_files = self.check_for_lfs_files() - test_data_flagfile_name = "test_data_flags.txt" - if (self.path.parent / test_data_flagfile_name).exists(): + mlir_files = sorted(self.path.parent.glob("*.mlir*")) + if len(mlir_files) == 0: + if not skip_missing: + raise FileNotFoundError( + f"Missing .mlir file for test {self.path.parent.name}" + ) + return test_cases + assert len(mlir_files) <= 1, "Test directories may only contain one .mlir file" + mlir_file = mlir_files[0] + + if self.path.name == TEST_DATA_FLAGFILE_NAME: test_cases.append( - MlirFile.TestCase( - name="test", - runtime_flagfile=test_data_flagfile_name, + MlirCompileRunTest.TestCase( + name="", + mlir_file=mlir_file, + runtime_flagfile=TEST_DATA_FLAGFILE_NAME, enabled=have_lfs_files, ) ) - - for test_cases_path in self.path.parent.glob("*.json"): - with open(test_cases_path) as f: + elif self.path.suffix == ".json": + with open(self.path) as f: test_cases_json = pyjson5.load(f) - if test_cases_json.get("file_format", "") != "test_cases_v0": - continue - for test_case_json in test_cases_json["test_cases"]: - test_case_name = test_case_json["name"] - have_remote_files = self.check_for_remote_files(test_case_json) - have_all_files = have_lfs_files and have_remote_files - - if not skip_missing and not have_all_files: - raise FileNotFoundError( - f"Missing files for test {self.path.parent.name}::{test_case_name}" + if test_cases_json.get("file_format", "") == "test_cases_v0": + for test_case_json in test_cases_json["test_cases"]: + test_case_name = test_case_json["name"] + have_remote_files = self.check_for_remote_files(test_case_json) + have_all_files = have_lfs_files and have_remote_files + + if not skip_missing and not have_all_files: + raise FileNotFoundError( + f"Missing files for test {self.path.parent.name}::{test_case_name}" + ) + test_cases.append( + MlirCompileRunTest.TestCase( + name=test_case_name, + mlir_file=mlir_file, + runtime_flagfile=test_case_json["runtime_flagfile"], + enabled=have_all_files, + ) ) - test_cases.append( - MlirFile.TestCase( - name=test_case_name, - runtime_flagfile=test_case_json["runtime_flagfile"], - enabled=have_all_files, - ) - ) return test_cases @@ -244,7 +259,7 @@ def collect(self): # Expected directory structure: # path/to/test_some_ml_operator/ # - *.mlir[bc] - # - test_data_flags.txt OR test_cases.json + # - test_data_flags.txt OR *.json using "test cases" schema # path/to/test_some_ml_model/ # ... @@ -274,14 +289,21 @@ def collect(self): ) or relative_test_directory in config.get("skip_run_tests", []) config_name = config["config_name"] - # TODO(scotttodd): don't compile once per test case? - # try pytest-dependency or pytest-depends for test_case in test_cases: + # Generate test item names like 'model.mlir::cpu_llvm_sync::splats'. + # These show up in pytest output. + mlir_file = test_case.mlir_file + name_parts = [ + e for e in [mlir_file.name, config_name, test_case.name] if e + ] + item_name = "::".join(name_parts) + # Generate simpler test names to use in filenames. test_name = config_name + "_" + test_case.name + spec = IreeCompileAndRunTestSpec( test_directory=test_directory, - input_mlir_name=self.path.name, - input_mlir_stem=self.path.stem, + input_mlir_name=mlir_file.name, + input_mlir_stem=mlir_file.stem, data_flagfile_name=test_case.runtime_flagfile, test_name=test_name, iree_compile_flags=config["iree_compile_flags"], @@ -291,7 +313,7 @@ def collect(self): skip_run=skip_run, skip_test=not test_case.enabled, ) - yield IreeCompileRunItem.from_parent(self, name=test_name, spec=spec) + yield IreeCompileRunItem.from_parent(self, name=item_name, spec=spec) class IreeCompileRunItem(pytest.Item): @@ -409,7 +431,9 @@ def repr_failure(self, excinfo): return super().repr_failure(excinfo) def reportinfo(self): - display_name = f"{self.path.parent.name}::{self.name}" + display_name = ( + f"{self.path.parent.name}::{self.spec.input_mlir_name}::{self.name}" + ) return self.path, 0, f"IREE compile and run: {display_name}" # Defining this for pytest-retry to avoid an AttributeError.