Skip to content

Commit

Permalink
Enable usage of block pointer semantics for AMD gpus (#301)
Browse files Browse the repository at this point in the history
* Enable usage of block pointer semantics for AMD gpus

This commit enables the use of block pointer semantics on AMD GPUs by
enabling the rewrite_tensor_pointer_pass, which rewrites block pointer
loads/stores into legacy loads/stores.

* Update the FA (flash-attention) forward kernel in the tutorial to use block pointers

* Use compute capability 90 for AMD GPUs in python/triton/compiler/compiler.py

Co-authored-by: Alexander Efimov <[email protected]>

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lixun Zhang <[email protected]>
Co-authored-by: Aleksandr Efimov <[email protected]>
Co-authored-by: Alexander Efimov <[email protected]>
  • Loading branch information
5 people authored Aug 24, 2023
1 parent fa42931 commit ff7e707
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 123 deletions.
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace triton {

std::unique_ptr<Pass> createCombineOpsPass();

std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80);
std::unique_ptr<Pass> createRewriteTensorPointerPass(int computeCapability = 80,
bool isROCM = false);

} // namespace triton

Expand Down
11 changes: 6 additions & 5 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,12 @@ class RewriteTensorPointerPass
: public TritonRewriteTensorPointerBase<RewriteTensorPointerPass> {
private:
int computeCapability;
bool isROCM;
DenseMap<Value, RewritedInfo> rewritedInfo;

public:
explicit RewriteTensorPointerPass(int computeCapability)
: computeCapability(computeCapability) {}
explicit RewriteTensorPointerPass(int computeCapability, bool isROCM)
: computeCapability(computeCapability), isROCM(isROCM) {}

static bool needRewrite(Operation *op) {
return std::any_of(op->getOperands().begin(), op->getOperands().end(),
Expand Down Expand Up @@ -470,7 +471,7 @@ class RewriteTensorPointerPass

void runOnOperation() override {
// Only rewrite if the hardware does not support
if (computeCapability >= 90)
if (!isROCM && computeCapability >= 90)
return;

// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
Expand Down Expand Up @@ -499,6 +500,6 @@ class RewriteTensorPointerPass
};

std::unique_ptr<Pass>
triton::createRewriteTensorPointerPass(int computeCapability) {
return std::make_unique<RewriteTensorPointerPass>(computeCapability);
triton::createRewriteTensorPointerPass(int computeCapability, bool isROCM) {
return std::make_unique<RewriteTensorPointerPass>(computeCapability, isROCM);
}
4 changes: 2 additions & 2 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1637,9 +1637,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createCombineOpsPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
[](mlir::PassManager &self, int computeCapability, bool isROCM) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
computeCapability, isROCM));
})
.def(
"add_convert_triton_to_tritongpu_pass",
Expand Down
4 changes: 2 additions & 2 deletions python/test/unit/language/test_block_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option:
for padding in ("zero", "nan")])
def test_block_copy(dtype_str, n, padding_option):
capability = torch.cuda.get_device_capability()
if capability[0] >= 9:
if torch.version.hip is None and capability[0] >= 9:
pytest.skip("Hopper support is working in progress")

dtype = getattr(torch, dtype_str)
Expand Down Expand Up @@ -82,7 +82,7 @@ def matmul_no_scf_with_advance_kernel(
])
def test_block_ptr_matmul_no_scf(shape, num_warps):
capability = torch.cuda.get_device_capability()
if capability[0] >= 9:
if torch.version.hip is None and capability[0] >= 9:
pytest.skip("Hopper support is working in progress")

m, n, k = shape
Expand Down
7 changes: 6 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def ttir_compute_capability_rewrite(mod, arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if _is_cuda(arch):
pm.add_rewrite_tensor_pointer_pass(arch)
pm.add_rewrite_tensor_pointer_pass(arch, False)
elif is_hip():
capability = 90
pm.add_rewrite_tensor_pointer_pass(capability, True)
else:
assert(False, "unsupported target")
pm.run(mod)
return mod

Expand Down
Loading

0 comments on commit ff7e707

Please sign in to comment.