[PASS][ConvertLayout] Fixes AttributeError during ConvertLayout to NHWC
Fixes an issue described in apache#6410. In order to retrieve the shape of a tensor, its `checked_type` should be used.

Change-Id: I991d194d9cc15ee20464ff2e239fd05c035000c8
lhutton1 committed Sep 8, 2020
1 parent 4b48d89 commit 6c829f4
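
For context, the sketch below is not part of the commit; it illustrates the distinction the fix relies on. Relay expressions (e.g. a Call node) do not in general expose a `.shape` attribute, whereas the TensorType obtained from `checked_type` after type inference always does, and ConvertLayout passes these inferred types to the conversion callbacks via `tinfos`. Variable names are illustrative.

# Minimal sketch, assuming TVM's Relay Python API around this commit.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.layout_transform(x, "NCHW", "NHWC")  # relay Call: has no `.shape` attribute
# y.shape                                      # would raise AttributeError

mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
out_type = mod["main"].body.checked_type       # TensorType after type inference
print(out_type.shape)                          # [1, 56, 56, 64]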
Showing 3 changed files with 46 additions and 4 deletions.
python/tvm/relay/op/nn/_nn.py (6 changes: 4 additions, 2 deletions)
@@ -157,8 +157,10 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
         return relay.nn.conv2d(data, weight, **new_attrs)
     elif desired_data_layout == 'NHWC':
         # Check for depthwise convolution.
-        if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
-                               attrs['kernel_layout'], attrs['groups']):
+        data_info, weight_info = tinfos
+        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+                               weight_info.shape, attrs['kernel_layout'],
+                               attrs['groups']):
             new_attrs['kernel_layout'] = 'HWOI'
         else:
             new_attrs['kernel_layout'] = 'HWIO'
python/tvm/relay/qnn/op/layout_conversions.py (6 changes: 4 additions, 2 deletions)
@@ -62,8 +62,10 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
     if desired_data_layout == 'NHWC':
         # Check for depthwise convolution.
-        if is_depthwise_conv2d(inputs[0].shape, attrs['data_layout'], inputs[1].shape,
-                               attrs['kernel_layout'], attrs['groups']):
+        data_info, weight_info = tinfos
+        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+                               weight_info.shape, attrs['kernel_layout'],
+                               attrs['groups']):
             new_attrs['kernel_layout'] = 'HWOI'
         else:
             new_attrs['kernel_layout'] = 'HWIO'
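
Both callbacks are exercised through the ConvertLayout pass. A hedged usage sketch follows; the module `mod` and the exact pass pipeline are assumptions for illustration, not part of this commit.

import tvm
from tvm import relay

# `mod` is assumed to be an existing relay.IRModule containing nn.conv2d /
# qnn.conv2d calls in NCHW layout.
desired_layouts = {"nn.conv2d": ["NHWC", "default"],
                   "qnn.conv2d": ["NHWC", "default"]}
seq = tvm.transform.Sequential([relay.transform.InferType(),
                                relay.transform.ConvertLayout(desired_layouts)])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)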
tests/python/relay/test_pass_convert_op_layout.py (38 changes: 38 additions, 0 deletions)
@@ -90,6 +90,43 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_nhwc_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NCHW',
+                            kernel_layout='OIHW')
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, 'NCHW', 'NHWC')
+        weight = relay.layout_transform(weight, 'OIHW', 'HWIO')
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NHWC",
+                            kernel_layout="HWIO")
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, 'NHWC', 'NCHW')
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'default']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_conv_transpose_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -795,6 +832,7 @@ def expected():
 if __name__ == "__main__":
     test_no_convert_layout()
     test_conv_convert_layout()
+    test_conv_nhwc_convert_layout()
     test_conv_bias_pool_convert_layout()
     test_conv_concat_convert_layout()
     test_dual_path_convert_layout()
