From f808819c18e118e0df415a8caaee3a0ed126f4c8 Mon Sep 17 00:00:00 2001 From: Ilya V <152324710+joviliast@users.noreply.github.com> Date: Wed, 18 Sep 2024 23:01:18 +0200 Subject: [PATCH] [AMD] Add load CV cache modifier (#4746) - Supported load CV cache modifier. It corresponds to `flat load sc0 sc1` - Added unit test for `gfx94` arch Signed-off-by: Ilya Veselov --- include/triton/Dialect/Triton/IR/TritonAttrDefs.td | 1 + python/src/ir.cc | 1 + python/test/unit/language/test_core.py | 6 +++++- python/triton/language/semantic.py | 2 ++ .../amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp | 8 +++++++- third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 2 ++ third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h | 1 + 7 files changed, 19 insertions(+), 2 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index adfeaff6f788..e0d8c7ce35cb 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -13,6 +13,7 @@ def TT_CacheModifierAttr : I32EnumAttr< I32EnumAttrCase<"WB", 4, "wb">, I32EnumAttrCase<"CS", 5, "cs">, I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, ]> { let cppNamespace = "::mlir::triton"; } diff --git a/python/src/ir.cc b/python/src/ir.cc index 95e48a692041..abbf7ab41828 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -162,6 +162,7 @@ void init_triton_ir(py::module &&m) { .value("WB", CacheModifier::WB) .value("CS", CacheModifier::CS) .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) .export_values(); py::enum_(m, "MEM_SEMANTIC", py::module_local()) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3f88118e812e..5cec71b40175 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3782,7 +3782,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ @pytest.mark.interpreter -@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) def test_load_cache_modifier(cache, device): src = torch.empty(128, device=device) dst = torch.empty(128, device=device) @@ -3802,11 +3802,15 @@ def _kernel(dst, src, CACHE: tl.constexpr): return amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' + cv_cache_modifier_str = 'sc0 sc1' global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line] if cache == '' or cache == '.ca': assert cg_cache_modifier_str not in global_load_line[0] if cache == '.cg': assert cg_cache_modifier_str in global_load_line[0] + if cache == '.cv': + assert cv_cache_modifier_str in flat_load_line[0] if is_cuda(): ptx = pgm.asm['ptx'] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 188e3279a80f..c153a44ae89e 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -922,6 +922,8 @@ def _str_to_load_cache_modifier(cache_modifier): cache = ir.CACHE_MODIFIER.CA elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV else: raise ValueError(f"Cache modifier {cache_modifier} not supported") return cache diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 62a1c502145a..edf74cea4259 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -51,6 +51,11 @@ class CallOpConversion : public mlir::RewritePattern { mlir::LLVM::AMD::predicatedLoadCG); } + bool isPredicatedLoadCV(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedLoadCV); + } + bool isPredicatedStore(LLVM::CallOp callOp) const { return callOp.getCallee().value().contains( mlir::LLVM::AMD::predicatedStore); @@ -137,8 +142,9 @@ class CallOpConversion : public mlir::RewritePattern { | vialatile | non-tmp | gcn instr gfx94 LLVM::LoadOp | 0 | 0 | (ca) global load | 0/1 | 1 | (cg) global load nt + | 1 | 0 | (cv) flat load sc0 sc1 */ - bool vialatileFlag = false; + bool vialatileFlag = isPredicatedLoadCV(callOp); bool nonTmpFlag = isPredicatedLoadCG(callOp); auto loadOp = rewriter.create( loc, elemTy, ptr, /*alignment=*/0, vialatileFlag, nonTmpFlag); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 01c9abe16e0a..262055d64508 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -217,6 +217,8 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, return predicatedLoadCA; case triton::CacheModifier::CG: return predicatedLoadCG; + case triton::CacheModifier::CV: + return predicatedLoadCV; default: // Do not fail in compile time in the case of unsupported modifier. // Just apply default config. diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index c7a23937e32e..631f9bcce45d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -13,6 +13,7 @@ namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; const char predicatedLoadCA[] = "__predicated_load_CA"; const char predicatedLoadCG[] = "__predicated_load_CG"; +const char predicatedLoadCV[] = "__predicated_load_CV"; const char predicatedStore[] = "__predicated_store"; const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS";