diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
index 2713e745e272..06852b0fec9f 100644
--- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
+++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -3,13 +3,15 @@
 
 include "mlir/IR/EnumAttr.td"
 
-// Attributes for LoadOp
+// Attributes for LoadOp and StoreOp
 def TT_CacheModifierAttr : I32EnumAttr<
     "CacheModifier", "",
     [
         I32EnumAttrCase<"NONE", 1, "none">,
         I32EnumAttrCase<"CA", 2, "ca">,
         I32EnumAttrCase<"CG", 3, "cg">,
+        I32EnumAttrCase<"WB", 4, "wb">,
+        I32EnumAttrCase<"CS", 5, "cs">,
     ]> {
     let cppNamespace = "::mlir::triton";
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index df7ca527107b..37e336480800 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -354,6 +354,9 @@ struct StoreOpConversion
       auto &ptxStoreInstr =
           ptxBuilder.create<>("st")
               ->global()
+              .o("wb", op.getCache() == triton::CacheModifier::WB)
+              .o("cg", op.getCache() == triton::CacheModifier::CG)
+              .o("cs", op.getCache() == triton::CacheModifier::CS)
               .o("L1::evict_first",
                  op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
               .o("L1::evict_last",
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 0f351cae84dd..be93cf4ba1dc 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -89,6 +89,8 @@ void init_triton_ir(py::module &&m) {
       .value("NONE", mlir::triton::CacheModifier::NONE)
       .value("CA", mlir::triton::CacheModifier::CA)
       .value("CG", mlir::triton::CacheModifier::CG)
+      .value("WB", mlir::triton::CacheModifier::WB)
+      .value("CS", mlir::triton::CacheModifier::CS)
       .export_values();
 
   py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 2b1c68cbb7b7..b4e5f328f65d 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2266,6 +2266,37 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
 # test store
 # ---------------
 
+
+@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs"])
+def test_store_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
+
+    @triton.jit
+    def _kernel(dst, src, CACHE: tl.constexpr):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src + offsets)
+        tl.store(dst + offsets, x, cache_modifier=CACHE)
+
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+    if cache == '':
+        assert 'st.global.wb' not in ptx
+        assert 'st.global.cg' not in ptx
+        assert 'st.global.cs' not in ptx
+    if cache == '.wb':
+        assert 'st.global.wb' in ptx
+        assert 'st.global.cg' not in ptx
+        assert 'st.global.cs' not in ptx
+    if cache == '.cg':
+        assert 'st.global.wb' not in ptx
+        assert 'st.global.cg' in ptx
+        assert 'st.global.cs' not in ptx
+    if cache == '.cs':
+        assert 'st.global.wb' not in ptx
+        assert 'st.global.cg' not in ptx
+        assert 'st.global.cs' in ptx
+
 # ---------------
 # test if
 # ---------------
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 19c409dfa54e..33bff321dd61 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -775,7 +775,7 @@ def cast(input: tl.tensor,
 # ===----------------------------------------------------------------------===//
 
 
-def _str_to_cache_modifier(cache_modifier):
+def _str_to_load_cache_modifier(cache_modifier):
     cache = ir.CACHE_MODIFIER.NONE  # default
     if cache_modifier:
         if cache_modifier == ".ca":
@@ -787,6 +787,20 @@ def _str_to_cache_modifier(cache_modifier):
     return cache
 
 
+def _str_to_store_cache_modifier(cache_modifier):
+    cache = ir.CACHE_MODIFIER.NONE  # default
+    if cache_modifier:
+        if cache_modifier == ".wb":
+            cache = ir.CACHE_MODIFIER.WB
+        elif cache_modifier == ".cg":
+            cache = ir.CACHE_MODIFIER.CG
+        elif cache_modifier == ".cs":
+            cache = ir.CACHE_MODIFIER.CS
+        else:
+            raise ValueError(f"Cache modifier {cache_modifier} not supported")
+    return cache
+
+
 def _str_to_eviction_policy(eviction_policy):
     eviction = ir.EVICTION_POLICY.NORMAL  # default
     if eviction_policy:
@@ -929,7 +943,7 @@ def load(ptr: tl.tensor,
          is_volatile: bool,
          builder: ir.builder) -> tl.tensor:
     # Cache, eviction and padding options
-    cache = _str_to_cache_modifier(cache_modifier)
+    cache = _str_to_load_cache_modifier(cache_modifier)
     eviction = _str_to_eviction_policy(eviction_policy)
     padding = _str_to_padding_option(padding_option)
 
@@ -1018,7 +1032,7 @@ def store(ptr: tl.tensor,
           eviction_policy: str,
           builder: ir.builder) -> tl.tensor:
     # Cache and eviction options
-    cache = _str_to_cache_modifier(cache_modifier)
+    cache = _str_to_store_cache_modifier(cache_modifier)
     eviction = _str_to_eviction_policy(eviction_policy)
 
     if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
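For context (not part of the patch): a minimal kernel-level sketch of how the new store modifiers would be used, assuming a CUDA device and a Triton build that includes this change. The kernel name, sizes, and grid below are hypothetical; only the cache_modifier strings accepted by tl.store ("", ".wb", ".cg", ".cs") come from the patch itself.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(dst_ptr, src_ptr, N, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < N
        x = tl.load(src_ptr + offsets, mask=mask)
        # Hypothetical choice of ".cs" (streaming / evict-first, per the PTX ISA);
        # ".wb" is write-back and ".cg" caches at L2 only. "" keeps the default.
        tl.store(dst_ptr + offsets, x, mask=mask, cache_modifier=".cs")


    n = 1024
    src = torch.randn(n, device='cuda')
    dst = torch.empty_like(src)
    copy_kernel[(triton.cdiv(n, 256),)](dst, src, n, BLOCK=256)

With ".cs" the emitted store becomes st.global.cs, which is what test_store_cache_modifier in the patch asserts on the generated PTX.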