Skip to content

Commit

Permalink
dialects: (builtin) remove AnyMemRefTypeConstr (#3832)
Browse files Browse the repository at this point in the history
Can be replaced with `MemRefType.constr()`
  • Loading branch information
alexarice authored Feb 4, 2025
1 parent e6f5649 commit 65f957b
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 33 deletions.
5 changes: 2 additions & 3 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ToTensorOp,
)
from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
AnyUnrankedMemrefTypeConstr,
IndexType,
IntegerType,
Expand All @@ -34,7 +33,7 @@


def test_tensor_from_memref_inference():
constr = TensorFromMemrefConstraint(AnyMemRefTypeConstr)
constr = TensorFromMemrefConstraint(MemRefType.constr())
assert not constr.can_infer(set())

constr2 = TensorFromMemrefConstraint(
Expand All @@ -53,7 +52,7 @@ def test_tensor_from_memref_inference():
@irdl_op_definition
class TensorFromMemrefOp(IRDLOperation):
name = "test.tensor_from_memref"
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)

in_tensor = operand_def(
TensorFromMemrefConstraint(
Expand Down
5 changes: 2 additions & 3 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from xdsl.dialects.builtin import (
DYNAMIC_INDEX,
AnyIntegerAttr,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
AnyUnrankedMemrefTypeConstr,
AnyUnrankedTensorTypeConstr,
Expand Down Expand Up @@ -596,14 +595,14 @@ class SameOperandsAndResultTypeOp(IRDLOperation):
name = "test.same_operand_and_result_type"

ops = var_operand_def(
AnyMemRefTypeConstr
MemRefType.constr()
| AnyUnrankedMemrefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
)

res = var_result_def(
AnyMemRefTypeConstr
MemRefType.constr()
| AnyUnrankedMemrefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
Expand Down
9 changes: 4 additions & 5 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, ClassVar

from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
AnyUnrankedMemrefTypeConstr,
AnyUnrankedTensorTypeConstr,
Expand Down Expand Up @@ -140,7 +139,7 @@ def __init__(
class CloneOp(IRDLOperation):
name = "bufferization.clone"

T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)

input = operand_def(T)
output = result_def(T)
Expand All @@ -156,7 +155,7 @@ def __init__(self, input: SSAValue | Operation):
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"

T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)

memref = operand_def(T)
tensor = result_def(TensorFromMemrefConstraint(T))
Expand Down Expand Up @@ -196,7 +195,7 @@ def __init__(
class ToMemrefOp(IRDLOperation):
name = "bufferization.to_memref"

T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
tensor = operand_def(TensorFromMemrefConstraint(T))
memref = result_def(T)

Expand All @@ -209,7 +208,7 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestinationOp(IRDLOperation):
name = "bufferization.materialize_in_destination"

T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
source = operand_def(TensorFromMemrefConstraint(T))
dest = operand_def(T | TensorFromMemrefConstraint(T))
result = opt_result_def(TensorFromMemrefConstraint(T))
Expand Down
3 changes: 0 additions & 3 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,9 +1975,6 @@ def constr(
)


AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)


@dataclass(frozen=True, init=False)
class TensorOrMemrefOf(
GenericAttrConstraint[TensorType[AttributeCovT] | MemRefType[AttributeCovT]]
Expand Down
19 changes: 9 additions & 10 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
Expand Down Expand Up @@ -132,7 +131,7 @@ class PrefetchOp(IRDLOperation):
name = "csl_stencil.prefetch"

input_stencil = operand_def(
stencil.StencilTypeConstr | AnyMemRefTypeConstr | AnyTensorTypeConstr
stencil.StencilTypeConstr | MemRefType.constr() | AnyTensorTypeConstr
)

swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
Expand All @@ -141,7 +140,7 @@ class PrefetchOp(IRDLOperation):

num_chunks = prop_def(AnyIntegerAttr)

result = result_def(AnyMemRefTypeConstr | AnyTensorTypeConstr)
result = result_def(MemRefType.constr() | AnyTensorTypeConstr)

def __init__(
self,
Expand Down Expand Up @@ -227,13 +226,13 @@ class ApplyOp(IRDLOperation):

name = "csl_stencil.apply"

field = operand_def(stencil.StencilTypeConstr | AnyMemRefTypeConstr)
field = operand_def(stencil.StencilTypeConstr | MemRefType.constr())

accumulator = operand_def(AnyTensorTypeConstr | AnyMemRefTypeConstr)
accumulator = operand_def(AnyTensorTypeConstr | MemRefType.constr())

args_rchunk = var_operand_def(Attribute)
args_dexchng = var_operand_def(Attribute)
dest = var_operand_def(stencil.FieldTypeConstr | AnyMemRefTypeConstr)
dest = var_operand_def(stencil.FieldTypeConstr | MemRefType.constr())

receive_chunk = region_def()
done_exchange = region_def()
Expand Down Expand Up @@ -364,7 +363,7 @@ def verify_(self) -> None:
# typecheck required (only) block arguments
assert isattr(
self.accumulator.type,
AnyTensorTypeConstr | AnyMemRefTypeConstr,
AnyTensorTypeConstr | MemRefType.constr(),
)
chunk_region_req_types = [
type(self.accumulator.type)(
Expand Down Expand Up @@ -460,11 +459,11 @@ class AccessOp(IRDLOperation):

name = "csl_stencil.access"
op = operand_def(
AnyMemRefTypeConstr | stencil.StencilTypeConstr | AnyTensorTypeConstr
MemRefType.constr() | stencil.StencilTypeConstr | AnyTensorTypeConstr
)
offset = prop_def(stencil.IndexAttr)
offset_mapping = opt_prop_def(stencil.IndexAttr)
result = result_def(AnyTensorTypeConstr | AnyMemRefTypeConstr)
result = result_def(AnyTensorTypeConstr | MemRefType.constr())

traits = traits_def(HasAncestor(stencil.ApplyOp, ApplyOp), Pure())

Expand Down Expand Up @@ -582,7 +581,7 @@ def verify_(self) -> None:
f"{type(self)} access to own data requires type stencil.StencilType or memref.MemRefType but found {self.op.type}"
)
else:
if not isattr(self.op.type, AnyTensorTypeConstr | AnyMemRefTypeConstr):
if not isattr(self.op.type, AnyTensorTypeConstr | MemRefType.constr()):
raise VerifyException(
f"{type(self)} access to neighbor data requires type memref.MemRefType or TensorType but found {self.op.type}"
)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from xdsl.dialects import memref
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyMemRefTypeConstr,
ArrayAttr,
ContainerType,
IndexType,
IntAttr,
IntegerAttr,
IntegerType,
MemRefType,
StringAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
Expand Down Expand Up @@ -463,7 +463,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | WritableStreamType.constr())
outputs = var_operand_def(MemRefType.constr() | WritableStreamType.constr())
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from xdsl.dialects import builtin, memref
from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
ArrayAttr,
IndexType,
IntAttr,
IntegerAttr,
MemRefType,
TensorType,
)
from xdsl.ir import (
Expand Down Expand Up @@ -910,7 +910,7 @@ class ExternalLoadOp(IRDLOperation):

name = "stencil.external_load"
field = operand_def(Attribute)
result = result_def(base(FieldType[Attribute]) | AnyMemRefTypeConstr)
result = result_def(base(FieldType[Attribute]) | MemRefType.constr())

assembly_format = (
"$field attr-dict-with-keyword `:` type($field) `->` type($result)"
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from xdsl.dialects import arith, builtin, memref, stencil, tensor, varith
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyMemRefTypeConstr,
AnyTensorType,
DenseIntOrFPElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
MemRefType,
ModuleOp,
TensorType,
)
Expand Down Expand Up @@ -177,7 +177,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):

assert isattr(
op.input_stencil.type,
AnyMemRefTypeConstr | stencil.StencilTypeConstr,
MemRefType.constr() | stencil.StencilTypeConstr,
)
assert isa(
t_type := op.input_stencil.type.get_element_type(), TensorType[Attribute]
Expand Down
5 changes: 2 additions & 3 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, llvm, memref, stencil
from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
ArrayAttr,
DictionaryAttr,
Expand Down Expand Up @@ -113,7 +112,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
stencil.StencilTypeConstr,
) and isattr(
el_type := field_t.element_type,
AnyTensorTypeConstr | AnyMemRefTypeConstr,
AnyTensorTypeConstr | MemRefType.constr(),
):
# unbufferized csl_stencil
z_dim = max(z_dim, el_type.get_shape()[-1])
Expand All @@ -124,7 +123,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
num_chunks = max(num_chunks, apply_op.num_chunks.value.data)
if isattr(
buf_t := apply_op.receive_chunk.block.args[0].type,
AnyTensorTypeConstr | AnyMemRefTypeConstr,
AnyTensorTypeConstr | MemRefType.constr(),
):
chunk_size = max(chunk_size, buf_t.get_shape()[-1])

Expand Down

0 comments on commit 65f957b

Please sign in to comment.