[CUDA] Fix codegen for warp shuffle intrinsics (apache#5606)
* fix shfl intrin

* improve test_lower_warp_memory_cuda_half_a_warp
roastduck authored and Trevor Morris committed Jun 9, 2020
1 parent: 1c60e71 · commit: a7e6a08
Showing 2 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/target/source/intrin_rule_cuda.cc
@@ -116,7 +116,7 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
   const char* name = T()(call->dtype, call->name);
   *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
 }
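The forwarded argument matters because of the CUDA intrinsic's signature: __shfl_sync(mask, var, srcLane, width) takes an optional fourth width parameter that defaults to the full warp, so passing only the first three of TVM's five arguments (mask, value, warp_id) pinned every shuffle to 32 lanes. The stand-alone sketch below is hand-written for illustration, not code generated by TVM; the kernel name and data layout are made up. It shows the difference inside a single warp that holds two independent 16-element rows.

// Hypothetical demo (not TVM output): one 32-lane warp holds two independent
// 16-element rows. Rotating each row by one element needs a width-16 shuffle;
// with the default width, the upper half-warp reads from the lower half.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void shuffle_width_demo(const float* in, float* with_width, float* without_width) {
  int lane = threadIdx.x;                 // 0..31: one full warp
  float v = in[lane];
  int src = (lane % 16 + 1) % 16;         // source lane within a 16-lane row
  with_width[lane] = __shfl_sync(0xffffffffu, v, src, 16);   // stays inside the half-warp
  without_width[lane] = __shfl_sync(0xffffffffu, v, src);    // width defaults to warpSize (32)
}

int main() {
  float h_in[32], h_a[32], h_b[32];
  for (int i = 0; i < 32; ++i) h_in[i] = static_cast<float>(i);
  float *d_in, *d_a, *d_b;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_a, sizeof(h_a));
  cudaMalloc(&d_b, sizeof(h_b));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  shuffle_width_demo<<<1, 32>>>(d_in, d_a, d_b);
  cudaMemcpy(h_a, d_a, sizeof(h_a), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_b, d_b, sizeof(h_b), cudaMemcpyDeviceToHost);
  // Lane 16 should receive lane 17's value; with the default width it
  // receives lane 1's value from the other row instead.
  printf("with width 16: lane 16 -> %g\n", h_a[16]);   // expect 17
  printf("default width: lane 16 -> %g\n", h_b[16]);   // expect 1
  cudaFree(d_in); cudaFree(d_a); cudaFree(d_b);
  return 0;
}

Compile with nvcc (CUDA 9 or later for the *_sync intrinsics); the two printed values differ only because of the explicit width.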
26 changes: 14 additions & 12 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -136,30 +136,32 @@ def check_cuda(dtype):
             print("Skip because gpu does not have fp16 support")
             return

-        m = 16
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+        n, m = 16, 16
+        A = te.placeholder((n, m,), name='A', dtype=dtype)
+        B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')

         cuda_target = tvm.target.create("cuda")
         assert cuda_target.thread_warp_size == 2 * m
         with cuda_target:
             s = te.create_schedule(B.op)
             tx = te.thread_axis("threadIdx.x")
+            ty = te.thread_axis("threadIdx.y")
             bx = te.thread_axis("blockIdx.x")

             AA = s.cache_read(A, "warp", [B])
-            xo, xi = s[B].split(B.op.axis[0], nparts=1)
-            s[B].bind(xi, tx)
-            s[B].bind(xo, bx)
-            s[AA].compute_at(s[B], xo)
-            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
-            s[AA].bind(xo, bx)
-            s[AA].bind(xi, tx)
+            y, x = B.op.axis
+            z, y = s[B].split(y, nparts=2)
+            s[B].bind(x, tx)
+            s[B].bind(y, ty)
+            s[B].bind(z, bx)
+            s[AA].compute_at(s[B], y)
+            _, x = AA.op.axis
+            s[AA].bind(x, tx)

             ctx = tvm.gpu(0)
             func = tvm.build(s, [A, B], "cuda")
-            A_np = np.array(list(range(m)), dtype=dtype)
-            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+            B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)
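For reference, here is a hand-written approximation of the pattern the updated test exercises; it is not the TVM-generated kernel, and the kernel name, launch configuration, and indexing are illustrative only. The test covers a 16x16 problem with 16 lanes of threadIdx.x per row and the rows split across threadIdx.y, so each 32-lane warp carries two rows that must shuffle with width 16 to stay independent.

// Hypothetical stand-alone approximation (not TVM output) of the test's
// pattern: each 16-lane half of a warp owns one row and rotates it by one
// element via a width-16 shuffle, independently of the other half.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void rotate_rows_half_warp(const float* A, float* B, int m) {
  int lane = threadIdx.x;                           // 0..15, column index
  int row = blockIdx.x * blockDim.y + threadIdx.y;  // row owned by this half-warp
  float v = A[row * m + lane];
  // Width 16 keeps the shuffle inside this half-warp, so the two rows that
  // share a 32-lane warp never exchange values. Assumes every lane of the
  // warp is active, hence the full mask.
  B[row * m + lane] = __shfl_sync(0xffffffffu, v, (lane + 1) % 16, 16);
}

int main() {
  const int n = 16, m = 16;
  float h_A[n * m], h_B[n * m];
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) h_A[j * m + i] = static_cast<float>(j + i);  // same data as A_np
  float *d_A, *d_B;
  cudaMalloc(&d_A, sizeof(h_A));
  cudaMalloc(&d_B, sizeof(h_B));
  cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);
  rotate_rows_half_warp<<<2, dim3(16, 8)>>>(d_A, d_B, m);  // 2 blocks of 8 rows each
  cudaMemcpy(h_B, d_B, sizeof(h_B), cudaMemcpyDeviceToHost);
  // Expect B[j][i] == A[j][(i + 1) % m]: each row rotated left by one.
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      if (h_B[j * m + i] != h_A[j * m + (i + 1) % m]) { printf("mismatch\n"); return 1; }
  printf("ok\n");
  cudaFree(d_A);
  cudaFree(d_B);
  return 0;
}

The <<<2, dim3(16, 8)>>> launch mirrors the schedule's split of the row axis into 2 blocks of 8 rows, and the final check is the same B[j][i] == A[j][(i + 1) % m] relation encoded in B_np above.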
