Skip to content

Commit

Permalink
Merge pull request #167 from numericalEFT/computgraph_zhiyi
Browse files Browse the repository at this point in the history
Modify the python compiler for jax and modify the dot_compiler for visualization.
  • Loading branch information
peter0627ustc authored Dec 22, 2023
2 parents aac010c + 466bddf commit 3317708
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 256 deletions.
327 changes: 195 additions & 132 deletions src/backend/compiler_python.jl
Original file line number Diff line number Diff line change
@@ -1,132 +1,195 @@
# ms = pyimport("mindspore")

"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python.
"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
"Static representation for computational graph nodes with operator $(operator) not yet implemented! " *
"Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`."
)
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
# return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework.
# Arguments:
- `graphs` vector of computational graphs
- `framework` the type of the python frameworks, including `:jax` and `mindspore`.
"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
if framework == :jax
head = ""
elseif framework == :mindspore
head = "import mindspore as ms\n@ms.jit\n"
else
error("no support for $type framework")
end
body = ""
leafidx = 1
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
gid_to_leafid = Dict{String, Int64}()
rootidx = 1
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = l$(leafidx)$factor_str\n"
gid_to_leafid[target] = leafidx
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " out$(rootidx)=$target\n"
rootidx +=1
end
end
end
input = ["l$(i)" for i in 1:leafidx-1]
input = join(input,",")
output = ["out$(i)" for i in 1:rootidx-1]
output = join(output,",")
head *="def graphfunc($input):\n"
tail = " return $output\n"
# tail*= "def to_StaticGraph(leaf):\n"
# tail*= " output = graphfunc(leaf)\n"
# tail*= " return output"
expr = head * body * tail
println(expr)
# return head * body * tail
f = open("GraphFunc.py", "w")
write(f, expr)
return expr, leafidx-1, gid_to_leafid
end

# ms = pyimport("mindspore")

"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python.
"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
"Static representation for computational graph nodes with operator $(operator) not yet implemented! " *
"Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`."
)
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
# return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework.
# Arguments:
- `graphs` vector of computational graphs
- `framework` the type of the python frameworks, including `:jax` and `mindspore`.
"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
if framework == :jax
head = "from jax import jit\n"
elseif framework == :mindspore
head = "import mindspore as ms\n@ms.jit\n"
else
error("no support for $type framework")
end
body = ""
leafidx = 0
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
gid_to_leafid = Dict{String,Int64}()
rootidx = 0
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = leaf[$(leafidx)]$factor_str\n"
gid_to_leafid[target] = leafidx
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " root$(rootidx) = $target\n"
rootidx += 1
end
end
end
head *= "def graphfunc(leaf):\n"
output = ["root$(i)" for i in 0:rootidx-1]
output = join(output,",")
tail = " return $output\n\n"

if framework == :jax
tail *="graphfunc_jit = jit(graphfunc)"
end
expr = head * body * tail

return expr, leafidx , gid_to_leafid
end
function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py")
py_string, leafnum, leafmap = to_python_str(graphs,framework)
println("The number of leaves: $leafnum")
open(filename, "w") do f
write(f, py_string)
end
return leafnum, leafmap
end

# function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
# if framework == :jax
# head = ""
# elseif framework == :mindspore
# head = "import mindspore as ms\[email protected]\n"
# else
# error("no support for $type framework")
# end
# body = ""
# leafidx = 1
# root = [id(g) for g in graphs]
# inds_visitedleaf = Int[]
# inds_visitednode = Int[]
# gid_to_leafid = Dict{String, Int64}()
# rootidx = 1
# for graph in graphs
# for g in PostOrderDFS(graph) #leaf first search
# g_id = id(g)
# target = "g$(g_id)"
# isroot = false
# if g_id in root
# isroot = true
# end
# if isempty(subgraphs(g)) #leaf
# g_id in inds_visitedleaf && continue
# factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
# body *= " $target = l$(leafidx)$factor_str\n"
# gid_to_leafid[target] = leafidx
# leafidx += 1
# push!(inds_visitedleaf, g_id)
# else
# g_id in inds_visitednode && continue
# factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
# body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
# push!(inds_visitednode, g_id)
# end
# if isroot
# body *= " out$(rootidx)=$target\n"
# rootidx +=1
# end
# end
# end
# input = ["l$(i)" for i in 1:leafidx-1]
# input = join(input,",")
# output = ["out$(i)" for i in 1:rootidx-1]
# output = join(output,",")
# head *="def graphfunc($input):\n"
# tail = " return $output\n"
# # tail*= "def to_StaticGraph(leaf):\n"
# # tail*= " output = graphfunc(leaf)\n"
# # tail*= " return output"
# expr = head * body * tail
# println(expr)
# # return head * body * tail
# f = open("GraphFunc.py", "w")
# write(f, expr)
# return expr, leafidx-1, gid_to_leafid
# end

Loading

0 comments on commit 3317708

Please sign in to comment.