Skip to content

Commit

Permalink
rocm: fix miopen convolutions (#5179)
Browse files Browse the repository at this point in the history
* fix miopen convolutions

* fix overly long lines
  • Loading branch information
t-vi authored Mar 30, 2020
1 parent b776ff3 commit 8412196
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
11 changes: 5 additions & 6 deletions tests/python/contrib/test_miopen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def test_conv2d():

yshape = [x.value for x in Y.shape]
import topi
with tvm.target.create("rocm -libs=miopen"):
s = topi.generic.schedule_extern(Y)
s = te.create_schedule(Y.op)

def verify():
ctx = tvm.rocm(0)
Expand All @@ -67,10 +66,10 @@ def verify():
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f(x, w, y)

Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w))
with tvm.target.rocm():
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w),
(dilation_h, dilation_w))
s_ref = te.create_schedule(Y_ref.op)
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm")
y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f_ref(x, w, y_ref)
print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
Expand Down
5 changes: 4 additions & 1 deletion topi/python/topi/rocm/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from ..nn.util import get_pad_tuple

@autotvm.register_topi_compute("conv2d_nchw_miopen.rocm")
def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
layout='NCHW', out_dtype='float32'):
"""Conv2D operator for rocm backend.
Parameters
Expand Down Expand Up @@ -58,6 +59,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype=
CO, CI, KH, KW = get_const_tuple(kernel.shape)
N, _, H, W = get_const_tuple(data.shape)

assert layout == 'NCHW'

# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
Expand Down

0 comments on commit 8412196

Please sign in to comment.