Skip to content

Commit

Permalink
feat(example): Add a python example to codegen whole gemm. (#20)
Browse files Browse the repository at this point in the history
* Add Rust crate docs.

* fix Python import.

* Add Shared to Register codegen.

* Add shared to register codegen

* Add a python example to codegen whole gemm.

* fix pynode buffer.

* Add docs.

* Add TODO comments.

* Delete pycache.

* fix gitignore.

* Add empty line.

* chore: fix PyNode tensor method.

* chore: fix emit_loop method.

* chore: follow code reviews.
  • Loading branch information
KuangjuX authored Sep 20, 2024
1 parent 39f95e0 commit 5b0977d
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 66 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
TiledCUDA/
**/__pycache__/
3 changes: 2 additions & 1 deletion thriller-bindings/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.env/
tests/__pycache__
pythriller/__pycache__
pythriller/__pycache__
*/__pycache__
Binary file not shown.
74 changes: 37 additions & 37 deletions thriller-bindings/examples/gemm/gemm_g2r.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import context

import pythriller
from pythriller import initialize_thriller_flow
from pythriller import Tensor, Layout, TensorType, Graph, Node, Edge
from pythriller import AttachedEdge, Block, IterationVar, AccessMap

if __name__ == '__main__':
pythriller.initialize_thriller_flow()
initialize_thriller_flow()

LayoutA = pythriller.PyLayout.RowMajor
LayoutB = pythriller.PyLayout.RowMajor
LayoutC = pythriller.PyLayout.RowMajor
LayoutA = Layout.RowMajor
LayoutB = Layout.RowMajor
LayoutC = Layout.RowMajor

GlobalLayoutA = pythriller.PyLayout.RowMajor
GlobalLayoutB = pythriller.PyLayout.ColMajor
GlobalLayoutC = pythriller.PyLayout.RowMajor
GlobalLayoutA = Layout.RowMajor
GlobalLayoutB = Layout.ColMajor
GlobalLayoutC = Layout.RowMajor

BufTypeA = pythriller.PyBufType.RegTile
BufTypeB = pythriller.PyBufType.RegTile
BufTypeC = pythriller.PyBufType.RegTile
BufTypeA = TensorType.RegTile
BufTypeB = TensorType.RegTile
BufTypeC = TensorType.RegTile

GlobalTypeA = pythriller.PyBufType.GlobalTile
GlobalTypeB = pythriller.PyBufType.GlobalTile
GlobalTypeC = pythriller.PyBufType.GlobalTile
GlobalTypeA = TensorType.GlobalTile
GlobalTypeB = TensorType.GlobalTile
GlobalTypeC = TensorType.GlobalTile

DimA = [64, 64]
DimB = [64, 64]
Expand All @@ -29,13 +31,13 @@
GlobalDimB = [256, 256]
GlobalDimC = [256, 256]

rA = pythriller.PyBuffer("rA", DimA, LayoutA, BufTypeA)
rB = pythriller.PyBuffer("rB", DimB, LayoutB, BufTypeB)
acc = pythriller.PyBuffer("acc", DimC, LayoutC, BufTypeC)
rA = Tensor("rA", DimA, LayoutA, BufTypeA)
rB = Tensor("rB", DimB, LayoutB, BufTypeB)
acc = Tensor("acc", DimC, LayoutC, BufTypeC)

gA = pythriller.PyBuffer("gA", GlobalDimA, GlobalLayoutA, GlobalTypeA)
gB = pythriller.PyBuffer("gB", GlobalDimB, GlobalLayoutB, GlobalTypeB)
gC = pythriller.PyBuffer("gC", GlobalDimC, GlobalLayoutC, GlobalTypeC)
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, GlobalTypeA)
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, GlobalTypeB)
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, GlobalTypeC)

print(rA)
print(rB)
Expand All @@ -45,37 +47,35 @@
print(gB)
print(gC)

MemoryLevel = pythriller.PyMemoryLevel.Register
RegGraph = pythriller.PyGraph()
RegGraph = Graph()

NodeA = pythriller.PyNode(rA)
NodeB = pythriller.PyNode(rB)
NodeAcc = pythriller.PyNode(acc)
NodeA = Node.tensor(rA)
NodeB = Node.tensor(rB)
NodeAcc = Node.tensor(acc)

GemmNode = pythriller.PyNode.gemm(NodeA, NodeB, NodeAcc)
GemmNode = Node.gemm(NodeA, NodeB, NodeAcc)

LoopIter = pythriller.IterationVar('i', (0, 4))
LoopIter = IterationVar('i', (0, 4))

access_dims = [1]

AccessMap = pythriller.AccessMap(
AccessMap = AccessMap(
access_dims, [[[1]], [[0]]], [[0], [10]], [LoopIter])

EdgeA_Gemm = pythriller.PyEdge(NodeA, GemmNode)
EdgeB_GEMM = pythriller.PyEdge(NodeB, GemmNode)
EdgeGemm_Acc = pythriller.PyEdge(GemmNode, NodeAcc)
EdgeA_Gemm = Edge(NodeA, GemmNode)
EdgeB_Gemm = Edge(NodeB, GemmNode)
EdgeGemm_Acc = Edge(GemmNode, NodeAcc)

RegGraph.add_nodes([NodeA, NodeB, NodeAcc, GemmNode])
RegGraph.add_edges([EdgeA_Gemm, EdgeB_GEMM, EdgeGemm_Acc])
RegGraph.add_edges([EdgeA_Gemm, EdgeB_Gemm, EdgeGemm_Acc])

RegGraph.connect()

LoadGlobalToRegEdgeA = pythriller.AttachedEdge(gA, rA, AccessMap)
LoadGlobalToRegEdgeB = pythriller.AttachedEdge(gB, rB, AccessMap)
StoreRegToGlobalEdgeC = pythriller.AttachedEdge(acc, gC, AccessMap)
G2RBlockMemLevel = pythriller.PyMemoryLevel.Register
LoadGlobalToRegEdgeA = AttachedEdge(gA, rA, AccessMap)
LoadGlobalToRegEdgeB = AttachedEdge(gB, rB, AccessMap)
StoreRegToGlobalEdgeC = AttachedEdge(acc, gC, AccessMap)

GlobalToRegBlock = pythriller.Block(
GlobalToRegBlock = Block(
[LoadGlobalToRegEdgeA, LoadGlobalToRegEdgeB], [StoreRegToGlobalEdgeC], RegGraph, [LoopIter])

code = GlobalToRegBlock.codegen()
Expand Down
163 changes: 163 additions & 0 deletions thriller-bindings/examples/gemm/whole_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
'''
Whole GEMM is an example of GEMM that utilizes all memory hierarchies
of NVIDIA GPU.
'''
import context

from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap
from pythriller import Block


if __name__ == '__main__':
# Initialize runtime.
initialize_thriller_flow()

# Define reg layout for A, B, C.
RegLayoutA = Layout.RowMajor
RegLayoutB = Layout.RowMajor
RegLayoutC = Layout.RowMajor

# Define shared layout for A, B, C.
SharedLayoutA = Layout.RowMajor
SharedLayoutB = Layout.ColMajor
SharedLayoutC = Layout.RowMajor

# Define global layout for A, B, C.
GlobalLayoutA = Layout.RowMajor
GlobalLayoutB = Layout.ColMajor
GlobalLayoutC = Layout.RowMajor

# Define Reg Dim for A, B, C.
RegDimA = [64, 64]
RegDimB = [64, 64]
RegDimC = [64, 64]

# Define Shared Dim for A, B, C.
SharedDimA = [64, 64]
SharedDimB = [64, 64]
SharedDimC = [64, 64]

# Define Global Dim for A, B, C.
GlobalDimA = [256, 256]
GlobalDimB = [256, 256]
GlobalDimC = [256, 256]

# Define Reg Tensor for A, B, C.
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile)
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile)
acc = Tensor("acc", RegDimC, RegLayoutC, TensorType.RegTile)

# Define Shared Tensor for A, B, C.
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile)
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile)
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile)

# Define Global Tensor for A, B, C.
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile)
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile)
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile)

# Define Reg Node for A, B, C.
NodeRA = Node.tensor(rA)
NodeRB = Node.tensor(rB)
NodeRC = Node.tensor(acc)

# Define Reg GEMM Node.
RegGemmNode = Node.gemm(NodeRA, NodeRB, NodeRC)

# Define Reg Edge for A, B, C, GEMM.
RegEdgeA = Edge(NodeRA, RegGemmNode)
RegEdgeB = Edge(NodeRB, RegGemmNode)
RegEdgeC = Edge(RegGemmNode, NodeRC)

# Define Shared Node for A, B, C.
NodeSA = Node.tensor(sA)
NodeSB = Node.tensor(sB)
NodeSC = Node.tensor(sC)

# Define Global Node for A, B, C.
NodeGA = Node.tensor(gA)
NodeGB = Node.tensor(gB)
NodeGC = Node.tensor(gC)

# Define loop iter from shared to register
LoopIterS2R = IterationVar('j', (0, 1))

# Define loop iter from global to shared
LoopIterG2S = IterationVar('i', (0, 4))

# Build AccessMap from Shared to Register.
AccessMapSA2RA = AccessMap(
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
AccessMapSB2RB = AccessMap(
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
AccessMapRC2SC = AccessMap([0], [[[]], [[]]], [[], []], [])

# Build AccessMap from Global to Shared.
AccessMapGA2SA = AccessMap(
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
AccessMapGB2SB = AccessMap(
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
AccessMapSC2GC = AccessMap([0], [[[]], [[]]], [[], []], [])

# Build Attached Edge from Shared to Register.
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA)
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB)
AttachedEdgeSC2RC = AttachedEdge(acc, sC, AccessMapRC2SC)

# Build Attached Edge from Global to Shared.
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA)
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB)
AttachedEdgeSC2GC = AttachedEdge(sC, gC, AccessMapSC2GC)

# Build Register Level ETDG.
RegGraph = Graph()

# Add Reg Nodes into Reg Graph.
RegGraph.add_nodes([NodeRA, NodeRB, NodeRC, RegGemmNode])
# Add Reg Edges into Reg Graph.
RegGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeC])
# Connect Reg Graph.
RegGraph.connect()

# Print codegen for Reg Graph.
reg_code = RegGraph.codegen()
print(reg_code)

# Build Block for Shared to Register.
SharedToRegBlock = Block(
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [AttachedEdgeSC2RC], RegGraph, [LoopIterS2R])

# Print codegen for Shared to Register Block.
shared_to_reg_code = SharedToRegBlock.codegen()
print(shared_to_reg_code)

# Define BlockNode for SharedToRegBlock
SharedBlockNode = Node.block(SharedToRegBlock)

# Define Edge for SA, SB, SC, SharedBlockNode.
EdgeSA2Block = Edge(NodeSA, SharedBlockNode)
EdgeSB2Block = Edge(NodeSB, SharedBlockNode)
EdgeBlock2SC = Edge(SharedBlockNode, NodeSC)

# Build Shared Level ETDG.
SharedGraph = Graph()
# Add Shared Nodes into Shared Graph.
SharedGraph.add_nodes([NodeSA, NodeSB, NodeSC, SharedBlockNode])
# Add Shared Edges into Shared Graph.
SharedGraph.add_edges([EdgeSA2Block, EdgeSB2Block, EdgeBlock2SC])
# Connect Shared Graph.
SharedGraph.connect()

# Print codegen for Shared Graph.
shared_code = SharedGraph.codegen()
print(shared_code)

# Build Block for Global to Shared.
GlobalToSharedBlock = Block(
[AttachedEdgeGA2SA, AttachedEdgeGB2SB], [AttachedEdgeSC2GC], SharedGraph, [LoopIterG2S])

# Print codegen for Global to Shared Block.
global_to_shared_code = GlobalToSharedBlock.codegen()
print(global_to_shared_code)
4 changes: 3 additions & 1 deletion thriller-bindings/pythriller/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .context import initialize_thriller_flow, PyLayout, PyBufType, PyBuffer, PyGraph, PyNode, PyEdge, PyMemoryLevel, Gemm, AttachedEdge, Block, IterationVar, AccessMap
from .context import initialize_thriller_flow, Layout, TensorType
from .context import Graph, Node, Edge, Gemm, AttachedEdge, Tensor
from .context import Block, IterationVar, AccessMap
4 changes: 2 additions & 2 deletions thriller-bindings/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use pyo3::{prelude::*, types::PyList};
use crate::{access::PyAccessMap, buffer::PyBuffer, graph::PyGraph, var::PyIterationVar};

#[pyclass(unsendable, module = "block", name = "Block")]
pub struct PyBlock(pub ThrillerBlock);
pub struct PyBlock(pub Rc<ThrillerBlock>);

#[pyclass(unsendable, module = "block", name = "AttachedEdge")]
pub struct PyAttachedEdge(pub Rc<AttachedEdge>);
Expand Down Expand Up @@ -52,7 +52,7 @@ impl PyBlock {

let block = ThrillerBlock::new(inputs, outputs, subgraph, ivars);

Ok(PyBlock(block))
Ok(PyBlock(Rc::new(block)))
}

fn codegen(&self) -> PyResult<String> {
Expand Down
6 changes: 3 additions & 3 deletions thriller-bindings/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ use std::rc::Rc;
use pyo3::prelude::*;
use thriller_core::{BufType, Buffer, Dim, Layout};

#[pyclass(unsendable)]
#[pyclass(unsendable, module = "buffer", name = "Tensor")]
pub struct PyBuffer(pub Rc<Buffer>);

#[pyclass]
#[pyclass(module = "buffer", name = "Layout")]
pub enum PyLayout {
RowMajor,
ColMajor,
}

#[pyclass]
#[pyclass(module = "buffer", name = "TensorType")]
pub enum PyBufType {
GlobalTile,
SharedTile,
Expand Down
25 changes: 13 additions & 12 deletions thriller-bindings/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,12 @@ use thriller_core::{
AccessMap, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode, ThrillerNodeInner,
};

use crate::block::PyBlock;
use crate::buffer::PyBuffer;

use std::{cell::RefCell, rc::Rc};

#[pyclass]
pub enum PyMemoryLevel {
Register,
Shared,
Global,
}

#[pyclass(unsendable)]
#[pyclass(unsendable, module = "graph", name = "Graph")]
pub struct PyGraph(pub Rc<RefCell<ThrillerGraph>>);

#[pymethods]
Expand Down Expand Up @@ -66,17 +60,24 @@ impl PyGraph {
}
}

#[pyclass(unsendable)]
#[pyclass(unsendable, module = "graph", name = "Node")]
pub struct PyNode(pub Rc<RefCell<ThrillerNode>>);

#[pymethods]
impl PyNode {
#[new]
fn buffer(buf: &PyBuffer) -> Self {
#[staticmethod]
fn tensor(buf: PyRef<PyBuffer>) -> Self {
let node = ThrillerNode::new(thriller_core::ThrillerNodeInner::Buffer(Rc::clone(&buf.0)));
PyNode(Rc::new(RefCell::new(node)))
}

#[staticmethod]
fn block(block: PyRef<PyBlock>) -> Self {
let node = ThrillerNode::new(ThrillerNodeInner::Block(Rc::clone(&block.0)));
PyNode(Rc::new(RefCell::new(node)))
}

#[staticmethod]
fn gemm(a: PyRef<PyNode>, b: PyRef<PyNode>, c: PyRef<PyNode>) -> Self {
let access_map = AccessMap::new(0, vec![]);

Expand All @@ -98,7 +99,7 @@ impl PyNode {
}
}

#[pyclass(unsendable)]
#[pyclass(unsendable, module = "graph", name = "Edge")]
pub struct PyEdge(pub Rc<ThrillerEdge>);

#[pymethods]
Expand Down
Loading

0 comments on commit 5b0977d

Please sign in to comment.