diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
index 2252e34dff38..11c8f4a929ae 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
@@ -97,7 +97,7 @@ def _module_lowering(
     # Lower from ONNX to Torch
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
+        f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
         "Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
     )
 
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index e0d3529d942e..238d8fd94873 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -35,7 +35,7 @@
 
 from typing import Optional, List, Dict, Tuple
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import numpy as np
 import re
@@ -90,6 +90,47 @@ class Config:
     # making an assumption.
     elide_initialized_inputs: bool = True
 
+    # Some ONNX operators are defined by ONNX functions and will be
+    # automatically expanded (see get_operator_function() below) to MLIR
+    # functions by the importer. This option allows denylisting functions that
+    # should not be expanded.
+    function_expansion_denylists_by_domain: Dict[str, set[str]] = field(
+        default_factory=lambda: {
+            # Default domain (ONNX built-in ops)
+            "": {
+                # CastLike's second input `target_type` is used only for its
+                # type (T2), from which its output's type is inferred, but
+                # because its value is unused, ONNX's shape inference doesn't
+                # annotate the input value with a type, so looking up the
+                # function by the provided input types will fail.
+                "CastLike",
+                # ONNX errors when trying to infer the type of the Loop op
+                # within this function: "[ShapeInferenceError] Inferred shape
+                # and existing shape differ in rank: (1) vs (0)"
+                "Range",
+                # e2e test failures for this ONNX function in
+                # BernoulliOnesModule_basic, BernoulliZerosModule_basic suggest
+                # the ONNX function has the opposite behavior to what it is
+                # documented to have! Disabling until some decision is made
+                # about how to deal with this.
+                # FIXME: make and link issue for follow-up
+                "Bernoulli",
+                # e2e tests CrossEntropyLossModule_basic and
+                # CrossEntropyLossNoReductionModule_basic fail because the
+                # expansion of this function uses Unsqueeze with non-constant
+                # values for tensor axes. Disabling pending fix to converter
+                # (if appropriate).
+                # FIXME: make and link issue for follow-up
+                "SoftmaxCrossEntropyLoss",
+                # e2e tests HardswishModule_basic, HardswishRandomModule_basic
+                # fail with some golden value mismatch when this function is
+                # expanded. Disabling pending investigation of why this is.
+                # FIXME: make and link issue for follow-up
+                "HardSwish",
+            }
+        }
+    )
+
 
 class ModelInfo:
     """Top-level accounting and accessors for an ONNX model."""
@@ -111,7 +152,12 @@ def create_module(self, context: Optional[Context] = None) -> Module:
 class GraphInfo:
     """Information about a Graph within a model."""
 
-    def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
+    def __init__(
+        self,
+        model_info: ModelInfo,
+        graph_proto: onnx.GraphProto,
+        is_subgraph: bool = False,
+    ):
         self.model_info = model_info
         self.graph_proto = graph_proto
         self.initializer_map: Dict[str, onnx.TensorProto] = {
@@ -129,7 +175,11 @@ def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
 
         # Generate the effective input map, which for old models can be a
         # subset of the input map.
-        if model_info and model_info.config.elide_initialized_inputs:
+        if (
+            not is_subgraph
+            and model_info
+            and model_info.config.elide_initialized_inputs
+        ):
             self.input_map = {
                 k: v
                 for k, v in self.declared_input_map.items()
@@ -149,8 +199,18 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
         # Node outputs don't typically have type information, but shape inference
         # will associate them in the value_info. If not there, it may be a
         # graph output, which must have type information.
-        value_info = self.value_info_map.get(name) or self.output_map.get(name)
-        if value_info is not None:
+        value_info = (
+            self.value_info_map.get(name)
+            or self.output_map.get(name)
+            or self.declared_input_map.get(name)
+        )
+        if value_info is None:
+            tensor_proto = self.initializer_map.get(name)
+            if tensor_proto is not None:
+                return onnx.helper.make_tensor_type_proto(
+                    tensor_proto.data_type, tensor_proto.dims
+                )
+        else:
             return value_info.type
         # No type information is associated, this can occur when the value is unused:
         return ""
@@ -172,6 +232,8 @@ class NodeImporter:
     __slots__ = [
         "_c",
         "_cc",
+        "_m",
+        "_mc",
         "_gi",
         "_p",
         "_b",
@@ -185,9 +247,13 @@ def __init__(
         parent_op: Operation,
         block: Block,
         context_cache: "ContextCache",
+        module_op: Operation,
+        module_cache: "ModuleCache",
     ):
         self._c = parent_op.context
         self._cc = context_cache
+        self._m = module_op
+        self._mc = module_cache
         self._gi = graph_info
         self._p = parent_op
         self._b = block
@@ -195,9 +261,19 @@ def __init__(
 
     @classmethod
     def define_function(
-        cls, graph_info: GraphInfo, module_op: Operation
+        cls,
+        graph_info: GraphInfo,
+        module_op: Operation,
+        context_cache: Optional["ContextCache"] = None,
+        module_cache: Optional["ModuleCache"] = None,
+        public: bool = True,
     ) -> "NodeImporter":
-        cc = ContextCache(module_op.context)
+        cc = (
+            context_cache
+            if context_cache is not None
+            else ContextCache(module_op.context)
+        )
+        mc = module_cache if module_cache is not None else ModuleCache(module_op, cc)
         with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
             body = module_op.regions[0].blocks[0]
             func_name = graph_info.graph_proto.name
@@ -209,11 +285,23 @@ def define_function(
                 for out in graph_info.output_map.values()
             ]
             ftype = FunctionType.get(input_types, output_types)
-            func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body))
+            func_op = func_dialect.FuncOp(
+                func_name,
+                ftype,
+                ip=InsertionPoint(body),
+                visibility="public" if public else "private",
+            )
             block = func_op.add_entry_block(
                 [Location.name(k) for k in graph_info.input_map.keys()]
             )
-            imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc)
+            imp = NodeImporter(
+                graph_info,
+                parent_op=func_op,
+                block=block,
+                context_cache=cc,
+                module_op=module_op,
+                module_cache=mc,
+            )
             for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
                 imp._nv_map[node_name] = input_value
             imp._populate_graph_attrs(func_op)
@@ -293,6 +381,8 @@ def get_none(self):
     def import_node(self, node: onnx.NodeProto):
         with InsertionPoint(self._b), Location.name(node.name):
             op_type = node.op_type
+            op_domain = node.domain
+
             # Handle special op types that materialize to non-op IR constructs.
             # Handlers return True if the op was handled, else this function
             # should process it as a general node.
@@ -303,33 +393,58 @@ def import_node(self, node: onnx.NodeProto):
                     return
 
             # General node import.
             input_values = []
+            input_type_protos = []
             for input_name in node.input:
                 try:
                     input_values.append(self._nv_map[input_name])
+                    # Missing optional arguments will have empty types
+                    input_type_protos.append(
+                        self._gi.find_type_proto_for_name(input_name)
+                        or onnx.TypeProto()
+                    )
                 except KeyError:
                     raise OnnxImportError(
                         f"Non topologically produced ONNX node input '{input_name}': {node}"
                     )
 
-            output_names = list(node.output)
-            output_types = [
-                self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n))
-                for n in output_names
-            ]
-
-            attrs = self.import_attributes(node.attribute)
-            attrs["name"] = StringAttr.get(f"onnx.{op_type}")
-            regions = self.count_regions(node.attribute)
-
-            custom_op = Operation.create(
-                name="torch.operator",
-                results=output_types,
-                operands=input_values,
-                attributes=attrs,
-                regions=regions,
+            output_names = []
+            output_type_protos = []
+            output_types = []
+            for output_name in node.output:
+                output_names.append(output_name)
+                type_proto = self._gi.find_type_proto_for_name(output_name)
+                output_type_protos.append(type_proto)
+                output_types.append(self._cc.type_proto_to_type(type_proto))
+
+            for opset_import in self._gi.model_info.model_proto.opset_import:
+                if opset_import.domain == op_domain:
+                    opset_version = opset_import.version
+                    break
+            operator_func_op = self._mc.get_operator_function(
+                op_type,
+                op_domain,
+                opset_version,
+                input_type_protos,
+                output_type_protos,
+                node,
+                self._gi.model_info.config,
             )
-            self.import_regions(node.attribute, custom_op)
+            if operator_func_op is not None:
+                custom_op = func_dialect.CallOp(operator_func_op, input_values)
+            else:
+                attrs = self.import_attributes(node.attribute)
+                attrs["name"] = StringAttr.get(f"onnx.{op_type}")
+                regions = self.count_regions(node.attribute)
+                custom_op = Operation.create(
+                    name="torch.operator",
+                    results=output_types,
+                    operands=input_values,
+                    attributes=attrs,
+                    regions=regions,
+                )
+                self.import_regions(node.attribute, custom_op)
+
             for output_name, output_value in zip(output_names, custom_op.results):
                 self._nv_map[output_name] = output_value
 
@@ -387,9 +502,14 @@ def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
                 *block_types, arg_locs=[op.location] * len(block_types)
             )
             block = region.blocks[0]
-            graph_info = GraphInfo(None, attr.g)
+            graph_info = GraphInfo(self._gi.model_info, attr.g, is_subgraph=True)
             imp = NodeImporter(
-                graph_info, parent_op=op, block=block, context_cache=self._cc
+                graph_info,
+                parent_op=op,
+                block=block,
+                context_cache=self._cc,
+                module_op=self._m,
+                module_cache=self._mc,
             )
 
             for node_name, input_value in zip(block_names, block.arguments):
@@ -603,6 +723,13 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
             element_type = self.get_optional_element_type(ot.elem_type)
             return self.get_optional_type(element_type)
 
+        # Check if TypeProto is empty (sometimes happens for unused function
+        # arguments)
+        if tp.SerializeToString(
+            deterministic=True
+        ) == onnx.TypeProto().SerializeToString(deterministic=True):
+            return self.get_none_type()
+
         # TODO: Others if ever needed. Or we consider ourselves DNN-only.
         # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
         raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
@@ -631,6 +758,240 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
         return handler(tp)
 
 
+def _specialize_function_and_create_model(
+    function_proto: onnx.FunctionProto,
+    op_schema: onnx.defs.OpSchema,
+    name_to_give_model: str,
+    input_type_protos: list[onnx.TypeProto],
+    output_type_protos: list[onnx.TypeProto],
+    caller_node: onnx.NodeProto,
+) -> onnx.ModelProto:
+    """
+    Helper for ModuleCache::get_operator_function() that specializes a function
+    and converts it to a model.
+
+    An ONNX function may be polymorphic, parameterized over the types of its
+    inputs and values of its attributes (~= compile-time constants). We need to
+    monomorphize it for importing into MLIR. It seems like the only practical
+    way to do this is by turning it into a model:
+    - models can have types on their inputs and outputs, unlike functions
+    - ONNX provides a function to do shape inference (providing concrete
+      types for everything in the body) for models, but not for functions
+    - the rest of the code in this importer can only handle models, not
+      functions
+    """
+
+    graph_proto = onnx.GraphProto()
+
+    for input_name, input_type_proto in zip(function_proto.input, input_type_protos):
+        input_proto = onnx.ValueInfoProto()
+        input_proto.name = input_name
+        input_proto.type.CopyFrom(input_type_proto)
+        graph_proto.input.append(input_proto)
+
+    for output_name, output_type_proto in zip(
+        function_proto.output, output_type_protos
+    ):
+        output_proto = onnx.ValueInfoProto()
+        output_proto.name = output_name
+        output_proto.type.CopyFrom(output_type_proto)
+        graph_proto.output.append(output_proto)
+
+    call_attributes = caller_node.attribute
+    for node in function_proto.node:
+        # Import referenced attributes from call-site or default values
+        new_node = onnx.NodeProto()
+        new_node.CopyFrom(node)
+        old_attributes = list(node.attribute)
+        # .clear() isn't available on protobuf lists for some reason
+        while len(new_node.attribute) > 0:
+            new_node.attribute.pop()
+        for node_attribute in old_attributes:
+            ref_name = node_attribute.ref_attr_name
+            if not ref_name:
+                new_node.attribute.append(node_attribute)
+                continue
+
+            for call_attribute in call_attributes:
+                if call_attribute.name == ref_name:
+                    new_attribute = onnx.AttributeProto()
+                    new_attribute.CopyFrom(call_attribute)
+                    new_attribute.name = node_attribute.name
+                    new_node.attribute.append(new_attribute)
+                    break
+            else:
+                # The default value is sometimes empty for optional attributes
+                # that don't have a default, in which case it is dropped.
+                # cf. https://github.com/onnx/onnx/blob/88f8ef15cfaa3138d336f3502aed5018d802bf43/onnx/shape_inference/attribute_binder.h#L21-L23
+                if (
+                    op_schema.attributes[ref_name].default_value
+                    and op_schema.attributes[ref_name].default_value.type
+                ):
+                    new_attribute = onnx.AttributeProto()
+                    new_attribute.CopyFrom(op_schema.attributes[ref_name].default_value)
+                    new_attribute.name = node_attribute.name
+                    new_node.attribute.append(new_attribute)
+        graph_proto.node.append(new_node)
+
+    graph_proto.name = name_to_give_model
+
+    model_proto = onnx.ModelProto()
+    model_proto.opset_import.extend(function_proto.opset_import)
+    # FIXME: is this the correct IR version, or should it be the latest, or the
+    # one used by the actual model, or something else?
+    model_proto.ir_version = onnx.helper.find_min_ir_version_for(
+        function_proto.opset_import
+    )
+    model_proto.graph.CopyFrom(graph_proto)
+
+    model_proto = onnx.shape_inference.infer_shapes(
+        model_proto, check_type=True, strict_mode=True, data_prop=True
+    )
+    graph_proto = model_proto.graph
+
+    # Useful for debugging.
+    # onnx.checker.check_model(model_proto, full_check=True)
+
+    return model_proto
+
+
+class ModuleCache:
+    """Caches per-module lookups of various things."""
+
+    __slots__ = [
+        "_m",
+        "_cc",
+        "_operator_function_map",
+    ]
+
+    def __init__(self, module_op: Operation, context_cache: ContextCache):
+        self._m = module_op
+        self._cc = context_cache
+        self._operator_function_map: Dict[str, func_dialect.FuncOp] = {}
+
+    def get_operator_function(
+        self,
+        op_name: str,
+        op_domain: str,
+        opset_version: int,
+        input_type_protos: list[onnx.TypeProto],
+        output_type_protos: list[onnx.TypeProto],
+        caller_node: onnx.NodeProto,
+        config: Config,
+    ) -> Optional[func_dialect.FuncOp]:
+        """
+        Get or create MLIR function corresponding to an ONNX operator.
+
+        Returns None for ONNX operators that aren't functions.
+        """
+
+        if (
+            op_domain in config.function_expansion_denylists_by_domain
+            and op_name in config.function_expansion_denylists_by_domain[op_domain]
+        ):
+            return None
+
+        op_schema = onnx.defs.get_schema(
+            op_name, domain=op_domain, max_inclusive_version=opset_version
+        )
+
+        # The get_schema() lookup above should get the right version of the
+        # operator definition, but the function body can change slightly
+        # within a single operator version, as explained in
+        # https://github.com/onnx/onnx/blob/093a8d335a66ea136eb1f16b3a1ce6237ee353ab/onnx/defs/schema.h#L1070-L1086
+        # There also seem to be cases where a function goes from being not
+        # context-dependent to context-dependent.
+        f = lambda ver: ver <= opset_version
+        ncd_function_version = max(
+            filter(f, op_schema.function_opset_versions),
+            default=None,
+        )
+        cd_function_version = max(
+            filter(f, op_schema.context_dependent_function_opset_versions),
+            default=None,
+        )
+        if ncd_function_version is None and cd_function_version is None:
+            # No relevant function definition
+            return None
+        elif ncd_function_version is not None and (
+            cd_function_version is None or cd_function_version < ncd_function_version
+        ):
+            specific_version = ncd_function_version
+            is_context_dependent = False
+        else:
+            specific_version = cd_function_version
+            is_context_dependent = True
+
+        # This is both a key for memoization of function importing and also a
+        # name mangling scheme, so it must include all information needed to
+        # uniquely identify a function and anything it might be parameterized
+        # over.
+        key = repr(
+            (
+                op_name,
+                op_domain,
+                opset_version,
+                input_type_protos,
+                # Though output types can be inferred from input types, it does
+                # not seem to be the case that there's only one legal set of
+                # outputs for a given set of inputs. When attempting to always
+                # use onnx.shape_inference.infer_function_output_types instead
+                # of the caller-provided types, IR verification sometimes fails.
+                output_type_protos,
+                # Avoid including the attributes twice (once on their own and
+                # once as part of the node) for context-dependent functions,
+                # and avoid including unused parts of the node for other
+                # functions.
+                caller_node if is_context_dependent else caller_node.attribute,
+            )
+        )
+
+        existing = self._operator_function_map.get(key)
+        if existing is not None:
+            return existing
+
+        if is_context_dependent:
+            function_proto_str = (
+                op_schema.get_context_dependent_function_with_opset_version(
+                    specific_version,
+                    caller_node.SerializeToString(),
+                    [
+                        t.SerializeToString() if not isinstance(t, bytes) else t
+                        for t in input_type_protos
+                    ],
+                )
+            )
+        else:
+            function_proto_str = op_schema.get_function_with_opset_version(
+                specific_version
+            )
+        if not function_proto_str:
+            raise OnnxImportError(
+                f"Function lookup for {op_name}/{op_domain}/{specific_version}/{is_context_dependent} failed unexpectedly. This probably indicates a bug."
+            )
+        function_proto = onnx.onnx_pb.FunctionProto()
+        function_proto.ParseFromString(function_proto_str)
+
+        tmp_model_proto = _specialize_function_and_create_model(
+            function_proto,
+            op_schema,
+            key,
+            input_type_protos,
+            output_type_protos,
+            caller_node,
+        )
+
+        tmp_model_info = ModelInfo(tmp_model_proto)
+        tmp_graph_info = GraphInfo(tmp_model_info, tmp_model_proto.graph)
+        imp = NodeImporter.define_function(
+            tmp_graph_info, self._m, self._cc, self, public=False
+        )
+        imp.import_all()
+        func_op = imp._p
+
+        self._operator_function_map[key] = func_op
+        return func_op
+
+
 ELEM_TYPE_TO_IR_TYPE_CB = {
     onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
     onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),