Skip to content

Commit 5b0977d

Browse files
authored
feat(example): Add a python example to codegen whole gemm. (#20)
* Add Rust crate docs. * fix Python import. * Add Shared to Register codegen. * Add shared to register codegen * Add a python example to codegen whole gemm. * fix pynode buffer. * Add docs. * Add TODO comments. * Delete pycache. * fix gitignore. * Add empty line. * chore: fix PyNode tensor method. * chore: fix emit_loop method. * chore: follow code reviews.
1 parent 39f95e0 commit 5b0977d

File tree

13 files changed

+279
-66
lines changed

13 files changed

+279
-66
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
TiledCUDA/
3+
**/__pycache__/

thriller-bindings/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.env/
22
tests/__pycache__
3-
pythriller/__pycache__
3+
pythriller/__pycache__
4+
*/__pycache__
Binary file not shown.

thriller-bindings/examples/gemm/gemm_g2r.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
import context
22

3-
import pythriller
3+
from pythriller import initialize_thriller_flow
4+
from pythriller import Tensor, Layout, TensorType, Graph, Node, Edge
5+
from pythriller import AttachedEdge, Block, IterationVar, AccessMap
46

57
if __name__ == '__main__':
6-
pythriller.initialize_thriller_flow()
8+
initialize_thriller_flow()
79

8-
LayoutA = pythriller.PyLayout.RowMajor
9-
LayoutB = pythriller.PyLayout.RowMajor
10-
LayoutC = pythriller.PyLayout.RowMajor
10+
LayoutA = Layout.RowMajor
11+
LayoutB = Layout.RowMajor
12+
LayoutC = Layout.RowMajor
1113

12-
GlobalLayoutA = pythriller.PyLayout.RowMajor
13-
GlobalLayoutB = pythriller.PyLayout.ColMajor
14-
GlobalLayoutC = pythriller.PyLayout.RowMajor
14+
GlobalLayoutA = Layout.RowMajor
15+
GlobalLayoutB = Layout.ColMajor
16+
GlobalLayoutC = Layout.RowMajor
1517

16-
BufTypeA = pythriller.PyBufType.RegTile
17-
BufTypeB = pythriller.PyBufType.RegTile
18-
BufTypeC = pythriller.PyBufType.RegTile
18+
BufTypeA = TensorType.RegTile
19+
BufTypeB = TensorType.RegTile
20+
BufTypeC = TensorType.RegTile
1921

20-
GlobalTypeA = pythriller.PyBufType.GlobalTile
21-
GlobalTypeB = pythriller.PyBufType.GlobalTile
22-
GlobalTypeC = pythriller.PyBufType.GlobalTile
22+
GlobalTypeA = TensorType.GlobalTile
23+
GlobalTypeB = TensorType.GlobalTile
24+
GlobalTypeC = TensorType.GlobalTile
2325

2426
DimA = [64, 64]
2527
DimB = [64, 64]
@@ -29,13 +31,13 @@
2931
GlobalDimB = [256, 256]
3032
GlobalDimC = [256, 256]
3133

32-
rA = pythriller.PyBuffer("rA", DimA, LayoutA, BufTypeA)
33-
rB = pythriller.PyBuffer("rB", DimB, LayoutB, BufTypeB)
34-
acc = pythriller.PyBuffer("acc", DimC, LayoutC, BufTypeC)
34+
rA = Tensor("rA", DimA, LayoutA, BufTypeA)
35+
rB = Tensor("rB", DimB, LayoutB, BufTypeB)
36+
acc = Tensor("acc", DimC, LayoutC, BufTypeC)
3537

36-
gA = pythriller.PyBuffer("gA", GlobalDimA, GlobalLayoutA, GlobalTypeA)
37-
gB = pythriller.PyBuffer("gB", GlobalDimB, GlobalLayoutB, GlobalTypeB)
38-
gC = pythriller.PyBuffer("gC", GlobalDimC, GlobalLayoutC, GlobalTypeC)
38+
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, GlobalTypeA)
39+
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, GlobalTypeB)
40+
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, GlobalTypeC)
3941

4042
print(rA)
4143
print(rB)
@@ -45,37 +47,35 @@
4547
print(gB)
4648
print(gC)
4749

48-
MemoryLevel = pythriller.PyMemoryLevel.Register
49-
RegGraph = pythriller.PyGraph()
50+
RegGraph = Graph()
5051

51-
NodeA = pythriller.PyNode(rA)
52-
NodeB = pythriller.PyNode(rB)
53-
NodeAcc = pythriller.PyNode(acc)
52+
NodeA = Node.tensor(rA)
53+
NodeB = Node.tensor(rB)
54+
NodeAcc = Node.tensor(acc)
5455

55-
GemmNode = pythriller.PyNode.gemm(NodeA, NodeB, NodeAcc)
56+
GemmNode = Node.gemm(NodeA, NodeB, NodeAcc)
5657

57-
LoopIter = pythriller.IterationVar('i', (0, 4))
58+
LoopIter = IterationVar('i', (0, 4))
5859

5960
access_dims = [1]
6061

61-
AccessMap = pythriller.AccessMap(
62+
AccessMap = AccessMap(
6263
access_dims, [[[1]], [[0]]], [[0], [10]], [LoopIter])
6364

64-
EdgeA_Gemm = pythriller.PyEdge(NodeA, GemmNode)
65-
EdgeB_GEMM = pythriller.PyEdge(NodeB, GemmNode)
66-
EdgeGemm_Acc = pythriller.PyEdge(GemmNode, NodeAcc)
65+
EdgeA_Gemm = Edge(NodeA, GemmNode)
66+
EdgeB_Gemm = Edge(NodeB, GemmNode)
67+
EdgeGemm_Acc = Edge(GemmNode, NodeAcc)
6768

6869
RegGraph.add_nodes([NodeA, NodeB, NodeAcc, GemmNode])
69-
RegGraph.add_edges([EdgeA_Gemm, EdgeB_GEMM, EdgeGemm_Acc])
70+
RegGraph.add_edges([EdgeA_Gemm, EdgeB_Gemm, EdgeGemm_Acc])
7071

7172
RegGraph.connect()
7273

73-
LoadGlobalToRegEdgeA = pythriller.AttachedEdge(gA, rA, AccessMap)
74-
LoadGlobalToRegEdgeB = pythriller.AttachedEdge(gB, rB, AccessMap)
75-
StoreRegToGlobalEdgeC = pythriller.AttachedEdge(acc, gC, AccessMap)
76-
G2RBlockMemLevel = pythriller.PyMemoryLevel.Register
74+
LoadGlobalToRegEdgeA = AttachedEdge(gA, rA, AccessMap)
75+
LoadGlobalToRegEdgeB = AttachedEdge(gB, rB, AccessMap)
76+
StoreRegToGlobalEdgeC = AttachedEdge(acc, gC, AccessMap)
7777

78-
GlobalToRegBlock = pythriller.Block(
78+
GlobalToRegBlock = Block(
7979
[LoadGlobalToRegEdgeA, LoadGlobalToRegEdgeB], [StoreRegToGlobalEdgeC], RegGraph, [LoopIter])
8080

8181
code = GlobalToRegBlock.codegen()
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
'''
2+
Whole GEMM is an example of GEMM that utilizes all memory hierarchies
3+
of NVIDIA GPU.
4+
'''
5+
import context
6+
7+
from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
8+
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap
9+
from pythriller import Block
10+
11+
12+
if __name__ == '__main__':
13+
# Initialize runtime.
14+
initialize_thriller_flow()
15+
16+
# Define reg layout for A, B, C.
17+
RegLayoutA = Layout.RowMajor
18+
RegLayoutB = Layout.RowMajor
19+
RegLayoutC = Layout.RowMajor
20+
21+
# Define shared layout for A, B, C.
22+
SharedLayoutA = Layout.RowMajor
23+
SharedLayoutB = Layout.ColMajor
24+
SharedLayoutC = Layout.RowMajor
25+
26+
# Define global layout for A, B, C.
27+
GlobalLayoutA = Layout.RowMajor
28+
GlobalLayoutB = Layout.ColMajor
29+
GlobalLayoutC = Layout.RowMajor
30+
31+
# Define Reg Dim for A, B, C.
32+
RegDimA = [64, 64]
33+
RegDimB = [64, 64]
34+
RegDimC = [64, 64]
35+
36+
# Define Shared Dim for A, B, C.
37+
SharedDimA = [64, 64]
38+
SharedDimB = [64, 64]
39+
SharedDimC = [64, 64]
40+
41+
# Define Global Dim for A, B, C.
42+
GlobalDimA = [256, 256]
43+
GlobalDimB = [256, 256]
44+
GlobalDimC = [256, 256]
45+
46+
# Define Reg Tensor for A, B, C.
47+
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile)
48+
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile)
49+
acc = Tensor("acc", RegDimC, RegLayoutC, TensorType.RegTile)
50+
51+
# Define Shared Tensor for A, B, C.
52+
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile)
53+
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile)
54+
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile)
55+
56+
# Define Global Tensor for A, B, C.
57+
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile)
58+
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile)
59+
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile)
60+
61+
# Define Reg Node for A, B, C.
62+
NodeRA = Node.tensor(rA)
63+
NodeRB = Node.tensor(rB)
64+
NodeRC = Node.tensor(acc)
65+
66+
# Define Reg GEMM Node.
67+
RegGemmNode = Node.gemm(NodeRA, NodeRB, NodeRC)
68+
69+
# Define Reg Edge for A, B, C, GEMM.
70+
RegEdgeA = Edge(NodeRA, RegGemmNode)
71+
RegEdgeB = Edge(NodeRB, RegGemmNode)
72+
RegEdgeC = Edge(RegGemmNode, NodeRC)
73+
74+
# Define Shared Node for A, B, C.
75+
NodeSA = Node.tensor(sA)
76+
NodeSB = Node.tensor(sB)
77+
NodeSC = Node.tensor(sC)
78+
79+
# Define Global Node for A, B, C.
80+
NodeGA = Node.tensor(gA)
81+
NodeGB = Node.tensor(gB)
82+
NodeGC = Node.tensor(gC)
83+
84+
# Define loop iter from shared to register
85+
LoopIterS2R = IterationVar('j', (0, 1))
86+
87+
# Define loop iter from global to shared
88+
LoopIterG2S = IterationVar('i', (0, 4))
89+
90+
# Build AccessMap from Shared to Register.
91+
AccessMapSA2RA = AccessMap(
92+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
93+
AccessMapSB2RB = AccessMap(
94+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
95+
AccessMapRC2SC = AccessMap([0], [[[]], [[]]], [[], []], [])
96+
97+
# Build AccessMap from Global to Shared.
98+
AccessMapGA2SA = AccessMap(
99+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
100+
AccessMapGB2SB = AccessMap(
101+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
102+
AccessMapSC2GC = AccessMap([0], [[[]], [[]]], [[], []], [])
103+
104+
# Build Attached Edge from Shared to Register.
105+
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA)
106+
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB)
107+
AttachedEdgeSC2RC = AttachedEdge(acc, sC, AccessMapRC2SC)
108+
109+
# Build Attached Edge from Global to Shared.
110+
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA)
111+
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB)
112+
AttachedEdgeSC2GC = AttachedEdge(sC, gC, AccessMapSC2GC)
113+
114+
# Build Register Level ETDG.
115+
RegGraph = Graph()
116+
117+
# Add Reg Nodes into Reg Graph.
118+
RegGraph.add_nodes([NodeRA, NodeRB, NodeRC, RegGemmNode])
119+
# Add Reg Edges into Reg Graph.
120+
RegGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeC])
121+
# Connect Reg Graph.
122+
RegGraph.connect()
123+
124+
# Print codegen for Reg Graph.
125+
reg_code = RegGraph.codegen()
126+
print(reg_code)
127+
128+
# Build Block for Shared to Register.
129+
SharedToRegBlock = Block(
130+
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [AttachedEdgeSC2RC], RegGraph, [LoopIterS2R])
131+
132+
# Print codegen for Shared to Register Block.
133+
shared_to_reg_code = SharedToRegBlock.codegen()
134+
print(shared_to_reg_code)
135+
136+
# Define BlockNode for SharedToRegBlock
137+
SharedBlockNode = Node.block(SharedToRegBlock)
138+
139+
# Define Edge for SA, SB, SC, SharedBlockNode.
140+
EdgeSA2Block = Edge(NodeSA, SharedBlockNode)
141+
EdgeSB2Block = Edge(NodeSB, SharedBlockNode)
142+
EdgeBlock2SC = Edge(SharedBlockNode, NodeSC)
143+
144+
# Build Shared Level ETDG.
145+
SharedGraph = Graph()
146+
# Add Shared Nodes into Shared Graph.
147+
SharedGraph.add_nodes([NodeSA, NodeSB, NodeSC, SharedBlockNode])
148+
# Add Shared Edges into Shared Graph.
149+
SharedGraph.add_edges([EdgeSA2Block, EdgeSB2Block, EdgeBlock2SC])
150+
# Connect Shared Graph.
151+
SharedGraph.connect()
152+
153+
# Print codegen for Shared Graph.
154+
shared_code = SharedGraph.codegen()
155+
print(shared_code)
156+
157+
# Build Block for Global to Shared.
158+
GlobalToSharedBlock = Block(
159+
[AttachedEdgeGA2SA, AttachedEdgeGB2SB], [AttachedEdgeSC2GC], SharedGraph, [LoopIterG2S])
160+
161+
# Print codegen for Global to Shared Block.
162+
global_to_shared_code = GlobalToSharedBlock.codegen()
163+
print(global_to_shared_code)
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
from .context import initialize_thriller_flow, PyLayout, PyBufType, PyBuffer, PyGraph, PyNode, PyEdge, PyMemoryLevel, Gemm, AttachedEdge, Block, IterationVar, AccessMap
1+
from .context import initialize_thriller_flow, Layout, TensorType
2+
from .context import Graph, Node, Edge, Gemm, AttachedEdge, Tensor
3+
from .context import Block, IterationVar, AccessMap

thriller-bindings/src/block.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pyo3::{prelude::*, types::PyList};
77
use crate::{access::PyAccessMap, buffer::PyBuffer, graph::PyGraph, var::PyIterationVar};
88

99
#[pyclass(unsendable, module = "block", name = "Block")]
10-
pub struct PyBlock(pub ThrillerBlock);
10+
pub struct PyBlock(pub Rc<ThrillerBlock>);
1111

1212
#[pyclass(unsendable, module = "block", name = "AttachedEdge")]
1313
pub struct PyAttachedEdge(pub Rc<AttachedEdge>);
@@ -52,7 +52,7 @@ impl PyBlock {
5252

5353
let block = ThrillerBlock::new(inputs, outputs, subgraph, ivars);
5454

55-
Ok(PyBlock(block))
55+
Ok(PyBlock(Rc::new(block)))
5656
}
5757

5858
fn codegen(&self) -> PyResult<String> {

thriller-bindings/src/buffer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ use std::rc::Rc;
33
use pyo3::prelude::*;
44
use thriller_core::{BufType, Buffer, Dim, Layout};
55

6-
#[pyclass(unsendable)]
6+
#[pyclass(unsendable, module = "buffer", name = "Tensor")]
77
pub struct PyBuffer(pub Rc<Buffer>);
88

9-
#[pyclass]
9+
#[pyclass(module = "buffer", name = "Layout")]
1010
pub enum PyLayout {
1111
RowMajor,
1212
ColMajor,
1313
}
1414

15-
#[pyclass]
15+
#[pyclass(module = "buffer", name = "TensorType")]
1616
pub enum PyBufType {
1717
GlobalTile,
1818
SharedTile,

thriller-bindings/src/graph.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,12 @@ use thriller_core::{
55
AccessMap, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode, ThrillerNodeInner,
66
};
77

8+
use crate::block::PyBlock;
89
use crate::buffer::PyBuffer;
910

1011
use std::{cell::RefCell, rc::Rc};
1112

12-
#[pyclass]
13-
pub enum PyMemoryLevel {
14-
Register,
15-
Shared,
16-
Global,
17-
}
18-
19-
#[pyclass(unsendable)]
13+
#[pyclass(unsendable, module = "graph", name = "Graph")]
2014
pub struct PyGraph(pub Rc<RefCell<ThrillerGraph>>);
2115

2216
#[pymethods]
@@ -66,17 +60,24 @@ impl PyGraph {
6660
}
6761
}
6862

69-
#[pyclass(unsendable)]
63+
#[pyclass(unsendable, module = "graph", name = "Node")]
7064
pub struct PyNode(pub Rc<RefCell<ThrillerNode>>);
7165

7266
#[pymethods]
7367
impl PyNode {
74-
#[new]
75-
fn buffer(buf: &PyBuffer) -> Self {
68+
#[staticmethod]
69+
fn tensor(buf: PyRef<PyBuffer>) -> Self {
7670
let node = ThrillerNode::new(thriller_core::ThrillerNodeInner::Buffer(Rc::clone(&buf.0)));
7771
PyNode(Rc::new(RefCell::new(node)))
7872
}
7973

74+
#[staticmethod]
75+
fn block(block: PyRef<PyBlock>) -> Self {
76+
let node = ThrillerNode::new(ThrillerNodeInner::Block(Rc::clone(&block.0)));
77+
PyNode(Rc::new(RefCell::new(node)))
78+
}
79+
80+
#[staticmethod]
8081
fn gemm(a: PyRef<PyNode>, b: PyRef<PyNode>, c: PyRef<PyNode>) -> Self {
8182
let access_map = AccessMap::new(0, vec![]);
8283

@@ -98,7 +99,7 @@ impl PyNode {
9899
}
99100
}
100101

101-
#[pyclass(unsendable)]
102+
#[pyclass(unsendable, module = "graph", name = "Edge")]
102103
pub struct PyEdge(pub Rc<ThrillerEdge>);
103104

104105
#[pymethods]

0 commit comments

Comments
 (0)