diff --git a/examples/gemm/gemm_codegen.rs b/examples/gemm/gemm_codegen.rs index 8aa3c76..fd82123 100644 --- a/examples/gemm/gemm_codegen.rs +++ b/examples/gemm/gemm_codegen.rs @@ -25,7 +25,7 @@ fn main() { global_graph.add_nodes(vec![shared_block_node.clone()]); - let global_graph = Rc::new(global_graph); + let global_graph = Rc::new(RefCell::new(global_graph)); let global_block = ThrillerBlock::new( vec![], diff --git a/examples/gemm/global_block.rs b/examples/gemm/global_block.rs index 6b1fd09..20296ac 100644 --- a/examples/gemm/global_block.rs +++ b/examples/gemm/global_block.rs @@ -25,7 +25,7 @@ fn main() { global_graph.add_nodes(vec![shared_block_node.clone()]); - let global_graph = Rc::new(global_graph); + let global_graph = Rc::new(RefCell::new(global_graph)); let global_block = ThrillerBlock::new( vec![], diff --git a/examples/gemm/rf_block.rs b/examples/gemm/rf_block.rs index dd849c0..65a8d33 100644 --- a/examples/gemm/rf_block.rs +++ b/examples/gemm/rf_block.rs @@ -89,7 +89,7 @@ fn main() { vec![Rc::new(in_edge0), Rc::new(in_edge1)], vec![Rc::new(out_edge)], MemoryLevel::Register, - Rc::new(subgraph), + Rc::new(RefCell::new(subgraph)), BlockType::Reduce, ); diff --git a/examples/gemm/shared_block.rs b/examples/gemm/shared_block.rs index ab69099..42f409c 100644 --- a/examples/gemm/shared_block.rs +++ b/examples/gemm/shared_block.rs @@ -87,7 +87,7 @@ fn main() { vec![Rc::new(in_edge0), Rc::new(in_edge1)], vec![Rc::new(out_edge)], MemoryLevel::Shared, - Rc::new(subgraph), + Rc::new(RefCell::new(subgraph)), BlockType::Map, ); diff --git a/examples/loop/multiloops.rs b/examples/loop/multiloops.rs index 0090899..30a0a4e 100644 --- a/examples/loop/multiloops.rs +++ b/examples/loop/multiloops.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{cell::RefCell, rc::Rc}; use thriller_core::{ initialize, AccessMap, AccessMatrix, AccessOffset, AttachedEdge, BlockType, IterationBound, @@ -57,7 +57,7 @@ fn main() { vec![Rc::new(in_edge0), Rc::new(in_edge1), Rc::new(in_edge2)], vec![], MemoryLevel::Register, - Rc::new(subgraph), + Rc::new(RefCell::new(subgraph)), BlockType::Reduce, ); diff --git a/thriller-core/src/dataflow/block.rs b/thriller-core/src/dataflow/block.rs index 45238ee..52dc673 100644 --- a/thriller-core/src/dataflow/block.rs +++ b/thriller-core/src/dataflow/block.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::rc::Rc; use std::vec::Vec; @@ -26,7 +27,7 @@ pub struct ThrillerBlock { pub(crate) inputs: Vec>, pub(crate) outputs: Vec>, pub(crate) mem_level: MemoryLevel, - pub(crate) subgraph: Rc, + pub(crate) subgraph: Rc>, pub(crate) block_type: BlockType, pub(crate) unified_access_map: Option>, pub(crate) loop_groups: Vec, @@ -38,7 +39,7 @@ impl ThrillerBlock { inputs: Vec>, outputs: Vec>, mem_level: MemoryLevel, - subgraph: Rc, + subgraph: Rc>, block_type: BlockType, ) -> Self { ThrillerBlock { @@ -228,11 +229,11 @@ impl ThrillerBlock { if self.mem_level == MemoryLevel::Shared { inner_code += Sync::emit_copy_async().as_str(); } - inner_code += self.subgraph.emit()?.as_str(); + inner_code += self.subgraph.borrow().emit()?.as_str(); code += access_map.gen_loop_access(inner_code)?.as_str(); code += Sync::emit_sync().as_str(); - if let Some(reduce_outputs) = self.subgraph.reduce_block_outputs() { + if let Some(reduce_outputs) = self.subgraph.borrow().reduce_block_outputs() { // self.outputs.extend(reduce_outputs); for output in reduce_outputs { code += &self.emit_store(&output)?; @@ -244,7 +245,7 @@ impl ThrillerBlock { } else { // TODO: Handle cases without an unified access map. if self.inputs.is_empty() && self.outputs.is_empty() { - let code = self.subgraph.emit()?; + let code = self.subgraph.borrow().emit()?; Ok(code) } else { // unimplemented!(); diff --git a/thriller-core/src/dataflow/node.rs b/thriller-core/src/dataflow/node.rs index f116283..58aafa9 100644 --- a/thriller-core/src/dataflow/node.rs +++ b/thriller-core/src/dataflow/node.rs @@ -68,8 +68,8 @@ impl ThrillerNode { &self.nexts } - #[allow(dead_code)] - pub(crate) fn get_inner(&self) -> &ThrillerNodeInner { + #[doc(hidden)] + pub fn get_inner(&self) -> &ThrillerNodeInner { &self.inner } diff --git a/thriller-utils/src/gemm.rs b/thriller-utils/src/gemm.rs index 50fd072..d76ee23 100644 --- a/thriller-utils/src/gemm.rs +++ b/thriller-utils/src/gemm.rs @@ -91,7 +91,7 @@ impl ThrillerUtils { vec![Rc::new(in_edge0), Rc::new(in_edge1)], vec![Rc::new(out_edge)], MemoryLevel::Register, - Rc::new(subgraph), + Rc::new(RefCell::new(subgraph)), BlockType::Reduce, ); @@ -176,7 +176,7 @@ impl ThrillerUtils { vec![Rc::new(in_edge0), Rc::new(in_edge1)], vec![Rc::new(out_edge)], MemoryLevel::Shared, - Rc::new(subgraph), + Rc::new(RefCell::new(subgraph)), BlockType::Map, );