Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: (constraints) add IntAttrConstraint #3797

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Float64Type,
FloatAttr,
IndexType,
IntAttrConstraint,
IntegerAttr,
IntegerType,
MemRefType,
Expand Down Expand Up @@ -3155,3 +3156,125 @@ def test_multiple_operand_extraction_fails():
"Possible values are: {i32, index}",
):
parser.parse_operation()


################################################################################
# IntAttr #
################################################################################


@irdl_op_definition
class IntAttrExtractOp(IRDLOperation):
name = "test.int_attr_extract"

_I: ClassVar = IntVarConstraint("I", AnyInt())

prop = prop_def(
IntegerAttr.constr(value=IntAttrConstraint(_I), type=eq(IndexType()))
)

outs = var_result_def(RangeOf(eq(IndexType()), length=_I))

assembly_format = "$prop attr-dict"


@pytest.mark.parametrize(
"program",
["%0 = test.int_attr_extract 1", "%0, %1 = test.int_attr_extract 2"],
)
def test_int_attr_extraction(program: str):
ctx = MLContext()
ctx.load_op(IntAttrExtractOp)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program, error",
[
(
"%0 = test.int_attr_extract 2",
"Operation has 2 results, but was given 1 to bind",
),
(
"%0, %1 = test.int_attr_extract 1",
"Operation has 1 results, but was given 2 to bind",
),
],
)
def test_int_attr_extraction_errors(program: str, error: str):
ctx = MLContext()
ctx.load_op(IntAttrExtractOp)
parser = Parser(ctx, program)
with pytest.raises(ParseError, match=error):
parser.parse_optional_operation()


@irdl_op_definition
class IntAttrVerifyOp(IRDLOperation):
name = "test.int_attr_verify"

_I: ClassVar = IntVarConstraint("I", AnyInt())

prop = prop_def(
IntegerAttr.constr(value=IntAttrConstraint(_I), type=eq(IndexType()))
)

prop2 = opt_prop_def(
IntegerAttr.constr(value=IntAttrConstraint(_I), type=eq(IndexType()))
)

ins = var_operand_def(RangeOf(eq(IndexType()), length=_I))

assembly_format = "$prop (`and` $prop2^)? `,` $ins attr-dict"


@pytest.mark.parametrize(
"program",
[
"test.int_attr_verify 1, %0",
"test.int_attr_verify 2, %0, %1",
"test.int_attr_verify 1 and 1, %0",
"test.int_attr_verify 2 and 2, %0, %1",
],
)
def test_int_attr_verify(program: str):
ctx = MLContext()
ctx.load_op(IntAttrVerifyOp)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program, error_type, error",
[
(
"test.int_attr_verify 1, %0, %1",
ValueError,
"Value of variable I could not be uniquely extracted",
),
(
"test.int_attr_verify 1 and 2, %0",
VerifyException,
"integer 1 expected from int variable 'I', but got 2",
),
(
"test.int_attr_verify 2, %0",
ValueError,
"Value of variable I could not be uniquely extracted",
),
(
"test.int_attr_verify 2 and 1, %0, %1",
VerifyException,
"integer 2 expected from int variable 'I', but got 1",
),
],
)
def test_int_attr_verify_errors(program: str, error_type: type[Exception], error: str):
ctx = MLContext()
ctx.load_op(IntAttrVerifyOp)

parser = Parser(ctx, program)
with pytest.raises(error_type, match=error):
op = parser.parse_operation()
op.verify()
40 changes: 39 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import struct
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set
from dataclasses import dataclass
from enum import Enum
from math import prod
Expand Down Expand Up @@ -52,6 +52,8 @@
ConstraintVariableType,
GenericAttrConstraint,
GenericData,
InferenceContext,
IntConstraint,
IRDLOperation,
MessageConstraint,
ParamAttrConstraint,
Expand Down Expand Up @@ -296,6 +298,42 @@ def print_parameter(self, printer: Printer) -> None:
printer.print_string(f"{self.data}")


@dataclass(frozen=True)
class IntAttrConstraint(GenericAttrConstraint[IntAttr]):
"""
Constrains the value of an IntAttr.
"""

int_constraint: IntConstraint

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if not isinstance(attr, IntAttr):
raise VerifyException(f"attribute {attr} expected to be an IntAttr")
self.int_constraint.verify(attr.data, constraint_context)

@dataclass(frozen=True)
class _Extractor(VarExtractor[IntAttr]):
inner: VarExtractor[int]

def extract_var(self, a: IntAttr) -> ConstraintVariableType:
return self.inner.extract_var(a.data)

def get_variable_extractors(self) -> dict[str, VarExtractor[IntAttr]]:
return {
k: self._Extractor(v)
for k, v in self.int_constraint.get_length_extractors().items()
}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.int_constraint.can_infer(var_constraint_names)

def infer(self, context: InferenceContext) -> IntAttr:
return IntAttr(self.int_constraint.infer(context))

def get_unique_base(self) -> type[Attribute] | None:
return IntAttr


class Signedness(Enum):
"Signedness semantics for integer"

Expand Down