Skip to content

Commit

Permalink
[CIR] Vector types, comparison operators (#432)
Browse files Browse the repository at this point in the history
This is part 3 of implementing vector types and vector operations in
ClangIR, issue #284.

Create new operation `cir.vec.cmp` which implements the relational
comparison operators (`== != < > <= >=`) on vector types. A new
operation was created rather than reusing `cir.cmp` because the result
is a vector of a signed intergral type, not a `bool`.

Add CodeGen and Lowering tests for vector comparisons.

Fix the floating-point comparison predicate when lowering to LLVM. To
handle NaN values correctly, the comparisons need to be ordered rather
than unordered. (Except for `!=`, which needs to be unordered.) For
example, "ueq" was changed to "oeq".
  • Loading branch information
dkolsen-pgi authored and lanza committed Jan 31, 2024
1 parent fae636d commit 1da7e33
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 85 deletions.
25 changes: 25 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,31 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecCmp
//===----------------------------------------------------------------------===//

def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {

let summary = "Compare two vectors";
let description = [{
The `cir.vec.cmp` operation does an element-wise comparison of two vectors
of the same type. The result is a vector of the same size as the operands
whose element type is the signed integral type that is the same size as the
element type of the operands. The values in the result are 0 or -1.
}];

let arguments = (ins Arg<CmpOpKind, "cmp kind">:$kind, CIR_VectorType:$lhs,
CIR_VectorType:$rhs);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
`(` $kind `,` $lhs `,` $rhs `)` `:` type($lhs) `,` type($result) attr-dict
}];

let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// BaseClassAddr
//===----------------------------------------------------------------------===//
Expand Down
63 changes: 33 additions & 30 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,26 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
QualType LHSTy = E->getLHS()->getType();
QualType RHSTy = E->getRHS()->getType();

auto ClangCmpToCIRCmp = [](auto ClangCmp) -> mlir::cir::CmpOpKind {
switch (ClangCmp) {
case BO_LT:
return mlir::cir::CmpOpKind::lt;
case BO_GT:
return mlir::cir::CmpOpKind::gt;
case BO_LE:
return mlir::cir::CmpOpKind::le;
case BO_GE:
return mlir::cir::CmpOpKind::ge;
case BO_EQ:
return mlir::cir::CmpOpKind::eq;
case BO_NE:
return mlir::cir::CmpOpKind::ne;
default:
llvm_unreachable("unsupported comparison kind");
return mlir::cir::CmpOpKind(-1);
}
};

if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
assert(0 && "not implemented");
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
Expand All @@ -773,12 +793,18 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value RHS = BOInfo.RHS;

if (LHSTy->isVectorType()) {
// Cannot handle any vector just yet.
assert(0 && "not implemented");
// If AltiVec, the comparison results in a numeric type, so we use
// intrinsics comparing vectors and giving 0 or 1 as a result
if (!E->getType()->isVectorType())
assert(0 && "not implemented");
if (!E->getType()->isVectorType()) {
// If AltiVec, the comparison results in a numeric type, so we use
// intrinsics comparing vectors and giving 0 or 1 as a result
llvm_unreachable("NYI: AltiVec comparison");
} else {
// Other kinds of vectors. Element-wise comparison returning
// a vector.
mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
return Builder.create<mlir::cir::VecCmpOp>(
CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.Ty), Kind,
BOInfo.LHS, BOInfo.RHS);
}
}
if (BOInfo.isFixedPointOp()) {
assert(0 && "not implemented");
Expand All @@ -793,30 +819,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
llvm_unreachable("NYI");
}

mlir::cir::CmpOpKind Kind;
switch (E->getOpcode()) {
case BO_LT:
Kind = mlir::cir::CmpOpKind::lt;
break;
case BO_GT:
Kind = mlir::cir::CmpOpKind::gt;
break;
case BO_LE:
Kind = mlir::cir::CmpOpKind::le;
break;
case BO_GE:
Kind = mlir::cir::CmpOpKind::ge;
break;
case BO_EQ:
Kind = mlir::cir::CmpOpKind::eq;
break;
case BO_NE:
Kind = mlir::cir::CmpOpKind::ne;
break;
default:
llvm_unreachable("unsupported");
}

mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
return Builder.create<mlir::cir::CmpOp>(CGF.getLoc(BOInfo.Loc),
CGF.getCIRType(BOInfo.Ty), Kind,
BOInfo.LHS, BOInfo.RHS);
Expand Down
137 changes: 88 additions & 49 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,51 @@ void walkRegionSkipping(mlir::Region &region,
});
}

/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
mlir::LLVM::ICmpPredicate
convertCmpKindToICmpPredicate(mlir::cir::CmpOpKind kind, bool isSigned) {
using CIR = mlir::cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
case CIR::le:
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
case CIR::gt:
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
case CIR::ge:
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
/// kind.
mlir::LLVM::FCmpPredicate
convertCmpKindToFCmpPredicate(mlir::cir::CmpOpKind kind) {
using CIR = mlir::cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMFCmp::oeq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::olt;
case CIR::le:
return LLVMFCmp::ole;
case CIR::gt:
return LLVMFCmp::ogt;
case CIR::ge:
return LLVMFCmp::oge;
}
llvm_unreachable("Unknown CmpOpKind");
}

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1133,6 +1178,41 @@ class CIRVectorExtractLowering
}
};

class CIRVectorCmpOpLowering
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
public:
using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecCmpOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(op.getType().isa<mlir::cir::VectorType>() &&
op.getLhs().getType().isa<mlir::cir::VectorType>() &&
op.getRhs().getType().isa<mlir::cir::VectorType>() &&
"Vector compare with non-vector type");
// LLVM IR vector comparison returns a vector of i1. This one-bit vector
// must be sign-extended to the correct result type.
auto elementType =
op.getLhs().getType().dyn_cast<mlir::cir::VectorType>().getEltType();
mlir::Value bitResult;
if (auto intType = elementType.dyn_cast<mlir::cir::IntType>()) {
bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
op.getLoc(),
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
adaptor.getLhs(), adaptor.getRhs());
} else if (elementType.isa<mlir::FloatType>()) {
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
adaptor.getLhs(), adaptor.getRhs());
} else {
return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(
op, typeConverter->convertType(op.getType()), bitResult);
return mlir::success();
}
};

class CIRVAStartLowering
: public mlir::OpConversionPattern<mlir::cir::VAStartOp> {
public:
Expand Down Expand Up @@ -1835,50 +1915,6 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
public:
using OpConversionPattern<mlir::cir::CmpOp>::OpConversionPattern;

mlir::LLVM::ICmpPredicate convertToICmpPredicate(mlir::cir::CmpOpKind kind,
bool isSigned) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
case CIR::le:
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
case CIR::gt:
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
case CIR::ge:
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LLVM::FCmpPredicate
convertToFCmpPredicate(mlir::cir::CmpOpKind kind) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMFCmp::ueq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::ult;
case CIR::le:
return LLVMFCmp::ule;
case CIR::gt:
return LLVMFCmp::ugt;
case CIR::ge:
return LLVMFCmp::uge;
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand All @@ -1887,15 +1923,17 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {

// Lower to LLVM comparison op.
if (auto intTy = type.dyn_cast<mlir::cir::IntType>()) {
auto kind = convertToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
auto kind =
convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ptrTy = type.dyn_cast<mlir::cir::PointerType>()) {
auto kind = convertToICmpPredicate(cmpOp.getKind(), /* isSigned=*/false);
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
/* isSigned=*/false);
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
auto kind = convertToFCmpPredicate(cmpOp.getKind());
auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind());
llResult = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else {
Expand Down Expand Up @@ -2090,8 +2128,9 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering,
CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering,
CIRFAbsOpLowering, CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRStackSaveLowering,
CIRStackRestoreLowering>(converter, patterns.getContext());
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRStackSaveLowering, CIRStackRestoreLowering>(converter,
patterns.getContext());
}

namespace {
Expand Down
29 changes: 29 additions & 0 deletions clang/test/CIR/CodeGen/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

typedef int vi4 __attribute__((vector_size(16)));
typedef double vd2 __attribute__((vector_size(16)));
typedef long long vll2 __attribute__((vector_size(16)));

void vector_int_test(int x) {

Expand Down Expand Up @@ -49,6 +50,20 @@ void vector_int_test(int x) {
// CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.vector<!s32i x 4>, !cir.vector<!s32i x 4>
vi4 n = ~a;
// CHECK: %{{[0-9]+}} = cir.unary(not, %{{[0-9]+}}) : !cir.vector<!s32i x 4>, !cir.vector<!s32i x 4>

// Comparisons
vi4 o = a == b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(eq, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 p = a != b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ne, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 q = a < b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(lt, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 r = a > b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(gt, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 s = a <= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(le, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 t = a >= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ge, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
}

void vector_double_test(int x, double y) {
Expand Down Expand Up @@ -86,4 +101,18 @@ void vector_double_test(int x, double y) {
// CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.vector<f64 x 2>, !cir.vector<f64 x 2>
vd2 m = -a;
// CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.vector<f64 x 2>, !cir.vector<f64 x 2>

// Comparisons
vll2 o = a == b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(eq, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 p = a != b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ne, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 q = a < b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(lt, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 r = a > b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(gt, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 s = a <= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(le, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 t = a >= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ge, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
}
10 changes: 5 additions & 5 deletions clang/test/CIR/Lowering/cmp.cir
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,27 @@ module {
%23 = cir.load %2 : cir.ptr <f32>, f32
%24 = cir.load %3 : cir.ptr <f32>, f32
%25 = cir.cmp(gt, %23, %24) : f32, !cir.bool
// CHECK: llvm.fcmp "ugt"
// CHECK: llvm.fcmp "ogt"
%26 = cir.load %2 : cir.ptr <f32>, f32
%27 = cir.load %3 : cir.ptr <f32>, f32
%28 = cir.cmp(eq, %26, %27) : f32, !cir.bool
// CHECK: llvm.fcmp "ueq"
// CHECK: llvm.fcmp "oeq"
%29 = cir.load %2 : cir.ptr <f32>, f32
%30 = cir.load %3 : cir.ptr <f32>, f32
%31 = cir.cmp(lt, %29, %30) : f32, !cir.bool
// CHECK: llvm.fcmp "ult"
// CHECK: llvm.fcmp "olt"
%32 = cir.load %2 : cir.ptr <f32>, f32
%33 = cir.load %3 : cir.ptr <f32>, f32
%34 = cir.cmp(ge, %32, %33) : f32, !cir.bool
// CHECK: llvm.fcmp "uge"
// CHECK: llvm.fcmp "oge"
%35 = cir.load %2 : cir.ptr <f32>, f32
%36 = cir.load %3 : cir.ptr <f32>, f32
%37 = cir.cmp(ne, %35, %36) : f32, !cir.bool
// CHECK: llvm.fcmp "une"
%38 = cir.load %2 : cir.ptr <f32>, f32
%39 = cir.load %3 : cir.ptr <f32>, f32
%40 = cir.cmp(le, %38, %39) : f32, !cir.bool
// CHECK: llvm.fcmp "ule"
// CHECK: llvm.fcmp "ole"

// Pointer comparisons.
%41 = cir.cmp(ne, %0, %1) : !cir.ptr<!s32i>, !cir.bool
Expand Down
Loading

0 comments on commit 1da7e33

Please sign in to comment.