fix issues after LLVM update (#22)
* fix issues after LLVM update

* update LKG (last known good) triton-ref for PRs

* update tests
manbearian authored Oct 17, 2023
1 parent 07ea842 commit 450e6be
Showing 6 changed files with 13 additions and 11 deletions.
.github/workflows/integration-tests.yml (2 changes: 1 addition & 1 deletion)
@@ -10,5 +10,5 @@ jobs:
   call-workflow:
     uses: ./.github/workflows/test-plugin.yml
     with:
-      triton-ref: '215b2e77a1d92f907bc4addfa3c3ba204028c32e' # known good commit "Add Shared Middle Layer to Triton via Plug-In"
+      triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)"
       triton-shared-ref: ${{ github.ref }}
include/triton-shared/Analysis/UseAnalysis.h (2 changes: 2 additions & 0 deletions)
@@ -85,6 +85,8 @@ class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis<UseInfo> {
 
   void visitBranchOperand(OpOperand &operand) override { return; }
 
+  void visitCallOperand(OpOperand &operand) override { return; }
+
   void setToExitState(UseInfo *lattice) override {
     lattice->type = UseType::Undefined;
   }
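Context for the addition above (an inference from the patch, not stated in the commit): the updated MLIR appears to make visitCallOperand a required hook on dataflow::SparseBackwardDataFlowAnalysis, alongside the existing visitBranchOperand, so subclasses must now say what call operands mean for their lattice. UseAnalysis opts out for both kinds of operand. A condensed, commented view of the affected members (a fragment of the class, not standalone code):

    // Inside UseAnalysis, which derives from
    // dataflow::SparseBackwardDataFlowAnalysis<UseInfo>; other members elided.

    // Branch operands carry no use information backward in this analysis.
    void visitBranchOperand(OpOperand &operand) override { return; }

    // Newly required hook (assumption: added to the base class in this
    // LLVM update); call operands get the same no-op treatment.
    void visitCallOperand(OpOperand &operand) override { return; }

    // Lattices at exit points start out in the Undefined state.
    void setToExitState(UseInfo *lattice) override {
      lattice->type = UseType::Undefined;
    }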
lib/Analysis/PtrAnalysis.cpp (4 changes: 2 additions & 2 deletions)
@@ -1231,7 +1231,7 @@ void PtrAnalysis::rewriteForOp(
   mapping.map(op.getInitArgs(), newInitArgs);
   mapping.map(op.getRegionIterArgs(), args);
 
-  for (auto &bodyOp : op.getLoopBody().getOps()) {
+  for (auto &bodyOp : op.getRegion().getOps()) {
     b.clone(bodyOp, mapping);
   }
 
@@ -1309,7 +1309,7 @@ void PtrAnalysis::rewriteForOp(
 
   // Update the loop body. Manually invoke the rewrite logic on addptr and yield
   // in the loop body, so we can take advantage of the states we built up
-  for (auto &bodyOp : newOp.getLoopBody().getOps()) {
+  for (auto &bodyOp : newOp.getRegion().getOps()) {
     if (auto addptrOp = dyn_cast<triton::AddPtrOp>(bodyOp)) {
       rewriteAddptrOp(addptrOp, rewriter, knownPtrs);
     } else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(bodyOp)) {
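Both hunks in this file are the same mechanical rename: upstream MLIR removed (or deprecated) the getLoopBody() accessor, which returned the loop's only region, and getRegion() is the replacement. A minimal sketch of the post-update idiom, assuming MLIR at roughly this commit's LLVM level (the helper name is hypothetical):

    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/IRMapping.h"

    using namespace mlir;

    // Clone every op in an scf.for body. getRegion() yields the op's
    // single region, which is what getLoopBody() used to return.
    static void cloneLoopBody(scf::ForOp op, OpBuilder &b,
                              IRMapping &mapping) {
      for (Operation &bodyOp : op.getRegion().getOps())
        b.clone(bodyOp, mapping);
    }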
lib/Conversion/TritonToLinalg/TritonToLinalg.cpp (12 changes: 6 additions & 6 deletions)
@@ -838,7 +838,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
   }
 
   bool isReductionOpSupported(Operation *redOp) const {
-    return isa<arith::AddFOp, arith::MaxFOp, arith::MinSIOp, arith::MinUIOp,
+    return isa<arith::AddFOp, arith::MaximumFOp, arith::MinSIOp, arith::MinUIOp,
                arith::MaxSIOp, arith::MaxUIOp>(redOp);
   }
 
@@ -852,7 +852,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
         .Case([&](arith::AddFOp) {
           return rewriter.getFloatAttr(constantType, 0.f);
         })
-        .Case([&](arith::MaxFOp) {
+        .Case([&](arith::MaximumFOp) {
           return rewriter.getFloatAttr(
               constantType, -std::numeric_limits<float>::infinity());
         })
@@ -899,7 +899,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
           }
           return b.create<arith::AddFOp>(loc, lhs, rhs);
         })
-        .Case<arith::MaxFOp, arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp,
+        .Case<arith::MaximumFOp, arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp,
               arith::MaxUIOp>([&](auto redOp) {
           return b.create<decltype(redOp)>(loc, lhs, rhs);
         })
@@ -1156,11 +1156,11 @@ struct MinMaxConverter : public OpRewritePattern<CmpOp> {
                       arith::CmpFPredicate pred) const {
     switch (pred) {
     case arith::CmpFPredicate::OGT:
-      rewriter.replaceOpWithNewOp<arith::MaxFOp>(selectOp, cmpOp.getLhs(),
-                                                 cmpOp.getRhs());
+      rewriter.replaceOpWithNewOp<arith::MaximumFOp>(selectOp, cmpOp.getLhs(),
+                                                     cmpOp.getRhs());
       break;
     case arith::CmpFPredicate::OLT:
-      rewriter.replaceOpWithNewOp<arith::MinFOp>(selectOp, cmpOp.getLhs(),
+      rewriter.replaceOpWithNewOp<arith::MinimumFOp>(selectOp, cmpOp.getLhs(),
                                                  cmpOp.getRhs());
       break;
     default:
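All of the renames in this file track one upstream arith dialect change: arith.maxf and arith.minf were split into two families, NaN-propagating arith.maximumf/arith.minimumf and NaN-ignoring arith.maxnumf/arith.minnumf, and the old names went away; this commit maps the converters to the propagating flavor. A small sketch of the post-rename builders (the helper is hypothetical; the op classes are real):

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    using namespace mlir;

    // arith::MaximumFOp / arith::MinimumFOp propagate NaN: if either
    // operand is NaN, the result is NaN. The arith::MaxNumFOp /
    // arith::MinNumFOp siblings instead return the non-NaN operand.
    static Value emitFloatMinMax(OpBuilder &b, Location loc, Value lhs,
                                 Value rhs, bool isMax) {
      if (isMax)
        return b.create<arith::MaximumFOp>(loc, lhs, rhs);
      return b.create<arith::MinimumFOp>(loc, lhs, rhs);
    }

The same textual rename is what the two test files below pick up in their FileCheck expectations.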
2 changes: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ module {
 // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_6_]][] : tensor<f32>
 // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_5_]] : tensor<128xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
 // CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) {
-// CHECK:   [[VAR_19_:%.+]] = arith.maxf [[in_1]], [[init_1]] : f32
+// CHECK:   [[VAR_19_:%.+]] = arith.maximumf [[in_1]], [[init_1]] : f32
 // CHECK:   linalg.yield [[VAR_19_]] : f32
 // CHECK: }
 // CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ module {
 // CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : tensor<256x16xbf16>) -> tensor<256x16xbf16>
 // CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_11]] : tensor<256x16xbf16>) dimensions = [0]
 // CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) {
-// CHECK:   %[[VAL_15:.*]] = arith.maxf %[[VAL_13]], %[[VAL_14]] : bf16
+// CHECK:   %[[VAL_15:.*]] = arith.maximumf %[[VAL_13]], %[[VAL_14]] : bf16
 // CHECK:   linalg.yield %[[VAL_15]] : bf16
 // CHECK: }
 // CHECK: memref.tensor_store %[[VAL_12]], %[[VAL_1]] : memref<256x16xbf16>
