[TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size #5307

Merged · 2 commits · Apr 17, 2020
10 changes: 9 additions & 1 deletion include/tvm/tir/expr.h
@@ -1228,9 +1228,17 @@ constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief See pseudo code
*
- *  Type tvm_warp_shuffle(Type value, warp_id) {
+ *  Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
* return (value passed in by warp indicated by warp_id);
* }
+ *
+ * Parameter warp_id indicates the source thread ID in a warp.
+ *
+ * Parameter width indicates the number of threads involved in one
+ * shuffle. See the CUDA documentation for __shfl.
+ *
+ * Parameter warp_size is the size of a warp, which helps a backend
+ * to determine whether the width parameter is legal.
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
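To make the new intrinsic signature concrete, here is a minimal Python sketch of the intended semantics (an illustration under assumed semantics, not part of the patch): with a sub-warp width, each lane reads only from its own group of `width` consecutive lanes.

# Simulation of tvm_warp_shuffle(value, warp_id, width, warp_size).
# Lane i belongs to sub-group i // width and receives the value held by
# lane `warp_id` of that sub-group, mirroring CUDA's __shfl semantics.
def simulate_warp_shuffle(value, warp_id, width, warp_size):
    assert len(value) == warp_size
    assert warp_size % width == 0, "width must divide warp_size"
    return [value[(i // width) * width + warp_id] for i in range(warp_size)]

# Example: warp_size 32, width 16 -> two independent sub-warps of 16 lanes.
print(simulate_warp_shuffle(list(range(32)), warp_id=3, width=16, warp_size=32))
# lanes 0..15 all read 3; lanes 16..31 all read 19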
16 changes: 10 additions & 6 deletions src/target/source/intrin_rule_cuda.cc
@@ -81,11 +81,15 @@ struct CUDAPopcount {
}
};

- struct CUDAShuffle {
-   std::string operator()(DataType t, std::string name) const {
-     return "__shfl";
-   }
- };
+ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
+   PrimExpr e = args[0];
+   const CallNode* call = e.as<CallNode>();
+   CHECK(call != nullptr);
+   CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+   Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+   *rv = CallNode::make(
+       call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
+ }

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);
@@ -154,7 +158,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-     .set_body(DispatchExtern<CUDAShuffle>);
+     .set_body(DispatchCUDAShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);
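For illustration with assumed values: this dispatch rewrites a call tvm_warp_shuffle(x, 3, 16, 32) into the extern call __shfl(x, 3, 16). The warp_size argument is checked and dropped here, since CUDA's __shfl takes only the value, the source lane, and the width.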
20 changes: 14 additions & 6 deletions src/target/source/intrin_rule_opencl.cc
@@ -21,6 +21,7 @@
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
+ #include <tvm/arith/analyzer.h>
#include "../intrin_rule.h"

namespace tvm {
@@ -89,14 +90,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")

// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
- struct IntelShuffle {
-   std::string operator()(DataType t, std::string name) const {
-     return "intel_sub_group_shuffle";
-   }
- };
+ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
+   PrimExpr e = args[0];
+   const CallNode* call = e.as<CallNode>();
+   CHECK(call != nullptr);
+   CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+   arith::Analyzer analyzer;
+   CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
+       << "Intel warp shuffle does not support width != warp_size";
+   Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
+   *rv = CallNode::make(
+       call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+ }

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
-     .set_body(DispatchExtern<IntelShuffle>);
+     .set_body(DispatchIntelShuffle);

} // namespace intrin
} // namespace codegen
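A note on the design choice: standard OpenCL has no warp shuffle, and Intel's cl_intel_subgroups extension exposes intel_sub_group_shuffle(value, lane) with no width parameter. That is why the dispatch proves width == warp_size before dropping the last two arguments, rather than emitting a sub-warp shuffle.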
65 changes: 44 additions & 21 deletions src/tir/transforms/lower_warp_memory.cc
@@ -60,28 +60,45 @@ namespace tir {
//
// Before rewrite,
//
- //   alloc warp warp_mem[n * warp_size * m]
- //   store warp_mem[m * warp_index + (warp_size * m) * y + x]
- //   load warp_mem[m * z + (warp_size * m) * y + x]
+ //   alloc warp warp_mem[n * width * m]
+ //   store warp_mem[m * warp_index + (width * m) * y + x]
+ //   load warp_mem[m * z + (width * m) * y + x]
// subject to x \in [0, m), y \in [0, n)
//
+ // where width equals the extent of threadIdx.x, which should
+ // be no larger than the warp size
+ //
// After rewrite:
//
//   alloc local local_mem[n * m]
//   store local_mem[m * y + x]
//   warp_shuffle(load local_mem[m * y + x], z)
// subject to (m * y + x) is invariant to warp_index
//
+ // If width == warp size, we are shuffling on full warps.
+ // Otherwise, we are virtually shuffling on sub-warps,
+ // whose size equals width. In this case, you can imagine
+ // that a warp consists of only `width` threads. Width is
+ // passed as an argument to the shuffle primitive, and will
+ // be lowered to the device code if the target supports it.
+ //
+ // A limitation of this sub-warp approach is that users
+ // cannot shuffle across the sub-warp boundary (i.e. shuffle
+ // with threadIdx.y or threadIdx.z indices). This could be
+ // solved by fusing threadIdx.x up to the warp size, or by
+ // improving the analyzer to handle all three thread axes,
+ // both of which are left for future work.
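//
// As a concrete illustration (numbers chosen for exposition; not part
// of the patch, but they match the half-a-warp unit test below):
// width = 16, m = 1, n = 1 gives
//
//   alloc warp warp_mem[16]
//   store warp_mem[warp_index]
//   load warp_mem[z]
//
// which is rewritten to
//
//   alloc local local_mem[1]
//   store local_mem[0]
//   warp_shuffle(load local_mem[0], z)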

// Algorithm
//
// To implement this rewrite rule, we can do the following steps:
// For each warp memory alloc
// - Use linear pattern detector on load index to find m
- //   - Deduce n given warp_size and alloc size
- //   - Now that we have m, n, warp_size, we can proceed with the rewrite
+ //   - Deduce n given width and alloc size
+ //   - Now that we have m, n, width, we can proceed with the rewrite
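//
// For example (illustrative numbers, not part of the patch): with
// warp_size 32 and extent(threadIdx.x) 16, width = 16; if the pattern
// detector finds m = 2 and the allocation holds 64 elements, then
// n = 64 / (width * m) = 2.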

// Visitor to find m in pattern
- //   store warp_mem[m * warp_index + (warp_size * m) * y + x]
+ //   store warp_mem[m * warp_index + (width * m) * y + x]
class WarpStoreCoeffFinder : private StmtVisitor {
public:
WarpStoreCoeffFinder(const VarNode* buffer,
@@ -153,12 +170,12 @@ class WarpIndexFinder : private StmtVisitor {
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
-   // find the warp co-efficient in the statement given the warp size
-   IterVar Find(const Stmt& stmt) {
+   // find the warp coefficient and the shuffle width in the statement
+   std::pair<Var, int> Find(const Stmt& stmt) {
this->VisitStmt(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
-     return warp_index_;
+     return std::make_pair(warp_index_->var, width_);
}

private:
@@ -167,11 +184,12 @@
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
-         int value;
+         int value = 0;
CHECK(arith::GetConstInt(op->value, &value) &&
-               value == warp_size_)
-             << "Expect threadIdx.x 's size to be equal to warp size("
-             << warp_size_ << ")" << " to enable warp memory"
+               value <= warp_size_ &&
+               warp_size_ % value == 0)
+             << "Expect threadIdx.x's size to be no larger than, and a factor of,"
+             << " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
if (warp_index_.defined()) {
CHECK(warp_index_.same_as(iv))
@@ -180,6 +198,7 @@
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
+           width_ = value;
warp_index_ = iv;
}
}
@@ -188,6 +207,8 @@
}
// warp size
int warp_size_{0};
+   // number of threads involved in one shuffle
+   int width_{0};
// the warp index
IterVar warp_index_{nullptr};
};
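// Note (illustration, not part of the patch): with warp_size_ = 32, the
// extents of threadIdx.x accepted by the check above are the divisors of
// 32 (1, 2, 4, 8, 16, 32); an extent such as 24 fails the
// warp_size_ % value == 0 condition.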
@@ -204,16 +225,16 @@
CHECK_GT(alloc_size, 0)
<< "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
-     warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
+     std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(
buffer_, warp_index_, analyzer_).Find(op->body);
-     CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
-         << "Warp memory must be multiple of warp size";
-     warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
+     CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
+         << "Warp memory must be a multiple of the extent of threadIdx.x";
+     warp_group_ = alloc_size / (width_ * warp_coeff_);
return AllocateNode::make(
op->buffer_var,
op->dtype,
-         {make_const(DataType::Int(32), alloc_size / warp_size_)},
+         {make_const(DataType::Int(32), alloc_size / width_)},
op->condition,
this->VisitStmt(op->body));
}
@@ -247,7 +268,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
op->dtype, op->buffer_var, local_index, op->predicate);
return CallNode::make(load_value.dtype(),
intrinsic::tvm_warp_shuffle,
-                            {load_value, group},
+                            {load_value, group, width_, warp_size_},
CallNode::Intrinsic);
} else {
return StmtExprMutator::VisitExpr_(op);
@@ -276,9 +297,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
return std::make_pair(x, z);
} else {
PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
-       PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
+       PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
y = y * m + x;
-       PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
+       PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
@@ -290,6 +311,8 @@
int warp_size_{0};
// The buffer variable
const VarNode* buffer_;
+   // number of threads involved in one shuffle
+   int width_{0};
// Warp index
Var warp_index_;
// the coefficient m
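To sanity-check the index split above (the x/y/z computation), here is a small standalone Python sketch with illustrative values (not TVM code): y addresses the per-thread local buffer and z selects the source lane within the sub-warp.

# Standalone re-implementation of the split used above:
#   x = index % m
#   y = (index // (m * width)) * m + x
#   z = (index % (m * width)) // m
def split_warp_index(index, m, width):
    x = index % m
    y = (index // (m * width)) * m + x
    z = (index % (m * width)) // m
    return y, z

# Example with m = 2, width = 16: warp-memory index 37 decomposes as
# 37 = m * z + (width * m) * y' + x with x = 1, z = 2, y' = 1, so the
# local index is y = m * y' + x = 3 and the source lane is z = 2.
print(split_warp_index(37, m=2, width=16))  # -> (3, 2)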
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -91,6 +91,48 @@ def check_cuda(dtype):
check_cuda("float32")
check_cuda("float16")

+def test_lower_warp_memory_cuda_half_a_warp():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 16
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 2 * m
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], nparts=1)
+            s[B].bind(xi, tx)
+            s[B].bind(xo, bx)
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()