Skip to content

Commit

Permalink
dialects: (llvm) Add a bunch of float methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonLydike committed Jan 23, 2025
1 parent 9cdc462 commit 7c83cf7
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 41 deletions.
22 changes: 21 additions & 1 deletion tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: XDSL_ROUNDTRIP

%arg0, %arg1 = "test.op"() : () -> (i32, i32)
%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32)

%add_both = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nsw, nuw>} : i32
// CHECK: %add_both = llvm.add %arg0, %arg1 {overflowFlags = #llvm.overflow<nsw,nuw>} : i32
Expand Down Expand Up @@ -121,3 +121,23 @@

%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
// CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32

// float arith:

%fmul = llvm.fmul %f1, %f1 : f32
// CHECK: %fmul = llvm.fmul %f1, %f1 : f32

%fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: %fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32

%fdiv = llvm.fdiv %f1, %f1 : f32
// CHECK: %fdiv = llvm.fdiv %f1, %f1 : f32

%fadd = llvm.fadd %f1, %f1 : f32
// CHECK: %fadd = llvm.fadd %f1, %f1 : f32

%fsub = llvm.fsub %f1, %f1 : f32
// CHECK: %fsub = llvm.fsub %f1, %f1 : f32

%frem = llvm.frem %f1, %f1 : f32
// CHECK: %frem = llvm.frem %f1, %f1 : f32
8 changes: 8 additions & 0 deletions tests/filecheck/dialects/llvm/example.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,12 @@ builtin.module {

// CHECK: %val = "test.op"() : () -> i32
// CHECK-NEXT: %fval = llvm.bitcast %val : i32 to f32

%fval2 = llvm.sitofp %val : i32 to f32

// CHECK-NEXT: %fval2 = llvm.sitofp %val : i32 to f32

%fval3 = llvm.fpext %fval : f32 to f64

// CHECK-NEXT: %fval3 = llvm.fpext %fval : f32 to f64
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s --mlir-print-op-generic | xdsl-opt | filecheck %s

builtin.module {
%arg0, %arg1 = "test.op"() : () -> (i32, i32)
%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32)

%add = llvm.add %arg0, %arg1 : i32
// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : i32
Expand Down Expand Up @@ -44,4 +44,24 @@ builtin.module {

%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : i32

// float arith:

%fmul = llvm.fmul %f1, %f1 : f32
// CHECK: %fmul = llvm.fmul %f1, %f1 : f32

%fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: %fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32

%fdiv = llvm.fdiv %f1, %f1 : f32
// CHECK: %fdiv = llvm.fdiv %f1, %f1 : f32

%fadd = llvm.fadd %f1, %f1 : f32
// CHECK: %fadd = llvm.fadd %f1, %f1 : f32

%fsub = llvm.fsub %f1, %f1 : f32
// CHECK: %fsub = llvm.fsub %f1, %f1 : f32

%frem = llvm.frem %f1, %f1 : f32
// CHECK: %frem = llvm.frem %f1, %f1 : f32
}
182 changes: 143 additions & 39 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
ContainerType,
DenseArrayBase,
DenseI64ArrayConstr,
Float16Type,
Float32Type,
Float64Type,
IndexType,
IntAttr,
IntegerAttr,
Expand All @@ -25,7 +28,7 @@
i32,
i64,
)
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand Down Expand Up @@ -54,10 +57,11 @@
result_def,
traits_def,
var_operand_def,
AnyOf,
)
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, NoMemoryEffect, SymbolOpInterface
from xdsl.traits import IsTerminator, NoMemoryEffect, SymbolOpInterface, Pure
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr
Expand Down Expand Up @@ -1710,10 +1714,7 @@ class ZeroOp(IRDLOperation):
res = result_def(LLVMTypeConstr)


@irdl_op_definition
class BitcastOp(IRDLOperation):
name = "llvm.bitcast"

class GenericCastOp(IRDLOperation, ABC):
arg = operand_def(Attribute)
"""
LLVM-compatible non-aggregate type
Expand All @@ -1735,56 +1736,159 @@ def __init__(self, val: Operation | SSAValue, res_type: Attribute):
)


floatingPointLike = AnyOf([Float16Type, Float32Type, Float64Type])


class AbstractFloatArithOp(IRDLOperation, ABC):
T: ClassVar = VarConstraint("T", floatingPointLike)

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)

fastmathFlags = prop_def(FastMathAttr, default_value=FastMathAttr(None))

traits = traits_def(Pure())

def __init__(
self,
lhs: SSAValue | Operation,
rhs: SSAValue | Operation,
fast_math: FastMathAttr | FastMathFlag | None = None,
attrs: dict[str, Attribute] | None = None,
):
if isinstance(fast_math, FastMathFlag | str | None):
fast_math = FastMathAttr(fast_math)

super().__init__(
operands=[lhs, rhs],
result_types=[SSAValue.get(lhs).type],
properties={"fastmathFlags": fast_math},
attributes=attrs if attrs is not None else {},
)

def print(self, printer: Printer):
printer.print_operand(self.lhs)
printer.print_string(", ")
printer.print_operand(self.rhs)
printer.print_string(" ")
printer.print_attr_dict(
{"fastmathFlags": self.fastmathFlags, **self.attributes}
)
printer.print_string(" : ")
printer.print_attribute(self.res.type)

@classmethod
def parse(cls, parser: Parser) -> AbstractFloatArithOp:
lhs = parser.parse_operand()
parser.parse_punctuation(",")
rhs = parser.parse_operand()
attrs = parser.parse_optional_attr_dict()
fastmath = attrs.pop("fastmathFlags", None)
parser.parse_punctuation(":")
type = parser.parse_type()
if not (lhs.type == type and rhs.type == type):
parser.raise_error(f"Expected arguments to be of type {type}")
return cls(lhs, rhs, fastmath, attrs)

Check failure on line 1792 in xdsl/dialects/llvm.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Argument of type "Attribute | None" cannot be assigned to parameter "fast_math" of type "FastMathAttr | FastMathFlag | None" in function "__init__"   Type "Attribute | None" is not assignable to type "FastMathAttr | FastMathFlag | None"     Type "Attribute" is not assignable to type "FastMathAttr | FastMathFlag | None"       "Attribute" is not assignable to "FastMathAttr"       "Attribute" is not assignable to "FastMathFlag"       "Attribute" is not assignable to "None" (reportArgumentType)

Check failure on line 1792 in xdsl/dialects/llvm.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Argument of type "Attribute | None" cannot be assigned to parameter "fast_math" of type "FastMathAttr | FastMathFlag | None" in function "__init__"   Type "Attribute | None" is not assignable to type "FastMathAttr | FastMathFlag | None"     Type "Attribute" is not assignable to type "FastMathAttr | FastMathFlag | None"       "Attribute" is not assignable to "FastMathAttr"       "Attribute" is not assignable to "FastMathFlag"       "Attribute" is not assignable to "None" (reportArgumentType)


@irdl_op_definition
class FAddOp(AbstractFloatArithOp):
name = "llvm.fadd"


@irdl_op_definition
class FMulOp(AbstractFloatArithOp):
name = "llvm.fmul"


@irdl_op_definition
class FDivOp(AbstractFloatArithOp):
name = "llvm.fdiv"


@irdl_op_definition
class FSubOp(AbstractFloatArithOp):
name = "llvm.fsub"


@irdl_op_definition
class FRemOp(AbstractFloatArithOp):
name = "llvm.frem"


@irdl_op_definition
class BitcastOp(GenericCastOp):
name = "llvm.bitcast"


@irdl_op_definition
class SIToFPOp(GenericCastOp):
name = "llvm.sitofp"


@irdl_op_definition
class FPExtOp(GenericCastOp):
name = "llvm.fpext"


LLVM = Dialect(
"llvm",
[
AShrOp,
AddOp,
AddressOfOp,
AllocaOp,
AndOp,
BitcastOp,
SubOp,
CallIntrinsicOp,
CallOp,
ConstantOp,
ExtractValueOp,
FAddOp,
FDivOp,
FMulOp,
FPExtOp,
FRemOp,
FSubOp,
FuncOp,
GEPOp,
GlobalOp,
ICmpOp,
InlineAsmOp,
InsertValueOp,
IntToPtrOp,
LShrOp,
LoadOp,
MulOp,
UDivOp,
NullOp,
OrOp,
ReturnOp,
SDivOp,
URemOp,
SExtOp,
SIToFPOp,
SRemOp,
AndOp,
OrOp,
XOrOp,
ShlOp,
LShrOp,
AShrOp,
StoreOp,
SubOp,
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
UDivOp,
URemOp,
UndefOp,
AllocaOp,
GEPOp,
IntToPtrOp,
NullOp,
LoadOp,
StoreOp,
GlobalOp,
AddressOfOp,
FuncOp,
CallOp,
ReturnOp,
ConstantOp,
CallIntrinsicOp,
XOrOp,
ZExtOp,
ZeroOp,
],
[
LLVMStructType,
LLVMPointerType,
CallingConventionAttr,
FastMathAttr,
LLVMArrayType,
LLVMVoidType,
LLVMFunctionType,
LLVMPointerType,
LLVMStructType,
LLVMVoidType,
LinkageAttr,
CallingConventionAttr,
TailCallKindAttr,
FastMathAttr,
OverflowAttr,
TailCallKindAttr,
],
)

0 comments on commit 7c83cf7

Please sign in to comment.