Skip to content

Commit

Permalink
use assembly format
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 15, 2024
1 parent 82e7072 commit dff4a4e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 52 deletions.
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/bufferization/bufferization_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

// CHECK-NEXT: %t1 = bufferization.alloc_tensor(%i0, %i1) {"hello" = "world"} : tensor<10x20x?x?xf64>
// CHECK-NEXT: %t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64>
// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint=%i1 : tensor<10x20x?x?xf64>
// CHECK-NEXT: %t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1 : tensor<10x20x?x?xf64>
%t1 = bufferization.alloc_tensor(%i0, %i1) {"hello"="world"}: tensor<10x20x?x?xf64>
%t2 = bufferization.alloc_tensor() copy(%t0) : tensor<10x20x30xf64>
%t3 = bufferization.alloc_tensor(%i0, %i1) size_hint=%i1: tensor<10x20x?x?xf64>
%t3 = bufferization.alloc_tensor(%i0, %i1) size_hint = %i1: tensor<10x20x?x?xf64>

// CHECK-NEXT: }

Expand Down
57 changes: 7 additions & 50 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, Self
from typing import Any, ClassVar

from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
Expand Down Expand Up @@ -29,8 +29,6 @@
result_def,
var_operand_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -120,14 +118,18 @@ class AllocTensorOp(IRDLOperation):

name = "bufferization.alloc_tensor"

T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)

dynamic_sizes = var_operand_def(IndexType())
copy = opt_operand_def(AnyOf((AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr)))
copy = opt_operand_def(T)
size_hint = opt_operand_def(IndexType())

tensor = result_def(AnyOf((AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr)))
tensor = result_def(T)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

assembly_format = "`(` $dynamic_sizes `)` ( `copy` `(` $copy^ `)`)? (`size_hint` `=` $size_hint^)? attr-dict `:` type($tensor)"

def __init__(
self,
result_type: Attribute,
Expand All @@ -140,51 +142,6 @@ def __init__(
result_types=(result_type,),
)

@classmethod
def parse(cls, parser: Parser) -> Self:
dynamic_sizes = parser.parse_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_operand
)
if parser.parse_optional_keyword("copy") is not None:
parser.parse_punctuation("(")
copy = parser.parse_operand()
parser.parse_punctuation(")")
else:
copy = None

if parser.parse_optional_keyword("size_hint") is not None:
parser.parse_punctuation("=")
size_hint = parser.parse_operand()
else:
size_hint = None

attr_dict = parser.parse_optional_attr_dict()

parser.parse_punctuation(":")
result_type = parser.parse_type()

result = cls(result_type, dynamic_sizes, copy, size_hint)

result.attributes |= attr_dict

return result

def print(self, printer: Printer):
printer.print_string("(", indent=0)
printer.print_list(self.dynamic_sizes, printer.print_ssa_value)
printer.print_string(")", indent=0)
if self.copy is not None:
printer.print_string(" copy(", indent=0)
printer.print_ssa_value(self.copy)
printer.print_string(")", indent=0)
if self.size_hint is not None:
printer.print_string(" size_hint=", indent=0)
printer.print_ssa_value(self.size_hint)

printer.print_op_attributes(self.attributes)
printer.print_string(" : ")
printer.print_attribute(self.tensor.type)


@irdl_op_definition
class ToTensorOp(IRDLOperation):
Expand Down

0 comments on commit dff4a4e

Please sign in to comment.