Skip to content

Commit

Permalink
Add atomic intrinsic for output nonzero inference. (#25)
Browse files Browse the repository at this point in the history
* upd

* upd
  • Loading branch information
yzh119 committed Jan 21, 2022
1 parent a91a93c commit 79e5bc5
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 6 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ TVM_DLL const Op& tvm_lower_bound();
*/
TVM_DLL const Op& tvm_upper_bound();

/*!
* \brief Atomic add function.
*/
TVM_DLL const Op& tvm_atomic_add();

/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,14 @@ TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
Span span = Span());

/*!
* \brief Perform atomic add on ptr by val, and return the old value.
* \param ptr The address to perform atomic add.
* \param val The value to add.
* \return The old result stored in ptr.
*/
TVM_DLL PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span = Span());

/*!
* \brief Calculate trunc(x)
* \param x The input expression.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def upper_bound(arr, val, l, r, span):
return tvm.tir.upper_bound(arr, val, l, r, span)


@register
def atomic_add(ptr, val, span):
return tvm.tir.atomic_add(ptr, val, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .function import PrimFunc

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound, atomic_add
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,26 @@ def upper_bound(arr, val, l, r, span=None):
return _ffi_api.upper_bound(arr, val, l, r, span) # type: ignore


def atomic_add(ptr, val, span=None):
"""Perform an atomic add operation to ptr by the given val.
Parameters
----------
ptr : Var
The pointer to the address we perform atomic add.
val : PrimExpr
The value to add.
span : Optional[Span]
The location of this expression in the source code.
Returns
-------
PrimExpr
The value on pointer before we perform the atomic add.
"""
return _ffi_api.atomic_add(ptr, val, span) # type: ignore


def isnan(x, span=None):
"""Check if input value is Nan.
Expand Down
7 changes: 7 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
os << ", ";
this->PrintExpr(op->args[3], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_atomic_add())) {
os << "atomicAdd(";
ICHECK_EQ(op->args.size(), 2U);
this->PrintExpr(op->args[0], os);
os << ", ";
this->PrintExpr(op->args[1], os);
os << ")";
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_lower_bound)
TIR_DEFINE_BUILTIN_FUNC(tvm_upper_bound)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_atomic_add)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

Expand Down
7 changes: 7 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,11 @@ PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span sp
return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span);
}

// atomic_add
PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span) {
return tir::Call(val->dtype, builtin::tvm_atomic_add(), {ptr, val}, span);
}

// trunc
PrimExpr trunc(PrimExpr x, Span span) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
Expand Down Expand Up @@ -932,6 +937,8 @@ TVM_REGISTER_GLOBAL("tir.lower_bound").set_body_typed(tvm::lower_bound);

TVM_REGISTER_GLOBAL("tir.upper_bound").set_body_typed(tvm::upper_bound);

TVM_REGISTER_GLOBAL("tir.atomic_add").set_body_typed(tvm::atomic_add);

// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
Expand Down
12 changes: 10 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,16 @@ Definition of a scope that is a stage pipeline:
if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
!IsReductionBlock(self, block_sref, scope_root_sref)) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
// NOTE(Zihao): check if the block has atomic attribute.
auto&& it = block->annotations.find("atomic");
bool is_atomic = false;
if (it != block->annotations.end()) {
is_atomic = ((*it).second).as<IntImmNode>()->value;
}
if (!is_atomic) {
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
}
}
}
}
Expand Down
38 changes: 35 additions & 3 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,20 @@ def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
D[vi] = T.upper_bound(A.data, B[vi], 0, n)


@T.prim_func
def global_add(a: T.handle) -> None:
A = T.match_buffer(a, (1,), dtype='int32')
for i in T.serial(0, 1024):
with T.block('global_add'):
T.block_attr({
"atomic": True
})
T.reads([A[0:1]])
T.writes([A[0:1]])
vi = T.axis.S(1024, i)
T.evaluate(T.atomic_add(A.data, vi))


def test_binary_search():
sch = tir.Schedule(binary_search)
b = sch.get_block('search')
Expand All @@ -281,7 +295,7 @@ def test_binary_search():
# print(f.imported_modules[0].get_source())

x = np.arange(-128, 128).astype(np.int32)
y = np.random.randint(-200, 200, size=1024).astype(np.int32)
y = np.random.randint(-200, 200, size=1024).astype(np.int32)
a = np.zeros((1024,)).astype(np.int32)
b = np.zeros((1024,)).astype(np.int32)

Expand All @@ -293,7 +307,7 @@ def test_binary_search():
dev = tvm.cuda(0)
x_array = tvm.nd.array(x, device=dev)
y_array = tvm.nd.array(y, device=dev)
a_array = tvm.nd.array(a, device=dev)
a_array = tvm.nd.array(a, device=dev)
b_array = tvm.nd.array(b, device=dev)
f(x_array, y_array, a_array, b_array)
tvm_a = a_array.numpy()
Expand All @@ -304,12 +318,30 @@ def test_binary_search():
tvm.testing.assert_allclose(np_b, tvm_b)


def test_global_add():
sch = tir.Schedule(global_add)
b = sch.get_block('global_add')
i, = sch.get_loops(b)
sch.bind(i, 'blockIdx.x')
f = tvm.build(sch.mod['main'], target='cuda')

# create input and run kernel
dev = tvm.cuda(0)
a = np.zeros((1,)).astype(np.int32)
a_gpu = tvm.nd.array(a, device=dev)
f(a_gpu)

# check output
tvm.testing.assert_allclose(a_gpu.numpy(), np.array([1024 * 1023 / 2]).astype(np.int32))


if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()
test_binary_intrin()
test_ldexp()
test_clz()
# test_clz()
test_fma()
test_binary_search()
test_global_add()

0 comments on commit 79e5bc5

Please sign in to comment.