diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 624483708d3a6..51696f2378e48 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1457,6 +1457,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const logging::Logger& logger, std::unique_ptr& graph); + Status UpdateUsingGraphApiModel(const OrtModel& api_model); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { return runtime_optimizations_; @@ -1796,7 +1798,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::unordered_map> node_arg_to_consumer_nodes_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const std::unordered_map domain_to_version_; + std::unordered_map domain_to_version_; // Model IR version. Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION}; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 3f3e088dd7474..6d686e6ba6cca 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -850,7 +850,8 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -4785,19 +4786,39 @@ struct OrtApi { // Get the OrtGraphApi instance for creating a model dynamically. const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(); - // Create TypeInfo for editing. + // User provides custom OrtAllocator. We only call the OrtAllocator::Info and OrtAllocator::Free functions. + // The OrtMemoryInfo returned by the allocator must match where `p_data` is. + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ const OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + // - ORT_API2_STATUS(CreateTypeInfo, _In_ enum ONNXType onnx_type, _Out_ OrtTypeInfo** type_info); - - // Mutable getter. Can we used for Tensor or SparseTensor - ORT_API2_STATUS(GetMutableTensorInfoFromTypeInfo, _In_ OrtTypeInfo* type_info, - _Outptr_result_maybenull_ OrtTensorTypeAndShapeInfo** tensor_info); - ORT_API2_STATUS(GetMutableMapInfoFromTypeInfo, _In_ OrtTypeInfo* type_info, - _Outptr_result_maybenull_ OrtMapTypeInfo** map_info); - ORT_API2_STATUS(GetMutableSequenceInfoFromTypeInfo, _In_ OrtTypeInfo* type_info, - _Outptr_result_maybenull_ OrtSequenceTypeInfo** tensor_info); - ORT_API2_STATUS(GetMutableOptionalInfoFromTypeInfo, _In_ OrtTypeInfo* type_info, - _Outptr_result_maybenull_ OrtOptionalTypeInfo** tensor_info); + // Additions for creating/editing graph inputs/outputs in model builder API + // + + // Get the opset for the given domain. And nodes added must conform to this opset. + ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); + + // Create TypeInfo instance to populate for graph inputs/outputs. + // We have to create the TypeInfo and the internal type, and need args for the latter. + // We _could_ just return OrtTypeInfo, but the user is going to need the internal type to actually do something + // meaningful with it, so return both instead of making the user call another function to get a mutable instance + // of the internal type. + + // Create Tensor TypeInfo. User owns type_info. type_info owns tensor_info. + ORT_API2_STATUS(CreateTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info); + // Create SparseTensor TypeInfo. User owns type_info. type_info owns tensor_info. + ORT_API2_STATUS(CreateSparseTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info); + // Create Map TypeInfo. User owns type_info. type_info owns map_info. + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtMapTypeInfo** map_info); + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtSequenceTypeInfo** sequence_info); + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtOptionalTypeInfo** optional_info); }; /* @@ -4923,10 +4944,10 @@ ORT_RUNTIME_CLASS(ValueInfo); struct OrtGraphApi { /*** New usage *** Create OrtTypeInfo for Tensor input: - CreateTypeInfo - GetMutableTensorInfoFromTypeInfo - SetTensorElementType - SetDimensions + SetSymbolicDimensions to define shape (these are complimentary and overlap) + CreateTensorTypeInfo + SetDimensions + SetSymbolicDimensions to define shape. These are 1:1. + If dim is -1 there should be a non-nullptr symbolic value. May be empty string but must not be nullptr. + If dim >= 0 there should be a nullptr symbolic value. Read from OrtTypeInfo: GetOnnxTypeFromTypeInfo @@ -4939,12 +4960,11 @@ struct OrtGraphApi { Call CreateValueInfo to combine value name with OrtTypeInfo ***/ - // OrtValueInfo takes ownership of the OrtTypeInfo - ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ OrtTypeInfo* type_info, + // user can release type_info. + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, _Outptr_ OrtValueInfo** value_info); ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); - ORT_API2_STATUS(CloneValueInfo, _In_ const OrtValueInfo* src_value_info, _Out_ OrtValueInfo** new_value_info); ORT_CLASS_RELEASE(ValueInfo); // call if not added to Graph // @@ -4968,18 +4988,11 @@ struct OrtGraphApi { // Set the Inputs/Outputs. Replaces and frees any existing inputs/outputs. // Graph takes ownership of the OrtTypeInfo instances. - ORT_API2_STATUS(SetInputs, _In_ OrtGraph* graph, + ORT_API2_STATUS(SetGraphInputs, _In_ OrtGraph* graph, _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); - ORT_API2_STATUS(SetOutputs, _In_ OrtGraph* graph, + ORT_API2_STATUS(SetGraphOutputs, _In_ OrtGraph* graph, _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); - // Get current inputs/outputs. Expected usage is augmenting an existing graph. - // To change the inputs/outputs, use CloneValueInfo for any OrtValueInfo instances you want to keep, - // and CreateValueInfo to create new inputs/outputs. - // Call SetInputs/SetOutputs with the new instances to replace all existing inputs/outputs. - ORT_API2_STATUS(GetInputs, _In_ OrtGraph* graph, _Inout_ const OrtValueInfo** inputs, _Out_ size_t* inputs_len); - ORT_API2_STATUS(GetOutputs, _In_ OrtGraph* graph, _Inout_ const OrtValueInfo** outputs, _Out_ size_t* outputs_len); - // 2 use cases. User is free to choose either approach but the suggested usage would be: // // 1: Weights: @@ -4991,8 +5004,11 @@ struct OrtGraphApi { // e.g. min/max input of Clip, indices for Gather. // Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the actual data. // We will copy the data when converting to TensorProto so user doesn't need to keep the data alive. - ORT_API2_STATUS(AddInitializer, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue** tensor); - ORT_API2_STATUS(AddNode, _In_ OrtGraph* graph, _In_ OrtNode** node); + // + // Graph takes ownership of initializer. + ORT_API2_STATUS(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor); + // Graph takes ownership of node + ORT_API2_STATUS(AddNodeToGraph, _In_ OrtGraph* graph, _In_ OrtNode* node); ORT_CLASS_RELEASE(Graph); // call if not added to Model // @@ -5007,7 +5023,7 @@ struct OrtGraphApi { size_t opset_entries_len, _Outptr_ OrtModel** model); - ORT_API2_STATUS(AddGraph, _In_ OrtModel* model, _Inout_ OrtGraph** graph); + ORT_API2_STATUS(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); ORT_CLASS_RELEASE(Model); // @@ -5022,6 +5038,10 @@ struct OrtGraphApi { ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + // + // Model editing + // + // TODO: How best to support creating a session with an existing model and allowing updates to it? // // We need to be able to update the opsets @@ -5035,12 +5055,54 @@ struct OrtGraphApi { // If we add something to OrtSessionOptions the existing APIs could be used // - based on flag in session options, skip the call to InitializeSession in CreateSession // - that will leave the OrtSession in an invalid state + // - could use SessionGetInput*/SessionGetOutput* functions to get input/output info but that's just the TypeInfo + // and our API currently deals in ValueInfo for those + // - how does that work with changing the inputs/outputs? + // - do we skip populating the Graph API inputs/outputs until there's a call to SetInputs/SetOutputs and simply + // return info from the existing model up until that point? + // - how do we manage ownership of existing vs. new inputs/outputs? // - add function to get OrtModel from session // - add function to augment the onnxruntime::Graph and finalize the session // - can share implementation details with Graph::LoadFromGraphApiModel - ORT_API2_STATUS(GetModelFromSession, _In_ OrtSession* session, _Outptr_ OrtModel** model); - ORT_API2_STATUS(GetOpsetFromModel, _In_ const OrtModel* model, _In_ const char* domain_name, int* opset_version); - ORT_API2_STATUS(UpdateSessionWithModel, _In_ OrtSession* session, _In_ OrtModel* model, + + ORT_API2_STATUS(CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model); + + ORT_API2_STATUS(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model); + + // To get current inputs/outputs to augment an existing graph use SessionGetInputCount/SessionGetOutputCount, + // SessionGetInputTypeInfo/SessionGetOutputTypeInfo to get the TypeInfo, and SessionGetInputName/SessionGetOutputName + // to get the names. + // + // Call SetInputs/SetOutputs with the new instances to replace all existing inputs/outputs. + // Get OrtModel for editing. This can be used to add nodes before or after the existing Graph. + // The model inputs/outputs MUST be updated to reflect the changes made to the Graph. + // Any new operator domains must be added. + // Existing operator domains must be honoured. i.e. you must use the same opset the model currently uses. + // ORT_API2_STATUS(GetModelFromSession, _In_ OrtSession* session, _Outptr_ OrtModel** model); + + // Get OrtGraph from OrtModel. OrtModel owns OrtGraph so use should NOT call ReleaseGraph. + // + // User can call AddNode to add nodes at the start of end of the OrtGraph. + // User can call AddInitializer. + // User must adjust the Graph inputs/outputs to be valid for the changes made. + // + // e.g. Original: InputA -> NodeA ->. + // User adds a new Node 'InputPreA -> NodePreA -> InputA'. + // NodeA will now consume the output of NodePreA. + // The original InputA graph input must be replaced with a new graph input for InputPreA. + // + // This is done by using SessionGetInputCount/SessionGetInputName/SessionGetInputTypeInfo to get the existing + // input info. Create new OrtValueInfo instances for the required existing and new inputs. + // Call SetInputs + ORT_API2_STATUS(GetGraphFromModel, _In_ OrtModel* model, _Outptr_ OrtGraph** graph); + + // Update the model in the session. + // Existing opsets cannot be changed, so `additional_domain_names` must NOT match any existing domain names. + // Session will be finalized and ready for inferencing on completion. + ORT_API2_STATUS(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model, _In_reads_(additional_opset_entries_len) const char* const* additional_domain_names, _In_reads_(additional_opset_entries_len) const int* additional_opset_versions, _In_ size_t additional_opset_entries_len); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d4b1e7031c661..36daf535907a2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -536,7 +536,6 @@ ORT_DEFINE_RELEASE(KernelInfo); #define ORT_DEFINE_GRAPH_API_RELEASE(NAME) \ inline void OrtRelease(Ort##NAME* ptr) { GetGraphApi().Release##NAME(ptr); } -ORT_DEFINE_GRAPH_API_RELEASE(Shape); ORT_DEFINE_GRAPH_API_RELEASE(ValueInfo); ORT_DEFINE_GRAPH_API_RELEASE(Node); ORT_DEFINE_GRAPH_API_RELEASE(Graph); @@ -600,14 +599,6 @@ struct Base { return p; } - /// \brief Allows relinquishing ownership of the contained C object pointer if the API call is successful - /// The underlying object is not destroyed - // TODO: Refine this. Method name is based on usage not what it does which is ugly. - // Should we use `friend` instead to allow ownership transfer when building a model at runtime? - contained_type** release_on_success() { - return &p_; - } - protected: contained_type* p_{}; }; @@ -2509,39 +2500,15 @@ struct CustomOpBase : OrtCustomOp { // namespace GraphApi { -namespace detail { -template -struct ShapeImpl : Ort::detail::Base { - using B = Ort::detail::Base; - using B::B; - - template - bool operator==(const ShapeImpl& o) const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstShape = detail::ShapeImpl>; - -/** \brief Wrapper around ::OrtShape - * - */ -struct Shape : detail::ShapeImpl { - using Dimension = std::variant; - explicit Shape(std::nullptr_t) {} ///< No instance is created - explicit Shape(OrtShape* p) : ShapeImpl{p} {} ///< Take ownership of a pointer created by C Api - Shape(const std::vector& dims); ///< Wraps CreateFixedShape. All dims must be >= 0. - Shape(const std::vector& dims); /// p_}; } -}; - namespace detail { template struct ValueInfoImpl : Ort::detail::Base { using B = Ort::detail::Base; using B::B; + std::string Name() const; + ConstTypeInfo TypeInfo() const; + template bool operator==(const ValueInfoImpl& o) const; }; @@ -2558,7 +2525,7 @@ struct ValueInfo : detail::ValueInfoImpl { explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} ///< Take ownership of a pointer created by C Api // Create ValueInfo for a tensor - static ValueInfo CreateTensorValueInfo(const std::string& name, ONNXTensorElementDataType type, Shape& shape); + explicit ValueInfo(const std::string& name, ConstTypeInfo& type_info); ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } }; @@ -2615,8 +2582,8 @@ struct GraphImpl : Ort::detail::Base { using B = Ort::detail::Base; using B::B; - void AddInput(ValueInfo& input); - void AddOutput(ValueInfo& output); + void SetInputs(std::vector& inputs); + void SetOutputs(std::vector& outputs); void AddInitializer(const std::string& name, Value& initializer); // Graph takes ownership of Value void AddNode(Node& node); // Graph takes ownership of Node diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 688808f69ab48..83d950994dbd1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2181,30 +2181,6 @@ inline std::vector StringsToCharPtrs(const std::vector return ptrs; } -inline Shape::Shape(const std::vector& dims) { - ThrowOnError(GetGraphApi().CreateFixedShape(dims.data(), dims.size(), &p_)); -} - -inline Shape::Shape(const std::vector& dims) { - ThrowOnError(GetGraphApi().CreateShape(&p_)); - for (const auto& dim : dims) { - if (std::holds_alternative(dim)) { - ThrowOnError(GetGraphApi().AddDimension(p_, std::get(dim))); - } else { - ThrowOnError(GetGraphApi().AddDynamicDimension(p_, std::get(dim).c_str())); - } - } -} - -// static -inline ValueInfo ValueInfo::CreateTensorValueInfo(const std::string& name, ONNXTensorElementDataType type, - Shape& shape) { - // ValueInfo takes ownership of `shape` - ValueInfo vi(nullptr); - ThrowOnError(GetGraphApi().CreateTensorValueInfo(name.c_str(), type, shape.release_on_success(), &vi.p_)); - return vi; -} - // static inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, const std::string& node_name, @@ -2215,17 +2191,19 @@ inline void Node::Init(const std::string& operator_name, const std::string& oper auto inputs = StringsToCharPtrs(input_names); auto outputs = StringsToCharPtrs(output_names); - std::vector attributes_ptrs; + std::vector attributes_ptrs; attributes_ptrs.reserve(attributes.size()); std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), - [](OpAttr& attr) { return attr.release_on_success(); }); + [](OpAttr& attr) -> OrtOpAttr* { return attr; }); - // Node takes ownership of `attributes` ThrowOnError(GetGraphApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), inputs.data(), inputs.size(), outputs.data(), outputs.size(), attributes_ptrs.data(), attributes_ptrs.size(), &node)); + + // Node now owns the attributes + std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); } inline Node::Node(const std::string& operator_name, const std::string& operator_domain, @@ -2262,30 +2240,61 @@ inline Model::Model(const std::vector& opsets) { ThrowOnError(GetGraphApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); } +inline ValueInfo::ValueInfo(const std::string& name, ConstTypeInfo& type_info) { + ThrowOnError(GetGraphApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} namespace detail { -inline void GraphImpl::AddInput(ValueInfo& input) { - // Graph takes ownership of `input` - ThrowOnError(GetGraphApi().AddInput(p_, input.release_on_success())); +inline std::string ValueInfoImpl::Name() const { + const char* name = nullptr; + ThrowOnError(GetGraphApi().GetValueInfoName(p_, &name)); + return name; +} + +inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetGraphApi().GetValueInfoTypeInfo(p_, &type_info)); + return ConstTypeInfo{type_info}; } -inline void GraphImpl::AddOutput(ValueInfo& output) { - // Graph takes ownership of `output` - ThrowOnError(GetGraphApi().AddOutput(p_, output.release_on_success())); +inline void GraphImpl::SetInputs(std::vector& inputs) { + std::vector inputs_ptrs; + inputs_ptrs.reserve(inputs.size()); + + // Graph takes ownership. + std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi.release(); }); + + ThrowOnError(GetGraphApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + + // Graph now owns the inputs + std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +inline void GraphImpl::SetOutputs(std::vector& outputs) { + std::vector outputs_ptrs; + outputs_ptrs.reserve(outputs.size()); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetGraphApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + + // Graph now owns the outputs + std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); } inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer) { // Graph takes ownership of `initializer` - ThrowOnError(GetGraphApi().AddInitializer(p_, name.c_str(), initializer.release_on_success())); + ThrowOnError(GetGraphApi().AddInitializerToGraph(p_, name.c_str(), initializer.release())); } inline void GraphImpl::AddNode(Node& node) { // Graph takes ownership of `node` - ThrowOnError(GetGraphApi().AddNode(p_, node.release_on_success())); + ThrowOnError(GetGraphApi().AddNodeToGraph(p_, node.release())); } inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` - ThrowOnError(GetGraphApi().AddGraph(p_, graph.release_on_success())); + ThrowOnError(GetGraphApi().AddGraphToModel(p_, graph.release())); } } // namespace detail } // namespace GraphApi diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8f1bc98ce7b49..d6b38f8cbccfe 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -289,3 +289,13 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] // “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; + +// Create an Inference Session that will use the Model Builder API to create/update the model. +// This flag will create the session but not fully initialize it. A model, if provided, will be loaded. +// A session logger will be created, and execution providers will be registered. +// Any device specific allocators and IDataTransfer objects will be registered. +// This allows CreateAllocator to return device specific allocators registered by EPs. +// FUTURE: This will also allow CopyTensors to utilize the IDataTransfer objects +// "0": Disabled. [DEFAULT] +// "1": Enable Model Builder Session +static const char* const kOrtSessionOptionsEnableModelBuilder = "session.model_builder_session"; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index a884927abddb7..4e8ac674206f8 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -40,7 +40,7 @@ OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept - : type(type), data(std::move(data)) { + : type(type), tensor_type_info(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -55,7 +55,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) + ? input->tensor_type_info.get() + : nullptr; return nullptr; API_IMPL_END } @@ -93,6 +95,69 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* API_IMPL_END } +ORT_API_STATUS_IMPL(CreateTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_TENSOR); + ti->tensor_type_info = std::make_unique(); + ti->tensor_type_info->type = element_type; + + *tensor_info = ti->tensor_type_info.get(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SPARSETENSOR); + ti->tensor_type_info = std::make_unique(); + ti->tensor_type_info->type = element_type; + + *tensor_info = ti->tensor_type_info.get(); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtMapTypeInfo** map_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_MAP); + ti->map_type_info = std::make_unique(map_key_type, map_value_type->Clone()); + *map_info = ti->map_type_info.get(); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtSequenceTypeInfo** sequence_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SEQUENCE); + ti->sequence_type_info = std::make_unique(sequence_type->Clone()); + *sequence_info = ti->sequence_type_info.get(); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtOptionalTypeInfo** optional_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_OPTIONAL); + ti->optional_type_info = std::make_unique(contained_type->Clone()); + *optional_info = ti->optional_type_info.get(); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } @@ -298,8 +363,8 @@ std::unique_ptr OrtTypeInfo::Clone() const { #endif case ONNX_TYPE_TENSOR: { std::unique_ptr info; - if (data) { - info = data->Clone(); + if (tensor_type_info) { + info = tensor_type_info->Clone(); } result = MakePtr(type, std::move(info)); result->denotation = denotation; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 72d263d5fa442..54bb946e0d36b 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -31,7 +31,7 @@ struct OrtTypeInfo { ONNXType type; std::string denotation; - std::unique_ptr data; + std::unique_ptr tensor_type_info; std::unique_ptr map_type_info; std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 41ce01d2627a5..a37221ba225b2 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -16,6 +16,7 @@ #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensorprotoutils.h" @@ -5233,6 +5234,9 @@ Status Graph::InlineFunction(Node& callnode) { } void Graph::SetInputs(gsl::span inputs) { + graph_inputs_including_initializers_.clear(); + graph_inputs_excluding_initializers_.clear(); + // creating graph from scratch // rely on SetGraphInputsOutputs() to fix up graph_inputs_excluding_initializers_ // if is_loaded_from_model_file_ == false @@ -5241,7 +5245,6 @@ void Graph::SetInputs(gsl::span inputs) { if (is_loaded_from_model_file_) { // graph loaded from model file - graph_inputs_excluding_initializers_.clear(); for (const auto* input : inputs) { ORT_ENFORCE(input->Exists(), "Input to set must exist."); if (name_to_initial_tensor_.find(input->Name()) == name_to_initial_tensor_.end()) { @@ -5258,6 +5261,7 @@ void Graph::SetInputs(gsl::span inputs) { } void Graph::SetOutputs(gsl::span outputs) { + graph_outputs_.clear(); graph_outputs_.reserve(outputs.size()); graph_outputs_.assign(outputs.begin(), outputs.end()); @@ -5576,6 +5580,41 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } +namespace { +ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { + // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the + // name is not null/empty. + ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, + "Internal error. Model Builder API should only allow OrtValueInfo for tensor to be created."); + + ValueInfoProto value_info_proto; + value_info_proto.set_name(vi.name); + + auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); + tensor->set_elem_type(tensor_info.type); + + auto& shape = *tensor->mutable_shape(); + + size_t idx = 0; + for (auto dim : tensor_info.shape.GetDims()) { + auto& dim_proto = *shape.add_dim(); + if (dim >= 0) { + dim_proto.set_dim_value(dim); + } else { + const std::string& dim_param = tensor_info.dim_params[idx]; + // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim + if (!dim_param.empty()) { + dim_proto.set_dim_param(dim_param); + } + } + } + + return value_info_proto; +} + +} // namespace + Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph) { ArgNameToTypeMap name_to_type_map; @@ -5589,9 +5628,7 @@ Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph) { std::vector node_args; node_args.reserve(graph_inputs_or_outputs.size()); for (auto& ort_value_info : graph_inputs_or_outputs) { - const ValueInfoProto& value_info = ort_value_info->value_info_proto; - // assuming input has name and type. this needs to be enforced either here or in the graph api implementation - ORT_ENFORCE(utils::HasName(value_info) && utils::HasType(value_info), "Graph input must have name and type."); + ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); name_to_type_map[value_info.name()] = value_info.type(); node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); @@ -5728,4 +5765,21 @@ Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, return graph->LoadFromGraphApiModel(api_graph); } +Status Graph::UpdateUsingGraphApiModel(const OrtModel& api_model) { + for (auto& entry : api_model.domain_to_version) { + if (auto it = domain_to_version_.find(entry.first); it != domain_to_version_.end()) { + if (it->second != entry.second) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Domain version can not be changed for '", entry.first, + "'. Current version: ", it->second); + } + } else { + domain_to_version_.insert(entry); + } + } + + // this will replace all inputs/outputs and add nodes. + return LoadFromGraphApiModel(*api_model.graph); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_api_types.h b/onnxruntime/core/graph/graph_api_types.h index 820da2b5b7824..f1a634a591513 100644 --- a/onnxruntime/core/graph/graph_api_types.h +++ b/onnxruntime/core/graph/graph_api_types.h @@ -2,17 +2,15 @@ // Licensed under the MIT License. #include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" // ORT C interface types for OrtGraphApi can't be in a namespace. // We need to define them here so onnxruntime::Model can be created from OrtModel. -struct OrtShape { - ONNX_NAMESPACE::TensorShapeProto shape_proto; -}; - struct OrtValueInfo { - ONNX_NAMESPACE::ValueInfoProto value_info_proto; + std::string name; + std::unique_ptr type_info; }; struct OrtOpAttr { diff --git a/onnxruntime/core/session/graph_apis.h b/onnxruntime/core/session/graph_apis.h index 6593536d77703..06e046c3c437c 100644 --- a/onnxruntime/core/session/graph_apis.h +++ b/onnxruntime/core/session/graph_apis.h @@ -3,14 +3,10 @@ namespace OrtGraphApis { // implementation that returns the API struct ORT_API(const OrtGraphApi*, GetGraphApi); -ORT_API_STATUS_IMPL(CreateFixedShape, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtShape** shape); -ORT_API_STATUS_IMPL(CreateShape, _Outptr_ OrtShape** shape); -ORT_API_STATUS_IMPL(AddDimension, _In_ OrtShape* shape, int64_t dim_value); -ORT_API_STATUS_IMPL(AddDynamicDimension, _In_ OrtShape* shape, const char* dimension_name); -ORT_API(void, ReleaseShape, _Frees_ptr_opt_ OrtShape* shape); - -ORT_API_STATUS_IMPL(CreateTensorValueInfo, _In_ const char* name, _In_ ONNXTensorElementDataType type, - _Inout_ OrtShape** shape, _Outptr_ OrtValueInfo** value_info); +ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); +ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); +ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); ORT_API(void, ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info); ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_name, _In_ const char* node_name, @@ -21,10 +17,12 @@ ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_na ORT_API(void, ReleaseNode, _Frees_ptr_opt_ OrtNode* node); ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph); -ORT_API_STATUS_IMPL(AddInput, _In_ OrtGraph* graph, _Inout_ OrtValueInfo** value_info); -ORT_API_STATUS_IMPL(AddOutput, _In_ OrtGraph* graph, _Inout_ OrtValueInfo** value_info); -ORT_API_STATUS_IMPL(AddInitializer, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue** tensor); -ORT_API_STATUS_IMPL(AddNode, _In_ OrtGraph* graph, _Inout_ OrtNode** node); +ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); +ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor); +ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph); ORT_API_STATUS_IMPL(CreateModel, @@ -32,10 +30,26 @@ ORT_API_STATUS_IMPL(CreateModel, _In_reads_(opset_entries_len) const int* opset_versions, size_t opset_entries_len, _Outptr_ OrtModel** model); -ORT_API_STATUS_IMPL(AddGraph, _In_ OrtModel* model, _Inout_ OrtGraph** graph); +ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); ORT_API(void, ReleaseModel, _Frees_ptr_opt_ OrtModel* model); ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); +// +// Model editing APIs for updating existing model +// +ORT_API_STATUS_IMPL(CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model); + +ORT_API_STATUS_IMPL(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model); + +ORT_API_STATUS_IMPL(GetGraphFromModel, _In_ OrtModel* model, _Outptr_ OrtGraph** graph); + +ORT_API_STATUS_IMPL(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model, + _In_reads_(additional_opset_entries_len) const char* const* additional_domain_names, + _In_reads_(additional_opset_entries_len) const int* additional_opset_versions, + _In_ size_t additional_opset_entries_len); } // namespace OrtGraphApis diff --git a/onnxruntime/core/session/graph_c_api.cc b/onnxruntime/core/session/graph_c_api.cc index a47c7db0cf033..3b8ee1449857d 100644 --- a/onnxruntime/core/session/graph_c_api.cc +++ b/onnxruntime/core/session/graph_c_api.cc @@ -5,6 +5,8 @@ #include "core/framework/error_code_helper.h" #include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/graph/constants.h" #include "core/graph/graph_api_types.h" #include "core/graph/onnx_protobuf.h" @@ -19,60 +21,56 @@ using namespace onnxruntime; namespace OrtGraphApis { -ORT_API_STATUS_IMPL(CreateShape, _Outptr_ OrtShape** shape) { - API_IMPL_BEGIN - *shape = new OrtShape(); - return nullptr; - API_IMPL_END -} +namespace { +std::unique_ptr CreateOrtModelFromSession(InferenceSession& session) { + auto model = std::make_unique(); -ORT_API_STATUS_IMPL(AddDimension, _In_ OrtShape* shape, int64_t dim_value) { - API_IMPL_BEGIN - shape->shape_proto.add_dim()->set_dim_value(dim_value); - return nullptr; - API_IMPL_END + // start with the minimal amount of info. + // we need opsets and a Graph instance + // user can query session inputs/outputs using existing API + // they can add Node and Initializer instances to the Graph, and provide additional domain:opset info if needed + model->graph = std::make_unique(); } -ORT_API_STATUS_IMPL(AddDynamicDimension, _In_ OrtShape* shape, const char* dimension_name) { +ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info) { API_IMPL_BEGIN - if (dimension_name == nullptr || *dimension_name == '\0') { - shape->shape_proto.add_dim(); // 'unknown'dimension exists but has neither dim_value nor dim_param - } else { - shape->shape_proto.add_dim()->set_dim_param(dimension_name); + if (name == nullptr || *name == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); } - return nullptr; - API_IMPL_END -} + if (type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_info cannot be null"); + } -ORT_API_STATUS_IMPL(CreateFixedShape, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtShape** shape) { - API_IMPL_BEGIN - auto s = std::make_unique(); - for (size_t i = 0; i < dim_count; ++i) { - s->shape_proto.add_dim()->set_dim_value(dim_values[i]); + if (type_info->type != ONNX_TYPE_TENSOR) { + return OrtApis::CreateStatus(ORT_FAIL, "Only tensor types are supported currently"); + } + + if (type_info->tensor_type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); } - *shape = s.release(); + auto vi = std::make_unique(); + vi->name = name; + vi->type_info = type_info->Clone(); + + *value_info = vi.release(); + return nullptr; API_IMPL_END } -ORT_API(void, ReleaseShape, _Frees_ptr_opt_ OrtShape* shape) { - delete shape; +ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) { + API_IMPL_BEGIN + *name = value_info->name.c_str(); + return nullptr; + API_IMPL_END } - -ORT_API_STATUS_IMPL(CreateTensorValueInfo, _In_ const char* name, _In_ ONNXTensorElementDataType type, - _Inout_ OrtShape** shape, _Outptr_ OrtValueInfo** value_info) { +ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) { API_IMPL_BEGIN - auto vi = std::make_unique(); - vi->value_info_proto.set_name(name); - auto* tensor = vi->value_info_proto.mutable_type()->mutable_tensor_type(); - tensor->set_elem_type(type); - *tensor->mutable_shape() = (*shape)->shape_proto; - *value_info = vi.release(); - delete *shape; // take ownership of the OrtShape - *shape = nullptr; + *type_info = value_info->type_info.get(); return nullptr; API_IMPL_END @@ -140,33 +138,48 @@ ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph) { API_IMPL_END } -ORT_API_STATUS_IMPL(AddInput, _In_ OrtGraph* graph, _Inout_ OrtValueInfo** value_info) { +ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { API_IMPL_BEGIN - graph->inputs.push_back(std::unique_ptr(*value_info)); // take ownership - *value_info = nullptr; + for (size_t i = 0; i < inputs_len; ++i) { + if (inputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); + } + + graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership + inputs[i] = nullptr; + } + return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(AddOutput, _In_ OrtGraph* graph, _Inout_ OrtValueInfo** value_info) { + +ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { API_IMPL_BEGIN - graph->outputs.push_back(std::unique_ptr(*value_info)); // take ownership - *value_info = nullptr; + for (size_t i = 0; i < outputs_len; ++i) { + if (outputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); + } + + graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership + outputs[i] = nullptr; + } + return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(AddInitializer, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue** tensor) { +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor) { API_IMPL_BEGIN - graph->initializers[name] = std::unique_ptr(*tensor); // take ownership - *tensor = nullptr; + graph->initializers[name] = std::unique_ptr(tensor); // take ownership return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(AddNode, _In_ OrtGraph* graph, _Inout_ OrtNode** node) { +ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { API_IMPL_BEGIN - graph->nodes.push_back(std::unique_ptr(*node)); // take ownership - *node = nullptr; + graph->nodes.push_back(std::unique_ptr(node)); // take ownership return nullptr; API_IMPL_END } @@ -191,16 +204,18 @@ ORT_API_STATUS_IMPL(CreateModel, API_IMPL_END } -ORT_API_STATUS_IMPL(AddGraph, _In_ OrtModel* model, _Inout_ OrtGraph** graph) { +ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { API_IMPL_BEGIN - // TODO: High level validation - // Has inputs - // Has outputs - // Nodes are not necessarily required in a subgraph as a branch of an If may just pass through a value + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } - model->graph = std::unique_ptr(*graph); // take ownership - *graph = nullptr; + if (graph->inputs.empty() || graph->outputs.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph must have at least one input and one output"); + } + + model->graph = std::unique_ptr(graph); // take ownership return nullptr; API_IMPL_END } @@ -239,39 +254,115 @@ ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const O API_IMPL_END } +ORT_API_STATUS_IMPL(CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session)); + // No call to InitializeSession. We do that in UpdateSessionWithModel. + // ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + + auto session_model = CreateOrtModelFromSession(*session); + *out = reinterpret_cast(session.release()); + *model = session_model.release(); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session)); + // No call to InitializeSession. We do that in UpdateSessionWithModel + // ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + + auto session_model = CreateOrtModelFromSession(*session); + *out = reinterpret_cast(session.release()); + *model = session_model.release(); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(GetGraphFromModel, _In_ OrtModel* model, _Outptr_ OrtGraph** graph) { + API_IMPL_BEGIN + *graph = model->graph.get(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model, + _In_reads_(additional_opset_entries_len) const char* const* additional_domain_names, + _In_reads_(additional_opset_entries_len) const int* additional_opset_versions, + _In_ size_t additional_opset_entries_len) { + API_IMPL_BEGIN + for (size_t i = 0; i < additional_opset_entries_len; ++i) { + model->domain_to_version[additional_domain_names[i]] = additional_opset_versions[i]; + } + + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model)); + + return nullptr; + API_IMPL_END + } // namespace OrtGraphApis static constexpr OrtGraphApi ort_graph_api = { // NOTE: The C# bindings depend on the API order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). - &OrtGraphApis::CreateFixedShape, - &OrtGraphApis::CreateShape, - &OrtGraphApis::AddDimension, - &OrtGraphApis::AddDynamicDimension, - &OrtGraphApis::ReleaseShape, - - &OrtGraphApis::CreateTensorValueInfo, + &OrtGraphApis::CreateValueInfo, + &OrtGraphApis::GetValueInfoName, + &OrtGraphApis::GetValueInfoTypeInfo, &OrtGraphApis::ReleaseValueInfo, &OrtGraphApis::CreateNode, &OrtGraphApis::ReleaseNode, &OrtGraphApis::CreateGraph, - &OrtGraphApis::AddInput, - &OrtGraphApis::AddOutput, - &OrtGraphApis::AddInitializer, - &OrtGraphApis::AddNode, + &OrtGraphApis::SetGraphInputs, + &OrtGraphApis::SetGraphOutputs, + &OrtGraphApis::AddInitializerToGraph, + &OrtGraphApis::AddNodeToGraph, &OrtGraphApis::ReleaseGraph, &OrtGraphApis::CreateModel, - &OrtGraphApis::AddGraph, + &OrtGraphApis::AddGraphToModel, &OrtGraphApis::ReleaseModel, &OrtGraphApis::CreateSessionFromModel, + + &OrtGraphApis::CreateModelBuilderSession, + &OrtGraphApis::CreateModelBuilderSessionFromArray, + &OrtGraphApis::GetGraphFromModel, + &OrtGraphApis::ApplyModelToSession, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned -static_assert(offsetof(OrtGraphApi, CreateSessionFromModel) / sizeof(void*) == 18, +static_assert(offsetof(OrtGraphApi, ApplyModelToSession) / sizeof(void*) == 19, "Size of version 21 API cannot change"); // initial version in ORT 1.21 ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi) { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 287e3d7e1a5c7..6b2e4415282c6 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -36,6 +36,7 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/transform_layout_functions.h" #include "core/framework/utils.h" +#include "core/graph/graph_api_types.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_transformer_utils.h" @@ -1226,6 +1227,26 @@ common::Status InferenceSession::Load(const OrtModel& graph_api_model) { return Status::OK(); } +common::Status InferenceSession::ApplyUpdates(const OrtModel& graph_api_model) { + std::lock_guard l(session_mutex_); + + if (!is_model_loaded_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session does not contain a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + auto& graph = model_->MainGraph().UpdateUsingGraphApiModel(graph_api_model); + + return Status::OK(); +} + common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -3318,6 +3339,10 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } +const Model& InferenceSession::GetModel() const { + return *model_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 02d76663e66d6..63ce0dd891c99 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -330,6 +330,19 @@ class InferenceSession { * @return OK if success. */ [[nodiscard]] common::Status Load(const OrtModel& graph_api_model); + + /** + * Apply updates from an OrtModel that was created via OrtGraphApi. + * This can: + * - add nodes at the start and end of the model + * - add initializers + * - update the graph inputs/outputs + * + * @param graph_api_model OrtModel from OrtGraphApi + * @return OK if success. + */ + [[nodiscard]] common::Status ApplyUpdates(const OrtModel& graph_api_model); + #endif // !defined(ORT_MINIMAL_BUILD) /** @@ -556,6 +569,8 @@ class InferenceSession { */ Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container); + const Model& GetModel() const; + protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5446118b3fec4..7ad257e0d5c30 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -28,6 +28,7 @@ #include "core/framework/utils.h" #include "core/graph/constants.h" #include "core/graph/graph.h" +#include "core/graph/model.h" #include "core/providers/get_execution_providers.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" @@ -115,6 +116,70 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); +namespace { +// Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + OrtAllocator* allocator, OrtValue& value) { + TensorShape tensor_shape(shape, shape_len); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { + OrtAllocator* allocator; + // TODO(pranav): what allocator should be used to create the tensor here? + // for the sake of simplicity of the API using the default one here + ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + TensorShape tensor_shape(shape, shape_len); + out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); + return nullptr; +} + +// Create Tensor with existing data. Tensor does not own memory. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + OrtAllocator* deleter, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + AllocatorPtr alloc_ptr = std::make_shared(deleter); + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, std::move(alloc_ptr), ort_value); + return nullptr; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { @@ -188,58 +253,26 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, API_IMPL_END } -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, - _Inout_ OrtAllocator* allocator, OrtValue& value) { - TensorShape tensor_shape(shape, shape_len); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); - return nullptr; -} - -ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { - OrtAllocator* allocator; - // TODO(pranav): what allocator should be used to create the tensor here? - // for the sake of simplicity of the API using the default one here - ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - TensorShape tensor_shape(shape, shape_len); - out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); - return nullptr; -} - -/** - * - * this function will create a copy of the allocator info - */ -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - if (!status.IsOK()) { - return ToOrtStatus(status); - } - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); +ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, + _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { + API_IMPL_BEGIN + auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + auto value = std::make_unique(); + ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, info, p_data, p_data_len, *value)); + *out = value.release(); return nullptr; + API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, +ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, + _In_ OrtAllocator* deleter, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { API_IMPL_BEGIN auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); auto value = std::make_unique(); - ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, info, p_data, p_data_len, *value)); + ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, deleter, p_data, p_data_len, *value)); *out = value.release(); return nullptr; API_IMPL_END @@ -680,66 +713,6 @@ ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* opti API_IMPL_END } -namespace { -// provider either model_path, or modal_data + model_data_length. -static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { - // quick check here to decide load path. InferenceSession will provide error message for invalid values. - // TODO: Could move to a helper - const Env& os_env = Env::Default(); // OS environment (!= ORT environment) - bool load_config_from_model = - os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; - - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - if (model_path != nullptr) { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_path); - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_data, static_cast(model_data_length)); - } -#else - return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); -#endif - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Add custom domains - if (options && !options->custom_op_domains_.empty()) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); - } -#endif - - // Finish load - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); -#endif - } else { - if (model_path != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); - } else { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); - } - } - - return nullptr; -} -} // namespace - ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN @@ -1346,6 +1319,20 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerCount, _In_ const O return GetNodeDefListCountHelper(sess, get_overridable_initializers_fn, out); } +ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* ort_session, _In_ const char* domain, + _Out_ int* opset) { + const auto& session = *reinterpret_cast(ort_session); + const auto& domain_opset_map = session.Model().MainGraph().DomainToVersionMap(); + + auto it = domain_opset_map.find(domain); + if (it == domain_opset_map.cend()) { + return OrtApis::CreateStatus(ORT_FAIL, "Domain not used by model."); + } + + *opset = it->second; + return nullptr; +} + static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn get_fn, size_t index, _Outptr_ struct OrtTypeInfo** out) { API_IMPL_BEGIN @@ -2786,6 +2773,15 @@ static constexpr OrtApi ort_api_1_to_21 = { // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetGraphApi, + + &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, + + // APIs to create/edit type info when building/modifying a model using the Graph API + &OrtApis::CreateTensorTypeInfo, + &OrtApis::CreateSparseTensorTypeInfo, + &OrtApis::CreateMapTypeInfo, + &OrtApis::CreateSequenceTypeInfo, + &OrtApis::CreateOptionalTypeInfo, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 27ff21260f292..b75cc1b8a40de 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -536,4 +536,23 @@ ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv ORT_API(const OrtGraphApi*, GetGraphApi); +ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ const OrtAllocator* deleter, + _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + +ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, + _Out_ int* opset); + +// APIs to create/edit type info +ORT_API_STATUS_IMPL(CreateTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info); +ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, ONNXTensorElementDataType element_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info); +ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtMapTypeInfo** map_info); +ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtSequenceTypeInfo** sequence_info); +ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info, _Out_ OrtOptionalTypeInfo** optional_info); } // namespace OrtApis diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 4063e9d8f20d3..57ed887c9906d 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -7,6 +7,7 @@ #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" using namespace onnxruntime; @@ -32,6 +33,64 @@ common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); } +// provider either model_path, or modal_data + model_data_length. +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + // quick check here to decide load path. InferenceSession will provide error message for invalid values. + // TODO: Could move to a helper + const Env& os_env = Env::Default(); // OS environment (!= ORT environment) + bool load_config_from_model = + os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; + + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + if (model_path != nullptr) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_path); + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_data, static_cast(model_data_length)); + } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); +#endif + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Finish load + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); +#endif + } else { + if (model_path != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); + } else { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); + } + } + + return nullptr; +} + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index d1a2386732b4b..2ed138913e466 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -19,3 +19,10 @@ class InferenceSession; OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); + +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess);