diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index 9854f9b3..e4648998 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -230,9 +230,9 @@ DEFINE_BUILTIN_OP_IMPORTER(BatchNormalization) { ASSERT( (inputs.at(1).shape().nbDims == 1) && "The shape of the scale input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the bias input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the mean input must be (C, )", ErrorCode::kINVALID_NODE); - ASSERT((inputs.at(1).shape().nbDims == 1) && "The shape of the var input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(2).shape().nbDims == 1) && "The shape of the bias input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(3).shape().nbDims == 1) && "The shape of the mean input must be (C, )", ErrorCode::kINVALID_NODE); + ASSERT((inputs.at(4).shape().nbDims == 1) && "The shape of the var input must be (C, )", ErrorCode::kINVALID_NODE); const bool allInputsWeights = inputs.at(1).is_weights() && inputs.at(2).is_weights() && inputs.at(3).is_weights() && inputs.at(4).is_weights();