Skip to content

Commit

Permalink
dialects: (varith) Add varith (variadic arithmetic) dialect (#3241)
Browse files Browse the repository at this point in the history
Add a variadic arithmetic dialect.

The goal is to provide a place for nice-to-use arithmetic operations.
Also makes re-writing arithmetic easier in some cases, e.g. when you
want to split a summation over one set of values into two summations by
selecting certain values from the set.

Coming next is two rewrites to canonicalize arith to varith, and then to
de-canonicalize back.
  • Loading branch information
AntonLydike authored Oct 4, 2024
1 parent e5d483b commit 104245e
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/filecheck/dialects/varith/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: xdsl-opt --parsing-diagnostics --verify-diagnostics --split-input-file


%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>)
varith.add %i, %f : i32
// CHECK: operand is used with type i32, but has been previously used or defined with type f32


// -----
// CHECK: -----


%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>)
varith.add %t1, %t2 : tensor<10xf32>
// CHECK: operand is used with type tensor<10xf32>, but has been previously used or defined with type tensor<5xf32>
31 changes: 31 additions & 0 deletions tests/filecheck/dialects/varith/varith_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP

%ia, %ib, %ic, %id = "test.op"() : () -> (i32, i32, i32, i32)
%fa, %fb, %fc, %fd = "test.op"() : () -> (f32, f32, f32, f32)
%ta, %tb, %tc, %td = "test.op"() : () -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)

%x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32
// CHECK: %x1 = varith.add %ia, %ib, %ic, %id : i32
// CHECK-GENERIC: %x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32

%x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32
// CHECK: %x2 = varith.add %fa, %fb, %fc, %fd : f32
// CHECK-GENERIC: %x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32

%x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: %x3 = varith.add %ta, %tb, %tc, %td : tensor<10xf32>
// CHECK-GENERIC: %x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>


%x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32
// CHECK: %x4 = varith.mul %ia, %ib, %ic, %id : i32
// CHECK-GENERIC: %x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32

%x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32
// CHECK: %x5 = varith.mul %fa, %fb, %fc, %fd : f32
// CHECK-GENERIC: %x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32

%x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: %x6 = varith.mul %ta, %tb, %tc, %td : tensor<10xf32>
// CHECK-GENERIC: %x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
6 changes: 6 additions & 0 deletions xdsl/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,11 @@ def get_tosa():

return TOSA

def get_varith():
from xdsl.dialects.varith import Varith

return Varith

def get_vector():
from xdsl.dialects.vector import Vector

Expand Down Expand Up @@ -371,6 +376,7 @@ def get_transform():
"tensor": get_tensor,
"test": get_test,
"tosa": get_tosa,
"varith": get_varith,
"vector": get_vector,
"wasm": get_wasm,
"x86": get_x86,
Expand Down
76 changes: 76 additions & 0 deletions xdsl/dialects/varith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Annotated

from xdsl.dialects.builtin import (
BFloat16Type,
ContainerOf,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
IndexType,
IntegerType,
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
ConstraintVar,
IRDLOperation,
irdl_op_definition,
result_def,
var_operand_def,
)
from xdsl.traits import Pure

integerOrFloatLike: ContainerOf = ContainerOf(
AnyOf(
[
IntegerType,
IndexType,
BFloat16Type,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
]
)
)


class VarithOp(IRDLOperation):
"""
Variadic arithmetic operation
"""

T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike]

args = var_operand_def(T)
res = result_def(T)

traits = frozenset((Pure(),))

assembly_format = "$args attr-dict `:` type($res)"

def __init__(self, *args: SSAValue | Operation):
assert len(args) > 0
super().__init__(operands=[args], result_types=[SSAValue.get(args[-1]).type])


@irdl_op_definition
class VarithAddOp(VarithOp):
name = "varith.add"


@irdl_op_definition
class VarithMulOp(VarithOp):
name = "varith.mul"


Varith = Dialect(
"varith",
[
VarithAddOp,
VarithMulOp,
],
)

0 comments on commit 104245e

Please sign in to comment.