Skip to content

Commit

Permalink
Revert "Catch up gen_exp"
Browse files Browse the repository at this point in the history
This reverts commit 4997e9c.
  • Loading branch information
anshumanmohan committed Aug 20, 2023
1 parent 4997e9c commit 69deb30
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 69 deletions.
62 changes: 26 additions & 36 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,6 @@ def le(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdLe cell."""
return self.binary("le", size, name, signed)

def rsh(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
"""Generate a StdRsh cell."""
return self.binary("rsh", size, name, signed)

def logic(self, operation, size: int, name: str = None) -> CellBuilder:
"""Generate a logical operator cell, of the flavor specified in `operation`."""
name = name or self.generate_name(operation)
Expand All @@ -357,12 +353,10 @@ def logic(self, operation, size: int, name: str = None) -> CellBuilder:

def and_(self, size: int, name: str = None) -> CellBuilder:
"""Generate a StdAnd cell."""
name = name or self.generate_name("and")
return self.logic("and", size, name)

def not_(self, size: int, name: str = None) -> CellBuilder:
"""Generate a StdNot cell."""
name = name or self.generate_name("not")
return self.logic("not", size, name)

def pipelined_mult(self, name: str) -> CellBuilder:
Expand Down Expand Up @@ -427,55 +421,55 @@ def binary_use(self, left, right, cell, groupname=None):
cell.right = right
return CellAndGroup(cell, comb_group)

def eq_use(self, left, right, width, signed=False, cellname=None):
def eq_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` == `right`."""
return self.binary_use(left, right, self.eq(width, cellname, signed))
return self.binary_use(left, right, self.eq(width, cellname))

def neq_use(self, left, right, width, signed=False, cellname=None):
def neq_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` != `right`."""
return self.binary_use(left, right, self.neq(width, cellname, signed))
return self.binary_use(left, right, self.neq(width, cellname))

def lt_use(self, left, right, width, signed=False, cellname=None):
def lt_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` < `right`."""
return self.binary_use(left, right, self.lt(width, cellname, signed))
return self.binary_use(left, right, self.lt(width, cellname))

def le_use(self, left, right, width, signed=False, cellname=None):
def le_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` <= `right`."""
return self.binary_use(left, right, self.le(width, cellname, signed))
return self.binary_use(left, right, self.le(width, cellname))

def ge_use(self, left, right, width, signed=False, cellname=None):
def ge_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` >= `right`."""
return self.binary_use(left, right, self.ge(width, cellname, signed))
return self.binary_use(left, right, self.ge(width, cellname))

def gt_use(self, left, right, width, signed=False, cellname=None):
def gt_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to check if `left` > `right`."""
return self.binary_use(left, right, self.gt(width, cellname, signed))
return self.binary_use(left, right, self.gt(width, cellname))

def add_use(self, left, right, width, signed=False, cellname=None):
def add_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to compute `left` + `right`."""
return self.binary_use(left, right, self.add(width, cellname, signed))
return self.binary_use(left, right, self.add(width, cellname))

def sub_use(self, left, right, width, signed=False, cellname=None):
def sub_use(self, left, right, width, cellname=None):
"""Inserts wiring into `self` to compute `left` - `right`."""
return self.binary_use(left, right, self.sub(width, cellname, signed))
return self.binary_use(left, right, self.sub(width, cellname))

def bitwise_flip_reg(self, reg, width, signed=False, cellname=None):
def bitwise_flip_reg(self, reg, width, cellname=None):
"""Inserts wiring into `self` to bitwise-flip the contents of `reg`
and put the result back into `reg`.
"""
cellname = cellname or f"{reg.name}_not"
not_cell = self.not_(width, cellname, signed)
not_cell = self.not_(width, cellname)
with self.group(f"{cellname}_group") as not_group:
not_cell.in_ = reg.out
reg.write_en = 1
reg.in_ = not_cell.out
not_group.done = reg.done
return not_group

def incr(self, reg, width, val=1, signed=False, cellname=None):
def incr(self, reg, width, val=1, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg + val`."""
cellname = cellname or f"{reg.name}_incr"
add_cell = self.add(width, cellname, signed)
add_cell = self.add(width, cellname)
with self.group(f"{cellname}_group") as incr_group:
add_cell.left = reg.out
add_cell.right = const(width, val)
Expand All @@ -484,10 +478,10 @@ def incr(self, reg, width, val=1, signed=False, cellname=None):
incr_group.done = reg.done
return incr_group

def decr(self, reg, width, val=1, signed=False, cellname=None):
def decr(self, reg, width, val=1, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg - val`."""
cellname = cellname or f"{reg.name}_decr"
sub_cell = self.sub(width, cellname, signed)
sub_cell = self.sub(width, cellname)
with self.group(f"{cellname}_group") as decr_group:
sub_cell.left = reg.out
sub_cell.right = const(width, val)
Expand Down Expand Up @@ -582,11 +576,9 @@ def mem_load_to_mem(self, mem, i, ans, j, groupname=None):
load_grp.done = ans.done
return load_grp

def add_store_in_reg(
self, left, right, cellname, width, ans_reg=None, signed=False
):
def add_store_in_reg(self, left, right, cellname, width, ans_reg=None):
"""Inserts wiring into `self` to perform `reg := left + right`."""
add_cell = self.add(width, cellname, signed)
add_cell = self.add(width, cellname)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", width)
with self.group(f"{cellname}_group") as adder_group:
add_cell.left = left
Expand All @@ -596,11 +588,9 @@ def add_store_in_reg(
adder_group.done = ans_reg.done
return adder_group, ans_reg

def sub_store_in_reg(
self, left, right, cellname, width, ans_reg=None, signed=False
):
def sub_store_in_reg(self, left, right, cellname, width, ans_reg=None):
"""Inserts wiring into `self` to perform `reg := left - right`."""
sub_cell = self.sub(width, cellname, signed)
sub_cell = self.sub(width, cellname)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", width)
with self.group(f"{cellname}_group") as sub_group:
sub_cell.left = left
Expand Down
124 changes: 91 additions & 33 deletions calyx-py/calyx/gen_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def generate_fp_pow_component(
"mult_pipe", width, int_width, frac_width, signed=is_signed
),
)
lt = comp.cell("lt", Stdlib.op("lt", width, signed=is_signed))
incr = comp.cell("incr", Stdlib.op("add", width, signed=is_signed))

# groups
Expand All @@ -69,14 +70,27 @@ def generate_fp_pow_component(
pow.in_ = mul.out
execute_mul.done = pow.done

incr_count = comp.incr(count, width)
with comp.group("incr_count") as incr_count:
incr.left = const(width, 1)
incr.right = count.out
count.in_ = incr.out
count.write_en = 1
incr_count.done = count.done

cond = comp.lt_use(count.out, comp.this().integer_exp, width, is_signed)
with comp.comb_group("cond") as cond:
lt.left = count.out
lt.right = comp.this().integer_exp

with comp.continuous:
comp.this().out = pow.out

comp.control += [init, while_with(cond, par(execute_mul, incr_count))]
comp.control += [
init,
while_with(
CellAndGroup(lt, cond),
par(execute_mul, incr_count),
),
]

return comp.component

Expand All @@ -91,12 +105,11 @@ def generate_cells(
comp.reg("int_x", width)
comp.reg("frac_x", width)
comp.reg("m", width)

comp.and_(width, "and0")
comp.and_(width, "and1")
comp.rsh(width, "rsh")
comp.cell("and0", Stdlib.op("and", width, signed=False))
comp.cell("and1", Stdlib.op("and", width, signed=False))
comp.cell("rsh", Stdlib.op("rsh", width, signed=False))
if is_signed:
comp.lt(width, "lt", is_signed)
comp.cell("lt", Stdlib.op("lt", width, signed=is_signed))

# constants
for i in range(2, degree + 1):
Expand Down Expand Up @@ -528,6 +541,21 @@ def gen_reverse_sign(
group.done = base_cell.done


def gen_comb_lt(
comp: ComponentBuilder,
name: str,
lhs: ExprBuilder,
lt: CellBuilder,
const_cell: CellBuilder,
):
"""
Generates lhs < const_cell
"""
with comp.comb_group(name):
lt.left = lhs
lt.right = const_cell.out


# This appears to be unused. Brilliant.
# TODO (griffin): Double check that this is unused and, if so, remove it.
def gen_constant_cell(
Expand Down Expand Up @@ -562,8 +590,7 @@ def generate_fp_pow_full(
comp.input("base", width)
comp.input("exp_value", width)
comp.output("out", width)
lt = comp.lt(width, "lt", is_signed)

lt = comp.cell("lt", Stdlib.op("lt", width, is_signed))
div = comp.cell(
"div",
Stdlib.fixed_point_op(
Expand Down Expand Up @@ -605,13 +632,7 @@ def generate_fp_pow_full(
)
gen_reverse_sign(comp, "rev_base_sign", new_base_reg, mult, const_neg_one),
gen_reverse_sign(comp, "rev_res_sign", res, mult, const_neg_one),

base_lt_zero = comp.lt_use(
comp.this().base,
const_zero.out,
width,
is_signed,
)
gen_comb_lt(comp, "base_lt_zero", comp.this().base, lt, const_zero),

new_exp_val = comp.reg("new_exp_val", width)
e = comp.comp_instance("e", "exp", check_undeclared=False)
Expand All @@ -628,29 +649,48 @@ def generate_fp_pow_full(
with comp.continuous:
comp.this().out = res.out

write_to_base_reg = comp.reg_store(
new_base_reg, comp.this().base, "write_to_base_reg"
)
store_old_reg_val = comp.reg_store(
stored_base_reg, new_base_reg.out, "store_old_reg_val"
)
write_e_to_res = comp.reg_store(res, e.out, "write_e_to_res")
with comp.group("write_to_base_reg") as write_to_base_reg:
new_base_reg.write_en = 1
new_base_reg.in_ = comp.this().base
write_to_base_reg.done = new_base_reg.done

with comp.group("store_old_reg_val") as store_old_reg_val:
stored_base_reg.write_en = 1
stored_base_reg.in_ = new_base_reg.out
store_old_reg_val.done = stored_base_reg.done

with comp.group("write_e_to_res") as write_e_to_res:
res.write_en = 1
res.in_ = e.out
write_e_to_res.done = res.done

gen_reciprocal(comp, "set_base_reciprocal", new_base_reg, div, const_one)
gen_reciprocal(comp, "set_res_reciprocal", res, div, const_one),
base_lt_one = comp.lt_use(stored_base_reg.out, const_one.out, width, is_signed)
gen_comb_lt(
comp,
"base_lt_one",
stored_base_reg.out,
lt,
const_one,
)

base_reciprocal = if_with(base_lt_one, comp.get_group("set_base_reciprocal"))
base_reciprocal = if_with(
CellAndGroup(lt, comp.get_group("base_lt_one")),
comp.get_group("set_base_reciprocal"),
)

res_reciprocal = if_with(base_lt_one, comp.get_group("set_res_reciprocal"))
res_reciprocal = if_with(
CellAndGroup(lt, comp.get_group("base_lt_one")),
comp.get_group("set_res_reciprocal"),
)

if is_signed:
base_rev = if_with(
base_lt_zero,
CellAndGroup(lt, comp.get_group("base_lt_zero")),
comp.get_group("rev_base_sign"),
)
res_rev = if_with(
base_lt_zero,
CellAndGroup(lt, comp.get_group("base_lt_zero")),
comp.get_group("rev_res_sign"),
)
pre_process = [base_rev, store_old_reg_val, base_reciprocal]
Expand Down Expand Up @@ -701,9 +741,23 @@ def build_base_not_e(degree, width, int_width, is_signed) -> Program:
ret = main.mem_d1("ret", width, 1, 1, is_external=True)
f = main.comp_instance("f", "fp_pow_full")

read_base = main.mem_load_std_d1(b, 0, base_reg, "read_base")
read_exp = main.mem_load_std_d1(x, 0, exp_reg, "read_exp")
write_to_memory = main.mem_store_std_d1(ret, 0, f.out, "write_to_memory")
with main.group("read_base") as read_base:
b.addr0 = 0
base_reg.in_ = b.read_data
base_reg.write_en = 1
read_base.done = base_reg.done

with main.group("read_exp") as read_exp:
x.addr0 = 0
exp_reg.in_ = x.read_data
exp_reg.write_en = 1
read_exp.done = exp_reg.done

with main.group("write_to_memory") as write_to_memory:
ret.addr0 = 0
ret.write_en = 1
ret.write_data = f.out
write_to_memory.done = ret.done

main.control += [
read_base,
Expand Down Expand Up @@ -742,7 +796,11 @@ def build_base_is_e(degree, width, int_width, is_signed) -> Program:
t.write_en = 1
init.done = t.done

write_to_memory = main.mem_store_std_d1(ret, 0, e.out, "write_to_memory")
with main.group("write_to_memory") as write_to_memory:
ret.addr0 = 0
ret.write_en = 1
ret.write_data = e.out
write_to_memory.done = ret.done

main.control += [
init,
Expand Down

0 comments on commit 69deb30

Please sign in to comment.