Use partial conversion in TritonToLinalg (#59)
nhat-nguyen authored Nov 16, 2023
1 parent eb06d57 commit c2fb18d
Showing 3 changed files with 47 additions and 24 deletions.
21 changes: 10 additions & 11 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -574,16 +574,14 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
       // The workaround is to broadcast the pointers early in the address
       // calculation. A proper fix is complicated, but at least we can provide a
       // better error message.
-      op.emitError("LoadOp expects a memref, not a memref of pointers");
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "LoadOp expects a memref, not a memref of pointers");
     }

-    DictionaryAttr attrs;
     auto tensorType =
-        RankedTensorType::get(type.getShape(), type.getElementType(), attrs);
+        RankedTensorType::get(type.getShape(), type.getElementType());
     auto alloc = rewriter.create<memref::AllocOp>(
-        loc, MemRefType::get(type.getShape(), type.getElementType(),
-                             AffineMap(), attrs));
+        loc, MemRefType::get(type.getShape(), type.getElementType()));

     if (!mask) {
       assert(!other && "other value used in non-masked load");
@@ -626,7 +624,8 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
     auto isContMask = mstate.parse(mask, loc, rewriter);

     if (isContMask.failed()) {
-      return op.emitError("Cannot lower continuous masked loads");
+      return rewriter.notifyMatchFailure(
+          op, "Cannot lower continuous masked loads");
     }

     // fill load destination with other value
@@ -930,8 +929,9 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
     // subview that skips over each first element.
     if (reductionOps.size() != 1 ||
         !isReductionOpSupported(reductionOps.front())) {
-      return op.emitError("Only support lowering reduction with body "
-                          "containing one max(i/f) or add(i/f).");
+      return rewriter.notifyMatchFailure(
+          op, "Only support lowering reduction with body "
+              "containing 1 max(i/f) or addf.");
     }

     auto rop = reductionOps.front();
@@ -1044,8 +1044,6 @@ struct GetProgramIDConverter
       getMaxEnumValForProgramIDDim() + 1;

 public:
-  GetProgramIDConverter(MLIRContext *context) : OpConversionPattern(context) {}
-
   LogicalResult
   matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -1311,6 +1309,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
     unsigned int launchGridRank) {
   populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
       patterns, typeConverter);
+
   patterns.add<MetaOpConverter>(patterns.getContext());
   patterns.add<StoreConverter>(patterns.getContext());
   patterns.add<AddPtrConverter>(patterns.getContext());
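
The recurring edit in this file is the switch from op.emitError(...) plus return failure() to rewriter.notifyMatchFailure(...). The distinction matters once the driver below moves to partial conversion: emitError unconditionally surfaces a diagnostic, whereas notifyMatchFailure merely records why the pattern declined to apply and returns failure(), so the op can be left in place without an error. A minimal sketch of the idiom, where MyOp, isSupported, and lowerMyOp are hypothetical stand-ins rather than triton-shared APIs:

// Sketch only: MyOp, isSupported, and lowerMyOp are hypothetical.
struct MyOpConverter : public OpConversionPattern<MyOp> {
  using OpConversionPattern<MyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isSupported(op)) {
      // Record the reason (visible in debug output) and return failure();
      // under applyPartialConversion the op is left unconverted instead of
      // aborting the whole pass.
      return rewriter.notifyMatchFailure(op, "unsupported operand layout");
    }
    rewriter.replaceOp(op, lowerMyOp(op, adaptor, rewriter));
    return success();
  }
};
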
15 changes: 2 additions & 13 deletions lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp
@@ -49,7 +49,7 @@ class TritonTypeConverter : public TypeConverter {
   }
 };

-struct TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {
+class TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {

   static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1;
   static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT =
@@ -124,17 +124,6 @@ struct TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {

     target.addLegalOp<ModuleOp>();

-    target.addIllegalDialect<triton::TritonDialect>();
-
-    // triton.reduce will be lowered to linalg.reduce. Unfortunately, mlir
-    // inserts the ops inside triton.reduce's region BEFORE triton.reduce
-    // itself, so the conversion algorithm will visit triton.reduce_return
-    // first. Without marking this op as legal, the conversion process will fail
-    // because there's no legalization pattern for triton.reduce_return.
-    target.addLegalOp<triton::ReduceReturnOp>();
-
-    target.addLegalOp<triton::ReturnOp>();
-
     // Update function signature to use memrefs
     target.addDynamicallyLegalOp<triton::FuncOp>([&](triton::FuncOp op) {
       return tritonTypeConverter.isSignatureLegal(op.getFunctionType());
@@ -192,7 +181,7 @@ struct TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {
     for (auto func : getOperation().getOps<triton::FuncOp>())
       addProgramInfo(func);

-    if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
       signalPassFailure();

     // Convert tt.func and tt.return into func's counterparts
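
The one-word change from applyFullConversion to applyPartialConversion is the core of the commit. Full conversion requires every remaining op to be legal, which is why the pass previously had to mark the entire Triton dialect illegal and then carve out exceptions for triton.reduce_return and tt.return. Partial conversion only fails on ops that are explicitly illegal, so ops with no lowering pattern, such as tt.extern_elementwise, now pass through untouched, and the legality workarounds deleted above become unnecessary. A minimal sketch of the resulting driver shape; runLowering is a hypothetical wrapper, not triton-shared API:

// Sketch of a partial-conversion driver; runLowering is hypothetical.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/DialectConversion.h"
#include <utility>

using namespace mlir;

LogicalResult runLowering(ModuleOp moduleOp, RewritePatternSet &&patterns) {
  ConversionTarget target(*moduleOp.getContext());
  target.addLegalDialect<linalg::LinalgDialect>();
  // Deliberately absent: addIllegalDialect<triton::TritonDialect>() and the
  // addLegalOp<triton::ReduceReturnOp>() workaround. Ops with no matching
  // pattern are simply kept as-is.
  return applyPartialConversion(moduleOp, target, std::move(patterns));
}

The new test below pins down exactly this behavior: tt.extern_elementwise has no lowering pattern, and the CHECK lines expect it to survive in the output.
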
35 changes: 35 additions & 0 deletions test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir
@@ -0,0 +1,35 @@
+// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
+
+module {
+  tt.func public @rand(%arg0: !tt.ptr<i32, 1>, %arg1: !tt.ptr<i32, 1>) attributes {noinline = false} {
+    %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
+    %1 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<8x!tt.ptr<i32, 1>>
+    %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr<i32, 1>>, tensor<8xi32>
+    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xi32>
+    %4 = tt.extern_elementwise %3, %0 {libname = "libdevice", libpath = "/path/to/something", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
+    %5 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<8x!tt.ptr<i32, 1>>
+    %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr<i32, 1>>, tensor<8xi32>
+    tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<8xi32>
+    tt.return
+  }
+}
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @rand
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
+// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<8xi32>
+// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<8xi32>) {
+// CHECK: ^bb0([[out:.+]]: i32):
+// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index
+// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32
+// CHECK: linalg.yield [[VAR_5_]] : i32
+// CHECK: } -> tensor<8xi32>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>>
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<8xi32>
+// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<8xi32, strided<[1]>> to memref<8xi32>
+// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<8xi32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tt.extern_elementwise [[VAR_2_]], [[VAR_1_]] {libname = "libdevice", libpath = "/path/to/something", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>>
+// CHECK: memref.tensor_store [[VAR_3_]], [[VAR_reinterpret_cast_0_]] : memref<8xi32, strided<[1]>>
+// CHECK: return
+// CHECK: }
