Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache onnxruntime inference session and input nodes #360

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions alt_e2eshark/e2e_testing/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,29 @@ def __init__(
self.sess_options = ort.SessionOptions()
self.dim_param_dict = None

def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
"""Applies self.model to self.input. Only override if necessary for specific models"""
input = input.to_numpy().data
@property
def ort_session(self):
    """Lazily build, cache, and return the onnxruntime InferenceSession.

    On first access this constructs the model file if it is missing,
    applies session options, builds the session, and also caches the
    session's input/output node lists (``_ort_input_nodes`` /
    ``_ort_output_nodes``) so callers such as ``construct_inputs`` can
    reuse them without spinning up a second session.

    Returns:
        onnxruntime.InferenceSession: the cached (or newly created) session.
    """
    # getattr with a default replaces the hasattr-then-read double check.
    cached = getattr(self, "_cached_ort_session", None)
    if cached is not None:
        return cached

    # Materialize the model on disk first if it has not been built yet.
    if not os.path.exists(self.model):
        self.construct_model()

    self.update_sess_options()
    self._cached_ort_session = ort.InferenceSession(self.model, self.sess_options)
    # Cache node metadata alongside the session; they are only valid
    # for this session instance.
    self._ort_input_nodes = self._cached_ort_session.get_inputs()
    self._ort_output_nodes = self._cached_ort_session.get_outputs()
    return self._cached_ort_session

@ort_session.deleter
def ort_session(self):
    """Invalidate the cached session AND its cached node lists.

    The input/output node objects are only meaningful for the session
    that produced them, so they must be dropped together with
    ``_cached_ort_session``; otherwise ``construct_inputs`` could keep
    using nodes from a session that no longer exists.
    """
    for attr in ("_cached_ort_session", "_ort_input_nodes", "_ort_output_nodes"):
        if hasattr(self, attr):
            delattr(self, attr)

def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
"""Applies self.model to self.input. Only override if necessary for specific models"""
input = input.to_numpy().data
session = self.ort_session
session_inputs = session.get_inputs()
session_outputs = session.get_outputs()

Expand Down Expand Up @@ -75,7 +91,12 @@ def construct_inputs(self) -> TestTensors:
self.update_dim_param_dict()
# print(self.get_signature())
# print(get_op_frequency(self.model))
return get_sample_inputs_for_onnx_model(self.model, self.dim_param_dict)
if hasattr(self, "_ort_input_nodes") and self._ort_input_nodes:
input_nodes = self._ort_input_nodes
else:
session = self.ort_session
input_nodes = session.get_inputs()
return get_sample_inputs_for_onnx_model(input_nodes, self.dim_param_dict)

def apply_postprocessing(self, output: TestTensors):
"""can be overridden to define post-processing methods for individual models"""
Expand Down
10 changes: 3 additions & 7 deletions alt_e2eshark/e2e_testing/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,10 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
raise NotImplementedError(f"Found an unhandled dtype: {node.type}.")


def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None) -> TestTensors:
"""A convenience function for generating sample inputs for an onnx model"""
opt = onnxruntime.SessionOptions()
opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
s = onnxruntime.InferenceSession(model_path, opt)
inputs = s.get_inputs()
def get_sample_inputs_for_onnx_model(input_nodes, dim_param_dict = None) -> TestTensors:
    """Generate sample inputs from a session's input node metadata.

    Args:
        input_nodes: sequence of onnxruntime NodeArg objects, e.g. the
            result of ``InferenceSession.get_inputs()`` (this used to take
            a model path; the caller now owns session creation).
        dim_param_dict: optional mapping of symbolic dimension names to
            concrete sizes, forwarded to ``generate_input_from_node``.

    Returns:
        TestTensors: one generated sample tensor per input node.
    """
    # Feed the generator straight to tuple(); no intermediate list needed.
    return TestTensors(
        tuple(generate_input_from_node(node, dim_param_dict) for node in input_nodes)
    )

Expand Down