From 280a125780f38af7ffdbc35289f31ef68431cf9c Mon Sep 17 00:00:00 2001 From: 103yiran <1039105206@qq.com> Date: Fri, 29 Jul 2022 14:25:12 +0800 Subject: [PATCH] check all inputs Signed-off-by: 103yiran <1039105206@qq.com> Signed-off-by: 103yiran <1039105206@qq.com> --- builtin_op_importers.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index 9854f9b3..e4648998 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -230,9 +230,9 @@ DEFINE_BUILTIN_OP_IMPORTER(BatchNormalization) { ASSERT( (inputs.at(1).shape().nbDims == 1) && "The shape of the scale input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the bias input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the mean input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the var input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(2).shape().nbDims == 1) && "The shape of the bias input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(3).shape().nbDims == 1) && "The shape of the mean input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(4).shape().nbDims == 1) && "The shape of the var input must be (C, )", ErrorCode::kINVALID_NODE); const bool allInputsWeights = inputs.at(1).is_weights() && inputs.at(2).is_weights() && inputs.at(3).is_weights() && inputs.at(4).is_weights();