generated from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodule_interface.py
47 lines (38 loc) · 1.19 KB
/
module_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from collections import deque

import networkx as nx
import streamlit as st
from streamlit_ace import st_ace

import minitorch
# Placeholder for the class typed into the sandbox editor:
# render_module_sandbox exec()s the editor's code into this module's globals
# and then instantiates whatever ``MyModule`` is bound to.
MyModule = None
# Bare reference to the ``minitorch`` import. In this file the import is
# otherwise used only by the code string that render_module_sandbox executes
# (which needs ``minitorch`` in these globals); presumably this line keeps
# linters from reporting the import as unused.
minitorch
def _module_tree_edges(root, root_name="base"):
    """Yield the edges of a minitorch module tree in breadth-first order.

    Args:
        root: Module at the top of the tree. Only the ``_parameters`` and
            ``_modules`` mappings stored in its ``__dict__`` (and those of
            its descendants) are read.
        root_name: Label for ``root``. Each descendant is labelled with its
            parent's label and its own name joined by ``"."``.

    Yields:
        ``(parent_label, child_label, is_parameter)`` tuples. For each module
        its parameters come first, then its submodules; modules are visited
        breadth-first.
    """
    # deque.popleft is O(1); slicing a list to pop its head copies it each step.
    queue = deque([(root, root_name)])
    while queue:
        module, label = queue.popleft()
        # Read the registries via __dict__ directly, presumably to bypass any
        # attribute hooks defined on minitorch.Module.
        for pname in module.__dict__["_parameters"]:
            yield label, f"{label}.{pname}", True
        for cname, child in module.__dict__["_modules"].items():
            child_label = f"{label}.{cname}"
            yield label, child_label, False
            queue.append((child, child_label))


def render_module_sandbox():
    """Render an editable sandbox that draws the module tree of ``MyModule``.

    Shows a code editor pre-filled with a small ``MyModule`` example,
    executes the edited code in this module's globals, instantiates
    ``MyModule``, then displays its named parameters and a graph whose
    rectangular leaves are parameters and whose inner nodes are submodules.

    Errors raised by the editor code (or a snippet that does not define
    ``MyModule``) are shown on the page instead of aborting the render.
    """
    st.write("## Sandbox for Module Trees")
    st.write(
        "Visual debugging checks showing the module tree that your code constructs."
    )

    code = st_ace(
        language="python",
        height=300,
        value="""
class MyModule(minitorch.Module):
    def __init__(self):
        super().__init__()
        self.parameter1 = minitorch.Parameter(15)
""",
    )

    # SECURITY NOTE: exec() runs whatever is typed in the editor with full
    # access to this module's globals. Acceptable only for a local,
    # single-user debugging tool; never expose this page to untrusted users.
    # Clear any class bound by an earlier run, so a snippet that no longer
    # defines MyModule is reported instead of silently reusing the old one.
    globals()["MyModule"] = None
    try:
        exec(code, globals())
        module_cls = globals()["MyModule"]
        if module_cls is None:
            st.error("Define a class named `MyModule` to see its module tree.")
            return
        root = module_cls()
    except Exception as err:  # user code may raise anything; show it in-page
        st.exception(err)
        return

    st.write(dict(root.named_parameters()))

    graph = nx.MultiDiGraph()
    graph.add_node("base")
    for parent, child, is_parameter in _module_tree_edges(root, "base"):
        if is_parameter:
            # Parameters are leaves: draw them as thin rectangles.
            graph.add_node(child, shape="rect", penwidth=0.5)
        graph.add_edge(parent, child)
    # Graph-level Graphviz attribute: lay the tree out top-to-bottom.
    graph.graph["graph"] = {"rankdir": "TB"}
    st.graphviz_chart(nx.nx_pydot.to_pydot(graph).to_string())