Skip to content

Commit

Permalink
Truncated gpt model examples for iree debug
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Oct 16, 2024
1 parent 649314a commit e0f4df0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
4 changes: 3 additions & 1 deletion alt_e2eshark/e2e_testing/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ class TestTensors:

def __init__(self, data: Tuple):
self.data = data
self.type = type(self.data[0])
self.type = None
if len(data) > 0 :
self.type = type(self.data[0])
if not all([type(d) == self.type for d in data]):
self.type == None

Expand Down
1 change: 1 addition & 0 deletions alt_e2eshark/onnx_tests/helper_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, og_model_info_class: type, og_name: str, *args, **kwargs):
run_dir = Path(self.model).parents[1]
og_model_path = os.path.join(run_dir, og_name)
self.sibling_inst = og_model_info_class(og_name, og_model_path)
self.opset_version = self.sibling_inst.opset_version

def construct_model(self):
if not os.path.exists(self.sibling_inst.model):
Expand Down
40 changes: 40 additions & 0 deletions alt_e2eshark/onnx_tests/models/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,43 @@ def construct_inputs(self):
register_test(dim_param_constructor(dim_params), model_name)

# You can add more customizations or specific handling for certain models here
from ..helper_classes import TruncatedModel, get_trucated_constructor
# download model.onnx to /proj/gdba/shark/chi/src/SHARK-TestSuite/alt_e2eshark/tmp/mygpt4
# export CACHE_DIR=/proj/gdba/shark/chi/src/SHARK-TestSuite/alt_e2eshark/tmp
model_name = "mygpt4"
mygpt4_nlp_params = {
"batch_size": 1,
"seq_len": 128,
"encoder_sequence_length": 128,
"decoder_sequence_length": 128,
"unk__2829": 1,
"unk__2828": 1,
"unk__2827": 1,
"unk__2826": 1,
"unk__2825": 1,
"unk__2824": 1,
}
t_model_constructor = get_trucated_constructor(TruncatedModel, dim_param_constructor(mygpt4_nlp_params), model_name)
# Run to the last onnx ops / the whole model with:
# python ./run.py --mode=cl-onnx-iree -v -t mygpt4
# Stages to be run: ['setup', 'import_model', 'preprocessing', 'compilation', 'construct_inputs', 'native_inference', 'compiled_inference', 'postprocessing']
# Test list: ['mygpt4']
# running test mygpt4...
# Unzipping - /proj/gdba/shark/chi/src/SHARK-TestSuite/alt_e2eshark/tmp/mygpt4/model.onnx.zip...
# Unzipping succeded. Look for extracted contents in /proj/gdba/shark/chi/src/SHARK-TestSuite/alt_e2eshark/test-run/mygpt4
# {'Shape': 63, 'Cast': 141, 'Slice': 136, 'Squeeze': 25, 'Range': 25, 'Concat': 65, 'Reshape': 161, 'Unsqueeze': 38, 'Sub': 85, 'Mul': 169, 'Add': 159, 'Gather': 3, 'MatMul': 72, 'Split': 12, 'Transpose': 48, 'GreaterOrEqual': 12, 'Softmax': 12, 'GlobalAveragePool': 48, 'Sqrt': 24, 'Reciprocal': 24, 'Erf': 12}#
# Running stage 'import_model'...
register_test(t_model_constructor(1, ""), "mygpt4")
# Run to the last 2/7/12/17 onnx op with:
# run with python ./run.py --mode=cl-onnx-iree -v -t mygpt4_trunc_
# for n in range(2, 20, 5):
# register_test(t_model_constructor(n,""), f"mygpt4_trunc_{n}")

# Run to 5/55/105/155 onnx.Add ops with:
# run with python ./run.py --mode=cl-onnx-iree -v -t mygpt4_trunc_add_
for n in range(5, 160, 50):
register_test(t_model_constructor(n,"Add"), f"mygpt4_trunc_add_{n}")

# python ./run.py --mode=cl-onnx-iree -v -t mygpt4_trunc_shape_1
register_test(t_model_constructor(1, "Shape"), "mygpt4_trunc_shape_1")
register_test(t_model_constructor(161, "Reshape"), "mygpt4_trunc_reshape_161")

0 comments on commit e0f4df0

Please sign in to comment.