Skip to content

Commit

Permalink
add Locks and Buffers to tile array (Xilinx#1046)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 24, 2024
1 parent d3d2c04 commit f4f09a7
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 97 deletions.
2 changes: 1 addition & 1 deletion python/aie-python-extras-req.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# this is actually the aie branch of mlir-python-extras (see https://github.com/makslevental/mlir-python-extras/pull/33)
aie-python-extras @ https://github.com/makslevental/mlir-python-extras/archive/a0f48f0b67affc1c88d79fe94ef3fe5c27c73f80.zip
aie-python-extras @ https://github.com/makslevental/mlir-python-extras/archive/c7cf7ca8587b24e1f169199ddec5d4424f24c42e.zip
126 changes: 124 additions & 2 deletions python/dialects/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import inspect
from collections import namedtuple
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from ._aie_enum_gen import *
Expand All @@ -20,6 +22,8 @@
_get_sym_name,
get_user_code_loc,
region_adder,
find_parent_of_type,
find_ops,
)
from ..ir import (
ArrayAttr,
Expand All @@ -31,6 +35,8 @@
IntegerAttr,
IntegerType,
TypeAttr,
DictAttr,
UnitAttr,
_i32ArrayAttr,
)

Expand Down Expand Up @@ -374,8 +380,10 @@ def buffer(buffer, tile, *, sym_name=None, address=None, loc=None, ip=None):
_lock = lock


def lock(tile, *, lock_id=None, init=None, sym_name=None, loc=None, ip=None):
return _lock(
def lock(
tile, *, lock_id=None, init=None, sym_name=None, annot=None, loc=None, ip=None
):
l = _lock(
tile,
lock_id=lock_id,
init=init,
Expand All @@ -384,6 +392,9 @@ def lock(tile, *, lock_id=None, init=None, sym_name=None, loc=None, ip=None):
loc=loc,
ip=ip,
)
if annot is not None:
l.owner.attributes["annot"] = DictAttr.get({annot: UnitAttr.get()})
return l


@_cext.register_operation(_Dialect, replace=True)
Expand Down Expand Up @@ -414,6 +425,101 @@ def flow(
)


def find_matching_flows(
tiles,
filter_source=False,
filter_dest=False,
source_annot=None,
dest_annot=None,
device=None,
):
assert not (filter_source and filter_dest), "Can only filter by source XOR dest"
if device is None:
device = find_parent_of_type(lambda op: isinstance(op, DeviceOp))

def _cb(op):
if isinstance(op, FlowOp):
if filter_source and op.source.owner.opview not in tiles:
return False
if filter_dest and op.dest.owner.opview not in tiles:
return False

return (
op.source.owner.opview in tiles
or op.dest.owner.opview in tiles
and (
(
"source_annot" in op.attributes
and source_annot in op.attributes["source_annot"]
)
if source_annot is not None
else True
)
and (
(
"dest_annot" in op.attributes
and dest_annot in op.attributes["dest_annot"]
)
if dest_annot is not None
else True
)
)

return find_ops(device, _cb)


def find_matching_locks(tiles, sym_name=None, annot=None, device=None):
if device is None:
device = find_parent_of_type(lambda op: isinstance(op, DeviceOp))

def _cb(op):
if isinstance(op, LockOp):
return (
op.tile.owner.opview in tiles
and (sym_name == str(op.sym_name) if sym_name is not None else True)
and (
("annot" in op.attributes and annot in op.attributes["annot"])
if annot is not None
else True
)
)

return find_ops(device, _cb)


@dataclass
class Neighbors:
north: TileOp = None
west: TileOp = None
south: TileOp = None


def find_neighbors(tile, device=None):
if device is None:
device = find_parent_of_type(lambda op: isinstance(op, DeviceOp))

assert int(device.device) == int(AIEDevice.ipu), "only ipu supported"

neighbors = {}
col, row = map(int, (tile.col, tile.row))
if col > 0 and row > 0 and not (col, row) == (1, 1):
neighbors[col - 1, row] = "west"
if row > 1:
neighbors[col, row - 1] = "south"
if 0 < row < 5:
neighbors[col, row + 1] = "north"

neighbors_ = {"north": None, "west": None, "south": None}

for n in find_ops(
device,
lambda op: isinstance(op, TileOp) and (int(op.col), int(op.row)) in neighbors,
):
neighbors_[neighbors[int(n.col), int(n.row)]] = n

return Neighbors(**neighbors_)


@_cext.register_operation(_Dialect, replace=True)
class TileOp(TileOp):
def __str__(self):
Expand All @@ -435,6 +541,22 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.col, self.row))

def flows(
self, filter_source=False, filter_dest=False, source_annot=None, dest_annot=None
):
return find_matching_flows(
[self],
filter_source=filter_source,
filter_dest=filter_dest,
source_annot=None,
dest_annot=None,
)

def locks(self, sym_name=None, annot=None, device=None):
return find_matching_locks(
[self], sym_name=sym_name, annot=annot, device=device
)


def tile(col, row, *, loc=None, ip=None):
return TileOp(col=col, row=row, loc=loc, ip=ip)
106 changes: 20 additions & 86 deletions python/dialects/aiex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from ._aiex_ops_gen import *
from .aie import (
DMAChannelDir,
DeviceOp,
FlowOp,
LockAction,
TileOp,
dma,
dma_bd,
find_neighbors,
find_matching_flows,
find_matching_locks,
flow,
lock,
tile,
Expand All @@ -26,11 +27,8 @@
from .transform.structured import MixedValues, _dispatch_mixed_values
from .._mlir_libs import get_dialect_registry
from .._mlir_libs._aie import *
from ..extras.util import _get_previous_frame_idents, find_ops, find_parent_of_type
from ..ir import DictAttr, IntegerAttr, UnitAttr

# Copyright (C) 2023, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Comes from _aie
register_dialect(get_dialect_registry())
Expand Down Expand Up @@ -373,26 +371,21 @@ def hold_lock(acq_lock, rel_lock, *, acq_val=None, rel_val=None):


class TileArray:
def __init__(self, cols=5, rows=6, df=None, flows=None):
def __init__(self, cols=5, rows=6, df=None):
if df is None:
df = np.array(
[[tile(c, r) for r in range(rows)] for c in range(cols)],
)
self.df = df
if flows is None:
flows = Flows(self)
self._flows = flows

def flow(self, other, *args, **kwargs):
return broadcast_flow(self.df, other.df, *args, **kwargs, flows=self.flows)
return broadcast_flow(self.df, other.df, *args, **kwargs)

def __rshift__(self, other):
return broadcast_flow(self.df, other.df, flows=self.flows)
return broadcast_flow(self.df, other.df)

def __lshift__(self, other):
r = np.frompyfunc(partial(broadcast_flow, flows=self.flows), 2, 1).outer(
other.df, self.df
)
r = np.frompyfunc(partial(broadcast_flow), 2, 1).outer(other.df, self.df)
if isinstance(r, np.ndarray):
r = r.flatten().tolist()
if len(r) == 1:
Expand Down Expand Up @@ -429,9 +422,15 @@ def __contains__(self, item):
assert isinstance(self.df, TileOp)
return item == self.df

def flows(self, **kwargs):
return find_matching_flows(self, **kwargs)

def locks(self, **kwargs):
return find_matching_locks(self, **kwargs)

@property
def flows(self):
return self._flows
def neighbors(self):
return np.vectorize(find_neighbors)(self.df)

@property
def shape(self):
Expand All @@ -441,66 +440,6 @@ def __repr__(self):
return f"<{self.__class__.__name__}: {self.df}>"


class Flows:
def __init__(self, tiles: TileArray):
self.tiles = tiles

def find_matching_flows(
self,
tiles,
filter_source=False,
filter_dest=False,
source_annot=None,
dest_annot=None,
):
assert not (filter_source and filter_dest), "Can only filter by source XOR dest"
device = find_parent_of_type(lambda op: isinstance(op, DeviceOp))

def _cb(op):
if isinstance(op, FlowOp):
if filter_source and op.source.owner.opview not in tiles:
return False
if filter_dest and op.dest.owner.opview not in tiles:
return False

return (
op.source.owner.opview in tiles
or op.dest.owner.opview in tiles
and (
(
"source_annot" in op.attributes
and source_annot in op.attributes["source_annot"]
)
if source_annot is not None
else True
)
and (
(
"dest_annot" in op.attributes
and dest_annot in op.attributes["dest_annot"]
)
if dest_annot is not None
else True
)
)

return find_ops(device, _cb)

def __getitem__(self, item):
kwargs = {}
if len(item) > 2:
# not sure how but you don't need two backs here
# (the previous frame is the call site of TileArray().flows...)
previous_frame = inspect.currentframe().f_back
for kwarg in item[2:]:
k = _get_previous_frame_idents(kwarg, previous_frame)
assert len(k) == 1, f"{len(k)=}"
kwargs[k[0]] = kwarg
item = item[:2]
tiles = self.tiles[item]
return self.find_matching_flows(tiles, **kwargs)


def broadcast_flow(
source: Union[np.ndarray, TileOp],
dest: Union[np.ndarray, TileOp],
Expand All @@ -510,14 +449,13 @@ def broadcast_flow(
dest_channel=None,
source_annot=None,
dest_annot=None,
flows: Flows = None,
):
if isinstance(source, TileOp):
source = np.asarray([source])
if isinstance(dest, TileOp):
dest = np.asarray([dest])
for chan in [source_channel, dest_channel]:
assert (chan is None and flows is not None) or np.all(
assert chan is None or np.all(
np.array(chan) != None
), "can't handle mixed auto channel assignment"

Expand All @@ -534,19 +472,15 @@ def _find_next_channel(used_channels):
if source_channel is None or np.all(np.array(source_channel) == None):
source_channel = np.empty_like(source, dtype=None)
for s, indices in zip(*map(list, np.unique(source, return_index=True))):
used_channels = set(
int(f.source_channel)
for f in flows.find_matching_flows([s], filter_source=True)
)
matching_flows = find_matching_flows([s], filter_source=True)
used_channels = set(int(f.source_channel) for f in matching_flows)
source_channel.flat[indices] = _find_next_channel(used_channels)

if dest_channel is None or np.all(np.array(dest_channel) == None):
used_channels = {}
for d in np.unique(dest):
used_channels[d] = set(
int(f.dest_channel)
for f in flows.find_matching_flows([d], filter_dest=True)
)
matching_flows = find_matching_flows([d], filter_dest=True)
used_channels[d] = set(int(f.dest_channel) for f in matching_flows)
dest_channel = np.empty_like(dest, dtype=None)
for idx, dst in np.ndenumerate(dest):
dest_channel[idx] = _find_next_channel(used_channels[dst])
Expand Down
Loading

0 comments on commit f4f09a7

Please sign in to comment.