diff --git a/examples/gemm.rs b/examples/gemm.rs index ec07f64..77ecf93 100644 --- a/examples/gemm.rs +++ b/examples/gemm.rs @@ -42,10 +42,6 @@ fn main() { ); let gemm_node = Rc::new(ThrillerNode::new(ThrillerNodeInner::Op(Box::new(gemm)))); - // let edge_a = Rc::new(ThrillerEdge::new(node_a.clone(), gemm_node.clone())); - // let edge_b = Rc::new(ThrillerEdge::new(node_b.clone(), gemm_node.clone())); - // let edge_acc = Rc::new(ThrillerEdge::new(gemm_node.clone(), node_acc.clone())); - let gemm_code = gemm_node.emit().unwrap(); println!("{}", gemm_code); diff --git a/thriller-bindings/pythriller/__init__.py b/thriller-bindings/pythriller/__init__.py index 1a398ba..32cad76 100644 --- a/thriller-bindings/pythriller/__init__.py +++ b/thriller-bindings/pythriller/__init__.py @@ -1 +1,2 @@ +from .context import initialize_thriller_flow from .buffer import create_buffer diff --git a/thriller-bindings/pythriller/buffer.py b/thriller-bindings/pythriller/buffer.py index f6c5984..f5760da 100644 --- a/thriller-bindings/pythriller/buffer.py +++ b/thriller-bindings/pythriller/buffer.py @@ -1,5 +1,4 @@ -import thriller_flow -from thriller_flow import PyBuffer +from .context import PyBuffer def create_buffer(name): diff --git a/thriller-bindings/pythriller/context.py b/thriller-bindings/pythriller/context.py new file mode 100644 index 0000000..2e4cf7f --- /dev/null +++ b/thriller-bindings/pythriller/context.py @@ -0,0 +1 @@ +from thriller_flow import * diff --git a/thriller-bindings/src/lib.rs b/thriller-bindings/src/lib.rs index c2f21f7..c37720a 100644 --- a/thriller-bindings/src/lib.rs +++ b/thriller-bindings/src/lib.rs @@ -11,12 +11,6 @@ fn initialize_thriller_flow() -> PyResult<()> { Ok(()) } -// #[pyfunction] -// fn create_buffer(name: String) -> PyResult { -// let buffer = PyBuffer(Buffer::new(name.as_str())); -// Ok(buffer) -// } - /// A Python module implemented in Rust. The name of this function must match /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// import the module. diff --git a/thriller-bindings/tests/test_bindings.py b/thriller-bindings/tests/test_bindings.py index 4952536..1f682bc 100644 --- a/thriller-bindings/tests/test_bindings.py +++ b/thriller-bindings/tests/test_bindings.py @@ -1,13 +1,14 @@ import context -import thriller_flow import pythriller if __name__ == '__main__': - thriller_flow.initialize_thriller_flow() + pythriller.initialize_thriller_flow() g_a = pythriller.create_buffer("g_a") g_b = pythriller.create_buffer("g_b") g_c = pythriller.create_buffer("g_c") - print(g_a, g_b, g_c) + print(g_a) + print(g_b) + print(g_c) diff --git a/thriller-core/src/access.rs b/thriller-core/src/access.rs index 681193f..f9b7ae6 100644 --- a/thriller-core/src/access.rs +++ b/thriller-core/src/access.rs @@ -132,3 +132,9 @@ impl PartialEq for AccessMap { && self.access_dims == other.access_dims } } + +impl Eq for AccessMap { + fn assert_receiver_is_total_eq(&self) { + todo!("Implement this function") + } +} diff --git a/thriller-core/src/dataflow/block.rs b/thriller-core/src/dataflow/block.rs index b0dc840..6ef7d3d 100644 --- a/thriller-core/src/dataflow/block.rs +++ b/thriller-core/src/dataflow/block.rs @@ -32,7 +32,7 @@ pub struct ThrillerBlock { } impl ThrillerBlock { - /// Create a new ThrillerBlock with the given inputs, outputs, memory level, subgraph, and block type. + /// Create a new [`ThrillerBlock`] with the given inputs, outputs, memory level, subgraph, and block type. pub fn new( inputs: Vec>, outputs: Vec>, @@ -51,23 +51,6 @@ impl ThrillerBlock { } } - // pub(crate) fn get_inputs(&self) -> &[Rc] { - // &self.inputs - // } - - // pub(crate) fn get_inner_bufs(&self) -> Vec { - // // Iterate through the inputs and collect all dst buffers. - // let mut bufs = Vec::new(); - // for input in self.inputs.iter() { - // // if let Some(buf) = input.get_dst() { - // // bufs.push(buf.clone()); - // // } - // let dst_buf = input.get_dst_name().clone(); - // bufs.push(dst_buf); - // } - // bufs - // } - /// Get the block type. pub fn get_block_type(&self) -> BlockType { self.block_type @@ -79,9 +62,21 @@ impl ThrillerBlock { // If they are the same, then we can merge them into a unified access map. // TODO: Implement this function. + self.inputs.windows(2).for_each(|window| { + let (first, second) = (&window[0], &window[1]); + assert!( + first.get_access() == second.get_access(), + "Access maps are not the same." + ); + }); + self.unified_access_map = Some(self.inputs[0].get_access().as_ref().unwrap().clone()); } + pub(crate) fn get_inputs(&self) -> &Vec> { + &self.inputs + } + pub(crate) fn gen_loop_load(&self) -> ThrillerResult { let mut code = String::new(); diff --git a/thriller-core/src/dataflow/loop_analysis.rs b/thriller-core/src/dataflow/loop_analysis.rs new file mode 100644 index 0000000..ac8086c --- /dev/null +++ b/thriller-core/src/dataflow/loop_analysis.rs @@ -0,0 +1,44 @@ +use std::collections::HashSet; +use std::hash::Hash; +use std::rc::Rc; + +use crate::AccessMap; + +use super::block::ThrillerBlock; + +pub struct AccessMapPtr<'a>(&'a Rc); + +impl<'a> Hash for AccessMapPtr<'a> { + fn hash(&self, state: &mut H) { + state.write_usize(Rc::as_ptr(self.0) as usize) + } +} + +impl ThrillerBlock { + #[allow(dead_code)] + pub(crate) fn merge_loop(&self) -> HashSet { + let mut sets = HashSet::new(); + + self.get_inputs().iter().for_each(|edge| { + if let Some(access_map) = edge.get_access() { + let ptr = AccessMapPtr(access_map); + sets.insert(ptr); + } + }); + + sets + } +} + +impl<'a> PartialEq for AccessMapPtr<'a> { + fn eq(&self, other: &Self) -> bool { + // TODO: Implement this function.s + Rc::ptr_eq(self.0, other.0) + } +} + +impl<'a> Eq for AccessMapPtr<'a> { + fn assert_receiver_is_total_eq(&self) { + todo!("Implement this function") + } +} diff --git a/thriller-core/src/dataflow/mod.rs b/thriller-core/src/dataflow/mod.rs index e84073a..2276d67 100644 --- a/thriller-core/src/dataflow/mod.rs +++ b/thriller-core/src/dataflow/mod.rs @@ -1,6 +1,7 @@ mod block; mod edge; mod graph; +mod loop_analysis; mod node; pub use block::{BlockType, ThrillerBlock}; diff --git a/thriller-core/src/lib.rs b/thriller-core/src/lib.rs index e126238..a5c88bb 100644 --- a/thriller-core/src/lib.rs +++ b/thriller-core/src/lib.rs @@ -2,6 +2,7 @@ #![deny(warnings)] #![deny(missing_docs)] + mod access; mod buffer; mod dataflow;