Skip to content

Commit

Permalink
[onert] Correct the shape when _keep_dims is false in backward of Mea…
Browse files Browse the repository at this point in the history
…nLayer (#13394)

This commit corrects the shape by creating a temporary shape having the
same rank as the input when _keep_dims is false because MeanGrad does not
support other rank cases.

- Revert #13351
- Fix `GenModelTrain.NonTrainableOps_FC_Mean` test case

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun authored Jul 11, 2024
1 parent 4e328ce commit 98ed924
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion runtime/onert/backend/train/ops/MeanLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,29 @@ void MeanLayer::forward(bool) { cpu::ops::MeanLayer::run(); }

void MeanLayer::backward()
{
nnfw::cker::Shape keep_dim_shape;
// If _keep_dims is false, the input rank and the output rank can be different.
// MeanGrad does not support other ranking cases. This code corrects the shape
// by creating a temporary shape having the same rank as the input.
if (_keep_dims == false)
{
keep_dim_shape.ReplaceWith(getShape(_input));
auto axes_vec = cpu::ops::getReducerAxes(_axes);
for (const auto &axis : axes_vec)
{
keep_dim_shape.SetDim(axis, 1);
}
}
else
{
keep_dim_shape.ReplaceWith(getShape(_back_prop_output));
}

switch (_back_prop_output->data_type())
{
case OperandType::FLOAT32:
{
nnfw::cker::train::MeanGrad(getShape(_back_prop_output), getBuffer<float>(_back_prop_output),
nnfw::cker::train::MeanGrad(keep_dim_shape, getBuffer<float>(_back_prop_output),
getShape(_back_prop_input), getBuffer<float>(_back_prop_input));
break;
}
Expand Down

0 comments on commit 98ed924

Please sign in to comment.