Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wrapper all existing winml adapter apis with API_IMPL to try catch #2854

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions onnxruntime/core/framework/onnxruntime_typeinfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/sparse_tensor.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/ort_apis.h"
#include "core/framework/error_code_helper.h"

#include "core/framework/tensor_type_and_shape.h"
#include "../../winml/adapter/winml_adapter_map_type_info.h"
Expand Down Expand Up @@ -61,19 +62,25 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType
}

ORT_API_STATUS_IMPL(winmla::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) {
API_IMPL_BEGIN
*out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info_ : nullptr;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) {
API_IMPL_BEGIN
*out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info_ : nullptr;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) {
API_IMPL_BEGIN
*out = type_info->denotation_.c_str();
*len = type_info->denotation_.size();
return nullptr;
API_IMPL_END
}

ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) {
Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_dml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProvi

ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
ID3D12Device* d3d_device, ID3D12CommandQueue* queue) {
API_IMPL_BEGIN
auto dml_device = CreateDmlDevice(d3d_device);
if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) {
return status;
Expand All @@ -61,6 +62,7 @@ ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_
// So we create the provider with rounding disabled, and expect the caller to enable it after.
onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled);
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) {
Expand Down
6 changes: 5 additions & 1 deletion winml/adapter/winml_adapter_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class WinmlAdapterLoggingWrapper : public LoggingWrapper {
ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function,
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
_In_ const char* logid, _Outptr_ OrtEnv** out) {
API_IMPL_BEGIN
std::string name = logid;
std::unique_ptr<onnxruntime::logging::ISink> logger = onnxruntime::make_unique<WinmlAdapterLoggingWrapper>(logging_function, profiling_function, logger_param);

Expand All @@ -64,6 +65,7 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en
// Set a new default logging manager
env->SetLoggingManager(std::move(winml_logging_manager));
return nullptr;
API_IMPL_END
}

// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
Expand All @@ -72,11 +74,13 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en
// registered schema are reachable only after upstream schema have been revised in a later OS release,
// which would be a compatibility risk.
ORT_API_STATUS_IMPL(winmla::OverrideSchema) {
API_IMPL_BEGIN
#ifdef USE_DML
static std::once_flag schema_override_once_flag;
std::call_once(schema_override_once_flag, []() {
SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
});
#endif USE_DML.
return nullptr;
#endif USE_DML
API_IMPL_END
}
4 changes: 4 additions & 0 deletions winml/adapter/winml_adapter_map_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,16 @@ OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) {

// OrtMapTypeInfo Accessors
ORT_API_STATUS_IMPL(winmla::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) {
API_IMPL_BEGIN
*out = map_type_info->map_key_type_;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) {
API_IMPL_BEGIN
return map_type_info->map_value_type_->Clone(out);
API_IMPL_END
}

ORT_API(void, winmla::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) {
Expand Down
36 changes: 36 additions & 0 deletions winml/adapter/winml_adapter_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,90 +195,118 @@ std::unique_ptr<onnx::ModelProto> OrtModel::DetachModelProto() {
}

ORT_API_STATUS_IMPL(winmla::CreateModelFromPath, const char* model_path, size_t size, OrtModel** out) {
API_IMPL_BEGIN
if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CreateModelFromData, void* data, size_t size, OrtModel** out) {
API_IMPL_BEGIN
if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CloneModel, const OrtModel* in, OrtModel** out) {
API_IMPL_BEGIN
auto model_proto_copy = std::make_unique<onnx::ModelProto>(*in->UseModelProto());
if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, const OrtModel* model, const char** const author, size_t* len) {
API_IMPL_BEGIN
*author = model->UseModelInfo()->author_.c_str();
*len = model->UseModelInfo()->author_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetName, const OrtModel* model, const char** const name, size_t* len) {
API_IMPL_BEGIN
*name = model->UseModelInfo()->name_.c_str();
*len = model->UseModelInfo()->name_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetDomain, const OrtModel* model, const char** const domain, size_t* len) {
API_IMPL_BEGIN
*domain = model->UseModelInfo()->domain_.c_str();
*len = model->UseModelInfo()->domain_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetDescription, const OrtModel* model, const char** const description, size_t* len) {
API_IMPL_BEGIN
*description = model->UseModelInfo()->description_.c_str();
*len = model->UseModelInfo()->description_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetVersion, const OrtModel* model, int64_t* version) {
API_IMPL_BEGIN
*version = model->UseModelInfo()->version_;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->model_metadata_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetMetadata, const OrtModel* model, size_t count, const char** const key,
size_t* key_len, const char** const value, size_t* value_len) {
API_IMPL_BEGIN
*key = model->UseModelInfo()->model_metadata_[count].first.c_str();
*key_len = model->UseModelInfo()->model_metadata_[count].first.size();
*value = model->UseModelInfo()->model_metadata_[count].second.c_str();
*value_len = model->UseModelInfo()->model_metadata_[count].second.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->input_features_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->output_features_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputName, const OrtModel* model, size_t index, const char** input_name, size_t* count) {
API_IMPL_BEGIN
*input_name = model->UseModelInfo()->input_features_[index]->name().c_str();
*count = model->UseModelInfo()->input_features_[index]->name().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t index, const char** output_name, size_t* count) {
API_IMPL_BEGIN
*output_name = model->UseModelInfo()->output_features_[index]->name().c_str();
*count = model->UseModelInfo()->output_features_[index]->name().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) {
Expand All @@ -288,26 +316,33 @@ ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, siz
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) {
API_IMPL_BEGIN
*output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str();
*count = model->UseModelInfo()->output_features_[index]->doc_string().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) {
API_IMPL_BEGIN
if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) {
API_IMPL_BEGIN
if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) {
API_IMPL_BEGIN
auto model_info = model->UseModelInfo();
auto model_proto = model->UseModelProto();
auto& graph = model_proto->graph();
Expand Down Expand Up @@ -372,6 +407,7 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) {
}
}
return nullptr;
API_IMPL_END
}

ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) {
Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_sequence_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) {
}

ORT_API_STATUS_IMPL(winmla::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) {
API_IMPL_BEGIN
return sequence_type_info->sequence_key_type_->Clone(out);
API_IMPL_END
}

ORT_API(void, winmla::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) {
Expand Down