Skip to content

Commit

Permalink
Fixed bugs of Gemm op
Browse files Browse the repository at this point in the history
 <!-- Added opsupport check of gemm
  Refine HandleBuildOp of gemm -->

  <!-- 1.Vsinpu cannot support Gemm where transA and transB both are true

  2.The beta param was used incorrectly in previous patch

  3.Make the code more readable -->

Signed-off-by: Feiyue Chen <[email protected]>
  • Loading branch information
chenfeiyue-cfy committed Jan 18, 2024
1 parent 46c90e3 commit 3ca67cc
Showing 1 changed file with 71 additions and 31 deletions.
102 changes: 71 additions & 31 deletions onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ namespace onnxruntime {
namespace vsi {
namespace npu {
class GemmOpBuilder : public BaseOpBuilder {
bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
const Node* node) const override {
NodeAttrHelper helper(*node);
auto trans_A = helper.Get("transA", 0);
auto trans_B = helper.Get("transB", 0);
if (trans_A == trans_B && trans_A == 1) {
LOGS_DEFAULT(WARNING) << "Cannot support Gemm Op with transA && transB both be true.";
return false;
}
return true;
}
bool HandleBuildOp(vsi::npu::GraphEP* graph_ep,
std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
Expand All @@ -52,41 +63,70 @@ class GemmOpBuilder : public BaseOpBuilder {
auto beta_tensor = graph_ep->GetGraph()->CreateTensor(CoefSpec);
beta_tensor->CopyDataToTensor(&beta);

auto updatedA = input_A;
auto updatedB = input_B;
if (has_alpha) {
updatedA = graph_ep->GetGraph()->CreateTensor(
input_A->GetSpec().AsTransientSpec());
auto mul1 =
graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Multiply>();
(*mul1).BindInput(input_A).BindInput(alpha_tensor).BindOutput(updatedA);
graph_ep->GetOps().push_back(std::move(mul1));
}
if (has_beta) {
updatedB = graph_ep->GetGraph()->CreateTensor(
input_B->GetSpec().AsTransientSpec());
auto mul2 =
graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Multiply>();
(*mul2).BindInput(input_B).BindInput(beta_tensor).BindOutput(updatedA);
graph_ep->GetOps().push_back(std::move(mul2));
}
auto matmul_impl = [&](std::shared_ptr<tim::vx::Tensor> input_A,
std::shared_ptr<tim::vx::Tensor> input_B,
std::shared_ptr<tim::vx::Tensor> output) {
auto matmul_op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Matmul>(
trans_A, trans_B);
(*matmul_op).BindInput(input_A).BindInput(input_B).BindOutput(output);
graph_ep->GetOps().push_back(std::move(matmul_op));
};

auto multiply_impl = [&](std::shared_ptr<tim::vx::Tensor> input,
std::shared_ptr<tim::vx::Tensor> coef,
std::shared_ptr<tim::vx::Tensor> output) {
auto multiply_op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Multiply>();
(*multiply_op).BindInput(input).BindInput(coef).BindOutput(output);
graph_ep->GetOps().push_back(std::move(multiply_op));
};

if (has_C) {
auto AB_output = graph_ep->GetGraph()->CreateTensor(
auto add_impl = [&](std::shared_ptr<tim::vx::Tensor> input_A,
std::shared_ptr<tim::vx::Tensor> input_B,
std::shared_ptr<tim::vx::Tensor> output) {
auto add_op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Add>();
(*add_op).BindInput(input_A).BindInput(input_B).BindOutput(output);
graph_ep->GetOps().push_back(std::move(add_op));
};

auto AB_output = outputs[0];
if (has_alpha) {
AB_output = graph_ep->GetGraph()->CreateTensor(
outputs[0]->GetSpec().AsTransientSpec());
auto matmul = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Matmul>(
trans_A, trans_B);
(*matmul).BindInput(updatedA).BindInput(updatedB).BindOutput(AB_output);
graph_ep->GetOps().push_back((std::move(matmul)));
auto add = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Add>();
(*add).BindInput(AB_output).BindInput(inputs[2]).BindOutput(outputs[0]);
graph_ep->GetOps().push_back((std::move(add)));
matmul_impl(input_A, input_B, AB_output);

if (has_C) {
auto mul1_output = graph_ep->GetGraph()->CreateTensor(
outputs[0]->GetSpec().AsTransientSpec());
multiply_impl(AB_output, alpha_tensor, mul1_output);
if (has_beta) {
auto multiplied_C = graph_ep->GetGraph()->CreateTensor(
inputs[2]->GetSpec().AsTransientSpec());
multiply_impl(inputs[2], beta_tensor, multiplied_C);
add_impl(mul1_output, multiplied_C, outputs[0]);
} else {
add_impl(mul1_output, inputs[2], outputs[0]);
}
} else {
multiply_impl(AB_output, alpha_tensor, outputs[0]);
}
} else {
auto op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Matmul>(
trans_A, trans_B);
(*op).BindInput(updatedA).BindInput(updatedB).BindOutput(outputs[0]);
graph_ep->GetOps().push_back((std::move(op)));
if (has_C) {
AB_output = graph_ep->GetGraph()->CreateTensor(
outputs[0]->GetSpec().AsTransientSpec());
matmul_impl(input_A, input_B, AB_output);
if (has_beta) {
auto multiplied_C = graph_ep->GetGraph()->CreateTensor(
inputs[2]->GetSpec().AsTransientSpec());
multiply_impl(inputs[2], beta_tensor, multiplied_C);
add_impl(AB_output, multiplied_C, outputs[0]);
} else {
add_impl(AB_output, inputs[2], outputs[0]);
}
} else {
matmul_impl(input_A, input_B, outputs[0]);
}
}

return true;
}
};
Expand Down

0 comments on commit 3ca67cc

Please sign in to comment.