Skip to content

Commit

Permalink
update toMindspore.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
peter0627ustc committed Nov 24, 2023
1 parent 5e5917d commit 46566ec
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/backend/toMindspore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Compile a list of graphs into a string for a python static function and output a

function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
head = "import mindspore as ms\n@ms.jit\n"
head *= "def graphfunc():\n"
head *= "def graphfunc(leaf):\n"
body = " graph_list = []\n"
leafidx = 1
root = [id(g) for g in graphs]
Expand All @@ -88,7 +88,7 @@ function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = ms.Tensor(1.0)$factor_str\n"
body *= " $target = ms.Tensor(leaf[$(leafidx-1)])$factor_str\n"
leafidx += 1
push!(inds_visitedleaf, g_id)
else
Expand All @@ -103,11 +103,15 @@ function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
end
end
tail = " return graph_list\n"
tail *= "output = graphfunc()"
tail*= "def to_StaticGraph(leaf)\n"
tail*= " output = graphfunc(leaf)\n"
tail*= " return output"
expr = head * body * tail
println(expr)
# return head * body * tail
f = open("GraphFunc.py", "w")
write(f, expr)
return expr
end

# function to_mindspore_graph(graphs::AbstractVector{<:AbstractGraph})
Expand Down

0 comments on commit 46566ec

Please sign in to comment.