Skip to content

Commit

Permalink
[CIR][CIRGen][Builtin][Neon] Lower neon_vqadds_s32 (#1200)
Browse files Browse the repository at this point in the history
This can't be simply implemented by our CIR Add via LLVM::AddOp, as
i[t's saturated add.](https://godbolt.org/z/MxqGrj6fP)
  • Loading branch information
ghehg authored Dec 9, 2024
1 parent cfe7c63 commit 21e8647
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 40 deletions.
8 changes: 6 additions & 2 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}

mlir::Value createSub(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
bool hasNSW = false) {
bool hasNSW = false, bool saturated = false) {
auto op = create<cir::BinOp>(lhs.getLoc(), lhs.getType(),
cir::BinOpKind::Sub, lhs, rhs);
if (hasNUW)
op.setNoUnsignedWrap(true);
if (hasNSW)
op.setNoSignedWrap(true);
if (saturated)
op.setSaturated(true);
return op;
}

Expand All @@ -438,13 +440,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}

mlir::Value createAdd(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
bool hasNSW = false) {
bool hasNSW = false, bool saturated = false) {
auto op = create<cir::BinOp>(lhs.getLoc(), lhs.getType(),
cir::BinOpKind::Add, lhs, rhs);
if (hasNUW)
op.setNoUnsignedWrap(true);
if (hasNSW)
op.setNoSignedWrap(true);
if (saturated)
op.setSaturated(true);
return op;
}

Expand Down
4 changes: 3 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1192,12 +1192,14 @@ def BinOp : CIR_Op<"binop", [Pure,
let arguments = (ins Arg<BinOpKind, "binop kind">:$kind,
CIR_AnyType:$lhs, CIR_AnyType:$rhs,
UnitAttr:$no_unsigned_wrap,
UnitAttr:$no_signed_wrap);
UnitAttr:$no_signed_wrap,
UnitAttr:$saturated);

let assemblyFormat = [{
`(` $kind `,` $lhs `,` $rhs `)`
(`nsw` $no_signed_wrap^)?
(`nuw` $no_unsigned_wrap^)?
(`sat` $saturated^)?
`:` type($lhs) attr-dict
}];

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2834,7 +2834,7 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
case NEON::BI__builtin_neon_vqaddh_u16:
llvm_unreachable(" neon_vqaddh_u16 NYI ");
case NEON::BI__builtin_neon_vqadds_s32:
llvm_unreachable(" neon_vqadds_s32 NYI ");
return builder.createAdd(ops[0], ops[1], false, false, true);
case NEON::BI__builtin_neon_vqadds_u32:
llvm_unreachable(" neon_vqadds_u32 NYI ");
case NEON::BI__builtin_neon_vqdmulhh_s16:
Expand Down Expand Up @@ -2983,7 +2983,7 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
case NEON::BI__builtin_neon_vqsubh_u16:
llvm_unreachable(" neon_vqsubh_u16 NYI ");
case NEON::BI__builtin_neon_vqsubs_s32:
llvm_unreachable(" neon_vqsubs_s32 NYI ");
return builder.createSub(ops[0], ops[1], false, false, true);
case NEON::BI__builtin_neon_vqsubs_u32:
llvm_unreachable(" neon_vqsubs_u32 NYI ");
case NEON::BI__builtin_neon_vrecped_f64:
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3785,6 +3785,7 @@ LogicalResult cir::AtomicFetch::verify() {

LogicalResult cir::BinOp::verify() {
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
bool saturated = getSaturated();

if (!isa<cir::IntType>(getType()) && noWrap)
return emitError()
Expand All @@ -3794,9 +3795,18 @@ LogicalResult cir::BinOp::verify() {
getKind() == cir::BinOpKind::Sub ||
getKind() == cir::BinOpKind::Mul;

bool saturatedOps =
getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;

if (noWrap && !noWrapOps)
return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
"'sub' and 'mul'";
if (saturated && !saturatedOps)
return emitError() << "The saturated flag is applicable to opcodes: 'add' "
"and 'sub'";
if (noWrap && saturated)
return emitError() << "The nsw/nuw flags and the saturated flag are "
"mutually exclusive";

bool complexOps =
getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
Expand Down
73 changes: 48 additions & 25 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,13 @@ CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
return mlir::LLVM::IntegerOverflowFlags::none;
}

static bool isIntTypeUnsigned(mlir::Type type) {
// TODO: Ideally, we should only need to check cir::IntType here.
return mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isUnsigned()
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
}

mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
cir::BinOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -2464,65 +2471,81 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
"operand type not supported yet");

auto llvmTy = getTypeConverter()->convertType(op.getType());
mlir::Type llvmEltTy =
mlir::isa<mlir::VectorType>(llvmTy)
? mlir::cast<mlir::VectorType>(llvmTy).getElementType()
: llvmTy;
auto rhs = adaptor.getRhs();
auto lhs = adaptor.getLhs();

type = elementTypeIfVector(type);

switch (op.getKind()) {
case cir::BinOpKind::Add:
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
if (op.getSaturated()) {
if (isIntTypeUnsigned(type)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs);
break;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs);
break;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs);
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Sub:
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
if (op.getSaturated()) {
if (isIntTypeUnsigned(type)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs);
break;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs);
break;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs);
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Mul:
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
if (mlir::isa<mlir::IntegerType>(llvmEltTy))
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Div:
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
auto isUnsigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isUnsigned()
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
auto isUnsigned = isIntTypeUnsigned(type);
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs);
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Rem:
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
auto isUnsigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isUnsigned()
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
auto isUnsigned = isIntTypeUnsigned(type);
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs);
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs);
break;
case cir::BinOpKind::And:
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Or:
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs);
break;
case cir::BinOpKind::Xor:
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
break;
}

Expand Down
29 changes: 20 additions & 9 deletions clang/test/CIR/CodeGen/AArch64/neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -9750,12 +9750,16 @@ poly16x8_t test_vmull_p8(poly8x8_t a, poly8x8_t b) {
// return vqaddh_s16(a, b);
// }

// NYI-LABEL: @test_vqadds_s32(
// NYI: [[VQADDS_S32_I:%.*]] = call i32 @llvm.aarch64.neon.sqadd.i32(i32 %a, i32 %b)
// NYI: ret i32 [[VQADDS_S32_I]]
// int32_t test_vqadds_s32(int32_t a, int32_t b) {
// return vqadds_s32(a, b);
// }
int32_t test_vqadds_s32(int32_t a, int32_t b) {
return vqadds_s32(a, b);

// CIR: vqadds_s32
// CIR: cir.binop(add, {{%.*}}, {{%.*}}) sat : !s32i

// LLVM:{{.*}}test_vqadds_s32(i32{{.*}}[[a:%.*]], i32{{.*}}[[b:%.*]])
// LLVM: [[VQADDS_S32_I:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[a]], i32 [[b]])
// LLVM: ret i32 [[VQADDS_S32_I]]
}

// NYI-LABEL: @test_vqaddd_s64(
// NYI: [[VQADDD_S64_I:%.*]] = call i64 @llvm.aarch64.neon.sqadd.i64(i64 %a, i64 %b)
Expand Down Expand Up @@ -9821,9 +9825,16 @@ poly16x8_t test_vmull_p8(poly8x8_t a, poly8x8_t b) {
// NYI-LABEL: @test_vqsubs_s32(
// NYI: [[VQSUBS_S32_I:%.*]] = call i32 @llvm.aarch64.neon.sqsub.i32(i32 %a, i32 %b)
// NYI: ret i32 [[VQSUBS_S32_I]]
// int32_t test_vqsubs_s32(int32_t a, int32_t b) {
// return vqsubs_s32(a, b);
// }
int32_t test_vqsubs_s32(int32_t a, int32_t b) {
return vqsubs_s32(a, b);

// CIR: vqsubs_s32
// CIR: cir.binop(sub, {{%.*}}, {{%.*}}) sat : !s32i

// LLVM:{{.*}}test_vqsubs_s32(i32{{.*}}[[a:%.*]], i32{{.*}}[[b:%.*]])
// LLVM: [[VQSUBS_S32_I:%.*]] = call i32 @llvm.ssub.sat.i32(i32 [[a]], i32 [[b]])
// LLVM: ret i32 [[VQSUBS_S32_I]]
}

// NYI-LABEL: @test_vqsubd_s64(
// NYI: [[VQSUBD_S64_I:%.*]] = call i64 @llvm.aarch64.neon.sqsub.i64(i64 %a, i64 %b)
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,33 @@ cir.func @bad_binop_for_nowrap(%x: !u32i, %y: !u32i) {

// -----

!u32i = !cir.int<u, 32>

cir.func @bad_binop_for_saturated(%x: !u32i, %y: !u32i) {
// expected-error@+1 {{The saturated flag is applicable to opcodes: 'add' and 'sub'}}
%0 = cir.binop(div, %x, %y) sat : !u32i
}

// -----

!s32i = !cir.int<s, 32>

cir.func @no_nsw_for_saturated(%x: !s32i, %y: !s32i) {
// expected-error@+1 {{The nsw/nuw flags and the saturated flag are mutually exclusive}}
%0 = cir.binop(add, %x, %y) nsw sat : !s32i
}

// -----

!s32i = !cir.int<s, 32>

cir.func @no_nuw_for_saturated(%x: !s32i, %y: !s32i) {
// expected-error@+1 {{The nsw/nuw flags and the saturated flag are mutually exclusive}}
%0 = cir.binop(add, %x, %y) nuw sat : !s32i
}

// -----

!s32i = !cir.int<s, 32>

module {
Expand Down
4 changes: 4 additions & 0 deletions clang/test/CIR/Lowering/binop-signed-int.cir
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ module {
%33 = cir.load %1 : !cir.ptr<!s32i>, !s32i
%34 = cir.binop(or, %32, %33) : !s32i
// CHECK: = llvm.or
%35 = cir.binop(add, %32, %33) sat: !s32i
// CHECK: = llvm.intr.sadd.sat{{.*}}(i32, i32) -> i32
%36 = cir.binop(sub, %32, %33) sat: !s32i
// CHECK: = llvm.intr.ssub.sat{{.*}}(i32, i32) -> i32
cir.store %34, %2 : !s32i, !cir.ptr<!s32i>
cir.return
}
Expand Down
7 changes: 6 additions & 1 deletion clang/test/CIR/Lowering/binop-unsigned-int.cir
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ module {
%33 = cir.load %1 : !cir.ptr<!u32i>, !u32i
%34 = cir.binop(or, %32, %33) : !u32i
cir.store %34, %2 : !u32i, !cir.ptr<!u32i>
%35 = cir.binop(add, %32, %33) sat: !u32i
%36 = cir.binop(sub, %32, %33) sat: !u32i
cir.return
}
}
Expand All @@ -62,7 +64,8 @@ module {
// MLIR: = llvm.shl
// MLIR: = llvm.and
// MLIR: = llvm.xor
// MLIR: = llvm.or
// MLIR: = llvm.intr.uadd.sat{{.*}}(i32, i32) -> i32
// MLIR: = llvm.intr.usub.sat{{.*}}(i32, i32) -> i32

// LLVM: = mul i32
// LLVM: = udiv i32
Expand All @@ -74,3 +77,5 @@ module {
// LLVM: = and i32
// LLVM: = xor i32
// LLVM: = or i32
// LLVM: = call i32 @llvm.uadd.sat.i32
// LLVM: = call i32 @llvm.usub.sat.i32

0 comments on commit 21e8647

Please sign in to comment.