Skip to content

Commit

Permalink
update python compiler with vectorization and the in-place argument
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Nov 1, 2024
1 parent 87071a2 commit 7e1b736
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/backend/compiler_python.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,54 @@
# Arguments:
- `graphs` vector of computational graphs
"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
function to_python_str(graphs::AbstractVector{<:AbstractGraph};
root::AbstractVector{Int}=[id(g) for g in graphs], name::String="eval_graph", in_place::Bool=false)
head = ""
body = ""
leafidx = 0
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
gid_to_leafid = Dict{String,Int64}()
rootidx = 0
map_validx_leaf = Dict{Int,eltype(graphs)}() # mapping from the index of the leafVal to the leaf graph
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
target_root = "root[:, $(findfirst(x -> x == g_id, root)-1)]"
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
body *= " $target = leaf[$(leafidx)]\n"
gid_to_leafid[target] = leafidx
body *= " $target = leafVal[:, $(leafidx)]\n"
leafidx += 1
map_validx_leaf[leafidx] = g
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
body *= " $target = $(to_static(operator(g), subgraphs(g), subgraph_factors(g), lang=:python))\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " root$(rootidx) = $target\n"
rootidx += 1
body *= " $target_root = $target\n"
end
end
end
head *= "def graphfunc(leaf):\n"
output = ["root$(i)" for i in 0:rootidx-1]
output = join(output, ",")
tail = " return $output\n\n"

expr = head * body * tail
if in_place
head *= "def $name(root, leafVal):\n"
else
head *= "import torch\n"
head *= "def $name(leafVal):\n"
head *= " root = torch.empty(leafVal.shape[0], $(length(graphs)), dtype=leafVal.dtype, device=leafVal.device)\n"
end
tail = " return root\n\n"

return expr, gid_to_leafid
return head * body * tail, map_validx_leaf
end
function compile_Python(graphs::AbstractVector{<:AbstractGraph}, filename::String="GraphFunc.py")
py_string, leafmap = to_python_str(graphs)
open(filename, "w") do f
function compile_Python(graphs::AbstractVector{<:AbstractGraph}, filename::String;
root::AbstractVector{Int}=[id(g) for g in graphs], func_name="eval_graph")
py_string, leafmap = to_python_str(graphs, root=root, name=func_name)
open(filename, "a") do f
write(f, py_string)
end
return leafmap
Expand Down

0 comments on commit 7e1b736

Please sign in to comment.