-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dialects: (varith) Add varith (variadic arithmetic) dialect (#3241)
Add a variadic arithmetic dialect. The goal is to provide a place for nice-to-use arithmetic operations. Also makes re-writing arithmetic easier in some cases, e.g. when you want to split a summation over one set of values into two summations by selecting certain values from the set. Coming next is two rewrites to canonicalize arith to varith, and then to de-canonicalize back.
- Loading branch information
1 parent
e5d483b
commit 104245e
Showing
4 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// RUN: xdsl-opt --parsing-diagnostics --verify-diagnostics --split-input-file | ||
|
||
|
||
%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) | ||
varith.add %i, %f : i32 | ||
// CHECK: operand is used with type i32, but has been previously used or defined with type f32 | ||
|
||
|
||
// ----- | ||
// CHECK: ----- | ||
|
||
|
||
%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) | ||
varith.add %t1, %t2 : tensor<10xf32> | ||
// CHECK: operand is used with type tensor<10xf32>, but has been previously used or defined with type tensor<5xf32> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
// RUN: XDSL_GENERIC_ROUNDTRIP | ||
|
||
%ia, %ib, %ic, %id = "test.op"() : () -> (i32, i32, i32, i32) | ||
%fa, %fb, %fc, %fd = "test.op"() : () -> (f32, f32, f32, f32) | ||
%ta, %tb, %tc, %td = "test.op"() : () -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) | ||
|
||
%x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
// CHECK: %x1 = varith.add %ia, %ib, %ic, %id : i32 | ||
// CHECK-GENERIC: %x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
|
||
%x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
// CHECK: %x2 = varith.add %fa, %fb, %fc, %fd : f32 | ||
// CHECK-GENERIC: %x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
|
||
%x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
// CHECK: %x3 = varith.add %ta, %tb, %tc, %td : tensor<10xf32> | ||
// CHECK-GENERIC: %x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
|
||
|
||
%x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
// CHECK: %x4 = varith.mul %ia, %ib, %ic, %id : i32 | ||
// CHECK-GENERIC: %x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
|
||
%x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
// CHECK: %x5 = varith.mul %fa, %fb, %fc, %fd : f32 | ||
// CHECK-GENERIC: %x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
|
||
%x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
// CHECK: %x6 = varith.mul %ta, %tb, %tc, %td : tensor<10xf32> | ||
// CHECK-GENERIC: %x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Annotated | ||
|
||
from xdsl.dialects.builtin import ( | ||
BFloat16Type, | ||
ContainerOf, | ||
Float16Type, | ||
Float32Type, | ||
Float64Type, | ||
Float80Type, | ||
Float128Type, | ||
IndexType, | ||
IntegerType, | ||
) | ||
from xdsl.ir import Attribute, Dialect, Operation, SSAValue | ||
from xdsl.irdl import ( | ||
AnyOf, | ||
ConstraintVar, | ||
IRDLOperation, | ||
irdl_op_definition, | ||
result_def, | ||
var_operand_def, | ||
) | ||
from xdsl.traits import Pure | ||
|
||
integerOrFloatLike: ContainerOf = ContainerOf( | ||
AnyOf( | ||
[ | ||
IntegerType, | ||
IndexType, | ||
BFloat16Type, | ||
Float16Type, | ||
Float32Type, | ||
Float64Type, | ||
Float80Type, | ||
Float128Type, | ||
] | ||
) | ||
) | ||
|
||
|
||
class VarithOp(IRDLOperation): | ||
""" | ||
Variadic arithmetic operation | ||
""" | ||
|
||
T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike] | ||
|
||
args = var_operand_def(T) | ||
res = result_def(T) | ||
|
||
traits = frozenset((Pure(),)) | ||
|
||
assembly_format = "$args attr-dict `:` type($res)" | ||
|
||
def __init__(self, *args: SSAValue | Operation): | ||
assert len(args) > 0 | ||
super().__init__(operands=[args], result_types=[SSAValue.get(args[-1]).type]) | ||
|
||
|
||
@irdl_op_definition | ||
class VarithAddOp(VarithOp): | ||
name = "varith.add" | ||
|
||
|
||
@irdl_op_definition | ||
class VarithMulOp(VarithOp): | ||
name = "varith.mul" | ||
|
||
|
||
Varith = Dialect( | ||
"varith", | ||
[ | ||
VarithAddOp, | ||
VarithMulOp, | ||
], | ||
) |