Skip to content

Commit

Permalink
Catch up gen_exp
Browse files Browse the repository at this point in the history
  • Loading branch information
anshumanmohan committed Aug 20, 2023
1 parent f7d83b4 commit 4997e9c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 117 deletions.
62 changes: 36 additions & 26 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ def le(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdLe cell."""
return self.binary("le", size, name, signed)

def rsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdRsh cell."""
return self.binary("rsh", size, name, signed)

def logic(self, operation, size: int, name: str = None) -> CellBuilder:
"""Generate a logical operator cell, of the flavor specified in `operation`."""
name = name or self.generate_name(operation)
Expand All @@ -353,10 +357,12 @@ def logic(self, operation, size: int, name: str = None) -> CellBuilder:

def and_(self, size: int, name: str = None) -> CellBuilder:
"""Generate a StdAnd cell."""
name = name or self.generate_name("and")
return self.logic("and", size, name)

def not_(self, size: int, name: str = None) -> CellBuilder:
"""Generate a StdNot cell."""
name = name or self.generate_name("not")
return self.logic("not", size, name)

def pipelined_mult(self, name: str) -> CellBuilder:
Expand Down Expand Up @@ -421,55 +427,55 @@ def binary_use(self, left, right, cell, groupname=None):
cell.right = right
return CellAndGroup(cell, comb_group)

def eq_use(self, left, right, width, cellname=None):
def eq_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` == `right`."""
return self.binary_use(left, right, self.eq(width, cellname))
return self.binary_use(left, right, self.eq(width, cellname, signed))

def neq_use(self, left, right, width, cellname=None):
def neq_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` != `right`."""
return self.binary_use(left, right, self.neq(width, cellname))
return self.binary_use(left, right, self.neq(width, cellname, signed))

def lt_use(self, left, right, width, cellname=None):
def lt_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` < `right`."""
return self.binary_use(left, right, self.lt(width, cellname))
return self.binary_use(left, right, self.lt(width, cellname, signed))

def le_use(self, left, right, width, cellname=None):
def le_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` <= `right`."""
return self.binary_use(left, right, self.le(width, cellname))
return self.binary_use(left, right, self.le(width, cellname, signed))

def ge_use(self, left, right, width, cellname=None):
def ge_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` >= `right`."""
return self.binary_use(left, right, self.ge(width, cellname))
return self.binary_use(left, right, self.ge(width, cellname, signed))

def gt_use(self, left, right, width, cellname=None):
def gt_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to check if `left` > `right`."""
return self.binary_use(left, right, self.gt(width, cellname))
return self.binary_use(left, right, self.gt(width, cellname, signed))

def add_use(self, left, right, width, cellname=None):
def add_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to compute `left` + `right`."""
return self.binary_use(left, right, self.add(width, cellname))
return self.binary_use(left, right, self.add(width, cellname, signed))

def sub_use(self, left, right, width, cellname=None):
def sub_use(self, left, right, width, signed=False, cellname=None):
"""Inserts wiring into `self` to compute `left` - `right`."""
return self.binary_use(left, right, self.sub(width, cellname))
return self.binary_use(left, right, self.sub(width, cellname, signed))

def bitwise_flip_reg(self, reg, width, cellname=None):
def bitwise_flip_reg(self, reg, width, signed=False, cellname=None):
"""Inserts wiring into `self` to bitwise-flip the contents of `reg`
and put the result back into `reg`.
"""
cellname = cellname or f"{reg.name}_not"
not_cell = self.not_(width, cellname)
not_cell = self.not_(width, cellname, signed)
with self.group(f"{cellname}_group") as not_group:
not_cell.in_ = reg.out
reg.write_en = 1
reg.in_ = not_cell.out
not_group.done = reg.done
return not_group

def incr(self, reg, width, val=1, cellname=None):
def incr(self, reg, width, val=1, signed=False, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg + val`."""
cellname = cellname or f"{reg.name}_incr"
add_cell = self.add(width, cellname)
add_cell = self.add(width, cellname, signed)
with self.group(f"{cellname}_group") as incr_group:
add_cell.left = reg.out
add_cell.right = const(width, val)
Expand All @@ -478,10 +484,10 @@ def incr(self, reg, width, val=1, cellname=None):
incr_group.done = reg.done
return incr_group

def decr(self, reg, width, val=1, cellname=None):
def decr(self, reg, width, val=1, signed=False, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg - val`."""
cellname = cellname or f"{reg.name}_decr"
sub_cell = self.sub(width, cellname)
sub_cell = self.sub(width, cellname, signed)
with self.group(f"{cellname}_group") as decr_group:
sub_cell.left = reg.out
sub_cell.right = const(width, val)
Expand Down Expand Up @@ -576,9 +582,11 @@ def mem_load_to_mem(self, mem, i, ans, j, groupname=None):
load_grp.done = ans.done
return load_grp

def add_store_in_reg(self, left, right, cellname, width, ans_reg=None):
def add_store_in_reg(
self, left, right, cellname, width, ans_reg=None, signed=False
):
"""Inserts wiring into `self` to perform `reg := left + right`."""
add_cell = self.add(width, cellname)
add_cell = self.add(width, cellname, signed)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", width)
with self.group(f"{cellname}_group") as adder_group:
add_cell.left = left
Expand All @@ -588,9 +596,11 @@ def add_store_in_reg(self, left, right, cellname, width, ans_reg=None):
adder_group.done = ans_reg.done
return adder_group, ans_reg

def sub_store_in_reg(self, left, right, cellname, width, ans_reg=None):
def sub_store_in_reg(
self, left, right, cellname, width, ans_reg=None, signed=False
):
"""Inserts wiring into `self` to perform `reg := left - right`."""
sub_cell = self.sub(width, cellname)
sub_cell = self.sub(width, cellname, signed)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", width)
with self.group(f"{cellname}_group") as sub_group:
sub_cell.left = left
Expand Down
124 changes: 33 additions & 91 deletions calyx-py/calyx/gen_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def generate_fp_pow_component(
"mult_pipe", width, int_width, frac_width, signed=is_signed
),
)
lt = comp.cell("lt", Stdlib.op("lt", width, signed=is_signed))
incr = comp.cell("incr", Stdlib.op("add", width, signed=is_signed))

# groups
Expand All @@ -70,27 +69,14 @@ def generate_fp_pow_component(
pow.in_ = mul.out
execute_mul.done = pow.done

with comp.group("incr_count") as incr_count:
incr.left = const(width, 1)
incr.right = count.out
count.in_ = incr.out
count.write_en = 1
incr_count.done = count.done
incr_count = comp.incr(count, width)

with comp.comb_group("cond") as cond:
lt.left = count.out
lt.right = comp.this().integer_exp
cond = comp.lt_use(count.out, comp.this().integer_exp, width, is_signed)

with comp.continuous:
comp.this().out = pow.out

comp.control += [
init,
while_with(
CellAndGroup(lt, cond),
par(execute_mul, incr_count),
),
]
comp.control += [init, while_with(cond, par(execute_mul, incr_count))]

return comp.component

Expand All @@ -105,11 +91,12 @@ def generate_cells(
comp.reg("int_x", width)
comp.reg("frac_x", width)
comp.reg("m", width)
comp.cell("and0", Stdlib.op("and", width, signed=False))
comp.cell("and1", Stdlib.op("and", width, signed=False))
comp.cell("rsh", Stdlib.op("rsh", width, signed=False))

comp.and_(width, "and0")
comp.and_(width, "and1")
comp.rsh(width, "rsh")
if is_signed:
comp.cell("lt", Stdlib.op("lt", width, signed=is_signed))
comp.lt(width, "lt", is_signed)

# constants
for i in range(2, degree + 1):
Expand Down Expand Up @@ -541,21 +528,6 @@ def gen_reverse_sign(
group.done = base_cell.done


def gen_comb_lt(
comp: ComponentBuilder,
name: str,
lhs: ExprBuilder,
lt: CellBuilder,
const_cell: CellBuilder,
):
"""
Generates lhs < const_cell
"""
with comp.comb_group(name):
lt.left = lhs
lt.right = const_cell.out


# This appears to be unused. Brilliant.
# TODO (griffin): Double check that this is unused and, if so, remove it.
def gen_constant_cell(
Expand Down Expand Up @@ -590,7 +562,8 @@ def generate_fp_pow_full(
comp.input("base", width)
comp.input("exp_value", width)
comp.output("out", width)
lt = comp.cell("lt", Stdlib.op("lt", width, is_signed))
lt = comp.lt(width, "lt", is_signed)

div = comp.cell(
"div",
Stdlib.fixed_point_op(
Expand Down Expand Up @@ -632,7 +605,13 @@ def generate_fp_pow_full(
)
gen_reverse_sign(comp, "rev_base_sign", new_base_reg, mult, const_neg_one),
gen_reverse_sign(comp, "rev_res_sign", res, mult, const_neg_one),
gen_comb_lt(comp, "base_lt_zero", comp.this().base, lt, const_zero),

base_lt_zero = comp.lt_use(
comp.this().base,
const_zero.out,
width,
is_signed,
)

new_exp_val = comp.reg("new_exp_val", width)
e = comp.comp_instance("e", "exp", check_undeclared=False)
Expand All @@ -649,48 +628,29 @@ def generate_fp_pow_full(
with comp.continuous:
comp.this().out = res.out

with comp.group("write_to_base_reg") as write_to_base_reg:
new_base_reg.write_en = 1
new_base_reg.in_ = comp.this().base
write_to_base_reg.done = new_base_reg.done

with comp.group("store_old_reg_val") as store_old_reg_val:
stored_base_reg.write_en = 1
stored_base_reg.in_ = new_base_reg.out
store_old_reg_val.done = stored_base_reg.done

with comp.group("write_e_to_res") as write_e_to_res:
res.write_en = 1
res.in_ = e.out
write_e_to_res.done = res.done
write_to_base_reg = comp.reg_store(
new_base_reg, comp.this().base, "write_to_base_reg"
)
store_old_reg_val = comp.reg_store(
stored_base_reg, new_base_reg.out, "store_old_reg_val"
)
write_e_to_res = comp.reg_store(res, e.out, "write_e_to_res")

gen_reciprocal(comp, "set_base_reciprocal", new_base_reg, div, const_one)
gen_reciprocal(comp, "set_res_reciprocal", res, div, const_one),
gen_comb_lt(
comp,
"base_lt_one",
stored_base_reg.out,
lt,
const_one,
)
base_lt_one = comp.lt_use(stored_base_reg.out, const_one.out, width, is_signed)

base_reciprocal = if_with(
CellAndGroup(lt, comp.get_group("base_lt_one")),
comp.get_group("set_base_reciprocal"),
)
base_reciprocal = if_with(base_lt_one, comp.get_group("set_base_reciprocal"))

res_reciprocal = if_with(
CellAndGroup(lt, comp.get_group("base_lt_one")),
comp.get_group("set_res_reciprocal"),
)
res_reciprocal = if_with(base_lt_one, comp.get_group("set_res_reciprocal"))

if is_signed:
base_rev = if_with(
CellAndGroup(lt, comp.get_group("base_lt_zero")),
base_lt_zero,
comp.get_group("rev_base_sign"),
)
res_rev = if_with(
CellAndGroup(lt, comp.get_group("base_lt_zero")),
base_lt_zero,
comp.get_group("rev_res_sign"),
)
pre_process = [base_rev, store_old_reg_val, base_reciprocal]
Expand Down Expand Up @@ -741,23 +701,9 @@ def build_base_not_e(degree, width, int_width, is_signed) -> Program:
ret = main.mem_d1("ret", width, 1, 1, is_external=True)
f = main.comp_instance("f", "fp_pow_full")

with main.group("read_base") as read_base:
b.addr0 = 0
base_reg.in_ = b.read_data
base_reg.write_en = 1
read_base.done = base_reg.done

with main.group("read_exp") as read_exp:
x.addr0 = 0
exp_reg.in_ = x.read_data
exp_reg.write_en = 1
read_exp.done = exp_reg.done

with main.group("write_to_memory") as write_to_memory:
ret.addr0 = 0
ret.write_en = 1
ret.write_data = f.out
write_to_memory.done = ret.done
read_base = main.mem_load_std_d1(b, 0, base_reg, "read_base")
read_exp = main.mem_load_std_d1(x, 0, exp_reg, "read_exp")
write_to_memory = main.mem_store_std_d1(ret, 0, f.out, "write_to_memory")

main.control += [
read_base,
Expand Down Expand Up @@ -796,11 +742,7 @@ def build_base_is_e(degree, width, int_width, is_signed) -> Program:
t.write_en = 1
init.done = t.done

with main.group("write_to_memory") as write_to_memory:
ret.addr0 = 0
ret.write_en = 1
ret.write_data = e.out
write_to_memory.done = ret.done
write_to_memory = main.mem_store_std_d1(ret, 0, e.out, "write_to_memory")

main.control += [
init,
Expand Down

0 comments on commit 4997e9c

Please sign in to comment.