
Commit

Merge commit '25324a7af5504213c6c6f9c43ab3b1b6bbc2a280'
whitneywhtsang committed Sep 9, 2024
2 parents 31bd963 + 25324a7 commit 5297206
Showing 11 changed files with 439 additions and 273 deletions.
9 changes: 6 additions & 3 deletions docs/python-api/triton-semantics.rst
@@ -1,7 +1,10 @@
Triton Semantics
================

Triton mostly follows the semantics of NumPy with minor exceptions. In this document, we go over some of the array computing features supported in Triton, and we cover the exceptions where Triton's semantics deviate from those of NumPy.

Type Promotion
==============
--------------

**Type Promotion** occurs when tensors of different data types are used in an operation. For binary operations associated with `dunder methods <https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types>`_ and the ternary function ``tl.where`` on its last two arguments, Triton automatically converts the input tensors to a common data type following a hierarchy of kinds (sets of dtypes): ``{bool} < {integral dtypes} < {floating point dtypes}``.
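
A minimal sketch of this rule (hypothetical kernel and pointer names; it assumes ``x_ptr`` points to int32 data while ``y_ptr`` and ``out_ptr`` point to float32 data). The sum is promoted to float32 because floating point dtypes sit above integral dtypes in the kind hierarchy:

import triton
import triton.language as tl

@triton.jit
def promote_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)   # int32 elements (assumed)
    y = tl.load(y_ptr + offs)   # float32 elements (assumed)
    z = x + y                   # promoted to float32 per the kind hierarchy
    tl.store(out_ptr + offs, z)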

@@ -27,7 +30,7 @@ When an operation involves a tensor and a scalar:


Broadcasting
============
------------

**Broadcasting** allows operations on tensors of different shapes by automatically expanding their shapes to a compatible size without copying the data. It follows these rules:
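
A minimal sketch (hypothetical kernel name, with ``out_ptr`` assumed to point to int32 data): a ``(BLOCK_M, 1)`` column of row indices and a ``(1, BLOCK_N)`` row of column indices broadcast to a common ``(BLOCK_M, BLOCK_N)`` shape:

import triton
import triton.language as tl

@triton.jit
def broadcast_kernel(out_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    rows = tl.arange(0, BLOCK_M)[:, None]   # shape (BLOCK_M, 1)
    cols = tl.arange(0, BLOCK_N)[None, :]   # shape (1, BLOCK_N)
    vals = rows * BLOCK_N + cols            # both operands broadcast to (BLOCK_M, BLOCK_N)
    tl.store(out_ptr + rows * BLOCK_N + cols, vals)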

@@ -37,7 +40,7 @@ Broadcasting


Differences with NumPy
======================
----------------------

**C rounding in integer division** Operators in Triton follow C semantics rather than Python semantics for efficiency. As such, ``int // int`` implements `rounding towards zero as in C <https://en.wikipedia.org/wiki/Modulo#In_programming_languages>`_ for integers of mixed signs, rather than rounding towards minus infinity as in Python. For the same reason, the modulus operator ``int % int`` (which is defined as ``a % b = a - b * (a // b)``) also follows C semantics rather than Python semantics.
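
A plain-Python sketch of the difference, using -7 and 2 as example operands; the C/Triton results are reproduced here by truncating the quotient towards zero:

a, b = -7, 2

# Python: floor division rounds towards minus infinity.
assert a // b == -4 and a % b == 1

# Triton (like C): division rounds towards zero, and a % b == a - b * (a // b).
c_div = int(a / b)       # -3, truncation towards zero
c_mod = a - b * c_div    # -1
assert (c_div, c_mod) == (-3, -1)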

55 changes: 36 additions & 19 deletions include/triton/Analysis/Membar.h
@@ -10,30 +10,38 @@ namespace mlir {

class OpBuilder;

/// Callback to allow the backend to provide more information on whether a
/// barrier is needed between two operations. Even though two operations access
/// the same shared memory, they may not require a barrier between them.
using MembarFilterFn = std::function<bool(Operation *, Operation *)>;
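/// An illustrative sketch (not part of this change) of a backend-supplied
/// filter: returning true from the callback means "no barrier is needed
/// between these two operations", while passing a null filter keeps the
/// default behavior of inserting a barrier for any intersecting accesses.
///
///   MembarFilterFn sameOpKindFilter = [](Operation *lhs, Operation *rhs) {
///     // Example heuristic only: treat two ops of the same kind as already
///     // synchronized with each other.
///     return lhs && rhs && lhs->getName() == rhs->getName();
///   };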

struct BlockInfo {
using IntervalSetT = std::set<Interval<size_t>>;
using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;

IntervalSetT syncReadIntervals;
IntervalSetT syncWriteIntervals;
IntervalMapT syncReadIntervals;
IntervalMapT syncWriteIntervals;

BlockInfo() = default;

/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadIntervals.insert(other.syncReadIntervals.begin(),
other.syncReadIntervals.end());
syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
other.syncWriteIntervals.end());
for (auto &interval : other.syncReadIntervals)
syncReadIntervals[interval.first].insert(interval.second.begin(),
interval.second.end());
for (auto &interval : other.syncWriteIntervals)
syncWriteIntervals[interval.first].insert(interval.second.begin(),
interval.second.end());
return *this;
}

/// Returns true if intervals in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other) const {
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
filter) ||
/*WAR*/
isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
/*WAW*/
isIntersected(syncWriteIntervals, other.syncWriteIntervals);
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
}

/// Clears the intervals because a barrier is inserted.
@@ -51,12 +59,17 @@ struct BlockInfo {
bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
bool isIntersected(const IntervalSetT &lhsIntervalSet,
const IntervalSetT &rhsIntervalSet) const {
bool isIntersected(const IntervalMapT &lhsIntervalSet,
const IntervalMapT &rhsIntervalSet,
MembarFilterFn filter) const {
for (auto &lhs : lhsIntervalSet)
for (auto &rhs : rhsIntervalSet)
if (lhs.intersects(rhs))
return true;
if (lhs.first.intersects(rhs.first))
for (auto lhsOp : lhs.second)
for (auto rhsOp : rhs.second)
if (!filter || !filter(lhsOp, rhsOp))
return true;

return false;
}
};
@@ -81,7 +94,8 @@ class MembarAnalysis {
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
MembarAnalysis() = default;
explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
: allocation(allocation), filter(filter) {}

/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
@@ -116,6 +130,7 @@ class MembarAnalysis {

private:
Allocation *allocation = nullptr;
MembarFilterFn filter = nullptr;
};

/// Postorder traversal on the callgraph to insert membar instructions
@@ -125,9 +140,10 @@ class MembarAnalysis {
/// before and after function calls, but might be a bit conservative.
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
public:
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation,
MembarFilterFn filter = nullptr)
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
moduleAllocation(moduleAllocation) {}
moduleAllocation(moduleAllocation), filter(filter) {}

void run() {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
@@ -138,14 +154,15 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
auto *allocation = moduleAllocation->getFuncData(funcOp);
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
if (inserted) {
MembarAnalysis analysis(allocation);
MembarAnalysis analysis(allocation, filter);
analysis.run(funcMap);
}
});
}

private:
ModuleAllocation *moduleAllocation;
MembarFilterFn filter;
};

} // namespace mlir
20 changes: 12 additions & 8 deletions lib/Analysis/Membar.cpp
@@ -136,11 +136,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
curBlockInfo
.syncWriteIntervals[allocation->getAllocatedInterval(
bufferId)]
.insert(op);
else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
curBlockInfo
.syncReadIntervals[allocation->getAllocatedInterval(
bufferId)]
.insert(op);
}
}
}
@@ -161,15 +165,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
"dependencies");
}
auto interval = allocation->getAllocatedInterval(scratchBufferId);
curBlockInfo.syncWriteIntervals.insert(interval);
if (blockInfo->isIntersected(curBlockInfo)) {
curBlockInfo.syncWriteIntervals[interval].insert(op);
if (blockInfo->isIntersected(curBlockInfo, filter)) {
builder->setInsertionPoint(op);
insertBarrier(op, builder);
}
// Ops with a scratch buffer internally sync read/write on shared memory
blockInfo->sync();
curBlockInfo.syncReadIntervals.insert(interval);
} else if (blockInfo->isIntersected(curBlockInfo)) {
curBlockInfo.syncReadIntervals[interval].insert(op);
} else if (blockInfo->isIntersected(curBlockInfo, filter)) {
builder->setInsertionPoint(op);
insertBarrier(op, builder);
blockInfo->sync();
18 changes: 13 additions & 5 deletions python/src/llvm.cc
@@ -331,18 +331,26 @@ void init_triton_llvm(py::module &&m) {
// regressions with some scheduling solution.
tuningOptions.SLPVectorization = true;

std::string pluginFile =
mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH");

// We don't pass the targetMachine to the LLVM-IR pass builder, unless
// `arch` is specified
// `arch` is specified.
//
// Don't set the target machine in the LLVM pass builder when using LLVM
// IR level plugins. LLVM IR level plugin passes typically want to insert
// calls to externally generated code (e.g. precompile a CUDA/HIP kernel
// with Clang and then insert a call to it within an instrumentation
// pass). Setting the targetMachine value here can cause a mismatch in
// the target machine between the MLIR- and Clang-generated kernels and
// break the lowering of some target-specific intrinsics.
std::unique_ptr<TargetMachine> targetMachine = nullptr;
if (!arch.empty())
if (!arch.empty() && pluginFile.empty())
targetMachine = std::move(
createTargetMachine(mod, arch, enable_fp_fusion, features));
PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions,
std::nullopt, instrCbPtr);

std::string pluginFile =
mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH");

if (!pluginFile.empty()) {
// TODO: Add some logging here that we inserted a pass into the LLVM
// pass pipeline
92 changes: 92 additions & 0 deletions test/Analysis/test-membar.mlir
@@ -713,3 +713,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
// CHECK-LABEL: tma_special_cases
tt.func @tma_special_cases(%arg1: !tt.ptr<i8, 0>) -> (tensor<256x64xf16, #blocked>){
%true = arith.constant 1 : i1
%c0 = arith.constant 0 : i32
%barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
%alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
// CHECK: triton_nvidia_gpu.init_barrier
// CHECK-NEXT: triton_nvidia_gpu.init_barrier
triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: triton_gpu.local_load
%t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>

// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_nvidia_gpu.inval_barrier
// CHECK-NEXT: triton_nvidia_gpu.inval_barrier
triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>

tt.return %t : tensor<256x64xf16, #blocked>
}
}

// -----

#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} {
// CHECK-LABEL: tma_special_cases_cf
tt.func @tma_special_cases_cf(%arg1: !tt.ptr<i8, 0>, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){
%true = arith.constant 1 : i1
%c0 = arith.constant 0 : i32
%barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable>
%alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
// CHECK: cf.cond_br
scf.if %i1 {
// CHECK-NOT: gpu.barrier
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK-NEXT: triton_nvidia_gpu.barrier_expect
// CHECK-NEXT: triton_nvidia_gpu.wait_barrier
// CHECK-NEXT: cf.br
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : <i8, 0>, <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield
} else {
// CHECK-NOT: gpu.barrier
// CHECK: triton_gpu.local_store
// CHECK-NEXT: cf.br
triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
scf.yield
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
%t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked>
tt.return %t : tensor<256x64xf16, #blocked>
}
}