Skip to content

Commit

Permalink
add llama + sd models to iree test suite (#126)
Browse files Browse the repository at this point in the history
This commit plugs the MLIR exported from the turbine tank into the IREE
testing framework for the llama and SD models.
I couldn't add a real-weight test for llama: the real weight file is 20GB,
and the runner crashes when trying to use it.
On that note, it looks like the splat test cases aren't running at the moment?
Maybe we can enable the real-weight test after we get the cluster.
Also, in case it's useful for anyone adding models in the future: for
the exported model MLIR, I had to change the real-weight flag to
`--parameters=model=real_weights.irpa` to get it working.
  • Loading branch information
saienduri authored Mar 27, 2024
2 parents 28091b0 + d62d79d commit d097807
Show file tree
Hide file tree
Showing 19 changed files with 196 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test_iree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ jobs:
- name: "Running real weight model tests"
run: |
source ${VENV_DIR}/bin/activate
pytest iree_tests -n auto -k real_weights -rpfE --timeout=60 --retries 2 --retry-delay 5 --durations=0
pytest iree_tests -n auto -k real_weights -rpfE --timeout=600 --retries 2 --retry-delay 5 --durations=0
67 changes: 67 additions & 0 deletions iree_tests/pytorch/models/llama-tank/real_weights_data_flags.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
--parameters=model=real_weights.irpa
--input=1x93xi64=@inference_input.0.bin
--expected_output=1x93x32000xf32=@inference_output.0.bin
--expected_output=1x32x93x128xf32=@inference_output.1.bin
--expected_output=1x32x93x128xf32=@inference_output.2.bin
--expected_output=1x32x93x128xf32=@inference_output.3.bin
--expected_output=1x32x93x128xf32=@inference_output.4.bin
--expected_output=1x32x93x128xf32=@inference_output.5.bin
--expected_output=1x32x93x128xf32=@inference_output.6.bin
--expected_output=1x32x93x128xf32=@inference_output.7.bin
--expected_output=1x32x93x128xf32=@inference_output.8.bin
--expected_output=1x32x93x128xf32=@inference_output.9.bin
--expected_output=1x32x93x128xf32=@inference_output.10.bin
--expected_output=1x32x93x128xf32=@inference_output.11.bin
--expected_output=1x32x93x128xf32=@inference_output.12.bin
--expected_output=1x32x93x128xf32=@inference_output.13.bin
--expected_output=1x32x93x128xf32=@inference_output.14.bin
--expected_output=1x32x93x128xf32=@inference_output.15.bin
--expected_output=1x32x93x128xf32=@inference_output.16.bin
--expected_output=1x32x93x128xf32=@inference_output.17.bin
--expected_output=1x32x93x128xf32=@inference_output.18.bin
--expected_output=1x32x93x128xf32=@inference_output.19.bin
--expected_output=1x32x93x128xf32=@inference_output.20.bin
--expected_output=1x32x93x128xf32=@inference_output.21.bin
--expected_output=1x32x93x128xf32=@inference_output.22.bin
--expected_output=1x32x93x128xf32=@inference_output.23.bin
--expected_output=1x32x93x128xf32=@inference_output.24.bin
--expected_output=1x32x93x128xf32=@inference_output.25.bin
--expected_output=1x32x93x128xf32=@inference_output.26.bin
--expected_output=1x32x93x128xf32=@inference_output.27.bin
--expected_output=1x32x93x128xf32=@inference_output.28.bin
--expected_output=1x32x93x128xf32=@inference_output.29.bin
--expected_output=1x32x93x128xf32=@inference_output.30.bin
--expected_output=1x32x93x128xf32=@inference_output.31.bin
--expected_output=1x32x93x128xf32=@inference_output.32.bin
--expected_output=1x32x93x128xf32=@inference_output.33.bin
--expected_output=1x32x93x128xf32=@inference_output.34.bin
--expected_output=1x32x93x128xf32=@inference_output.35.bin
--expected_output=1x32x93x128xf32=@inference_output.36.bin
--expected_output=1x32x93x128xf32=@inference_output.37.bin
--expected_output=1x32x93x128xf32=@inference_output.38.bin
--expected_output=1x32x93x128xf32=@inference_output.39.bin
--expected_output=1x32x93x128xf32=@inference_output.40.bin
--expected_output=1x32x93x128xf32=@inference_output.41.bin
--expected_output=1x32x93x128xf32=@inference_output.42.bin
--expected_output=1x32x93x128xf32=@inference_output.43.bin
--expected_output=1x32x93x128xf32=@inference_output.44.bin
--expected_output=1x32x93x128xf32=@inference_output.45.bin
--expected_output=1x32x93x128xf32=@inference_output.46.bin
--expected_output=1x32x93x128xf32=@inference_output.47.bin
--expected_output=1x32x93x128xf32=@inference_output.48.bin
--expected_output=1x32x93x128xf32=@inference_output.49.bin
--expected_output=1x32x93x128xf32=@inference_output.50.bin
--expected_output=1x32x93x128xf32=@inference_output.51.bin
--expected_output=1x32x93x128xf32=@inference_output.52.bin
--expected_output=1x32x93x128xf32=@inference_output.53.bin
--expected_output=1x32x93x128xf32=@inference_output.54.bin
--expected_output=1x32x93x128xf32=@inference_output.55.bin
--expected_output=1x32x93x128xf32=@inference_output.56.bin
--expected_output=1x32x93x128xf32=@inference_output.57.bin
--expected_output=1x32x93x128xf32=@inference_output.58.bin
--expected_output=1x32x93x128xf32=@inference_output.59.bin
--expected_output=1x32x93x128xf32=@inference_output.60.bin
--expected_output=1x32x93x128xf32=@inference_output.61.bin
--expected_output=1x32x93x128xf32=@inference_output.62.bin
--expected_output=1x32x93x128xf32=@inference_output.63.bin
--expected_output=1x32x93x128xf32=@inference_output.64.bin
2 changes: 2 additions & 0 deletions iree_tests/pytorch/models/llama-tank/splat_data_flags.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--input="1x93xi64"
--parameters=splats.irpa
Binary file not shown.
18 changes: 18 additions & 0 deletions iree_tests/pytorch/models/llama-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/llama-tank/",
"files": [
"llama-tank.mlirbc"
]
}
]
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--parameters=model=real_weights.irpa
--input=1x77xi64=@inference_input.0.bin
--expected_output=1x77x768xf32=@inference_output.0.bin
--expected_output=1x768xf32=@inference_output.1.bin
Binary file not shown.
2 changes: 2 additions & 0 deletions iree_tests/pytorch/models/sd-clip-tank/splat_data_flags.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--input="1x77xi64"
--parameters=splats.irpa
Binary file not shown.
26 changes: 26 additions & 0 deletions iree_tests/pytorch/models/sd-clip-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sd-clip-tank/",
"files": [
"inference_input.0.bin",
"inference_output.0.bin",
"inference_output.1.bin",
"real_weights.irpa"
]
}
]
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
--parameters=model=real_weights.irpa
--input=1x4x64x64xf32=@inference_input.0.bin
--input=1xf32=@inference_input.1.bin
--input=2x77x768xf32=@inference_input.2.bin
--expected_output=1x4x64x64xf32=@inference_output.0.bin
Binary file not shown.
4 changes: 4 additions & 0 deletions iree_tests/pytorch/models/sd-unet-tank/splat_data_flags.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--input="1x4x64x64xf32"
--input="1xf32"
--input="2x77x768xf32"
--parameters=splats.irpa
Binary file not shown.
27 changes: 27 additions & 0 deletions iree_tests/pytorch/models/sd-unet-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": []
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sd-unet-tank/",
"files": [
"inference_input.0.bin",
"inference_input.1.bin",
"inference_input.2.bin",
"inference_output.0.bin",
"real_weights.irpa"
]
}
]
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
--parameters=model=real_weights.irpa
--input=1x4x64x64xf32=@inference_input.0.bin
--expected_output=1x3x512x512xf32=@inference_output.0.bin
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--input="1x4x64x64xf32"
--parameters=splats.irpa
Binary file not shown.
35 changes: 35 additions & 0 deletions iree_tests/pytorch/models/sd-vae-decode-tank/test_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"test_cases": [
{
"name": "splats",
"runtime_flagfile": "splat_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sd-vae-decode-tank/",
"files": [
"sd-vae-decode-tank.mlirbc"
]
}
]
},
{
"name": "real_weights",
"runtime_flagfile": "real_weights_data_flags.txt",
"remote_file_groups": [
{
"azure_account_url": "https://sharkpublic.blob.core.windows.net",
"azure_container_name": "sharkpublic",
"azure_base_blob_name": "sai/sd-vae-decode-tank/",
"files": [
"inference_input.0.bin",
"inference_output.0.bin",
"real_weights.irpa",
"sd-vae-decode-tank.mlirbc"
]
}
]
}
]
}

0 comments on commit d097807

Please sign in to comment.