From 98ed924fefc8f498d5cf0f049adec1853de0d6cb Mon Sep 17 00:00:00 2001 From: Jiyoung Giuliana Yun Date: Thu, 11 Jul 2024 15:15:34 +0900 Subject: [PATCH] [onert] Correct the shape when _keep_dims is false in backward of MeanLayer (#13394) This commit corrects the shape by creating a temporary shape having the same rank as the input when _keep_dims is false because MeanGrad does not support other rank cases. - Revert https://github.com/Samsung/ONE/pull/13351 - Fix `GenModelTrain.NonTrainableOps_FC_Mean` test case ONE-DCO-1.0-Signed-off-by: Jiyoung Yun --- runtime/onert/backend/train/ops/MeanLayer.cc | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/runtime/onert/backend/train/ops/MeanLayer.cc b/runtime/onert/backend/train/ops/MeanLayer.cc index 2be83e82ce0..c6744e62cfa 100644 --- a/runtime/onert/backend/train/ops/MeanLayer.cc +++ b/runtime/onert/backend/train/ops/MeanLayer.cc @@ -48,11 +48,29 @@ void MeanLayer::forward(bool) { cpu::ops::MeanLayer::run(); } void MeanLayer::backward() { + nnfw::cker::Shape keep_dim_shape; + // If _keep_dims is false, the input rank and the output rank can be different. + // MeanGrad does not support other ranking cases. This code corrects the shape + // by creating a temporary shape having the same rank as the input. + if (_keep_dims == false) + { + keep_dim_shape.ReplaceWith(getShape(_input)); + auto axes_vec = cpu::ops::getReducerAxes(_axes); + for (const auto &axis : axes_vec) + { + keep_dim_shape.SetDim(axis, 1); + } + } + else + { + keep_dim_shape.ReplaceWith(getShape(_back_prop_output)); + } + switch (_back_prop_output->data_type()) { case OperandType::FLOAT32: { - nnfw::cker::train::MeanGrad(getShape(_back_prop_output), getBuffer(_back_prop_output), + nnfw::cker::train::MeanGrad(keep_dim_shape, getBuffer(_back_prop_output), getShape(_back_prop_input), getBuffer(_back_prop_input)); break; }