Skip to content

Commit

Permalink
remove previous support
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Apr 4, 2024
1 parent 2fc3f62 commit 9eb7a62
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions iree_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ def __init__(self, spec, **kwargs):

# TODO(scotttodd): swap cwd for a temp path?
self.test_cwd = self.spec.test_directory
vae_decode_path = os.path.dirname(os.path.dirname(self.test_cwd)) + "/pytorch/models/sdxl-vae-decode-tank"
scheduled_unet_path = os.path.dirname(os.path.dirname(self.test_cwd)) + "/pytorch/models/sdxl-scheduled-unet-3-tank"
prompt_encoder_path = os.path.dirname(os.path.dirname(self.test_cwd)) + "/pytorch/models/sdxl-prompt-encoder-tank"
vmfb_name = f"{self.spec.input_mlir_stem}_{self.spec.test_name}.vmfb"

self.compile_args = ["iree-compile", self.spec.input_mlir_name]
Expand All @@ -295,8 +292,6 @@ def __init__(self, spec, **kwargs):
self.run_args.extend(self.spec.iree_run_module_flags)
self.run_args.append(f"--flagfile={self.spec.data_flagfile_name}")

self.benchmark_args = ["iree-benchmark-module", "--device=local-task", f"--module={prompt_encoder_path}/model_sdxl_cpu_llvm_task_real_weights.vmfb", f"--parameters=model={prompt_encoder_path}/real_weights.irpa", f"--module={scheduled_unet_path}/model_sdxl_cpu_llvm_task_real_weights.vmfb", f"--parameters=model={scheduled_unet_path}/real_weights.irpa", f"--module={vae_decode_path}/model_sdxl_cpu_llvm_task_real_weights.vmfb", f"--parameters=model={vae_decode_path}/real_weights.irpa", f"--module={vmfb_name}", "--function=tokens_to_image", "--input=1x4x128x128xf16", "--input=1xf16", "--input=1x64xi64", "--input=1x64xi64", "--input=1x64xi64", "--input=1x64xi64"]

def runtest(self):
if self.spec.skip_test:
pytest.skip()
Expand Down Expand Up @@ -347,16 +342,13 @@ def runtest(self):

if self.spec.skip_run:
return

if self.spec.test_directory.name == "sdxl-benchmark":
self.test_benchmark()
else:
try:
self.test_run()
except IreeRunException as e:
if not self.spec.expect_compile_success:
raise IreeXFailCompileRunException from e
raise e

try:
self.test_run()
except IreeRunException as e:
if not self.spec.expect_compile_success:
raise IreeXFailCompileRunException from e
raise e

def test_compile(self):
proc = subprocess.run(self.compile_args, capture_output=True, cwd=self.test_cwd)
Expand All @@ -367,13 +359,6 @@ def test_run(self):
proc = subprocess.run(self.run_args, capture_output=True, cwd=self.test_cwd)
if proc.returncode != 0:
raise IreeRunException(proc, self.test_cwd, self.compile_args)

def test_benchmark(self):
proc = subprocess.run(self.benchmark_args, capture_output=True, cwd=self.test_cwd)
if proc.returncode != 0:
raise IreeRunException(proc, self.test_cwd, self.compile_args)
outs = proc.stdout.decode("utf-8")
print(f"Stdout benchmark:\n{outs}\n")

def repr_failure(self, excinfo):
"""Called when self.runtest() raises an exception."""
Expand Down

0 comments on commit 9eb7a62

Please sign in to comment.