Skip to content

Commit

Permalink
[alt] add basic optimizations to some tests (#364)
Browse files Browse the repository at this point in the history
Adds basic optimizations to tests failing due to
<nod-ai/SHARK-ModelDev#813> and
<nod-ai/SHARK-ModelDev#825>
  • Loading branch information
zjgarvey authored Oct 7, 2024
1 parent 11158e6 commit 0f57bb2
Showing 1 changed file with 88 additions and 2 deletions.
90 changes: 88 additions & 2 deletions alt_e2eshark/onnx_tests/models/azure_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from ..helper_classes import AzureDownloadableModel
from e2e_testing.registry import register_test
from e2e_testing.storage import load_test_txt_file
import onnxruntime

this_file = Path(__file__)
lists_dir = (this_file.parent).joinpath("external_lists")

model_names = load_test_txt_file(lists_dir.joinpath("shark-test-suite.txt"))
for i in range(1,4):
model_names += load_test_txt_file(lists_dir.joinpath(f"vai-hf-cnn-fp32-shard{i}.txt"))
for i in range(1, 4):
model_names += load_test_txt_file(
lists_dir.joinpath(f"vai-hf-cnn-fp32-shard{i}.txt")
)
model_names += load_test_txt_file(lists_dir.joinpath(f"vai-int8-p0p1-shard{i}.txt"))
model_names += load_test_txt_file(lists_dir.joinpath("vai-vision-int8.txt"))

Expand All @@ -23,7 +26,90 @@
"DeepLabV3_resnet50_vaiq_int8",
]

# if the model has significant shape issues, consider applying basic optimizations before import by adding to this list:
basic_opt = [
"gcvit_base",
"gcvit_small",
"gcvit_tiny",
"gcvit_xtiny",
"gcvit_xxtiny",
"swinv2_base_window12to16_192to256.ms_in22k_ft_in1k",
"swinv2_base_window12to24_192to384.ms_in22k_ft_in1k",
"swinv2_base_window16_256.ms_in1k",
"swinv2_base_window8_256.ms_in1k",
"swinv2_large_window12to16_192to256.ms_in22k_ft_in1k",
"swinv2_large_window12to24_192to384.ms_in22k_ft_in1k",
"swinv2_small_window16_256.ms_in1k",
"swinv2_small_window8_256.ms_in1k",
"swinv2_tiny_window16_256.ms_in1k",
"swinv2_tiny_window8_256.ms_in1k",
"xcit_large_24_p16_224",
"xcit_large_24_p16_224_dist",
"xcit_large_24_p16_384_dist",
"xcit_large_24_p8_224",
"xcit_large_24_p8_224_dist",
"xcit_large_24_p8_384_dist",
"xcit_medium_24_p16_224",
"xcit_medium_24_p16_224_dist",
"xcit_medium_24_p16_384_dist",
"xcit_medium_24_p8_224",
"xcit_medium_24_p8_224_dist",
"xcit_medium_24_p8_384_dist",
"xcit_nano_12_p16_224",
"xcit_nano_12_p16_224_dist",
"xcit_nano_12_p16_384_dist",
"xcit_nano_12_p8_224",
"xcit_nano_12_p8_224_dist",
"xcit_nano_12_p8_384_dist",
"xcit_small_12_p16_224",
"xcit_small_12_p16_224_dist",
"xcit_small_12_p16_384_dist",
"xcit_small_12_p8_224",
"xcit_small_12_p8_224_dist",
"xcit_small_12_p8_384_dist",
"xcit_small_24_p16_224",
"xcit_small_24_p16_224_dist",
"xcit_small_24_p16_384_dist",
"xcit_small_24_p8_224",
"xcit_small_24_p8_224_dist",
"xcit_small_24_p8_384_dist",
"xcit_tiny_12_p16_224",
"xcit_tiny_12_p16_224_dist",
"xcit_tiny_12_p16_384_dist",
"xcit_tiny_12_p8_224",
"xcit_tiny_12_p8_224_dist",
"xcit_tiny_12_p8_384_dist",
"xcit_tiny_24_p16_224",
"xcit_tiny_24_p16_224_dist",
"xcit_tiny_24_p16_384_dist",
"xcit_tiny_24_p8_224",
"xcit_tiny_24_p8_224_dist",
"xcit_tiny_24_p8_384_dist",
]

custom_registry += basic_opt
# for simple models without dim params or additional customization, we should be able to register them directly with AzureDownloadableModel
# TODO: many of the models in the text files loaded from above will likely need to be registered with an alternative test info class.
for t in set(model_names).difference(custom_registry):
register_test(AzureDownloadableModel, t)


class AzureWithOpt(AzureDownloadableModel):
def apply_ort_basic_optimizations(self):
opt = onnxruntime.SessionOptions()
opt.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
)
optimized_model = str(Path(self.model).parent.joinpath("model.optimized.onnx"))
opt.optimized_model_filepath = optimized_model
session = onnxruntime.InferenceSession(self.model, opt)
self.model = optimized_model
del session

def construct_model(self):
super().construct_model()
self.apply_ort_basic_optimizations()


for t in basic_opt:
register_test(AzureWithOpt, t)

0 comments on commit 0f57bb2

Please sign in to comment.