Skip to content

Commit

Permalink
refactor(tapac): check custom RTL using Module
Browse files Browse the repository at this point in the history
This helps reduce parser-specific details exposed in public APIs,
simplifies code, and makes it possible to change the internals used to
implement `Module`.
  • Loading branch information
Blaok committed Jan 21, 2025
1 parent 0528953 commit 0cdac1d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 189 deletions.
63 changes: 17 additions & 46 deletions tapa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
IntConst,
Land,
Minus,
ModuleDef,
Node,
NonblockingSubstitution,
Output,
Expand All @@ -54,15 +53,13 @@
Width,
Wire,
)
from pyverilog.vparser.parser import VerilogCodeParser

from tapa.backend.xilinx import RunAie, RunHls
from tapa.instance import Instance, Port
from tapa.safety_check import check_mmap_arg_name
from tapa.task import Task
from tapa.util import (
clang_format,
extract_ports_from_module,
get_instance_name,
get_module_name,
get_vendor_include_paths,
Expand Down Expand Up @@ -1499,49 +1496,23 @@ def replace_custom_rtl(self, rtl_paths: tuple[Path, ...]) -> None:

def _check_custom_rtl_format(self, rtl_paths: list[Path]) -> None:
"""Check if the custom RTL files are in the correct format."""
if not rtl_paths:
return
_logger.info("Checking custom RTL files format.")
with tempfile.TemporaryDirectory(prefix="pyverilog-") as output_dir:
codeparser = VerilogCodeParser(
rtl_paths,
preprocess_output=os.path.join(output_dir, "preprocess.output"),
outputdir=output_dir,
debug=False,
)
ast = codeparser.parse()
# Traverse the AST to find module definitions
module_defs: list[ModuleDef] = [
mod_def
for mod_def in ast.description.definitions
if isinstance(mod_def, ModuleDef)
]

for module_def in module_defs:
for task_name, task in self._tasks.items():
if task_name != module_def.name:
continue
_logger.info("Checking custom RTL file format for task %s.", task_name)
task_port_infos = extract_ports_from_module(
task.module.get_module_def()
)
custom_rtl_port_infos = extract_ports_from_module(module_def)
if task_port_infos != custom_rtl_port_infos:
assert set(task_port_infos.keys()) == set(
custom_rtl_port_infos.keys()
)
task_port_infos_str = "\n".join(
f" {port}" for port in task_port_infos.values()
)
custom_rtl_port_infos_str = "\n".join(
f" {port}" for port in custom_rtl_port_infos.values()
)
msg = (
f"Custom RTL file for task {task_name} does not match the "
f"expected ports. \nTask ports: \n{task_port_infos_str}\n"
f"Custom RTL ports:\n{custom_rtl_port_infos_str}"
)
raise ValueError(msg)
if rtl_paths:
_logger.info("checking custom RTL files format")
for rtl_path in rtl_paths:
rtl_module = Module([str(rtl_path)])
if (task := self._tasks.get(rtl_module.name)) is None:
continue # ignore RTL modules that are not tasks
if rtl_module.ports == task.module.ports:
continue # ports match exactly
msg = [
f"Custom RTL file {rtl_path} for task {task.name}"
" does not match the expected ports.",
"Task ports:",
*(f" {port}" for port in task.module.ports.values()),
"Custom RTL ports:",
*(f" {port}" for port in rtl_module.ports.values()),
]
raise ValueError("\n".join(msg))

def get_aie_graph(self, task: Task) -> str:
"""Generates the complete AIE graph code."""
Expand Down
139 changes: 0 additions & 139 deletions tapa/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,12 @@
from typing import TYPE_CHECKING, Literal

import coloredlogs
from pyverilog.ast_code_generator.codegen import ASTCodeGenerator
from pyverilog.vparser.ast import (
Decl,
Inout,
Input,
Ioport,
ModuleDef,
Output,
Port,
Width,
)

if TYPE_CHECKING:
from collections.abc import Iterable

_logger = logging.getLogger().getChild(__name__)

AST_IOPort = Input | Output | Inout


class PortInfo:
"""Port information extracted from a Verilog module definition."""

def __init__(self, name: str, direction: str | None, width: Width | None) -> None:
self.name = name
self.direction = direction
self.width = width

def __str__(self) -> str:
codegen = ASTCodeGenerator()
if self.width:
return f"{self.direction} {codegen.visit(self.width)} {self.name}"
return f"{self.direction} {self.name}"

def __eq__(self, other: PortInfo) -> bool:
codegen = ASTCodeGenerator()
width1 = codegen.visit(self.width) if self.width else None
width2 = codegen.visit(other.width) if other.width else None
return (
self.name == other.name
and self.direction == other.direction
and width1 == width2
)

def __hash__(self) -> int:
return hash((self.name, self.direction, self.width))


def clang_format(code: str, *args: str) -> str:
"""Apply clang-format with given arguments, if possible."""
Expand Down Expand Up @@ -241,101 +200,3 @@ def setup_logging(
logging.getLogger().addHandler(handler)

_logger.info("logging level set to %s", logging.getLevelName(logging_level))


def extract_ports_from_module_header(module_def: ModuleDef) -> dict[str, PortInfo]:
"""Extract ports direction and width from a given Verilog module header."""
port_infos = {}
for port in module_def.portlist.ports:
if isinstance(port, Port):
# When header contains only port names
# the port is an instance of Port
port_infos[port.name] = PortInfo(
name=port.name,
direction=None,
width=None,
)
else:
# When header contains port direction and/or width
# the port is an instance of Ioport
assert isinstance(port, Ioport)

if port.first.children():
width = port.first.children()[0]
assert isinstance(width, Width)
else:
width = None

if isinstance(port.first, Port):
direction = None
else:
direction = port.first.__class__.__name__

port_infos[port.first.name] = PortInfo(
name=port.first.name,
direction=direction,
width=width,
)
return port_infos


def extract_ports_from_module_body(module_def: ModuleDef) -> dict[str, PortInfo]:
"""Extract ports direction and width from a given Verilog module body."""
port_infos = {}
for decl in module_def.children():
if not isinstance(decl, Decl):
continue
port = decl.children()[0]
if not isinstance(port, AST_IOPort):
continue

if port.children():
width = port.children()[0]
assert isinstance(width, Width)
else:
width = None

direction = port.__class__.__name__

port_infos[port.name] = PortInfo(
name=port.name,
direction=direction,
width=width,
)
return port_infos


def extract_ports_from_module(module_def: ModuleDef) -> dict[str, PortInfo]:
"""Extract ports direction and width from a given Verilog module."""
# get port info from header and body separately
port_infos_1 = extract_ports_from_module_header(module_def)
port_infos_2 = extract_ports_from_module_body(module_def)

# merge the two dictionaries
merged_dict = {}

# Get all unique port names
all_ports = set(port_infos_1.keys()).union(port_infos_2.keys())

for port_name in all_ports:
port1 = port_infos_1.get(port_name)
port2 = port_infos_2.get(port_name)

# Merge the port information
if port1 and port2:
if port1.direction and port2.direction:
assert port1.direction == port2.direction
if port1.width and port2.width:
assert port1.width == port2.width
merged_port = PortInfo(
name=port1.name, # Assuming name is always the same
direction=port1.direction or port2.direction,
width=port1.width if port1.width is not None else port2.width,
)
else:
# If the port exists in only one dictionary, use it as is
merged_port = port1 or port2

merged_dict[port_name] = merged_port

return merged_dict
11 changes: 7 additions & 4 deletions tapa/verilog/xilinx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Input,
Instance,
InstanceList,
Ioport,
Lvalue,
ModuleDef,
Node,
Expand Down Expand Up @@ -212,16 +213,18 @@ def _module_def(self) -> ModuleDef:
raise ValueError(msg)
return module_defs[0]

def get_module_def(self) -> ModuleDef:
return self._module_def

@property
def name(self) -> str:
return self._module_def.name

@property
def ports(self) -> dict[str, ioport.IOPort]:
port_lists = (x.list for x in self._module_def.items if isinstance(x, Decl))
port_lists = [
# ANSI style: ports declared in header
(x.first for x in self._module_def.portlist.ports if isinstance(x, Ioport)),
# Non-ANSI style: ports declared in body
*(x.list for x in self._module_def.items if isinstance(x, Decl)),
]
return collections.OrderedDict(
(x.name, ioport.IOPort.create(x))
for x in itertools.chain.from_iterable(port_lists)
Expand Down

0 comments on commit 0cdac1d

Please sign in to comment.