Skip to content

Commit

Permalink
resolve reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Dec 12, 2024
1 parent 5121eff commit e061a16
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 12 deletions.
16 changes: 13 additions & 3 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_IntegerType_packing():


def test_DenseIntOrFPElementsAttr_fp_type_conversion():
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [])
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [2])

value1 = check1.get_attrs()[0].value.data
value2 = check1.get_attrs()[1].value.data
Expand All @@ -231,7 +231,7 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():
t1 = FloatAttr(4.0, f32)
t2 = FloatAttr(5.0, f32)

check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [])
check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [2])

value3 = check2.get_attrs()[0].value.data
value4 = check2.get_attrs()[1].value.data
Expand All @@ -244,9 +244,19 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():


def test_DenseIntOrFPElementsAttr_from_list():
# legal zero-rank tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, [])
assert attr.type == AnyTensorType(f32, [])

assert attr.type == AnyTensorType(f32, [1])
# illegal zero-rank tensor
with pytest.raises(
ValueError, match="A zero-rank tensor can only hold 1 value but 2 were given."
):
DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [])

# legal normal tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [2])
assert attr.type == AnyTensorType(f32, [2])


@pytest.mark.parametrize(
Expand Down
19 changes: 14 additions & 5 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,6 +1927,10 @@ def get_shape(self) -> tuple[int, ...]:
def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
return self.type.get_element_type()

@property
def nb_elements(self) -> int:
return prod(self.get_shape())

@property
def shape_is_complete(self) -> bool:
shape = self.get_shape()
Expand All @@ -1939,7 +1943,7 @@ def shape_is_complete(self) -> bool:
n *= dim

# Product of dimensions needs to equal length
return n == len(self.get_values())
return n == self.nb_elements

@staticmethod
def create_dense_index(
Expand Down Expand Up @@ -2036,6 +2040,12 @@ def from_list(
),
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
raise ValueError(
f"A zero-rank {type.name} can only hold 1 value but {len(data)} were given."
)

# splat value given
if len(data) == 1 and prod(type.get_shape()) != 1:
new_data = (data[0],) * prod(type.get_shape())
Expand Down Expand Up @@ -2063,8 +2073,9 @@ def from_list(
def vector_from_list(
data: Sequence[int] | Sequence[float],
data_type: IntegerType | IndexType | AnyFloat,
shape: Sequence[int],
) -> DenseIntOrFPElementsAttr:
t = VectorType(data_type, [len(data)])
t = VectorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

@staticmethod
Expand All @@ -2079,16 +2090,14 @@ def tensor_from_list(
data_type: IntegerType | IndexType | AnyFloat,
shape: Sequence[int],
) -> DenseIntOrFPElementsAttr:
if not len(shape):
shape = [len(data)]
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

def get_values(self) -> Sequence[int] | Sequence[float]:
"""
Return all the values of the elements in this DenseIntOrFPElementsAttr
"""
return self.type.element_type.unpack(self.data.data, prod(self.get_shape()))
return self.type.element_type.unpack(self.data.data, self.nb_elements)

def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
"""
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def parse(cls, parser: Parser) -> Self:
)
assert isinstance(flag_type, IntegerType | IndexType)
case_values = DenseIntOrFPElementsAttr.vector_from_list(
[x for (x, _, _) in cases], flag_type
[x for (x, _, _) in cases], flag_type, [len(cases)]
)
case_blocks = tuple(x for (_, x, _) in cases)
case_operands = tuple(tuple(x) for (_, _, x) in cases)
Expand Down
4 changes: 3 additions & 1 deletion xdsl/dialects/varith.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("]")
attr_dict = parser.parse_optional_attr_dict()

case_values = DenseIntOrFPElementsAttr.vector_from_list(values, flag_type)
case_values = DenseIntOrFPElementsAttr.vector_from_list(
values, flag_type, [len(values)]
)

return cls(
flag,
Expand Down
4 changes: 3 additions & 1 deletion xdsl/transforms/canonicalization_patterns/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def drop_case_helper(
op.default_block,
op.default_operands,
DenseIntOrFPElementsAttr.vector_from_list(
new_case_values, case_values.get_element_type()
new_case_values,
case_values.get_element_type(),
[len(new_case_values)],
),
new_case_blocks,
new_case_operands,
Expand Down
4 changes: 3 additions & 1 deletion xdsl/transforms/convert_scf_to_cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def match_and_rewrite(self, op: IndexSwitchOp, rewriter: PatternRewriter):
case_value,
default_block,
(),
DenseIntOrFPElementsAttr.vector_from_list(case_values, i32),
DenseIntOrFPElementsAttr.vector_from_list(
case_values, i32, [len(case_values)]
),
case_successors,
case_operands,
),
Expand Down

0 comments on commit e061a16

Please sign in to comment.