Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Android NNAPI EP] Add QLinearAdd op Support, move some throw with return status #4607

Merged
merged 9 commits into from
Jul 30, 2020
28 changes: 28 additions & 0 deletions onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "helper.h"

namespace onnxruntime {
namespace nnapi {

using std::string;
using std::vector;

Expand Down Expand Up @@ -40,6 +43,28 @@ std::string GetErrorCause(int error_code) {
}
}

// Classifies the given node by its ONNX operator type string.
// Returns QLinearOpType::Unknown for any op that is not one of the
// recognized linear-quantized operators.
QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
  const auto& op = node.OpType();
  if (op == "QuantizeLinear") {
    return QLinearOpType::QuantizeLinear;
  }
  if (op == "DequantizeLinear") {
    return QLinearOpType::DequantizeLinear;
  }
  if (op == "QLinearConv") {
    return QLinearOpType::QLinearConv;
  }
  if (op == "QLinearMatMul") {
    return QLinearOpType::QLinearMatMul;
  }
  if (op == "QLinearAdd") {
    return QLinearOpType::QLinearAdd;
  }
  return QLinearOpType::Unknown;
}

// True iff the qlinear op consumes two quantized data inputs and produces
// one output (QLinearConv, QLinearMatMul, QLinearAdd).
bool IsQLinearBinaryOp(QLinearOpType qlinear_op_type) {
  switch (qlinear_op_type) {
    case QLinearOpType::QLinearConv:
    case QLinearOpType::QLinearMatMul:
    case QLinearOpType::QLinearAdd:
      return true;
    default:
      return false;
  }
}

// Wraps the node's attribute map for typed attribute lookups.
// Only a reference to the attributes is stored, so the node must
// outlive this helper.
NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node)
: node_attributes_(node.GetAttributes()) {}

Expand Down Expand Up @@ -97,3 +122,6 @@ vector<float> NodeAttrHelper::Get(const std::string& key, const vector<float>& d
// Returns true if the wrapped node has an attribute with the given name.
bool NodeAttrHelper::HasAttr(const std::string& key) const {
return Contains(node_attributes_, key);
}

} // namespace nnapi
} // namespace onnxruntime
27 changes: 26 additions & 1 deletion onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h"

namespace onnxruntime {
namespace nnapi {

#define THROW_ON_ERROR(val) \
{ \
const auto ret = (val); \
Expand Down Expand Up @@ -36,12 +39,31 @@ inline bool Contains(const Map& map, const Key& key) {

std::string GetErrorCause(int error_code);

// Categories of linear-quantized ONNX operators recognized by the
// NNAPI EP builders; see GetQLinearOpType for the string mapping.
enum class QLinearOpType : uint8_t {
Unknown,  // Unknown or not a linear quantized op
DequantizeLinear,
QuantizeLinear,
QLinearConv,
QLinearMatMul,
QLinearAdd,
// Not yet supported
// QLinearAveragePool,
// QLinearMul,
// QLinearReduceMean,
};

QLinearOpType GetQLinearOpType(const onnxruntime::Node& node);

// Returns whether this qlinear op is an operator that takes 2 quantized data
// inputs and produces 1 output, such as QLinearConv, QLinearMatMul, QLinearAdd, ...
bool IsQLinearBinaryOp(QLinearOpType qlinear_op_type);

/**
* Wrapping onnxruntime::Node for retrieving attribute values
*/
class NodeAttrHelper {
public:
NodeAttrHelper(const onnxruntime::Node& proto);
NodeAttrHelper(const onnxruntime::Node& node);

float Get(const std::string& key, float def_val) const;
int32_t Get(const std::string& key, int32_t def_val) const;
Expand All @@ -54,3 +76,6 @@ class NodeAttrHelper {
private:
const onnxruntime::NodeAttributes& node_attributes_;
};

} // namespace nnapi
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,16 @@ std::unordered_map<std::string, vector<const Node*>> GetAllQuantizedOpInputs(con
const auto& node_indices = graph_view.GetNodesInTopologicalOrder();
for (const auto& node_idx : node_indices) {
const auto* node(graph_view.GetNode(node_idx));
const auto& op_type = node->OpType();
if (op_type == "DequantizeLinear" || op_type == "QLinearMatMul" || op_type == "QLinearConv") {
auto qlinear_op_type = GetQLinearOpType(*node);
if (qlinear_op_type == QLinearOpType::DequantizeLinear || IsQLinearBinaryOp(qlinear_op_type)) {
const auto& input_name = node->InputDefs()[0]->Name();
if (Contains(all_quantized_op_inputs, input_name))
all_quantized_op_inputs.at(input_name).push_back(node);
else
all_quantized_op_inputs.emplace(input_name, vector<const Node*>{node});
}

if (op_type == "QLinearMatMul" || op_type == "QLinearConv") {
if (IsQLinearBinaryOp(qlinear_op_type)) {
const auto& input_name = node->InputDefs()[3]->Name();
if (Contains(all_quantized_op_inputs, input_name))
all_quantized_op_inputs.at(input_name).push_back(node);
Expand Down Expand Up @@ -328,8 +328,8 @@ void ModelBuilder::RegisterModelInputs() {
}

// TODO, verify the scale and zero point match if there are multiple op using same input
std::tie(scale, zero_point) =
GetQuantizedInputScaleAndZeroPoint(*this, *all_quantized_op_inputs.at(input_name)[0], input_name);
ORT_THROW_IF_ERROR(GetQuantizedInputScaleAndZeroPoint(
*this, *all_quantized_op_inputs.at(input_name)[0], input_name, scale, zero_point));
break;
}
default:
Expand Down
Loading