Skip to content

Commit

Permalink
[CIR] introduce CIR floating-point types (llvm#385)
Browse files Browse the repository at this point in the history
This PR adds a dedicated `cir.float` type for representing
floating-point types. There are several issues linked to this PR: llvm#5,
llvm#78, and llvm#90.
  • Loading branch information
Lancern authored and lanza committed Apr 17, 2024
1 parent f5419d2 commit 0ba447b
Show file tree
Hide file tree
Showing 55 changed files with 937 additions and 617 deletions.
27 changes: 27 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,33 @@ def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// FPAttr
//===----------------------------------------------------------------------===//

def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
let summary = "An attribute containing a floating-point value";
let description = [{
An fp attribute is a literal attribute that represents a floating-point
value of the specified floating-point type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get(type.getContext(), type, value);
}]>,
];
let extraClassDeclaration = [{
static FPAttr getZero(mlir::Type type);
}];
let genVerifyDecl = 1;

let assemblyFormat = [{
`<` custom<FloatLiteral>($value, ref($type)) `>`
}];
}

//===----------------------------------------------------------------------===//
// ConstPointerAttr
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2615,8 +2615,8 @@ def IterEndOp : CIR_Op<"iterator_end"> {

class UnaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyFloat:$src);
let results = (outs AnyFloat:$result);
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"

#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"

Expand Down
40 changes: 37 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@

include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
include "clang/CIR/Interfaces/CIRFPTypeInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// CIR Types
//===----------------------------------------------------------------------===//

class CIR_Type<string name, string typeMnemonic, list<Trait> traits = []> :
TypeDef<CIR_Dialect, name, traits> {
class CIR_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<CIR_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}

Expand Down Expand Up @@ -94,6 +97,37 @@ def SInt16 : SInt<16>;
def SInt32 : SInt<32>;
def SInt64 : SInt<64>;

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class CIR_FloatType<string name, string mnemonic>
: CIR_Type<name, mnemonic,
[
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<CIRFPTypeInterface>,
]> {}

def CIR_Single : CIR_FloatType<"Single", "float"> {
let summary = "CIR single-precision float type";
let description = [{
Floating-point type that represents the `float` type in C/C++. Its
underlying floating-point format is the IEEE-754 binary32 format.
}];
}

def CIR_Double : CIR_FloatType<"Double", "double"> {
let summary = "CIR double-precision float type";
let description = [{
Floating-point type that represents the `double` type in C/C++. Its
underlying floating-point format is the IEEE-754 binar64 format.
}];
}

// Constraints

def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -318,7 +352,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,

def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, AnyFloat,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, CIR_AnyFloat,
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
22 changes: 22 additions & 0 deletions clang/include/clang/CIR/Interfaces/CIRFPTypeInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- CIRFPTypeInterface.h - Interface for CIR FP types -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
//
//===----------------------------------------------------------------------===//

#ifndef CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
#define CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H

#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"

/// Include the tablegen'd interface declarations.
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h.inc"

#endif // CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Interfaces/CIRFPTypeInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- CIRFPTypeInterface.td - CIR FP Interface Definitions -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
#define MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE

include "mlir/IR/OpBase.td"

def CIRFPTypeInterface : TypeInterface<"CIRFPTypeInterface"> {
let description = [{
Contains helper functions to query properties about a floating-point type.
}];
let cppNamespace = "::mlir::cir";

let methods = [
InterfaceMethod<[{
Returns the bit width of this floating-point type.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsSizeInBits($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the mantissa width.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getFPMantissaWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsPrecision($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the float semantics of this floating-point type.
}],
/*retTy=*/"const llvm::fltSemantics &",
/*methodName=*/"getFloatSemantics"
>,
];
}

#endif // MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
9 changes: 9 additions & 0 deletions clang/include/clang/CIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ function(add_clang_mlir_op_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

function(add_clang_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

add_clang_mlir_attr_interface(ASTAttrInterfaces)
add_clang_mlir_op_interface(CIROpInterfaces)
add_clang_mlir_op_interface(CIRLoopOpInterface)
add_clang_mlir_type_interface(CIRFPTypeInterface)
33 changes: 17 additions & 16 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (ty.isa<mlir::FloatType>())
return mlir::FloatAttr::get(ty, 0.0);
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
Expand Down Expand Up @@ -256,12 +258,13 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
if (const auto boolVal = attr.dyn_cast<mlir::cir::BoolAttr>())
return !boolVal.getValue();

if (const auto fpVal = attr.dyn_cast<mlir::FloatAttr>()) {
if (auto fpAttr = attr.dyn_cast<mlir::cir::FPAttr>()) {
auto fpVal = fpAttr.getValue();
bool ignored;
llvm::APFloat FV(+0.0);
FV.convert(fpVal.getValue().getSemantics(),
llvm::APFloat::rmNearestTiesToEven, &ignored);
return FV.bitwiseIsEqual(fpVal.getValue());
FV.convert(fpVal.getSemantics(), llvm::APFloat::rmNearestTiesToEven,
&ignored);
return FV.bitwiseIsEqual(fpVal);
}

if (const auto structVal = attr.dyn_cast<mlir::cir::ConstStructAttr>()) {
Expand Down Expand Up @@ -348,23 +351,21 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}
bool isInt(mlir::Type i) { return i.isa<mlir::cir::IntType>(); }

mlir::FloatType getLongDouble80BitsTy() const {
return typeCache.LongDouble80BitsTy;
}
mlir::Type getLongDouble80BitsTy() const { llvm_unreachable("NYI"); }

/// Get the proper floating point type for the given semantics.
mlir::FloatType getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
mlir::Type getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
if (&format == &llvm::APFloat::IEEEhalf()) {
llvm_unreachable("IEEEhalf float format is NYI");
}

if (&format == &llvm::APFloat::BFloat())
llvm_unreachable("BFloat float format is NYI");
if (&format == &llvm::APFloat::IEEEsingle())
llvm_unreachable("IEEEsingle float format is NYI");
return typeCache.FloatTy;
if (&format == &llvm::APFloat::IEEEdouble())
llvm_unreachable("IEEEdouble float format is NYI");
return typeCache.DoubleTy;
if (&format == &llvm::APFloat::IEEEquad())
llvm_unreachable("IEEEquad float format is NYI");
if (&format == &llvm::APFloat::PPCDoubleDouble())
Expand Down Expand Up @@ -491,9 +492,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

bool isSized(mlir::Type ty) {
if (ty.isIntOrFloat() ||
ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType>())
if (ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
mlir::cir::CIRFPTypeInterface>())
return true;
assert(0 && "Unimplemented size for type");
return false;
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,9 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
assert(0 && "not implemented");
else {
mlir::Type ty = CGM.getCIRType(DestType);
return builder.getFloatAttr(ty, Init);
assert(ty.isa<mlir::cir::CIRFPTypeInterface>() &&
"expected floating-point type");
return CGM.getBuilder().getAttr<mlir::cir::FPAttr>(ty, Init);
}
}
case APValue::Array: {
Expand Down
14 changes: 8 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
assert(Ty.isa<mlir::cir::CIRFPTypeInterface>() &&
"expect floating-point type");
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getFloatAttr(Ty, E->getValue()));
Builder.getAttr<mlir::cir::FPAttr>(Ty, E->getValue()));
}
mlir::Value VisitCharacterLiteral(const CharacterLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
Expand Down Expand Up @@ -1227,7 +1229,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
llvm_unreachable("NYI");

assert(!UnimplementedFeature::cirVectorType());
if (Ops.LHS.getType().isa<mlir::FloatType>()) {
if (Ops.LHS.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1701,20 +1703,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::bool_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
CastKind = mlir::cir::CastKind::bool_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (CGF.getBuilder().isInt(SrcTy)) {
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::integral;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
CastKind = mlir::cir::CastKind::int_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (SrcTy.isa<mlir::FloatType>()) {
} else if (SrcTy.isa<mlir::cir::CIRFPTypeInterface>()) {
if (CGF.getBuilder().isInt(DstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
Expand All @@ -1724,7 +1726,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (Builder.getIsFPConstrained())
llvm_unreachable("NYI");
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
} else {
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,10 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,

// TODO: HalfTy
// TODO: BFloatTy
FloatTy = builder.getF32Type();
DoubleTy = builder.getF64Type();
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
LongDouble80BitsTy = builder.getF80Type();

// TODO: PointerWidthInBits
PointerAlignInBytes =
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypeCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ struct CIRGenTypeCache {
// mlir::Type HalfTy, BFloatTy;
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
mlir::FloatType FloatTy, DoubleTy, LongDouble80BitsTy;
mlir::cir::SingleType FloatTy;
mlir::cir::DoubleType DoubleTy;

/// int
mlir::Type UIntTy;
Expand Down
Loading

0 comments on commit 0ba447b

Please sign in to comment.