Merge pull request microsoft#22 from NonStatic2014/bohu/fix_random_result

Fix random prediction results and replace camel-case names with snake-case names
NonStatic2014 authored Apr 2, 2019
2 parents 19747b1 + 74c2f8e commit 094aca2
Showing 4 changed files with 78 additions and 72 deletions.
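The "random prediction results" in the title trace to the buffer-ownership change in executor.cc below: the input buffer was handed to the non-owning onnxruntime::MemBuffer via data.get() while the std::unique_ptr still owned the allocation, so the memory was freed at the end of each loop iteration and the session later read through a dangling pointer. Switching to data.release() transfers ownership out of the loop scope, with cleanup left to the OrtCallback deleter path. A standalone C++ sketch of the two shapes, not the server code; Buffer is a hypothetical stand-in for the non-owning MemBuffer:

    #include <cstddef>
    #include <cstring>
    #include <memory>

    // Hypothetical stand-in for onnxruntime::MemBuffer: a non-owning view.
    struct Buffer {
      char* ptr;
      std::size_t len;
    };

    // BUG (old code's shape): the unique_ptr still owns the allocation, so the
    // returned view dangles as soon as `data` is destroyed at scope exit.
    Buffer make_view_buggy(std::size_t len) {
      std::unique_ptr<char[]> data(new char[len]);
      std::memset(data.get(), 0, len);
      return Buffer{data.get(), len};  // dangling after return
    }

    // FIX (new code's shape): release() hands ownership to the caller, who
    // must eventually free the memory (in the server, via the deleter).
    Buffer make_view_fixed(std::size_t len) {
      std::unique_ptr<char[]> data(new char[len]);
      std::memset(data.get(), 0, len);
      return Buffer{data.release(), len};
    }

    int main() {
      Buffer b = make_view_fixed(16);
      b.ptr[0] = 42;   // safe: the allocation is still alive
      delete[] b.ptr;  // caller cleans up
    }

The relocated memset in the same file is a smaller cleanup: it just avoids touching the buffer before the null check, and since plain new char[] throws std::bad_alloc rather than returning nullptr, that check is belt-and-braces either way.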
6 changes: 3 additions & 3 deletions onnxruntime/hosting/converter.cc
@@ -55,9 +55,9 @@ onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::Da
   }
 }
 
-common::Status MLValue2TensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
-                                   std::shared_ptr<onnxruntime::logging::Logger> logger,
-                                   /* out */ onnx::TensorProto& tensor_proto) {
+common::Status MLValueToTensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
+                                    std::shared_ptr<onnxruntime::logging::Logger> logger,
+                                    /* out */ onnx::TensorProto& tensor_proto) {
   // Tensor in MLValue
   const auto& tensor = ml_value.Get<onnxruntime::Tensor>();
 
6 changes: 3 additions & 3 deletions onnxruntime/hosting/converter.h
@@ -22,9 +22,9 @@ onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::Da
 // * segment field: we do not expect very large tensors in the prediction output
 // * external_data field: we do not expect very large tensors in the prediction output
 // Note: If any input data is in raw_data field, all outputs tensor data will be put into raw_data field.
-common::Status MLValue2TensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
-                                   std::shared_ptr<onnxruntime::logging::Logger> logger,
-                                   /* out */ onnx::TensorProto& tensor_proto);
+common::Status MLValueToTensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
+                                    std::shared_ptr<onnxruntime::logging::Logger> logger,
+                                    /* out */ onnx::TensorProto& tensor_proto);
 
 }  // namespace hosting
 }  // namespace onnxruntime
30 changes: 18 additions & 12 deletions onnxruntime/hosting/executor.cc
@@ -32,17 +32,17 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
   auto logger = env_->GetLogger(request_id);
 
   // Create the input NameMLValMap
-  onnxruntime::NameMLValMap nameMlValMap{};
+  onnxruntime::NameMLValMap name_ml_value_map{};
   common::Status status{};
   for (const auto& input : request.inputs()) {
     std::string input_name = input.first;
     onnx::TensorProto input_tensor = input.second;
     using_raw_data = using_raw_data && input_tensor.has_raw_data();
 
     // Prepare the MLValue object
-    OrtAllocatorInfo* cpuAllocatorInfo = nullptr;
-    auto ort_status = OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo);
-    if (ort_status != nullptr || cpuAllocatorInfo == nullptr) {
+    OrtAllocatorInfo* cpu_allocator_info = nullptr;
+    auto ort_status = OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpu_allocator_info);
+    if (ort_status != nullptr || cpu_allocator_info == nullptr) {
       LOGS(*logger, ERROR) << "OrtCreateAllocatorInfo FAILED! Input name: " << input_name;
       return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateAllocatorInfo() FAILED!");
     }
@@ -58,17 +58,17 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
     }
 
     std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
-    memset(data.get(), 0, cpu_tensor_length);
     if (nullptr == data) {
      LOGS(*logger, ERROR) << "Run out memory. Input name: " << input_name;
      return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "Run out of memory");
     }
+    memset(data.get(), 0, cpu_tensor_length);
 
     // TensorProto -> MLValue
     MLValue ml_value;
     OrtCallback deleter;
     status = onnxruntime::utils::TensorProtoToMLValue(onnxruntime::Env::Default(), nullptr, input_tensor,
-                                                      onnxruntime::MemBuffer(data.get(), cpu_tensor_length, *cpuAllocatorInfo),
+                                                      onnxruntime::MemBuffer(data.release(), cpu_tensor_length, *cpu_allocator_info),
                                                       ml_value, deleter);
     if (!status.IsOK()) {
       LOGS(*logger, ERROR) << "TensorProtoToMLValue() FAILED! Input name: " << input_name
@@ -78,7 +78,13 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
                                       "TensorProtoToMLValue() FAILED: " + status.ErrorMessage());
     }
 
-    nameMlValMap[input_name] = ml_value;
+    auto insertion_result = name_ml_value_map.insert(std::make_pair(input_name, ml_value));
+    if (!insertion_result.second) {
+      LOGS(*logger, ERROR) << "Predict() FAILED! Input name: " << input_name
+                           << " Trying to overwrite existing input value";
+      return protobufutil::Status(protobufutil::error::Code::ALREADY_EXISTS,
+                                  "Predict() FAILED: Trying to overwrite existing input value");
+    }
   }  // for(const auto& input : request.inputs())
 
   // Prepare the output names and vector
@@ -89,11 +95,11 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
   std::vector<onnxruntime::MLValue> outputs(output_names.size());
 
   // Run()!
-  OrtRunOptions runOptions{};
-  runOptions.run_log_verbosity_level = 4;  // TODO: respect user selected log level
-  runOptions.run_tag = request_id;
+  OrtRunOptions run_options{};
+  run_options.run_log_verbosity_level = 4;  // TODO: respect user selected log level
+  run_options.run_tag = request_id;
 
-  status = env_->GetSession()->Run(runOptions, nameMlValMap, output_names, &outputs);
+  status = env_->GetSession()->Run(run_options, name_ml_value_map, output_names, &outputs);
   if (!status.IsOK()) {
     LOGS(*logger, ERROR) << "Run() FAILED!"
                          << " Error code: " << status.Code()
@@ -105,7 +111,7 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
   // Build the response
   for (size_t i = 0; i < outputs.size(); ++i) {
     onnx::TensorProto output_tensor{};
-    status = MLValue2TensorProto(outputs[i], using_raw_data, std::move(logger), output_tensor);
+    status = MLValueToTensorProto(outputs[i], using_raw_data, std::move(logger), output_tensor);
     if (!status.IsOK()) {
       LOGS(*logger, ERROR) << "MLValue2TensorProto() FAILED! Output name: " << output_names[i]
                            << " Error code: " << status.Code()
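One behavioral note on the input-map hunk above: operator[] assignment (the old nameMlValMap[input_name] = ml_value) silently overwrites the value for an existing key, whereas insert() leaves the map untouched on a key collision and reports it through the returned pair's bool, which is what lets Predict() surface ALREADY_EXISTS for a duplicate input name. A minimal standalone illustration of that contract (it holds for both std::map and std::unordered_map; this is not the server code):

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, int> inputs;

      inputs["x"] = 1;  // operator[]: inserts, or silently overwrites on reuse
      inputs["x"] = 2;  // no error; the first value is gone

      auto result = inputs.insert(std::make_pair(std::string("x"), 3));
      if (!result.second) {  // insert(): refuses to overwrite, and says so
        std::cout << "duplicate key: " << result.first->first
                  << " (kept value " << result.first->second << ")\n";
      }
      return 0;
    }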
(Diff for the fourth changed file not shown: 54 additions and 54 deletions.)
