Skip to content

Commit

Permalink
misc: use experimental TypeForm feature of Pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 14, 2024
1 parent b7ca27c commit 6707289
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ build-backend = "setuptools.build_meta"
reportImportCycles = false
reportMissingModuleSource = false
enableTypeIgnoreComments = false
enableExperimentalFeatures = true
typeCheckingMode = "strict"
"include" = ["docs", "xdsl", "tests", "bench"]
"exclude" = [
Expand Down
5 changes: 3 additions & 2 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xdsl.dialects import test
from xdsl.dialects.arith import Addi, Arith, Constant, Muli
from xdsl.dialects.builtin import (
AnyIntegerAttr,
Builtin,
IndexType,
IntegerAttr,
Expand Down Expand Up @@ -245,7 +246,7 @@ def test_recursive_rewriter():
class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter):
if not isa(op_val := op.value, IntegerAttr):
if not isa(op_val := op.value, AnyIntegerAttr):
return
val = op_val.value.data
if val == 0 or val == 1:
Expand Down Expand Up @@ -288,7 +289,7 @@ def test_recursive_rewriter_reversed():
class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter):
if not isa(op_val := op.value, IntegerAttr):
if not isa(op_val := op.value, AnyIntegerAttr):
return
val = op_val.value.data
if val == 0 or val == 1:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_is_satisfying_hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class MyParamAttr(Generic[_T], ParametrizedAttribute):
def test_parametrized_attribute():
attr = MyParamAttr[IntAttr]([IntAttr(0)])

assert isa(attr, MyParamAttr)
# `assert isa(attr, MyParamAttr)` not supported: use isinstance instead
assert isa(attr, MyParamAttr[IntAttr])
assert isa(attr, MyParamAttr[IntAttr | FloatData])
assert not isa(attr, MyParamAttr[FloatData])
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
zip(op.receive_chunk.block.args, buf_apply_op.receive_chunk.block.args)
):
# arg0 has special meaning and does not need a `to_tensor` op
if isattr(old_arg.type, TensorType) and idx != 0:
if isattr(old_arg.type, AnyTensorTypeConstr) and idx != 0:
rewriter.insert_op(
# ensure iter_arg is writable
t := to_tensor_op(arg, writable=idx == 2),
Expand All @@ -139,7 +139,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
for idx, (old_arg, arg) in enumerate(
zip(op.done_exchange.block.args, buf_apply_op.done_exchange.block.args)
):
if isattr(old_arg.type, TensorType):
if isattr(old_arg.type, AnyTensorTypeConstr):
rewriter.insert_op(
# ensure iter_arg is writable
t := to_tensor_op(arg, writable=idx == 1),
Expand Down
3 changes: 2 additions & 1 deletion xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xdsl.dialects import arith, func, memref, stencil
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyMemRefTypeConstr,
ArrayAttr,
DenseIntOrFPElementsAttr,
Float32Type,
Expand Down Expand Up @@ -440,7 +441,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
return

if (
not isattr(accumulator.type, memref.MemRefType)
not isattr(accumulator.type, AnyMemRefTypeConstr)
or not isinstance(op.accumulator, OpResult)
or not isinstance(alloc := op.accumulator.op, memref.Alloc)
):
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/shape_inference_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def update_result_size(
apply = value.owner
res_types = (cast(TempType[Attribute], r.type) for r in apply.res)
newsize = reduce(
lambda l, r: l | r,
StencilBoundsAttr.union,
(
size,
*(
Expand Down
5 changes: 2 additions & 3 deletions xdsl/transforms/varith_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa

# map the arith operation to the right varith op:
ARITH_TO_VARITH_TYPE_MAP: dict[
Expand Down Expand Up @@ -138,14 +137,14 @@ def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
# iterate over operands of the varith op:
for inp in op.operands:
# if the input happens to be the right arith op:
if isa(inp.owner, target_arith_type):
if isinstance(inp.owner, target_arith_type):
# fuse the operands of the arith op into the new varith op
new_operands.append(inp.owner.lhs)
new_operands.append(inp.owner.rhs)
# check if the old arith op can be erased
possibly_erased_ops.append(inp.owner)
# if the parent op is a varith op of the same type as us
elif isa(inp.owner, type(op)):
elif isinstance(inp.owner, type(op)):
# include all operands of that
new_operands.extend(inp.owner.operands)
# check the old varith op for usages
Expand Down
6 changes: 5 additions & 1 deletion xdsl/utils/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterable, Sequence
from inspect import isclass
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Generic,
Expand All @@ -17,10 +18,13 @@
from xdsl.ir import ParametrizedAttribute
from xdsl.utils.exceptions import VerifyException

if TYPE_CHECKING:
from typing_extensions import TypeForm

_T = TypeVar("_T")


def isa(arg: Any, hint: type[_T]) -> TypeGuard[_T]:
def isa(arg: Any, hint: "TypeForm[_T]") -> TypeGuard[_T]:
from xdsl.irdl import ConstraintContext

"""
Expand Down
7 changes: 5 additions & 2 deletions xdsl/utils/isattr.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Any, TypeGuard
from typing import TYPE_CHECKING, Any, TypeGuard

from xdsl.ir import AttributeInvT
from xdsl.irdl import GenericAttrConstraint
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

if TYPE_CHECKING:
from typing_extensions import TypeForm


def isattr(
arg: Any, hint: type[AttributeInvT] | GenericAttrConstraint[AttributeInvT]
arg: Any, hint: "TypeForm[AttributeInvT]" | GenericAttrConstraint[AttributeInvT]
) -> TypeGuard[AttributeInvT]:
"""
A helper method to check whether a given attribute has a given type or conforms to a
Expand Down

0 comments on commit 6707289

Please sign in to comment.