Skip to content

Commit

Permalink
fix(tapac): keep different ast nodes separated in generated modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaok committed Feb 13, 2024
1 parent 5f72d09 commit bc092e4
Showing 1 changed file with 48 additions and 40 deletions.
88 changes: 48 additions & 40 deletions backend/python/tapa/verilog/xilinx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@
class Module:
"""AST and helpers for a verilog module.
_last_*_idx is the array bound if the type of item is not present.
`_next_*_idx` is the index to module_def.items where the next type of item
should be inserted.
Attributes:
ast: The ast.Source node.
directives: Tuple of Directives.
_handshake_output_ports: A mapping from ap_done, ap_idle, ap_ready signal
names to their ast.Assign nodes.
_last_io_port_idx: Last index of an IOPort in module_def.items.
_last_signal_idx: Last index of ast.Wire or ast.Reg in module_def.items.
_last_param_idx: Last index of ast.Parameter in module_def.items.
_last_instance_idx: Last index of ast.InstanceList in module_def.items.
_last_logic_idx: Last index of ast.Assign or ast.Always in module_def.items.
_next_io_port_idx: Next index of an IOPort in module_def.items.
_next_signal_idx: Next index of ast.Wire or ast.Reg in module_def.items.
_next_param_idx: Next index of ast.Parameter in module_def.items.
_next_instance_idx: Next index of ast.InstanceList in module_def.items.
_next_logic_idx: Next index of ast.Assign or ast.Always in module_def.items.
"""

# module_def.items should contain the following attributes, in that order.
_ATTRS = 'param', 'io_port', 'signal', 'logic', 'instance'

def __init__(self, files: Iterable[str], is_trimming_enabled: bool = False):
"""Construct a Module from files. """
if not files:
Expand Down Expand Up @@ -81,13 +85,13 @@ def _calculate_indices(self) -> None:
if any(
isinstance(x, (ast.Input, ast.Output, ast.Input))
for x in item.list):
self._last_io_port_idx = idx
self._next_io_port_idx = idx + 1
elif any(isinstance(x, (ast.Wire, ast.Reg)) for x in item.list):
self._last_signal_idx = idx
self._next_signal_idx = idx + 1
elif any(isinstance(x, ast.Parameter) for x in item.list):
self._last_param_idx = idx
self._next_param_idx = idx + 1
elif isinstance(item, (ast.Assign, ast.Always)):
self._last_logic_idx = idx
self._next_logic_idx = idx + 1
if isinstance(item, ast.Assign):
if isinstance(item.left, ast.Lvalue):
name = item.left.var.name
Expand All @@ -96,12 +100,12 @@ def _calculate_indices(self) -> None:
if name in HANDSHAKE_OUTPUT_PORTS:
self._handshake_output_ports[name] = item
elif isinstance(item, ast.InstanceList):
self._last_instance_idx = idx
self._next_instance_idx = idx + 1

# if the item type is not present, set idx to the array bound
for attr in 'io_port', 'signal', 'param', 'logic', 'instance':
if not hasattr(self, '_last_%s_idx' % attr):
setattr(self, '_last_%s_idx' % attr, len(self._module_def.items))
if not hasattr(self, f'_next_{attr}_idx'):
setattr(self, f'_next_{attr}_idx', 0)

@property
def _module_def(self) -> ast.ModuleDef:
Expand Down Expand Up @@ -260,18 +264,22 @@ def code(self) -> str:
) + codegen.ASTCodeGenerator().visit(self.ast)

def _increment_idx(self, length: int, target: str) -> None:
attrs = 'io_port', 'signal', 'param', 'logic', 'instance'
if target not in attrs:
raise ValueError('target must be one of %s' % str(attrs))

for attr in attrs:
if attr == target:
continue
attr = '_last_%s_idx' % attr
idx = '_last_%s_idx' % target
if getattr(self, attr) > getattr(self, idx):
setattr(self, attr, getattr(self, attr) + length)
setattr(self, idx, getattr(self, idx) + length)
attr_map = {attr: priority for priority, attr in enumerate(self._ATTRS)}
target_priority = attr_map.get(target)
if target_priority is None:
raise ValueError(f'target must be one of {self._ATTRS}')

# Get the index of the target once, since it could change in the loop.
target_idx = getattr(self, f'_next_{target}_idx')

# Increment `_next_*_idx` if it is after `_next_{target}_idx`. If
# `_next_*_idx` == `_next_{target}_idx`, increment only if `priority` is
# larger, i.e., `attr` should show up after `target` in `module_def.items`.
for priority, attr in enumerate(self._ATTRS):
attr_name = f'_next_{attr}_idx'
idx = getattr(self, attr_name)
if (idx, priority) >= (target_idx, target_priority):
setattr(self, attr_name, idx + length)

def _filter(self, func: Callable[[ast.Node], bool], target: str) -> None:
self._module_def.items = tuple(filter(func, self._module_def.items))
Expand All @@ -282,17 +290,17 @@ def add_ports(self, ports: Iterable[IOPort]) -> 'Module':
self._module_def.portlist.ports += tuple(
ast.Port(name=port.name, width=None, dimensions=None, type=None)
for port in port_tuple)
self._module_def.items = (
self._module_def.items[:self._last_io_port_idx + 1] + port_tuple +
self._module_def.items[self._last_io_port_idx + 1:])
self._module_def.items = (self._module_def.items[:self._next_io_port_idx] +
port_tuple +
self._module_def.items[self._next_io_port_idx:])
self._increment_idx(len(port_tuple), 'io_port')
return self

def add_signals(self, signals: Iterable[Signal]) -> 'Module':
signal_tuple = tuple(signals)
self._module_def.items = (
self._module_def.items[:self._last_signal_idx + 1] + signal_tuple +
self._module_def.items[self._last_signal_idx + 1:])
self._module_def.items = (self._module_def.items[:self._next_signal_idx] +
signal_tuple +
self._module_def.items[self._next_signal_idx:])
self._increment_idx(len(signal_tuple), 'signal')
return self

Expand Down Expand Up @@ -330,9 +338,9 @@ def func(item: ast.Node) -> bool:

def add_params(self, params: Iterable[ast.Parameter]) -> 'Module':
param_tuple = tuple(params)
self._module_def.items = (
self._module_def.items[:self._last_param_idx + 1] + param_tuple +
self._module_def.items[self._last_param_idx + 1:])
self._module_def.items = (self._module_def.items[:self._next_param_idx] +
param_tuple +
self._module_def.items[self._next_param_idx:])
self._increment_idx(len(param_tuple), 'param')
return self

Expand All @@ -350,9 +358,9 @@ def func(item: ast.Node) -> bool:
self._filter(func, 'param')

def add_instancelist(self, item: ast.InstanceList) -> 'Module':
self._module_def.items = (
self._module_def.items[:self._last_instance_idx + 1] + (item,) +
self._module_def.items[self._last_instance_idx + 1:])
self._module_def.items = (self._module_def.items[:self._next_instance_idx] +
(item,) +
self._module_def.items[self._next_instance_idx:])
self._increment_idx(1, 'instance')
return self

Expand All @@ -375,9 +383,9 @@ def add_instance(

def add_logics(self, logics: Iterable[Logic]) -> 'Module':
logic_tuple = tuple(logics)
self._module_def.items = (
self._module_def.items[:self._last_logic_idx + 1] + logic_tuple +
self._module_def.items[self._last_logic_idx + 1:])
self._module_def.items = (self._module_def.items[:self._next_logic_idx] +
logic_tuple +
self._module_def.items[self._next_logic_idx:])
self._increment_idx(len(logic_tuple), 'logic')
return self

Expand Down

0 comments on commit bc092e4

Please sign in to comment.