Skip to content

Commit

Permalink
gen_msb: Tidying using modern eDSL (#2039)
Browse files Browse the repository at this point in the history
* Tidy some cell generation

* Take a port by handle

* Use new lsh_use

* Use lsh_use more.

* More tidying using builder helpers

* No names for registers

* Tidy imports
  • Loading branch information
anshumanmohan authored May 16, 2024
1 parent 1be68ce commit 3db8a0c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 85 deletions.
28 changes: 28 additions & 0 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,10 @@ def rsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdRsh cell."""
return self.binary("rsh", size, name, signed)

def lsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdLsh cell."""
return self.binary("lsh", size, name, signed)

def logic(self, operation, size: int, name: str = None) -> CellBuilder:
"""Generate a logical operator cell, of the flavor specified in `operation`."""
name = name or self.generate_name(operation)
Expand Down Expand Up @@ -597,6 +601,30 @@ def decr(self, reg, val=1, signed=False, cellname=None):
decr_group.done = reg.done
return decr_group

def lsh_use(self, input, ans, val=1):
"""Inserts wiring into `self` to perform `ans := input << val`."""
width = ans.infer_width_reg()
cell = self.lsh(width)
with self.group(f"{cell.name}_group") as lsh_group:
cell.left = input
cell.right = const(width, val)
ans.write_en = 1
ans.in_ = cell.out
lsh_group.done = ans.done
return lsh_group

def rsh_use(self, input, ans, val=1):
"""Inserts wiring into `self` to perform `ans := input >> val`."""
width = ans.infer_width_reg()
cell = self.rsh(width)
with self.group(f"{cell.name}_group") as rsh_group:
cell.left = input
cell.right = const(width, val)
ans.write_en = 1
ans.in_ = cell.out
rsh_group.done = ans.done
return rsh_group

def reg_store(self, reg, val, groupname=None):
"""Inserts wiring into `self` to perform `reg := val`."""
groupname = groupname or f"{reg.name}_store_to_reg"
Expand Down
106 changes: 21 additions & 85 deletions calyx-py/calyx/gen_msb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
from typing import List
from calyx.py_ast import (
Stdlib,
Component,
)
from calyx.builder import (
Builder,
CellAndGroup,
const,
HI,
while_with,
)
from calyx.py_ast import Component
from calyx.builder import Builder, while_with


def gen_msb_calc(width: int, int_width: int) -> List[Component]:
Expand All @@ -19,88 +10,33 @@ def gen_msb_calc(width: int, int_width: int) -> List[Component]:
that 2^n <= x. Note that this is essentially finding the index of the most
significant bit in x. count_ans is the value for `n`, and `value_ans` is the
value for `2^n`.
Essentially, the component uses a while loop, a counter register, and shifts the input
The component uses a while loop, a counter register, and shifts the input
1 bit to the right at each iteration until it equals 0.
Important note: this component doesn't work when the input is 0.
"""
builder = Builder()
comp = builder.component("msb_calc")
comp.input("in", width)
in_ = comp.input("in", width)
comp.output("count", width)
comp.output("value", width)

rsh = comp.cell("rsh", Stdlib.op("rsh", width, signed=False))
counter = comp.reg(width, "counter")
cur_val = comp.reg(width, "cur_val")
add = comp.cell("add", Stdlib.op("add", width, signed=False))
sub = comp.cell("sub", Stdlib.op("sub", width, signed=False))
neq = comp.cell("neq", Stdlib.op("neq", width, signed=False))
lsh = comp.cell("lsh", Stdlib.op("lsh", width, signed=False))
count_ans = comp.reg(width, "count_ans")
val_ans = comp.reg(width, "val_ans")
val_build = comp.reg(width, "val_build")
counter = comp.reg(width)
cur_val = comp.reg(width)
count_ans = comp.reg(width)
val_ans = comp.reg(width)
val_build = comp.reg(width)

with comp.group("wr_cur_val") as wr_cur_val:
rsh.left = comp.this().in_
rsh.right = const(width, int_width)
cur_val.in_ = rsh.out
cur_val.write_en = HI
wr_cur_val.done = cur_val.done
wr_cur_val = comp.rsh_use(in_, cur_val, int_width)
wr_val_build = comp.reg_store(val_build, 1)
cur_val_cond = comp.neq_use(0, cur_val.out)
count_cond = comp.neq_use(0, counter.out)
incr_count = comp.incr(counter)
decr_count = comp.decr(counter)

with comp.group("wr_val_build") as wr_val_build:
val_build.in_ = const(32, 1)
val_build.write_en = HI
wr_val_build.done = val_build.done

with comp.comb_group("cur_val_cond") as cur_val_cond:
neq.left = const(width, 0)
neq.right = cur_val.out

with comp.comb_group("count_cond") as count_cond:
neq.left = const(width, 0)
neq.right = counter.out

with comp.group("incr_count") as incr_count:
add.left = counter.out
add.right = const(width, 1)
counter.in_ = add.out
counter.write_en = HI
incr_count.done = counter.done

with comp.group("shift_cur_val") as shift_cur_val:
rsh.left = cur_val.out
rsh.right = const(width, 1)
cur_val.in_ = rsh.out
cur_val.write_en = HI
shift_cur_val.done = cur_val.done

with comp.group("shift_val_build") as shift_val_build:
lsh.left = val_build.out
lsh.right = const(width, 1)
val_build.in_ = lsh.out
val_build.write_en = HI
shift_val_build.done = val_build.done

with comp.group("decr_count") as decr_count:
sub.left = counter.out
sub.right = const(width, 1)
counter.in_ = sub.out
counter.write_en = HI
decr_count.done = counter.done

with comp.group("wr_count") as wr_count:
lsh.left = counter.out
lsh.right = const(width, width - int_width)
count_ans.in_ = lsh.out
count_ans.write_en = HI
wr_count.done = count_ans.done

with comp.group("wr_val") as wr_val:
lsh.left = val_build.out
lsh.right = const(width, width - int_width)
val_ans.in_ = lsh.out
val_ans.write_en = HI
wr_val.done = val_ans.done
shift_cur_val = comp.rsh_use(cur_val.out, cur_val)
shift_val_build = comp.lsh_use(val_build.out, val_build)
wr_count = comp.lsh_use(counter.out, count_ans, width - int_width)
wr_val = comp.lsh_use(val_build.out, val_ans, width - int_width)

with comp.continuous:
comp.this().count = count_ans.out
Expand All @@ -109,14 +45,14 @@ def gen_msb_calc(width: int, int_width: int) -> List[Component]:
comp.control += [
wr_cur_val,
while_with(
CellAndGroup(neq, cur_val_cond),
cur_val_cond,
[incr_count, shift_cur_val],
),
decr_count,
wr_count,
wr_val_build,
while_with(
CellAndGroup(neq, count_cond),
count_cond,
[decr_count, shift_val_build],
),
wr_val,
Expand Down

0 comments on commit 3db8a0c

Please sign in to comment.