Skip to content

Commit

Permalink
add toMindspore.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
peter0627ustc committed Nov 21, 2023
1 parent 6034184 commit 5e5917d
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/backend/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module Compilers
using PyCall
using ..ComputationalGraphs
import ..ComputationalGraphs: id, name, set_name!, operator, subgraphs, subgraph_factors, factor

Expand All @@ -8,5 +9,6 @@ using ..RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(Compilers)

include("static.jl")
include("toMindspore.jl")

end
122 changes: 122 additions & 0 deletions src/backend/toMindspore.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# ms = pyimport("mindspore")

"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python.
"""
function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector)
error(
"Static representation for computational graph nodes with operator $(operator) not yet implemented! " *
"Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`."
)
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
# return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " + ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W}
if length(subgraphs) == 1
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "(g$(subgraphs[1].id)$factor_str)"
else
terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)]
return "(" * join(terms, " * ") * ")"
end
end

function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W}
factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])"
return "((g$(subgraphs[1].id))**$N$factor_str)"
end

"""
function to_julia_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the static graph representation in mindspore framework.
"""

function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
head = "import mindspore as ms\n@ms.jit\n"
head *= "def graphfunc():\n"
body = " graph_list = []\n"
leafidx = 1
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
target = "g$(g_id)"
isroot = false
if g_id in root
isroot = true
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = ms.Tensor(1.0)$factor_str\n"
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
push!(inds_visitednode, g_id)
end
if isroot
body *= " graph_list.append($target)\n"
end
end
end
tail = " return graph_list\n"
tail *= "output = graphfunc()"
expr = head * body * tail
# return head * body * tail
f = open("GraphFunc.py", "w")
write(f, expr)
end

# function to_mindspore_graph(graphs::AbstractVector{<:AbstractGraph})
# pyexpr = to_python_str_ms(graphs)
# py"""
# import mindspore as ms
# exec($pyexpr)
# ms_graph = jit(fn=graphfunc)
# out = ms_graph()
# """
# return py"out"
# end

0 comments on commit 5e5917d

Please sign in to comment.