Commit
[BACKEND] Remove ttg.cmp and ttg.select and replace by arith op (triton-lang#2526)

Now that the attribute-related bug in MLIR has been fixed, we can use arith ops for the cmp and select ops.
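As a rough sketch of what this change means at the IR level (the values %a, %b, %x, %y and the #blocked encoding are hypothetical, chosen only for illustration): the comparison and select previously had to go through the TritonGPU wrapper ops so that the layout encoding was preserved, e.g.

    %mask = "triton_gpu.cmpi"(%a, %b) <{predicate = 2 : i64}> : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
    %sel = "triton_gpu.select"(%mask, %x, %y) : (tensor<64xi1, #blocked>, tensor<64xf32, #blocked>, tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked>

With the upstream fix, the plain arith ops carry the encoding through unchanged, so the same computation can be written directly as

    %mask = arith.cmpi "slt", %a, %b : tensor<64xi32, #blocked>
    %sel = arith.select %mask, %x, %y : tensor<64xi1, #blocked>, tensor<64xf32, #blocked>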
ThomasRaoux authored and zhanglx13 committed Nov 9, 2023
1 parent c070b98 commit ac4ee36
Showing 20 changed files with 277 additions and 368 deletions.
48 changes: 0 additions & 48 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -85,54 +85,6 @@ def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
}];
}


// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";

let description = [{}];

let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntLike:$lhs,
TT_IntLike:$rhs);

let results = (outs TT_BoolLike:$result);
}

def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";

let description = [{}];

let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
TT_FloatLike:$lhs,
TT_FloatLike:$rhs);

let results = (outs TT_BoolLike:$result);
}

// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";

let description = [{}];

let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);

let results = (outs TT_Type:$result);
}

// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
[AttrSizedOperandSegments,
10 changes: 2 additions & 8 deletions lib/Analysis/AxisInfo.cpp
@@ -635,10 +635,6 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
}

private:
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
return op.getPredicate();
}

static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
return op.getPredicate();
}
@@ -917,13 +913,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
visitors.append<BroadcastOpAxisInfoVisitor>();
visitors.append<SplatOpAxisInfoVisitor>();
visitors.append<ExpandDimsOpAxisInfoVisitor>();
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>,
CmpOpAxisInfoVisitor<triton::gpu::CmpIOp>>();
visitors.append<CmpOpAxisInfoVisitor<arith::CmpIOp>>();
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
LogicalOpAxisInfoVisitor<arith::OrIOp>,
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>>();
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
visitors.append<MaxMinOpAxisInfoVisitor<arith::MaxSIOp>,
24 changes: 10 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1491,18 +1491,17 @@ Value EmitDualBF16ElementwiseOp(Location loc,
}

struct CmpIOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
CmpIOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
: public ElementwiseOpConversionBase<arith::CmpIOp, CmpIOpConversion> {
using Base = ElementwiseOpConversionBase<arith::CmpIOp, CmpIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

// An interface to support variant DestOp builder.
SmallVector<LLVM::ICmpOp>
createDestOps(triton::gpu::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
MultipleOperandsRange operands, Location loc) const {
SmallVector<LLVM::ICmpOp> createDestOps(arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy,
MultipleOperandsRange operands,
Location loc) const {
return {rewriter.create<LLVM::ICmpOp>(
loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()),
operands[0][0], operands[0][1])};
@@ -1533,16 +1532,14 @@ struct CmpIOpConversion
};

struct CmpFOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
CmpFOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
: public ElementwiseOpConversionBase<arith::CmpFOp, CmpFOpConversion> {
using Base = ElementwiseOpConversionBase<arith::CmpFOp, CmpFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

// An interface to support variant DestOp builder.
static SmallVector<LLVM::FCmpOp>
createDestOps(triton::gpu::CmpFOp op, OpAdaptor adaptor,
createDestOps(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
MultipleOperandsRange operands, Location loc) {
return {rewriter.create<LLVM::FCmpOp>(
@@ -2101,7 +2098,6 @@ void populateElementwiseOpToLLVMPatterns(
int computeCapability, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP

62 changes: 3 additions & 59 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -46,23 +46,6 @@ template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
}
};

template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
using OpConversionPattern<SrcOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
addNamedAttrs(
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
adaptor.getLhs(), adaptor.getRhs()),
adaptor.getAttributes());
return success();
}
};

class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
@@ -122,8 +105,9 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
GenericOpPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
GenericOpPattern<arith::CmpIOp>, GenericOpPattern<arith::CmpFOp>,
// Select
GenericOpPattern<arith::SelectOp>,
// Cast Ops
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
@@ -132,45 +116,6 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}

// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());

Value cond = adaptor.getCondition();
if (llvm::isa<RankedTensorType>(retType) &&
!llvm::isa<TensorType>(cond.getType())) {
// triton_gpu.select doesn't support scalar condition values, so add a
// splat
auto retTypeTensor = llvm::cast<RankedTensorType>(retType);
auto retShape = retTypeTensor.getShape();
auto retEncoding = retTypeTensor.getEncoding();
Type condTy =
RankedTensorType::get(retShape, cond.getType(), retEncoding);
cond = rewriter.create<triton::SplatOp>(op.getLoc(), condTy, cond);
}

addNamedAttrs(
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()),
adaptor.getAttributes());
return success();
}
};

void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
}

void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
@@ -745,7 +690,6 @@ class ConvertTritonToTritonGPU
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateStdPatternsAndLegality(typeConverter, patterns, target);
populateArithPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns, numCTAs);
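One consequence of dropping the StdSelectPattern above: arith.select already handles both a scalar i1 condition selecting whole tensors and a tensor-of-i1 condition selecting element-wise, so the splat workaround that triton_gpu.select required is no longer needed (the updated triton_to_tritongpu.mlir test below checks exactly this). A minimal sketch, with hypothetical values not taken from the commit:

    %r0 = arith.select %scalar_cond, %t, %f : tensor<128xf32, #blocked>
    %r1 = arith.select %tensor_cond, %t, %f : tensor<128xi1, #blocked>, tensor<128xf32, #blocked>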
2 changes: 0 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -98,8 +98,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
// We have custom versions of some arith operators
addIllegalOp<arith::CmpIOp, arith::CmpFOp>();

addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
triton::TritonDialect, cf::ControlFlowDialect,
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
@@ -649,7 +649,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
return true;
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
return externElementwiseOp.getPure();
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
return true;
return false;
}
@@ -98,13 +98,8 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) {
return !(boxDimSwizzle && strideDivisible && enableTMA);
}

// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't
// play well with encoding attributes. Move back to arith::CmpIOp when this pass
// moves back to triton IR level.
Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type,
arith::CmpIPredicate pred, Value lhs, Value rhs) {
if (type.getEncoding())
return builder.create<ttg::CmpIOp>(loc, type, pred, lhs, rhs);
return builder.create<arith::CmpIOp>(loc, type, pred, lhs, rhs);
}

2 changes: 1 addition & 1 deletion python/test/unit/language/test_core.py
@@ -2149,7 +2149,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
tt.reduce.return %13 : i32"""
elif op == "max":
op_str = f"""
%13 = "{GPU_DIALECT}.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1
%13 = arith.cmpi "sgt", %arg2, %arg3 : i32
%14 = arith.select %13, %arg2, %arg3 : i32
tt.reduce.return %14 : i32"""
ir = f"""
2 changes: 1 addition & 1 deletion test/Analysis/test-alignment.mlir
@@ -292,7 +292,7 @@ tt.func @select() {
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
%5 = arith.select %4, %3, %7 : tensor<128xi1>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
%8 = arith.select %7, %3, %2 : tensor<128xi1>, tensor<128xi1>
tt.return
}

3 changes: 1 addition & 2 deletions test/Conversion/triton_to_tritongpu.mlir
@@ -79,8 +79,7 @@ tt.func public @select_op(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>

// CHECK: %[[splat:.*]] = tt.splat %arg2 : (i1) -> tensor<128xi1, #blocked>
// CHECK-NEXT: %{{.*}} = "triton_gpu.select"(%[[splat]], %{{.*}}, %{{.*}}) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
// CHECK: %{{.*}} = arith.select %arg2, %{{.*}}, %{{.*}} : tensor<128xf32, #blocked>
%4 = arith.select %arg2, %cst, %3 : tensor<128xf32>

%5 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
@@ -331,7 +331,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>, tensor<64xi32, #blocked>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
%10 = arith.cmpi "slt", %4, %9 : tensor<64xi32, #blocked>
// load op has a vector width = 1 due to the %mask's alignment
// GCN-NOT: llvm.inline_asm
// GCN: llvm.addrspacecast {{.*}} : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
Expand Down
31 changes: 31 additions & 0 deletions test/Triton/print.mlir
@@ -0,0 +1,31 @@
// RUN: triton-translate %s --mlir-print-ir-after-all -o %t 2>&1 | FileCheck %s

// CHECK: IR Dump After SCFToControlFlow (convert-scf-to-cf)
// CHECK: tt.func public @add_kernel_0d1d2d3de
// CHECK: IR Dump After ConvertIndexToLLVMPass (convert-index-to-llvm)
// CHECK: tt.func public @add_kernel_0d1d2d3de

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel_0d1d2d3de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked>
%14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked>
tt.return
}
}
2 changes: 1 addition & 1 deletion test/Triton/vecadd.mlir
@@ -55,7 +55,7 @@ module {
// %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32>
// %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
4 changes: 2 additions & 2 deletions test/TritonGPU/coalesce.mlir
@@ -86,7 +86,7 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32, 1> {tt.divisibility =
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
%6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
@@ -120,7 +120,7 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32, 1> {tt.divisibility =
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
%6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>

0 comments on commit ac4ee36
