Skip to content

Commit

Permalink
Add loop analysis module.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed May 26, 2024
1 parent b36aa26 commit c93b524
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 33 deletions.
4 changes: 0 additions & 4 deletions examples/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ fn main() {
);
let gemm_node = Rc::new(ThrillerNode::new(ThrillerNodeInner::Op(Box::new(gemm))));

// let edge_a = Rc::new(ThrillerEdge::new(node_a.clone(), gemm_node.clone()));
// let edge_b = Rc::new(ThrillerEdge::new(node_b.clone(), gemm_node.clone()));
// let edge_acc = Rc::new(ThrillerEdge::new(gemm_node.clone(), node_acc.clone()));

let gemm_code = gemm_node.emit().unwrap();

println!("{}", gemm_code);
Expand Down
1 change: 1 addition & 0 deletions thriller-bindings/pythriller/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .context import initialize_thriller_flow
from .buffer import create_buffer
3 changes: 1 addition & 2 deletions thriller-bindings/pythriller/buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import thriller_flow
from thriller_flow import PyBuffer
from .context import PyBuffer


def create_buffer(name):
Expand Down
1 change: 1 addition & 0 deletions thriller-bindings/pythriller/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from thriller_flow import *
6 changes: 0 additions & 6 deletions thriller-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@ fn initialize_thriller_flow() -> PyResult<()> {
Ok(())
}

// #[pyfunction]
// fn create_buffer(name: String) -> PyResult<PyBuffer> {
// let buffer = PyBuffer(Buffer::new(name.as_str()));
// Ok(buffer)
// }

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
Expand Down
7 changes: 4 additions & 3 deletions thriller-bindings/tests/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@

import context
import thriller_flow

import pythriller

if __name__ == '__main__':
thriller_flow.initialize_thriller_flow()
pythriller.initialize_thriller_flow()
g_a = pythriller.create_buffer("g_a")
g_b = pythriller.create_buffer("g_b")
g_c = pythriller.create_buffer("g_c")

print(g_a, g_b, g_c)
print(g_a)
print(g_b)
print(g_c)
6 changes: 6 additions & 0 deletions thriller-core/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,9 @@ impl PartialEq for AccessMap {
&& self.access_dims == other.access_dims
}
}

impl Eq for AccessMap {
fn assert_receiver_is_total_eq(&self) {
todo!("Implement this function")
}
}
31 changes: 13 additions & 18 deletions thriller-core/src/dataflow/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub struct ThrillerBlock {
}

impl ThrillerBlock {
/// Create a new ThrillerBlock with the given inputs, outputs, memory level, subgraph, and block type.
/// Create a new [`ThrillerBlock`] with the given inputs, outputs, memory level, subgraph, and block type.
pub fn new(
inputs: Vec<Rc<AttachedEdge>>,
outputs: Vec<Rc<AttachedEdge>>,
Expand All @@ -51,23 +51,6 @@ impl ThrillerBlock {
}
}

// pub(crate) fn get_inputs(&self) -> &[Rc<AttachedEdge>] {
// &self.inputs
// }

// pub(crate) fn get_inner_bufs(&self) -> Vec<String> {
// // Iterate through the inputs and collect all dst buffers.
// let mut bufs = Vec::new();
// for input in self.inputs.iter() {
// // if let Some(buf) = input.get_dst() {
// // bufs.push(buf.clone());
// // }
// let dst_buf = input.get_dst_name().clone();
// bufs.push(dst_buf);
// }
// bufs
// }

/// Get the block type.
pub fn get_block_type(&self) -> BlockType {
self.block_type
Expand All @@ -79,9 +62,21 @@ impl ThrillerBlock {
// If they are the same, then we can merge them into a unified access map.

// TODO: Implement this function.
self.inputs.windows(2).for_each(|window| {
let (first, second) = (&window[0], &window[1]);
assert!(
first.get_access() == second.get_access(),
"Access maps are not the same."
);
});

self.unified_access_map = Some(self.inputs[0].get_access().as_ref().unwrap().clone());
}

pub(crate) fn get_inputs(&self) -> &Vec<Rc<AttachedEdge>> {
&self.inputs
}

pub(crate) fn gen_loop_load(&self) -> ThrillerResult<String> {
let mut code = String::new();

Expand Down
44 changes: 44 additions & 0 deletions thriller-core/src/dataflow/loop_analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::collections::HashSet;
use std::hash::Hash;
use std::rc::Rc;

use crate::AccessMap;

use super::block::ThrillerBlock;

pub struct AccessMapPtr<'a>(&'a Rc<AccessMap>);

impl<'a> Hash for AccessMapPtr<'a> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_usize(Rc::as_ptr(self.0) as usize)
}
}

impl ThrillerBlock {
#[allow(dead_code)]
pub(crate) fn merge_loop(&self) -> HashSet<AccessMapPtr> {
let mut sets = HashSet::new();

self.get_inputs().iter().for_each(|edge| {
if let Some(access_map) = edge.get_access() {
let ptr = AccessMapPtr(access_map);
sets.insert(ptr);
}
});

sets
}
}

impl<'a> PartialEq for AccessMapPtr<'a> {
fn eq(&self, other: &Self) -> bool {
// TODO: Implement this function.s
Rc::ptr_eq(self.0, other.0)
}
}

impl<'a> Eq for AccessMapPtr<'a> {
fn assert_receiver_is_total_eq(&self) {
todo!("Implement this function")
}
}
1 change: 1 addition & 0 deletions thriller-core/src/dataflow/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod block;
mod edge;
mod graph;
mod loop_analysis;
mod node;

pub use block::{BlockType, ThrillerBlock};
Expand Down
1 change: 1 addition & 0 deletions thriller-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![deny(warnings)]
#![deny(missing_docs)]

mod access;
mod buffer;
mod dataflow;
Expand Down

0 comments on commit c93b524

Please sign in to comment.