Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into update-uv
Browse files Browse the repository at this point in the history
superlopuh authored Dec 16, 2024
2 parents cd1f5ce + 7c4cb63 commit cc95819
Showing 9 changed files with 109 additions and 66 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/ci-lockfile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: CI - Lockfile

on:
# Trigger the workflow on push or pull request,
# but only for the master branch
push:
branches:
- main
pull_request:

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.12']

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install the package locally and check for lockfile mismatch
run: |
# Install all default extras.
# Fail if the lockfile dependencies are out of date with pyproject.toml.
XDSL_VERSION_OVERRIDE="0+dynamic" uv sync --extra gui --extra dev --extra jax --extra riscv --locked
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ uv-installed:
# set up the venv with all dependencies for development
.PHONY: ${VENV_DIR}/
${VENV_DIR}/: uv-installed
uv sync ${VENV_EXTRAS}
XDSL_VERSION_OVERRIDE="0+dynamic" uv sync ${VENV_EXTRAS}

# make sure `make venv` also works correctly
.PHONY: venv
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os
from collections.abc import Mapping
from typing import cast

from setuptools import Command, find_packages, setup

import versioneer

if "XDSL_VERSION_OVERRIDE" in os.environ:
version = os.environ["XDSL_VERSION_OVERRIDE"]
else:
version = versioneer.get_version()


setup(
version=versioneer.get_version(),
version=version,
cmdclass=cast(Mapping[str, type[Command]], versioneer.get_cmdclass()),
packages=find_packages(),
)
2 changes: 1 addition & 1 deletion tests/interactive/test_app.py
Original file line number Diff line number Diff line change
@@ -291,7 +291,7 @@ def callback(x: str):
assert app.condense_mode is True
rewrites = get_all_possible_rewrites(
expected_module,
individual_rewrite.REWRITE_BY_NAMES,
individual_rewrite.INDIVIDUAL_REWRITE_PATTERNS_BY_NAME,
)
assert app.available_pass_list == get_condensed_pass_list(
expected_module, app.all_passes
17 changes: 9 additions & 8 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from xdsl.context import MLContext
from xdsl.dialects import get_all_dialects
from xdsl.dialects.builtin import (
Builtin,
StringAttr,
)
from xdsl.dialects.test import TestOp
from xdsl.dialects.test import Test, TestOp
from xdsl.interactive.passes import AvailablePass
from xdsl.interactive.rewrites import (
IndexedIndividualRewrite,
@@ -39,9 +39,10 @@ def test_get_all_possible_rewrite():
}
"""

ctx = MLContext(True)
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Test)

parser = Parser(ctx, prog)
module = parser.parse_module()

@@ -73,9 +74,9 @@ def test_convert_indexed_individual_rewrites_to_available_pass():
}
"""

ctx = MLContext(True)
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Test)
parser = Parser(ctx, prog)
module = parser.parse_module()

2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
@@ -266,7 +266,7 @@ def compute_available_pass_list(self) -> tuple[AvailablePass, ...]:
self.input_text_area.text,
self.pass_pipeline,
self.condense_mode,
individual_rewrite.REWRITE_BY_NAMES,
individual_rewrite.INDIVIDUAL_REWRITE_PATTERNS_BY_NAME,
)

def watch_available_pass_list(
@@ -509,7 +509,7 @@ def expand_tree_node(
self.input_text_area.text,
child_pass_pipeline,
self.condense_mode,
individual_rewrite.REWRITE_BY_NAMES,
individual_rewrite.INDIVIDUAL_REWRITE_PATTERNS_BY_NAME,
)

self.expand_node(expanded_node, child_pass_list)
12 changes: 9 additions & 3 deletions xdsl/interactive/rewrites.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from xdsl.dialects.builtin import ModuleOp
from xdsl.interactive.passes import AvailablePass
from xdsl.pattern_rewriter import PatternRewriter, RewritePattern
from xdsl.traits import HasCanonicalizationPatternsTrait
from xdsl.transforms import individual_rewrite
from xdsl.utils.parse_pipeline import PipelinePassSpec

@@ -64,9 +65,14 @@ def get_all_possible_rewrites(
res: list[IndexedIndividualRewrite] = []

for op_idx, matched_op in enumerate(module.walk()):
if matched_op.name not in rewrite_by_name:
continue
pattern_by_name = rewrite_by_name[matched_op.name]
pattern_by_name = rewrite_by_name.get(matched_op.name, {}).copy()

if (
trait := matched_op.get_trait(HasCanonicalizationPatternsTrait)
) is not None:
for pattern in trait.get_canonicalization_patterns():
pattern_by_name[type(pattern).__name__] = pattern

for pattern_name, pattern in pattern_by_name.items():
cloned_op = tuple(module.clone().walk())[op_idx]
rewriter = PatternRewriter(cloned_op)
89 changes: 40 additions & 49 deletions xdsl/transforms/individual_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

from xdsl.context import MLContext
from xdsl.dialects import arith, get_all_dialects
from xdsl.dialects import arith
from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType, ModuleOp
from xdsl.ir import Operation
from xdsl.passes import ModulePass
@@ -44,43 +44,28 @@ def match_and_rewrite(self, op: arith.DivUIOp, rewriter: PatternRewriter) -> Non
rewriter.replace_matched_op([], [mul_op.lhs])


INDIVIDUAL_REWRITE_PATTERNS_BY_OP_CLASS: dict[
type[Operation], tuple[RewritePattern, ...]
] = {
arith.AddiOp: (AdditionOfSameVariablesToMultiplyByTwo(),),
arith.DivUIOp: (DivisionOfSameVariableToOne(),),
INDIVIDUAL_REWRITE_PATTERNS_BY_NAME: dict[str, dict[str, RewritePattern]] = {
arith.AddiOp.name: {
AdditionOfSameVariablesToMultiplyByTwo.__name__: AdditionOfSameVariablesToMultiplyByTwo()
},
arith.DivUIOp.name: {
DivisionOfSameVariableToOne.__name__: DivisionOfSameVariableToOne()
},
}
"""
Dictionary where the key is an Operation and the value is a tuple of rewrite pattern(s) associated with that operation. These are rewrite patterns defined in this class.
Extra rewrite patterns available to ApplyIndividualRewritePass
"""

CANONICALIZATION_PATTERNS_BY_OP_CLASS: dict[
type[Operation], tuple[RewritePattern, ...]
] = {
op: trait.get_canonicalization_patterns()
for dialect in get_all_dialects().values()
for op in dialect().operations
if (trait := op.get_trait(HasCanonicalizationPatternsTrait)) is not None
}
"""
Dictionary where the key is an Operation and the value is a tuple of rewrite pattern(s) associated with that operation. These are the xdsl canonicalization patterns.
"""

REWRITE_BY_NAMES: dict[str, dict[str, RewritePattern]] = {
op.name: {
pattern.__class__.__name__: pattern
for pattern in INDIVIDUAL_REWRITE_PATTERNS_BY_OP_CLASS.get(op, ())
+ CANONICALIZATION_PATTERNS_BY_OP_CLASS.get(op, ())
}
for op in set(INDIVIDUAL_REWRITE_PATTERNS_BY_OP_CLASS)
| set(CANONICALIZATION_PATTERNS_BY_OP_CLASS)
}
"""
Returns a dictionary representing all possible rewrites. Keys are operation names, and
values are dictionaries. In the inner dictionary, the keys are names of patterns
associated with each operation, and the values are the corresponding RewritePattern
instances.
"""
def _get_canonicalization_pattern(
op: Operation, pattern_name: str
) -> RewritePattern | None:
if (trait := op.get_trait(HasCanonicalizationPatternsTrait)) is None:
return None

for pattern in trait.get_canonicalization_patterns():
if type(pattern).__name__ == pattern_name:
return pattern


@dataclass(frozen=True)
@@ -94,30 +79,36 @@ class ApplyIndividualRewritePass(ModulePass):

name = "apply-individual-rewrite"

matched_operation_index: int | None = None
operation_name: str | None = None
pattern_name: str | None = None
matched_operation_index: int = field()
operation_name: str = field()
pattern_name: str = field()

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
assert self.matched_operation_index is not None
assert self.operation_name is not None
assert self.pattern_name is not None

matched_operation_list = list(op.walk())
if self.matched_operation_index >= len(matched_operation_list):
all_ops = list(op.walk())
if self.matched_operation_index >= len(all_ops):
raise ValueError("Matched operation index out of range.")

matched_operation = list(op.walk())[self.matched_operation_index]
matched_operation = all_ops[self.matched_operation_index]
rewriter = PatternRewriter(matched_operation)

rewrite_dictionary = REWRITE_BY_NAMES.get(self.operation_name)
if rewrite_dictionary is None:
if matched_operation.name != self.operation_name:
raise ValueError(
f"Operation name {self.operation_name} not found in the rewrite dictionary."
f"Operation {matched_operation.name} at index "
f"{self.matched_operation_index} does not match {self.operation_name}"
)

pattern = rewrite_dictionary.get(self.pattern_name)
if pattern is None:
# Check individual rewrites first
if (
individual_rewrites := INDIVIDUAL_REWRITE_PATTERNS_BY_NAME.get(
self.operation_name
)
) is not None and (p := individual_rewrites.get(self.pattern_name)) is not None:
pattern = p
elif (
p := _get_canonicalization_pattern(matched_operation, self.pattern_name)
) is not None:
pattern = p
else:
raise ValueError(
f"Pattern name {self.pattern_name} not found for the provided operation name."
)

0 comments on commit cc95819

Please sign in to comment.