From 3db8a0c3d6d49f24cfa198ccef7b378c48ee1417 Mon Sep 17 00:00:00 2001 From: Anshuman Mohan <10830208+anshumanmohan@users.noreply.github.com> Date: Thu, 16 May 2024 10:58:14 -0400 Subject: [PATCH] `gen_msb`: Tidying using modern eDSL (#2039) * Tidy some cell generation * Take a port by handle * Use new lsh_use * Use lsh_use more. * More tidying using builder helpers * No names for registers * Tidy imports --- calyx-py/calyx/builder.py | 28 ++++++++++ calyx-py/calyx/gen_msb.py | 106 ++++++++------------------------------ 2 files changed, 49 insertions(+), 85 deletions(-) diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index a695807ef6..025e4e2654 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -378,6 +378,10 @@ def rsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: """Generate a StdRsh cell.""" return self.binary("rsh", size, name, signed) + def lsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder: + """Generate a StdLsh cell.""" + return self.binary("lsh", size, name, signed) + def logic(self, operation, size: int, name: str = None) -> CellBuilder: """Generate a logical operator cell, of the flavor specified in `operation`.""" name = name or self.generate_name(operation) @@ -597,6 +601,30 @@ def decr(self, reg, val=1, signed=False, cellname=None): decr_group.done = reg.done return decr_group + def lsh_use(self, input, ans, val=1): + """Inserts wiring into `self` to perform `ans := input << val`.""" + width = ans.infer_width_reg() + cell = self.lsh(width) + with self.group(f"{cell.name}_group") as lsh_group: + cell.left = input + cell.right = const(width, val) + ans.write_en = 1 + ans.in_ = cell.out + lsh_group.done = ans.done + return lsh_group + + def rsh_use(self, input, ans, val=1): + """Inserts wiring into `self` to perform `ans := input >> val`.""" + width = ans.infer_width_reg() + cell = self.rsh(width) + with self.group(f"{cell.name}_group") as rsh_group: + cell.left = input + cell.right = const(width, val) + ans.write_en = 1 + ans.in_ = cell.out + rsh_group.done = ans.done + return rsh_group + def reg_store(self, reg, val, groupname=None): """Inserts wiring into `self` to perform `reg := val`.""" groupname = groupname or f"{reg.name}_store_to_reg" diff --git a/calyx-py/calyx/gen_msb.py b/calyx-py/calyx/gen_msb.py index 6ba1bc53ae..700927a2ab 100644 --- a/calyx-py/calyx/gen_msb.py +++ b/calyx-py/calyx/gen_msb.py @@ -1,15 +1,6 @@ from typing import List -from calyx.py_ast import ( - Stdlib, - Component, -) -from calyx.builder import ( - Builder, - CellAndGroup, - const, - HI, - while_with, -) +from calyx.py_ast import Component +from calyx.builder import Builder, while_with def gen_msb_calc(width: int, int_width: int) -> List[Component]: @@ -19,88 +10,33 @@ def gen_msb_calc(width: int, int_width: int) -> List[Component]: that 2^n <= x. Note that this is essentially finding the index of the most significant bit in x. count_ans is the value for `n`, and `value_ans` is the value for `2^n`. - Essentially, the component uses a while loop, a counter register, and shifts the input + The component uses a while loop, a counter register, and shifts the input 1 bit to the right at each iteration until it equals 0. Important note: this component doesn't work when the input is 0. """ builder = Builder() comp = builder.component("msb_calc") - comp.input("in", width) + in_ = comp.input("in", width) comp.output("count", width) comp.output("value", width) - rsh = comp.cell("rsh", Stdlib.op("rsh", width, signed=False)) - counter = comp.reg(width, "counter") - cur_val = comp.reg(width, "cur_val") - add = comp.cell("add", Stdlib.op("add", width, signed=False)) - sub = comp.cell("sub", Stdlib.op("sub", width, signed=False)) - neq = comp.cell("neq", Stdlib.op("neq", width, signed=False)) - lsh = comp.cell("lsh", Stdlib.op("lsh", width, signed=False)) - count_ans = comp.reg(width, "count_ans") - val_ans = comp.reg(width, "val_ans") - val_build = comp.reg(width, "val_build") + counter = comp.reg(width) + cur_val = comp.reg(width) + count_ans = comp.reg(width) + val_ans = comp.reg(width) + val_build = comp.reg(width) - with comp.group("wr_cur_val") as wr_cur_val: - rsh.left = comp.this().in_ - rsh.right = const(width, int_width) - cur_val.in_ = rsh.out - cur_val.write_en = HI - wr_cur_val.done = cur_val.done + wr_cur_val = comp.rsh_use(in_, cur_val, int_width) + wr_val_build = comp.reg_store(val_build, 1) + cur_val_cond = comp.neq_use(0, cur_val.out) + count_cond = comp.neq_use(0, counter.out) + incr_count = comp.incr(counter) + decr_count = comp.decr(counter) - with comp.group("wr_val_build") as wr_val_build: - val_build.in_ = const(32, 1) - val_build.write_en = HI - wr_val_build.done = val_build.done - - with comp.comb_group("cur_val_cond") as cur_val_cond: - neq.left = const(width, 0) - neq.right = cur_val.out - - with comp.comb_group("count_cond") as count_cond: - neq.left = const(width, 0) - neq.right = counter.out - - with comp.group("incr_count") as incr_count: - add.left = counter.out - add.right = const(width, 1) - counter.in_ = add.out - counter.write_en = HI - incr_count.done = counter.done - - with comp.group("shift_cur_val") as shift_cur_val: - rsh.left = cur_val.out - rsh.right = const(width, 1) - cur_val.in_ = rsh.out - cur_val.write_en = HI - shift_cur_val.done = cur_val.done - - with comp.group("shift_val_build") as shift_val_build: - lsh.left = val_build.out - lsh.right = const(width, 1) - val_build.in_ = lsh.out - val_build.write_en = HI - shift_val_build.done = val_build.done - - with comp.group("decr_count") as decr_count: - sub.left = counter.out - sub.right = const(width, 1) - counter.in_ = sub.out - counter.write_en = HI - decr_count.done = counter.done - - with comp.group("wr_count") as wr_count: - lsh.left = counter.out - lsh.right = const(width, width - int_width) - count_ans.in_ = lsh.out - count_ans.write_en = HI - wr_count.done = count_ans.done - - with comp.group("wr_val") as wr_val: - lsh.left = val_build.out - lsh.right = const(width, width - int_width) - val_ans.in_ = lsh.out - val_ans.write_en = HI - wr_val.done = val_ans.done + shift_cur_val = comp.rsh_use(cur_val.out, cur_val) + shift_val_build = comp.lsh_use(val_build.out, val_build) + wr_count = comp.lsh_use(counter.out, count_ans, width - int_width) + wr_val = comp.lsh_use(val_build.out, val_ans, width - int_width) with comp.continuous: comp.this().count = count_ans.out @@ -109,14 +45,14 @@ def gen_msb_calc(width: int, int_width: int) -> List[Component]: comp.control += [ wr_cur_val, while_with( - CellAndGroup(neq, cur_val_cond), + cur_val_cond, [incr_count, shift_cur_val], ), decr_count, wr_count, wr_val_build, while_with( - CellAndGroup(neq, count_cond), + count_cond, [decr_count, shift_val_build], ), wr_val,