diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 38c33b45936e..0ea71de367fa 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -61,10 +61,10 @@ bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, if (axis < 0) { axis = data->shape.size() + axis; } - if (axis >= static_cast(data->shape.size())) { + if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "The axis in bias_add must be in range for the shape; " - << "attempted to access index " << axis << " of " + << "attempted to access index " << param->axis << " of " << PrettyPrint(data->shape)); return false; } diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index ea5dd6948b11..dfd350486c3b 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -202,14 +202,16 @@ def test_bias_add(): def test_bias_add_type_failure(): - # the axis is out of range - try: - b_add = relay.nn.bias_add(relay.const(1), relay.const(2), axis=0) - run_infer_type(b_add) - except tvm._ffi.base.TVMError: - pass - else: - assert False + def assert_failure(expr): + try: + run_infer_type(expr) + except tvm._ffi.base.TVMError: + return + else: + assert False + + for axis in (0, -1, -3, 1): + assert_failure(relay.nn.bias_add(relay.const(1), relay.const(2), axis=axis)) def test_expand_dims_infer_type():