From d81662db24161f2c02ad5af9e21e8bc52d34e46d Mon Sep 17 00:00:00 2001 From: Ayaka Yorihiro <36107281+ayakayorihiro@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:19:47 -0400 Subject: [PATCH] Prevent dead-assignment-removal for protected cells (#2312) --- calyx-ir/src/structure.rs | 12 +++- .../src/passes/dead_assignment_removal.rs | 5 +- .../dead-assign-removal/protected.expect | 47 ++++++++++++++++ .../dead-assign-removal/protected.futil | 55 +++++++++++++++++++ 4 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 tests/passes/dead-assign-removal/protected.expect create mode 100644 tests/passes/dead-assign-removal/protected.futil diff --git a/calyx-ir/src/structure.rs b/calyx-ir/src/structure.rs index cbf0d1d1b5..2876f2236b 100644 --- a/calyx-ir/src/structure.rs +++ b/calyx-ir/src/structure.rs @@ -6,7 +6,7 @@ use crate::Nothing; use super::{ Attributes, Direction, GetAttributes, Guard, Id, PortDef, RRC, WRC, }; -use calyx_frontend::Attribute; +use calyx_frontend::{Attribute, BoolAttr}; use calyx_utils::{CalyxResult, Error, GetName}; use itertools::Itertools; use smallvec::{smallvec, SmallVec}; @@ -103,6 +103,16 @@ impl Port { } } + /// Checks if parent is a protected cell + pub fn parent_is_protected(&self) -> bool { + match &self.parent { + PortParent::Cell(cell) => { + cell.upgrade().borrow().attributes.has(BoolAttr::Protected) + } + _ => false, + } + } + /// Get the canonical representation for this Port. pub fn canonical(&self) -> Canonical { Canonical { diff --git a/calyx-opt/src/passes/dead_assignment_removal.rs b/calyx-opt/src/passes/dead_assignment_removal.rs index 51953ba072..b23c6843ea 100644 --- a/calyx-opt/src/passes/dead_assignment_removal.rs +++ b/calyx-opt/src/passes/dead_assignment_removal.rs @@ -130,8 +130,9 @@ impl Visitor for DeadAssignmentRemoval { gr.borrow_mut().assignments.retain(|assign| { let dst = assign.dst.borrow(); - // if dst is a combinational component, must be used - if dst.parent_is_comb() { + // if dst is a combinational component that is not protected, + // the assignment is removed if it is not used + if dst.parent_is_comb() && !dst.parent_is_protected() { return used_combs.contains(&dst.get_parent_name()); } // Make sure that the assignment's guard it not false diff --git a/tests/passes/dead-assign-removal/protected.expect b/tests/passes/dead-assign-removal/protected.expect new file mode 100644 index 0000000000..5d9bd6ac74 --- /dev/null +++ b/tests/passes/dead-assign-removal/protected.expect @@ -0,0 +1,47 @@ +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; +primitive std_protected_wire[WIDTH](in: WIDTH) -> (out: WIDTH) { + assign out = in; +} +component main(@go go: 1, @clk clk: 1, @reset reset: 1) -> (@done done: 1) { + cells { + @external i = comb_mem_d1(32, 1, 1); + @protected cond_inst = std_wire(1); + @protected incr_inst = std_wire(1); + lt = std_lt(32); + lt_reg = std_reg(1); + add = std_add(32); + } + wires { + group cond { + cond_inst.in = 1'd1; + lt_reg.write_en = 1'd1; + lt.right = 32'd8; + i.addr0 = 1'd0; + lt.left = i.read_data; + lt_reg.in = lt.out; + cond[done] = lt_reg.done; + } + group incr { + incr_inst.in = 1'd1; + i.write_en = 1'd1; + i.addr0 = 1'd0; + add.left = 32'd1; + add.right = i.read_data; + i.write_data = add.out; + incr[done] = i.done; + } + } + control { + seq { + cond; + while lt_reg.out { + seq { + incr; + incr; + cond; + } + } + } + } +} diff --git a/tests/passes/dead-assign-removal/protected.futil b/tests/passes/dead-assign-removal/protected.futil new file mode 100644 index 0000000000..0f7650805c --- /dev/null +++ b/tests/passes/dead-assign-removal/protected.futil @@ -0,0 +1,55 @@ +// -p validate -p dead-assign-removal + +import "primitives/core.futil"; +import "primitives/memories/comb.futil"; + +/// Wire for instrumentation +primitive std_protected_wire[WIDTH](in: WIDTH) -> (out: WIDTH) { + assign out = in; +} + +component main() -> () { + cells { + @external(1) i = comb_mem_d1(32, 1, 1); + @protected cond_inst = std_wire(1); + @protected incr_inst = std_wire(1); + lt = std_lt(32); + lt_reg = std_reg(1); + add = std_add(32); + } + + wires { + group cond { + i.addr0 = 1'd0; + lt.left = i.read_data; + lt.right = 32'd8; + lt_reg.in = lt.out; + lt_reg.write_en = 1'b1; + cond_inst.in = 1'b1; + cond[done] = lt_reg.done; + } + + group incr { + add.right = i.read_data; + add.left = 32'd1; + i.write_data = add.out; + i.addr0 = 1'd0; + i.write_en = 1'b1; + incr_inst.in = 1'b1; + incr[done] = i.done; + } + } + + control { + seq { + cond; + while lt_reg.out { + seq { + incr; + incr; + cond; + } + } + } + } +}