Skip to content

Commit

Permalink
Most API pieces in place
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Dec 21, 2024
1 parent 96f05d0 commit e71505f
Show file tree
Hide file tree
Showing 17 changed files with 705 additions and 312 deletions.
4 changes: 3 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const logging::Logger& logger,
std::unique_ptr<Graph>& graph);

Status UpdateUsingGraphApiModel(const OrtModel& api_model);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
return runtime_optimizations_;
Expand Down Expand Up @@ -1796,7 +1798,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

const std::unordered_map<std::string, int> domain_to_version_;
std::unordered_map<std::string, int> domain_to_version_;

// Model IR version.
Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
Expand Down
132 changes: 97 additions & 35 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ struct OrtApi {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length,
ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env,
_In_ const void* model_data, size_t model_data_length,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);

/** \brief Run the model in an ::OrtSession
Expand Down Expand Up @@ -4785,19 +4786,39 @@ struct OrtApi {
// Get the OrtGraphApi instance for creating a model dynamically.
const OrtGraphApi*(ORT_API_CALL* GetGraphApi)();

// Create TypeInfo for editing.
// User provides custom OrtAllocator. We only call the OrtAllocator::Info and OrtAllocator::Free functions.
// The OrtMemoryInfo returned by the allocator must match where `p_data` is.
ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ const OrtAllocator* deleter,
_In_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type,
_Outptr_ OrtValue** out);

//
ORT_API2_STATUS(CreateTypeInfo, _In_ enum ONNXType onnx_type, _Out_ OrtTypeInfo** type_info);

// Mutable getter. Can we used for Tensor or SparseTensor
ORT_API2_STATUS(GetMutableTensorInfoFromTypeInfo, _In_ OrtTypeInfo* type_info,
_Outptr_result_maybenull_ OrtTensorTypeAndShapeInfo** tensor_info);
ORT_API2_STATUS(GetMutableMapInfoFromTypeInfo, _In_ OrtTypeInfo* type_info,
_Outptr_result_maybenull_ OrtMapTypeInfo** map_info);
ORT_API2_STATUS(GetMutableSequenceInfoFromTypeInfo, _In_ OrtTypeInfo* type_info,
_Outptr_result_maybenull_ OrtSequenceTypeInfo** tensor_info);
ORT_API2_STATUS(GetMutableOptionalInfoFromTypeInfo, _In_ OrtTypeInfo* type_info,
_Outptr_result_maybenull_ OrtOptionalTypeInfo** tensor_info);
// Additions for creating/editing graph inputs/outputs in model builder API
//

// Get the opset for the given domain. And nodes added must conform to this opset.
ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset);

// Create TypeInfo instance to populate for graph inputs/outputs.
// We have to create the TypeInfo and the internal type, and need args for the latter.
// We _could_ just return OrtTypeInfo, but the user is going to need the internal type to actually do something
// meaningful with it, so return both instead of making the user call another function to get a mutable instance
// of the internal type.

// Create Tensor TypeInfo. User owns type_info. type_info owns tensor_info.
ORT_API2_STATUS(CreateTensorTypeInfo, ONNXTensorElementDataType element_type,
_Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info);
// Create SparseTensor TypeInfo. User owns type_info. type_info owns tensor_info.
ORT_API2_STATUS(CreateSparseTensorTypeInfo, ONNXTensorElementDataType element_type,
_Out_ OrtTypeInfo** type_info, _Out_ OrtTensorTypeAndShapeInfo** tensor_info);
// Create Map TypeInfo. User owns type_info. type_info owns map_info.
ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type,
_Out_ OrtTypeInfo** type_info, _Out_ OrtMapTypeInfo** map_info);
ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type,
_Out_ OrtTypeInfo** type_info, _Out_ OrtSequenceTypeInfo** sequence_info);
ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type,
_Out_ OrtTypeInfo** type_info, _Out_ OrtOptionalTypeInfo** optional_info);
};

/*
Expand Down Expand Up @@ -4923,10 +4944,10 @@ ORT_RUNTIME_CLASS(ValueInfo);
struct OrtGraphApi {
/*** New usage ***
Create OrtTypeInfo for Tensor input:
CreateTypeInfo
GetMutableTensorInfoFromTypeInfo
SetTensorElementType
SetDimensions + SetSymbolicDimensions to define shape (these are complimentary and overlap)
CreateTensorTypeInfo
SetDimensions + SetSymbolicDimensions to define shape. These are 1:1.
If dim is -1 there should be a non-nullptr symbolic value. May be empty string but must not be nullptr.
If dim >= 0 there should be a nullptr symbolic value.
Read from OrtTypeInfo:
GetOnnxTypeFromTypeInfo
Expand All @@ -4939,12 +4960,11 @@ struct OrtGraphApi {
Call CreateValueInfo to combine value name with OrtTypeInfo
***/

// OrtValueInfo takes ownership of the OrtTypeInfo
ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ OrtTypeInfo* type_info,
// user can release type_info.
ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info,
_Outptr_ OrtValueInfo** value_info);
ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name);
ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info);
ORT_API2_STATUS(CloneValueInfo, _In_ const OrtValueInfo* src_value_info, _Out_ OrtValueInfo** new_value_info);
ORT_CLASS_RELEASE(ValueInfo); // call if not added to Graph

//
Expand All @@ -4968,18 +4988,11 @@ struct OrtGraphApi {

// Set the Inputs/Outputs. Replaces and frees any existing inputs/outputs.
// Graph takes ownership of the OrtTypeInfo instances.
ORT_API2_STATUS(SetInputs, _In_ OrtGraph* graph,
ORT_API2_STATUS(SetGraphInputs, _In_ OrtGraph* graph,
_In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len);
ORT_API2_STATUS(SetOutputs, _In_ OrtGraph* graph,
ORT_API2_STATUS(SetGraphOutputs, _In_ OrtGraph* graph,
_In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len);

// Get current inputs/outputs. Expected usage is augmenting an existing graph.
// To change the inputs/outputs, use CloneValueInfo for any OrtValueInfo instances you want to keep,
// and CreateValueInfo to create new inputs/outputs.
// Call SetInputs/SetOutputs with the new instances to replace all existing inputs/outputs.
ORT_API2_STATUS(GetInputs, _In_ OrtGraph* graph, _Inout_ const OrtValueInfo** inputs, _Out_ size_t* inputs_len);
ORT_API2_STATUS(GetOutputs, _In_ OrtGraph* graph, _Inout_ const OrtValueInfo** outputs, _Out_ size_t* outputs_len);

// 2 use cases. User is free to choose either approach but the suggested usage would be:
//
// 1: Weights:
Expand All @@ -4991,8 +5004,11 @@ struct OrtGraphApi {
// e.g. min/max input of Clip, indices for Gather.
// Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the actual data.
// We will copy the data when converting to TensorProto so user doesn't need to keep the data alive.
ORT_API2_STATUS(AddInitializer, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue** tensor);
ORT_API2_STATUS(AddNode, _In_ OrtGraph* graph, _In_ OrtNode** node);
//
// Graph takes ownership of initializer.
ORT_API2_STATUS(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor);
// Graph takes ownership of node
ORT_API2_STATUS(AddNodeToGraph, _In_ OrtGraph* graph, _In_ OrtNode* node);
ORT_CLASS_RELEASE(Graph); // call if not added to Model

//
Expand All @@ -5007,7 +5023,7 @@ struct OrtGraphApi {
size_t opset_entries_len,
_Outptr_ OrtModel** model);

ORT_API2_STATUS(AddGraph, _In_ OrtModel* model, _Inout_ OrtGraph** graph);
ORT_API2_STATUS(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph);
ORT_CLASS_RELEASE(Model);

//
Expand All @@ -5022,6 +5038,10 @@ struct OrtGraphApi {
ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);

//
// Model editing
//

// TODO: How best to support creating a session with an existing model and allowing updates to it?
//
// We need to be able to update the opsets
Expand All @@ -5035,12 +5055,54 @@ struct OrtGraphApi {
// If we add something to OrtSessionOptions the existing APIs could be used
// - based on flag in session options, skip the call to InitializeSession in CreateSession
// - that will leave the OrtSession in an invalid state
// - could use SessionGetInput*/SessionGetOutput* functions to get input/output info but that's just the TypeInfo
// and our API currently deals in ValueInfo for those
// - how does that work with changing the inputs/outputs?
// - do we skip populating the Graph API inputs/outputs until there's a call to SetInputs/SetOutputs and simply
// return info from the existing model up until that point?
// - how do we manage ownership of existing vs. new inputs/outputs?
// - add function to get OrtModel from session
// - add function to augment the onnxruntime::Graph and finalize the session
// - can share implementation details with Graph::LoadFromGraphApiModel
ORT_API2_STATUS(GetModelFromSession, _In_ OrtSession* session, _Outptr_ OrtModel** model);
ORT_API2_STATUS(GetOpsetFromModel, _In_ const OrtModel* model, _In_ const char* domain_name, int* opset_version);
ORT_API2_STATUS(UpdateSessionWithModel, _In_ OrtSession* session, _In_ OrtModel* model,

ORT_API2_STATUS(CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model);

ORT_API2_STATUS(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env,
_In_ const void* model_data, size_t model_data_length,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out, _Outptr_ OrtModel** model);

// To get current inputs/outputs to augment an existing graph use SessionGetInputCount/SessionGetOutputCount,
// SessionGetInputTypeInfo/SessionGetOutputTypeInfo to get the TypeInfo, and SessionGetInputName/SessionGetOutputName
// to get the names.
//
// Call SetInputs/SetOutputs with the new instances to replace all existing inputs/outputs.
// Get OrtModel for editing. This can be used to add nodes before or after the existing Graph.
// The model inputs/outputs MUST be updated to reflect the changes made to the Graph.
// Any new operator domains must be added.
// Existing operator domains must be honoured. i.e. you must use the same opset the model currently uses.
// ORT_API2_STATUS(GetModelFromSession, _In_ OrtSession* session, _Outptr_ OrtModel** model);

// Get OrtGraph from OrtModel. OrtModel owns OrtGraph so use should NOT call ReleaseGraph.
//
// User can call AddNode to add nodes at the start of end of the OrtGraph.
// User can call AddInitializer.
// User must adjust the Graph inputs/outputs to be valid for the changes made.
//
// e.g. Original: InputA -> NodeA ->.
// User adds a new Node 'InputPreA -> NodePreA -> InputA'.
// NodeA will now consume the output of NodePreA.
// The original InputA graph input must be replaced with a new graph input for InputPreA.
//
// This is done by using SessionGetInputCount/SessionGetInputName/SessionGetInputTypeInfo to get the existing
// input info. Create new OrtValueInfo instances for the required existing and new inputs.
// Call SetInputs
ORT_API2_STATUS(GetGraphFromModel, _In_ OrtModel* model, _Outptr_ OrtGraph** graph);

// Update the model in the session.
// Existing opsets cannot be changed, so `additional_domain_names` must NOT match any existing domain names.
// Session will be finalized and ready for inferencing on completion.
ORT_API2_STATUS(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model,
_In_reads_(additional_opset_entries_len) const char* const* additional_domain_names,
_In_reads_(additional_opset_entries_len) const int* additional_opset_versions,
_In_ size_t additional_opset_entries_len);
Expand Down
45 changes: 6 additions & 39 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,6 @@ ORT_DEFINE_RELEASE(KernelInfo);
#define ORT_DEFINE_GRAPH_API_RELEASE(NAME) \
inline void OrtRelease(Ort##NAME* ptr) { GetGraphApi().Release##NAME(ptr); }

ORT_DEFINE_GRAPH_API_RELEASE(Shape);
ORT_DEFINE_GRAPH_API_RELEASE(ValueInfo);
ORT_DEFINE_GRAPH_API_RELEASE(Node);
ORT_DEFINE_GRAPH_API_RELEASE(Graph);
Expand Down Expand Up @@ -600,14 +599,6 @@ struct Base {
return p;
}

/// \brief Allows relinquishing ownership of the contained C object pointer if the API call is successful
/// The underlying object is not destroyed
// TODO: Refine this. Method name is based on usage not what it does which is ugly.
// Should we use `friend` instead to allow ownership transfer when building a model at runtime?
contained_type** release_on_success() {
return &p_;
}

protected:
contained_type* p_{};
};
Expand Down Expand Up @@ -2509,39 +2500,15 @@ struct CustomOpBase : OrtCustomOp {
//
namespace GraphApi {

namespace detail {
template <typename T>
struct ShapeImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

template <typename U>
bool operator==(const ShapeImpl<U>& o) const;
};
} // namespace detail

// Const object holder that does not own the underlying object
using ConstShape = detail::ShapeImpl<Ort::detail::Unowned<const OrtShape>>;

/** \brief Wrapper around ::OrtShape
*
*/
struct Shape : detail::ShapeImpl<OrtShape> {
using Dimension = std::variant<int64_t, std::string>;
explicit Shape(std::nullptr_t) {} ///< No instance is created
explicit Shape(OrtShape* p) : ShapeImpl<OrtShape>{p} {} ///< Take ownership of a pointer created by C Api
Shape(const std::vector<int64_t>& dims); ///< Wraps CreateFixedShape. All dims must be >= 0.
Shape(const std::vector<Dimension>& dims); /// <Wraps CreateShape + AddDimension/AddDynamicDimension

ConstShape GetConst() const { return ConstShape{this->p_}; }
};

namespace detail {
template <typename T>
struct ValueInfoImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

std::string Name() const;
ConstTypeInfo TypeInfo() const;

template <typename U>
bool operator==(const ValueInfoImpl<U>& o) const;
};
Expand All @@ -2558,7 +2525,7 @@ struct ValueInfo : detail::ValueInfoImpl<OrtValueInfo> {
explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl<OrtValueInfo>{p} {} ///< Take ownership of a pointer created by C Api

// Create ValueInfo for a tensor
static ValueInfo CreateTensorValueInfo(const std::string& name, ONNXTensorElementDataType type, Shape& shape);
explicit ValueInfo(const std::string& name, ConstTypeInfo& type_info);

ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; }
};
Expand Down Expand Up @@ -2615,8 +2582,8 @@ struct GraphImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

void AddInput(ValueInfo& input);
void AddOutput(ValueInfo& output);
void SetInputs(std::vector<ValueInfo>& inputs);
void SetOutputs(std::vector<ValueInfo>& outputs);
void AddInitializer(const std::string& name, Value& initializer); // Graph takes ownership of Value
void AddNode(Node& node); // Graph takes ownership of Node

Expand Down
Loading

0 comments on commit e71505f

Please sign in to comment.