Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified the python compiler for jax #167

Merged
merged 20 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 195 additions & 132 deletions src/backend/compiler_python.jl
Original file line number Diff line number Diff line change
@@ -1,132 +1,195 @@
# ms = pyimport("mindspore")

"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)

Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python.
"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
"Static representation for computational graph nodes with operator $(operator) not yet implemented! " *
"Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`."
)
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
# return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework.

# Arguments:
- `graphs` vector of computational graphs
- `framework` the type of the python frameworks, including `:jax` and `mindspore`.
"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
if framework == :jax
head = ""
elseif framework == :mindspore
head = "import mindspore as ms\[email protected]\n"
else
error("no support for $type framework")
end
body = ""
leafidx = 1
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
gid_to_leafid = Dict{String, Int64}()
rootidx = 1
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = l$(leafidx)$factor_str\n"
gid_to_leafid[target] = leafidx
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " out$(rootidx)=$target\n"
rootidx +=1
end
end
end
input = ["l$(i)" for i in 1:leafidx-1]
input = join(input,",")
output = ["out$(i)" for i in 1:rootidx-1]
output = join(output,",")
head *="def graphfunc($input):\n"
tail = " return $output\n"
# tail*= "def to_StaticGraph(leaf):\n"
# tail*= " output = graphfunc(leaf)\n"
# tail*= " return output"
expr = head * body * tail
println(expr)
# return head * body * tail
f = open("GraphFunc.py", "w")
write(f, expr)
return expr, leafidx-1, gid_to_leafid
end

# ms = pyimport("mindspore")

"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)

Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python.
"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
"Static representation for computational graph nodes with operator $(operator) not yet implemented! " *
"Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`."
)
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
# return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework.

# Arguments:
- `graphs` vector of computational graphs
- `framework` the type of the python frameworks, including `:jax` and `mindspore`.
"""
function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
if framework == :jax
head = "from jax import jit\n"
elseif framework == :mindspore
head = "import mindspore as ms\[email protected]\n"
else
error("no support for $type framework")
end
body = ""
leafidx = 0
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
gid_to_leafid = Dict{String,Int64}()
rootidx = 0
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = leaf[$(leafidx)]$factor_str\n"
gid_to_leafid[target] = leafidx
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " root$(rootidx) = $target\n"
rootidx += 1
end
end
end
head *= "def graphfunc(leaf):\n"
output = ["root$(i)" for i in 0:rootidx-1]
output = join(output,",")
tail = " return $output\n\n"

if framework == :jax
tail *="graphfunc_jit = jit(graphfunc)"
end
expr = head * body * tail

return expr, leafidx , gid_to_leafid
end
function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py")
py_string, leafnum, leafmap = to_python_str(graphs,framework)
println("The number of leaves: $leafnum")
open(filename, "w") do f
write(f, py_string)
end
return leafnum, leafmap
end

# function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
# if framework == :jax
# head = ""
# elseif framework == :mindspore
# head = "import mindspore as ms\[email protected]\n"
# else
# error("no support for $type framework")
# end
# body = ""
# leafidx = 1
# root = [id(g) for g in graphs]
# inds_visitedleaf = Int[]
# inds_visitednode = Int[]
# gid_to_leafid = Dict{String, Int64}()
# rootidx = 1
# for graph in graphs
# for g in PostOrderDFS(graph) #leaf first search
# g_id = id(g)
# target = "g$(g_id)"
# isroot = false
# if g_id in root
# isroot = true
# end
# if isempty(subgraphs(g)) #leaf
# g_id in inds_visitedleaf && continue
# factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
# body *= " $target = l$(leafidx)$factor_str\n"
# gid_to_leafid[target] = leafidx
# leafidx += 1
# push!(inds_visitedleaf, g_id)
# else
# g_id in inds_visitednode && continue
# factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
# body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
# push!(inds_visitednode, g_id)
# end
# if isroot
# body *= " out$(rootidx)=$target\n"
# rootidx +=1
# end
# end
# end
# input = ["l$(i)" for i in 1:leafidx-1]
# input = join(input,",")
# output = ["out$(i)" for i in 1:rootidx-1]
# output = join(output,",")
# head *="def graphfunc($input):\n"
# tail = " return $output\n"
# # tail*= "def to_StaticGraph(leaf):\n"
# # tail*= " output = graphfunc(leaf)\n"
# # tail*= " return output"
# expr = head * body * tail
# println(expr)
# # return head * body * tail
# f = open("GraphFunc.py", "w")
# write(f, expr)
# return expr, leafidx-1, gid_to_leafid
# end

Loading
Loading