From e061a1662e73363101d25a8be54263505fa57f91 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Thu, 12 Dec 2024 22:10:49 +0100 Subject: [PATCH] resolve reviewer comments --- tests/dialects/test_builtin.py | 16 +++++++++++++--- xdsl/dialects/builtin.py | 19 ++++++++++++++----- xdsl/dialects/cf.py | 2 +- xdsl/dialects/varith.py | 4 +++- .../canonicalization_patterns/cf.py | 4 +++- xdsl/transforms/convert_scf_to_cf.py | 4 +++- 6 files changed, 37 insertions(+), 12 deletions(-) diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index 2409978230..94035c1862 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -217,7 +217,7 @@ def test_IntegerType_packing(): def test_DenseIntOrFPElementsAttr_fp_type_conversion(): - check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, []) + check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [2]) value1 = check1.get_attrs()[0].value.data value2 = check1.get_attrs()[1].value.data @@ -231,7 +231,7 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion(): t1 = FloatAttr(4.0, f32) t2 = FloatAttr(5.0, f32) - check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, []) + check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [2]) value3 = check2.get_attrs()[0].value.data value4 = check2.get_attrs()[1].value.data @@ -244,9 +244,19 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion(): def test_DenseIntOrFPElementsAttr_from_list(): + # legal zero-rank tensor attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, []) + assert attr.type == AnyTensorType(f32, []) - assert attr.type == AnyTensorType(f32, [1]) + # illegal zero-rank tensor + with pytest.raises( + ValueError, match="A zero-rank tensor can only hold 1 value but 2 were given." + ): + DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, []) + + # legal normal tensor + attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [2]) + assert attr.type == AnyTensorType(f32, [2]) @pytest.mark.parametrize( diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 5bf44f4a72..c25580fdc0 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1927,6 +1927,10 @@ def get_shape(self) -> tuple[int, ...]: def get_element_type(self) -> IntegerType | IndexType | AnyFloat: return self.type.get_element_type() + @property + def nb_elements(self) -> int: + return prod(self.get_shape()) + @property def shape_is_complete(self) -> bool: shape = self.get_shape() @@ -1939,7 +1943,7 @@ def shape_is_complete(self) -> bool: n *= dim # Product of dimensions needs to equal length - return n == len(self.get_values()) + return n == self.nb_elements @staticmethod def create_dense_index( @@ -2036,6 +2040,12 @@ def from_list( ), data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], ) -> DenseIntOrFPElementsAttr: + # zero rank type should only hold 1 value + if not type.get_shape() and len(data) != 1: + raise ValueError( + f"A zero-rank {type.name} can only hold 1 value but {len(data)} were given." + ) + # splat value given if len(data) == 1 and prod(type.get_shape()) != 1: new_data = (data[0],) * prod(type.get_shape()) @@ -2063,8 +2073,9 @@ def from_list( def vector_from_list( data: Sequence[int] | Sequence[float], data_type: IntegerType | IndexType | AnyFloat, + shape: Sequence[int], ) -> DenseIntOrFPElementsAttr: - t = VectorType(data_type, [len(data)]) + t = VectorType(data_type, shape) return DenseIntOrFPElementsAttr.from_list(t, data) @staticmethod @@ -2079,8 +2090,6 @@ def tensor_from_list( data_type: IntegerType | IndexType | AnyFloat, shape: Sequence[int], ) -> DenseIntOrFPElementsAttr: - if not len(shape): - shape = [len(data)] t = TensorType(data_type, shape) return DenseIntOrFPElementsAttr.from_list(t, data) @@ -2088,7 +2097,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]: """ Return all the values of the elements in this DenseIntOrFPElementsAttr """ - return self.type.element_type.unpack(self.data.data, prod(self.get_shape())) + return self.type.element_type.unpack(self.data.data, self.nb_elements) def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]: """ diff --git a/xdsl/dialects/cf.py b/xdsl/dialects/cf.py index d067286c57..f49d7131a3 100644 --- a/xdsl/dialects/cf.py +++ b/xdsl/dialects/cf.py @@ -368,7 +368,7 @@ def parse(cls, parser: Parser) -> Self: ) assert isinstance(flag_type, IntegerType | IndexType) case_values = DenseIntOrFPElementsAttr.vector_from_list( - [x for (x, _, _) in cases], flag_type + [x for (x, _, _) in cases], flag_type, [len(cases)] ) case_blocks = tuple(x for (_, x, _) in cases) case_operands = tuple(tuple(x) for (_, _, x) in cases) diff --git a/xdsl/dialects/varith.py b/xdsl/dialects/varith.py index 26855f4abf..1fdae0a785 100644 --- a/xdsl/dialects/varith.py +++ b/xdsl/dialects/varith.py @@ -151,7 +151,9 @@ def parse(cls, parser: Parser) -> Self: parser.parse_punctuation("]") attr_dict = parser.parse_optional_attr_dict() - case_values = DenseIntOrFPElementsAttr.vector_from_list(values, flag_type) + case_values = DenseIntOrFPElementsAttr.vector_from_list( + values, flag_type, [len(values)] + ) return cls( flag, diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py index eea4fb276e..10c73e2c79 100644 --- a/xdsl/transforms/canonicalization_patterns/cf.py +++ b/xdsl/transforms/canonicalization_patterns/cf.py @@ -319,7 +319,9 @@ def drop_case_helper( op.default_block, op.default_operands, DenseIntOrFPElementsAttr.vector_from_list( - new_case_values, case_values.get_element_type() + new_case_values, + case_values.get_element_type(), + [len(new_case_values)], ), new_case_blocks, new_case_operands, diff --git a/xdsl/transforms/convert_scf_to_cf.py b/xdsl/transforms/convert_scf_to_cf.py index d1c07e683b..69d1898ea5 100644 --- a/xdsl/transforms/convert_scf_to_cf.py +++ b/xdsl/transforms/convert_scf_to_cf.py @@ -209,7 +209,9 @@ def match_and_rewrite(self, op: IndexSwitchOp, rewriter: PatternRewriter): case_value, default_block, (), - DenseIntOrFPElementsAttr.vector_from_list(case_values, i32), + DenseIntOrFPElementsAttr.vector_from_list( + case_values, i32, [len(case_values)] + ), case_successors, case_operands, ),