Skip to content

Commit

Permalink
[BACKEND] Add store cache modifiers (#1826)
Browse files Browse the repository at this point in the history
Plumb through store cache modifiers.
  • Loading branch information
ThomasRaoux committed Jun 23, 2023
1 parent 2eb7bc4 commit 3d1cd89
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 4 deletions.
4 changes: 3 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

include "mlir/IR/EnumAttr.td"

// Attributes for LoadOp
// Attributes for LoadOp and StoreOp
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
I32EnumAttrCase<"WB", 4, "wb">,
I32EnumAttrCase<"CS", 5, "cs">,
]> {
let cppNamespace = "::mlir::triton";
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ struct StoreOpConversion
auto &ptxStoreInstr =
ptxBuilder.create<>("st")
->global()
.o("wb", op.getCache() == triton::CacheModifier::WB)
.o("cg", op.getCache() == triton::CacheModifier::CG)
.o("cs", op.getCache() == triton::CacheModifier::CS)
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
Expand Down
2 changes: 2 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ void init_triton_ir(py::module &&m) {
.value("NONE", mlir::triton::CacheModifier::NONE)
.value("CA", mlir::triton::CacheModifier::CA)
.value("CG", mlir::triton::CacheModifier::CG)
.value("WB", mlir::triton::CacheModifier::WB)
.value("CS", mlir::triton::CacheModifier::CS)
.export_values();

py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
Expand Down
31 changes: 31 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,6 +2266,37 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
# test store
# ---------------


@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs"])
def test_store_cache_modifier(cache):
src = torch.empty(128, device='cuda')
dst = torch.empty(128, device='cuda')

@triton.jit
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
x = tl.load(src + offsets)
tl.store(dst + offsets, x, cache_modifier=CACHE)

pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
if cache == '.wb':
assert 'st.global.wb' in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
if cache == '.cg':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' in ptx
assert 'st.global.cs' not in ptx
if cache == '.cs':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' in ptx

# ---------------
# test if
# ---------------
Expand Down
20 changes: 17 additions & 3 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//


def _str_to_cache_modifier(cache_modifier):
def _str_to_load_cache_modifier(cache_modifier):
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
if cache_modifier == ".ca":
Expand All @@ -787,6 +787,20 @@ def _str_to_cache_modifier(cache_modifier):
return cache


def _str_to_store_cache_modifier(cache_modifier):
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
if cache_modifier == ".wb":
cache = ir.CACHE_MODIFIER.WB
elif cache_modifier == ".cg":
cache = ir.CACHE_MODIFIER.CG
elif cache_modifier == ".cs":
cache = ir.CACHE_MODIFIER.CS
else:
raise ValueError(f"Cache modifier {cache_modifier} not supported")
return cache


def _str_to_eviction_policy(eviction_policy):
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
Expand Down Expand Up @@ -929,7 +943,7 @@ def load(ptr: tl.tensor,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_cache_modifier(cache_modifier)
cache = _str_to_load_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
padding = _str_to_padding_option(padding_option)

Expand Down Expand Up @@ -1018,7 +1032,7 @@ def store(ptr: tl.tensor,
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
# Cache and eviction options
cache = _str_to_cache_modifier(cache_modifier)
cache = _str_to_store_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)

if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
Expand Down

0 comments on commit 3d1cd89

Please sign in to comment.