From dff4a4e7854d026abe9565110b5f968d47ff544c Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 15 Nov 2024 08:10:50 +0000 Subject: [PATCH] use assembly format --- .../bufferization/bufferization_ops.mlir | 4 +- xdsl/dialects/bufferization.py | 57 +++---------------- 2 files changed, 9 insertions(+), 52 deletions(-) diff --git a/tests/filecheck/dialects/bufferization/bufferization_ops.mlir b/tests/filecheck/dialects/bufferization/bufferization_ops.mlir index 60b0e98b31..c41a3bf21c 100644 --- a/tests/filecheck/dialects/bufferization/bufferization_ops.mlir +++ b/tests/filecheck/dialects/bufferization/bufferization_ops.mlir @@ -10,10 +10,10 @@ // CHECK-NEXT: %t1 = bufferization.alloc_tensor(%i0, %i1) {"hello" = "world"} : tensor<10x20x?x?xf64> // CHECK-NEXT: %t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64> -// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint=%i1 : tensor<10x20x?x?xf64> +// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1 : tensor<10x20x?x?xf64> %t1 = bufferization.alloc_tensor(%i0, %i1) {"hello"="world"}: tensor<10x20x?x?xf64> %t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64> -%t3 = bufferization.alloc_tensor(%i0, %i1) size_hint=%i1: tensor<10x20x?x?xf64> +%t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1: tensor<10x20x?x?xf64> // CHECK-NEXT: } diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py index 18b76339cc..2d66ba95e5 100644 --- a/xdsl/dialects/bufferization.py +++ b/xdsl/dialects/bufferization.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Self +from typing import Any, ClassVar from xdsl.dialects.builtin import ( AnyMemRefTypeConstr, @@ -29,8 +29,6 @@ result_def, var_operand_def, ) -from xdsl.parser import Parser -from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -120,14 +118,18 @@ class AllocTensorOp(IRDLOperation): name = "bufferization.alloc_tensor" + T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr) + dynamic_sizes = var_operand_def(IndexType()) - copy = opt_operand_def(AnyOf((AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr))) + copy = opt_operand_def(T) size_hint = opt_operand_def(IndexType()) - tensor = result_def(AnyOf((AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr))) + tensor = result_def(T) irdl_options = [AttrSizedOperandSegments(as_property=True)] + assembly_format = "`(` $dynamic_sizes `)` ( `copy` `(` $copy^ `)`)? (`size_hint` `=` $size_hint^)? attr-dict `:` type($tensor)" + def __init__( self, result_type: Attribute, @@ -140,51 +142,6 @@ def __init__( result_types=(result_type,), ) - @classmethod - def parse(cls, parser: Parser) -> Self: - dynamic_sizes = parser.parse_comma_separated_list( - parser.Delimiter.PAREN, parser.parse_operand - ) - if parser.parse_optional_keyword("copy") is not None: - parser.parse_punctuation("(") - copy = parser.parse_operand() - parser.parse_punctuation(")") - else: - copy = None - - if parser.parse_optional_keyword("size_hint") is not None: - parser.parse_punctuation("=") - size_hint = parser.parse_operand() - else: - size_hint = None - - attr_dict = parser.parse_optional_attr_dict() - - parser.parse_punctuation(":") - result_type = parser.parse_type() - - result = cls(result_type, dynamic_sizes, copy, size_hint) - - result.attributes |= attr_dict - - return result - - def print(self, printer: Printer): - printer.print_string("(", indent=0) - printer.print_list(self.dynamic_sizes, printer.print_ssa_value) - printer.print_string(")", indent=0) - if self.copy is not None: - printer.print_string(" copy(", indent=0) - printer.print_ssa_value(self.copy) - printer.print_string(")", indent=0) - if self.size_hint is not None: - printer.print_string(" size_hint=", indent=0) - printer.print_ssa_value(self.size_hint) - - printer.print_op_attributes(self.attributes) - printer.print_string(" : ") - printer.print_attribute(self.tensor.type) - @irdl_op_definition class ToTensorOp(IRDLOperation):