Skip to content

Commit

Permalink
Support for sdxl pipeline (testing) (#152)
Browse files Browse the repository at this point in the history
This commit adds support for testing all sdxl submodels
  • Loading branch information
saienduri authored Apr 4, 2024
1 parent 763f9d1 commit b8e771c
Show file tree
Hide file tree
Showing 36 changed files with 185 additions and 4 deletions.
42 changes: 38 additions & 4 deletions .github/workflows/test_iree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ concurrency:
cancel-in-progress: true

jobs:
linux_x86_64:
name: Linux (x86_64)
linux_x86_64_onnx:
name: Linux (x86_64) Onnx
runs-on: ubuntu-latest
env:
VENV_DIR: ${{ github.workspace }}/.venv
Expand Down Expand Up @@ -67,12 +67,46 @@ jobs:
source ${VENV_DIR}/bin/activate
pytest iree_tests/onnx/node/generated -n auto -rpfE --timeout=30 --retries 2 --retry-delay 5 --durations=10
linux_x86_64_w7900_gpu_models:
name: Linux (x86_64 w7900) Models GPU
runs-on: nodai-amdgpu-w7900-x86-64
env:
VENV_DIR: ${{ github.workspace }}/.venv
steps:
- name: "Checking out repository"
uses: actions/checkout@v4
with:
submodules: false
lfs: true
- name: "Setting up Python"
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: "Setup Python venv"
run: python3 -m venv ${VENV_DIR}

- name: "Installing IREE nightly release Python packages"
run: |
source ${VENV_DIR}/bin/activate
python3 -m pip install \
--find-links https://iree.dev/pip-release-links.html \
--upgrade \
iree-compiler \
iree-runtime
- name: "Installing other Python requirements"
run: |
source ${VENV_DIR}/bin/activate
python3 -m pip install -r iree_tests/requirements.txt
# TODO(scotttodd): add a local cache for these large files to a persistent runner
- name: "Downloading remote files for real weight model tests"
run: |
source ${VENV_DIR}/bin/activate
python3 iree_tests/download_remote_files.py
python3 iree_tests/download_remote_files.py --root-dir pytorch/models
- name: "Running real weight model tests"
env:
IREE_TEST_CONFIG_FILES: iree_tests/configs/config_pytorch_models_cpu_llvm_task.json
run: |
source ${VENV_DIR}/bin/activate
pytest iree_tests -n auto -k real_weights -rpfE --timeout=600 --retries 2 --retry-delay 5 --durations=0
pytest iree_tests/pytorch/models -s -n 4 -k real_weights -rpfE --timeout=1200 --retries 2 --retry-delay 5 --durations=0
15 changes: 15 additions & 0 deletions iree_tests/configs/config_pytorch_models_cpu_llvm_task.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"config_name": "sdxl_cpu_llvm_task",
"iree_compile_flags" : [
"--iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu-features=host",
"--iree-llvmcpu-distribution-size=32"
],
"iree_run_module_flags": [
"--device=local-task"
],
"skip_compile_tests": [],
"skip_run_tests": [],
"expected_compile_failures": [],
"expected_run_failures": []
}
Git LFS file not shown
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
--parameters=model=real_weights.irpa
--input=1x64xi64=@inference_input.0.bin
--input=1x64xi64=@inference_input.1.bin
--input=1x64xi64=@inference_input.2.bin
--input=1x64xi64=@inference_input.3.bin
--expected_output=2x64x2048xf16=@inference_output.0.bin
--expected_output=2x1280xf16=@inference_output.1.bin
--expected_f16_threshold=1.0f
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
--input="1x64xi64"
--input="1x64xi64"
--input="1x64xi64"
--input="1x64xi64"
--parameters=splats.irpa
Git LFS file not shown
29 changes: 29 additions & 0 deletions iree_tests/pytorch/models/sdxl-prompt-encoder-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sdxl-prompt-encoder/",
"files": [
"inference_input.0.bin",
"inference_input.1.bin",
"inference_input.2.bin",
"inference_input.3.bin",
"inference_output.0.bin",
"inference_output.1.bin",
"real_weights.irpa"
]
}
]
}
]
}
Git LFS file not shown
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
--parameters=model=real_weights.irpa
--module=sdxl_scheduled_unet_pipeline_fp16_.vmfb
--input=1x4x128x128xf16=@inference_input.0.bin
--input=2x64x2048xf16=@inference_input.1.bin
--input=2x1280xf16=@inference_input.2.bin
--input=1xf16=@inference_input.3.bin
--expected_output=1x4x128x128xf16=@inference_output.0.bin
--expected_f16_threshold=0.8f
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
--input="1x4x128x128xf16"
--input="2x64x2048xf16"
--input="2x1280xf16"
--input="1xf16"
--parameters=splats.irpa
Git LFS file not shown
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sdxl-scheduled-unet/",
"files": [
"inference_input.0.bin",
"inference_input.1.bin",
"inference_input.2.bin",
"inference_input.3.bin",
"inference_output.0.bin",
"real_weights.irpa"
]
}
]
}
]
}
3 changes: 3 additions & 0 deletions iree_tests/pytorch/models/sdxl-vae-decode-tank/model.mlirbc
Git LFS file not shown
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--parameters=model=real_weights.irpa
--input=1x4x128x128xf16=@inference_input.0.bin
--expected_output=1x3x1024x1024xf16=@inference_output.0.bin
--expected_f16_threshold=0.02f
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--input="1x4x128x128xf16"
--parameters=splats.irpa
3 changes: 3 additions & 0 deletions iree_tests/pytorch/models/sdxl-vae-decode-tank/splats.irpa
Git LFS file not shown
25 changes: 25 additions & 0 deletions iree_tests/pytorch/models/sdxl-vae-decode-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sdxl-vae-decode/",
"files": [
"inference_input.0.bin",
"inference_output.0.bin",
"real_weights.irpa"
]
}
]
}
]
}

0 comments on commit b8e771c

Please sign in to comment.