diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index ef5824c957e8..3310b6b2ed69 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -486,8 +486,7 @@ def before(): beta = relay.var("beta") y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3) y = y[0] - y = relay.Function(analysis.free_vars(y), y) - return y + return relay.Function(analysis.free_vars(y), y) def alter_conv2d(attrs, inputs, tinfos, out_type): data, weight = inputs @@ -509,9 +508,8 @@ def expected(): bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c") add = relay.add(y, bias) y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW") - y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC") - mean = relay.mean(y, axis=3, exclude=True) - var = relay.variance(y, axis=3, exclude=True) + mean = relay.mean(y, axis=1, exclude=True) + var = relay.variance(y, axis=1, exclude=True) denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05)) gamma = relay.var("gamma", shape=(16,)) denom = denom * gamma