Skip to content

Commit

Permalink
[AMD] Add load CV cache modifier (#4746)
Browse files Browse the repository at this point in the history
- Added support for the CV load cache modifier, which corresponds to `flat_load ... sc0 sc1` on `gfx94`
- Added unit test for `gfx94` arch

Signed-off-by: Ilya Veselov <[email protected]>
  • Loading branch information
joviliast committed Sep 18, 2024
1 parent f4c48a9 commit f808819
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
I32EnumAttrCase<"WB", 4, "wb">,
I32EnumAttrCase<"CS", 5, "cs">,
I32EnumAttrCase<"WT", 6, "wt">,
I32EnumAttrCase<"CV", 7, "cv">,
]> {
let cppNamespace = "::mlir::triton";
}
Expand Down
1 change: 1 addition & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ void init_triton_ir(py::module &&m) {
.value("WB", CacheModifier::WB)
.value("CS", CacheModifier::CS)
.value("WT", CacheModifier::WT)
.value("CV", CacheModifier::CV)
.export_values();

py::enum_<MemSemantic>(m, "MEM_SEMANTIC", py::module_local())
Expand Down
6 changes: 5 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3782,7 +3782,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"])
def test_load_cache_modifier(cache, device):
src = torch.empty(128, device=device)
dst = torch.empty(128, device=device)
Expand All @@ -3802,11 +3802,15 @@ def _kernel(dst, src, CACHE: tl.constexpr):
return
amdgcn = pgm.asm['amdgcn']
cg_cache_modifier_str = 'nt'
cv_cache_modifier_str = 'sc0 sc1'
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
if cache == '' or cache == '.ca':
assert cg_cache_modifier_str not in global_load_line[0]
if cache == '.cg':
assert cg_cache_modifier_str in global_load_line[0]
if cache == '.cv':
assert cv_cache_modifier_str in flat_load_line[0]

if is_cuda():
ptx = pgm.asm['ptx']
Expand Down
2 changes: 2 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,8 @@ def _str_to_load_cache_modifier(cache_modifier):
cache = ir.CACHE_MODIFIER.CA
elif cache_modifier == ".cg":
cache = ir.CACHE_MODIFIER.CG
elif cache_modifier == ".cv":
cache = ir.CACHE_MODIFIER.CV
else:
raise ValueError(f"Cache modifier {cache_modifier} not supported")
return cache
Expand Down
8 changes: 7 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class CallOpConversion : public mlir::RewritePattern {
mlir::LLVM::AMD::predicatedLoadCG);
}

/// Returns true when the callee name of `callOp` contains the
/// `mlir::LLVM::AMD::predicatedLoadCV` marker string.
bool isPredicatedLoadCV(LLVM::CallOp callOp) const {
  // `.value()` is kept so that a call without a callee name is handled
  // exactly as before.
  const auto calleeName = callOp.getCallee().value();
  return calleeName.contains(mlir::LLVM::AMD::predicatedLoadCV);
}

bool isPredicatedStore(LLVM::CallOp callOp) const {
return callOp.getCallee().value().contains(
mlir::LLVM::AMD::predicatedStore);
Expand Down Expand Up @@ -137,8 +142,9 @@ class CallOpConversion : public mlir::RewritePattern {
| volatile | non-tmp | gcn instr gfx94
LLVM::LoadOp | 0 | 0 | (ca) global load
| 0/1 | 1 | (cg) global load nt
| 1 | 0 | (cv) flat load sc0 sc1
*/
bool vialatileFlag = false;
bool vialatileFlag = isPredicatedLoadCV(callOp);
bool nonTmpFlag = isPredicatedLoadCG(callOp);
auto loadOp = rewriter.create<LLVM::LoadOp>(
loc, elemTy, ptr, /*alignment=*/0, vialatileFlag, nonTmpFlag);
Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
return predicatedLoadCA;
case triton::CacheModifier::CG:
return predicatedLoadCG;
case triton::CacheModifier::CV:
return predicatedLoadCV;
default:
// Do not fail in compile time in the case of unsupported modifier.
// Just apply default config.
Expand Down
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace mlir::LLVM::AMD {
const char predicatedLoad[] = "__predicated_load";
const char predicatedLoadCA[] = "__predicated_load_CA";
const char predicatedLoadCG[] = "__predicated_load_CG";
const char predicatedLoadCV[] = "__predicated_load_CV";
const char predicatedStore[] = "__predicated_store";
const char predicatedStoreCG[] = "__predicated_store_CG";
const char predicatedStoreCS[] = "__predicated_store_CS";
Expand Down

0 comments on commit f808819

Please sign in to comment.