From 7ef8905ba0f58d2a13502bdc3faf02d7112ac058 Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Tue, 19 Dec 2023 22:32:44 +0800 Subject: [PATCH 01/12] modify compile_python for sampling --- src/backend/compiler_python.jl | 96 +++++++++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 14 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 66ed79b2..878bd761 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -73,18 +73,18 @@ end """ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) if framework == :jax - head = "" + head = "from jax import jit" elseif framework == :mindspore head = "import mindspore as ms\n@ms.jit\n" - else + else error("no support for $type framework") end body = "" - leafidx = 1 + leafidx = 0 root = [id(g) for g in graphs] inds_visitedleaf = Int[] inds_visitednode = Int[] - gid_to_leafid = Dict{String, Int64}() + gid_to_leafid = Dict{String,Int64}() rootidx = 1 for graph in graphs for g in PostOrderDFS(graph) #leaf first search @@ -97,7 +97,7 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo if isempty(subgraphs(g)) #leaf g_id in inds_visitedleaf && continue factor_str = factor(g) == 1 ? "" : " * $(factor(g))" - body *= " $target = l$(leafidx)$factor_str\n" + body *= " $target = leaf[$(leafidx)]$factor_str\n" gid_to_leafid[target] = leafidx leafidx += 1 push!(inds_visitedleaf, g_id) @@ -109,24 +109,92 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo end if isroot body *= " out$(rootidx)=$target\n" - rootidx +=1 + rootidx += 1 end end end - input = ["l$(i)" for i in 1:leafidx-1] - input = join(input,",") - output = ["out$(i)" for i in 1:rootidx-1] - output = join(output,",") - head *="def graphfunc($input):\n" - tail = " return $output\n" + head *= "def graphfunc(root,leaf):\n" + tail = "\n" + + # tail = " return $output\n" # tail*= "def to_StaticGraph(leaf):\n" # tail*= " output = graphfunc(leaf)\n" # tail*= " return output" + if framework == :jax + tail *="graphfunc_jit = jit(graphfunc)" + end expr = head * body * tail - println(expr) + # println(expr) # return head * body * tail f = open("GraphFunc.py", "w") write(f, expr) - return expr, leafidx-1, gid_to_leafid + return expr, leafidx , gid_to_leafid end +function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py") + py_string, leafnum, leafmap = to_python_str(graphs,framework) + println("The number of leaves: $leafnum") + open(filename, "w") do f + write(f, py_string) + end + return leafnum, leafmap +end + +# function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) +# if framework == :jax +# head = "" +# elseif framework == :mindspore +# head = "import mindspore as ms\n@ms.jit\n" +# else +# error("no support for $type framework") +# end +# body = "" +# leafidx = 1 +# root = [id(g) for g in graphs] +# inds_visitedleaf = Int[] +# inds_visitednode = Int[] +# gid_to_leafid = Dict{String, Int64}() +# rootidx = 1 +# for graph in graphs +# for g in PostOrderDFS(graph) #leaf first search +# g_id = id(g) +# target = "g$(g_id)" +# isroot = false +# if g_id in root +# isroot = true +# end +# if isempty(subgraphs(g)) #leaf +# g_id in inds_visitedleaf && continue +# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" +# body *= " $target = l$(leafidx)$factor_str\n" +# gid_to_leafid[target] = leafidx +# leafidx += 1 +# push!(inds_visitedleaf, g_id) +# else +# g_id in inds_visitednode && continue +# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" +# body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" +# push!(inds_visitednode, g_id) +# end +# if isroot +# body *= " out$(rootidx)=$target\n" +# rootidx +=1 +# end +# end +# end +# input = ["l$(i)" for i in 1:leafidx-1] +# input = join(input,",") +# output = ["out$(i)" for i in 1:rootidx-1] +# output = join(output,",") +# head *="def graphfunc($input):\n" +# tail = " return $output\n" +# # tail*= "def to_StaticGraph(leaf):\n" +# # tail*= " output = graphfunc(leaf)\n" +# # tail*= " return output" +# expr = head * body * tail +# println(expr) +# # return head * body * tail +# f = open("GraphFunc.py", "w") +# write(f, expr) +# return expr, leafidx-1, gid_to_leafid +# end From 9051a07ee30947ff2bf5fb103530fff0a6062502 Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Tue, 19 Dec 2023 22:33:39 +0800 Subject: [PATCH 02/12] debug --- src/backend/compiler_python.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 878bd761..59f9c8d1 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -85,7 +85,7 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo inds_visitedleaf = Int[] inds_visitednode = Int[] gid_to_leafid = Dict{String,Int64}() - rootidx = 1 + rootidx = 0 for graph in graphs for g in PostOrderDFS(graph) #leaf first search g_id = id(g) From 408ccaeef08d85e98a3862c6e1887bbd5aff4a7d Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Tue, 19 Dec 2023 22:44:28 +0800 Subject: [PATCH 03/12] debug --- src/backend/compiler_python.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 59f9c8d1..6ebaf763 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -73,7 +73,7 @@ end """ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) if framework == :jax - head = "from jax import jit" + head = "from jax import jit\n" elseif framework == :mindspore head = "import mindspore as ms\n@ms.jit\n" else From 98ded11bfd01b11820963503ba6b01bee569ce1b Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Tue, 19 Dec 2023 22:56:47 +0800 Subject: [PATCH 04/12] debug --- src/backend/compiler_python.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 6ebaf763..13917b5b 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -108,7 +108,7 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo push!(inds_visitednode, g_id) end if isroot - body *= " out$(rootidx)=$target\n" + body *= " out[$(rootidx)]=$target\n" rootidx += 1 end end From f74169a3281468a79cb1456768652e236420ea5f Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Tue, 19 Dec 2023 22:59:09 +0800 Subject: [PATCH 05/12] debug --- src/backend/compiler_python.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 13917b5b..a79e177f 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -108,7 +108,7 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo push!(inds_visitednode, g_id) end if isroot - body *= " out[$(rootidx)]=$target\n" + body *= " root[$(rootidx)]=$target\n" rootidx += 1 end end From bf78162b25cc0ae7dcc9eba75c9ea784f194fc85 Mon Sep 17 00:00:00 2001 From: ZhiyiLi Date: Tue, 19 Dec 2023 23:25:29 +0800 Subject: [PATCH 06/12] delete redundancy --- src/backend/compiler_python.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index a79e177f..f6c4354b 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -126,8 +126,6 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo expr = head * body * tail # println(expr) # return head * body * tail - f = open("GraphFunc.py", "w") - write(f, expr) return expr, leafidx , gid_to_leafid end function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py") From f35f2be75b88fd6075a4faf1150b53a84ea4aa0d Mon Sep 17 00:00:00 2001 From: ZhiyiLi Date: Tue, 19 Dec 2023 23:26:00 +0800 Subject: [PATCH 07/12] delete redundancy --- src/backend/compiler_python.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index f6c4354b..c95167df 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -115,11 +115,7 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo end head *= "def graphfunc(root,leaf):\n" tail = "\n" - - # tail = " return $output\n" - # tail*= "def to_StaticGraph(leaf):\n" - # tail*= " output = graphfunc(leaf)\n" - # tail*= " return output" + if framework == :jax tail *="graphfunc_jit = jit(graphfunc)" end From 5dff47623de7eb6b881acdf430672bb1c7602dbe Mon Sep 17 00:00:00 2001 From: ZhiyiLi Date: Tue, 19 Dec 2023 23:26:05 +0800 Subject: [PATCH 08/12] delete redundancy --- src/backend/compiler_python.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index c95167df..ecd0025c 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -115,13 +115,12 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo end head *= "def graphfunc(root,leaf):\n" tail = "\n" - + if framework == :jax tail *="graphfunc_jit = jit(graphfunc)" end expr = head * body * tail - # println(expr) - # return head * body * tail + return expr, leafidx , gid_to_leafid end function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py") From e1914febbeac91b69bb03f7fd4dc50e70f30a347 Mon Sep 17 00:00:00 2001 From: ZhiyiLi Date: Wed, 20 Dec 2023 14:59:28 +0800 Subject: [PATCH 09/12] modify python compiler --- src/backend/compiler_python.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index ecd0025c..5c7d2622 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -108,13 +108,15 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo push!(inds_visitednode, g_id) end if isroot - body *= " root[$(rootidx)]=$target\n" + body *= " root$(rootidx) = $target\n" rootidx += 1 end end end - head *= "def graphfunc(root,leaf):\n" - tail = "\n" + head *= "def graphfunc(leaf):\n" + output = ["root$(i)" for i in 0:rootidx-1] + output = join(output,",") + tail = " return $output\n\n" if framework == :jax tail *="graphfunc_jit = jit(graphfunc)" From 1a74871851c4f590b43b4fe3e7dd97655512709f Mon Sep 17 00:00:00 2001 From: Peter Date: Thu, 21 Dec 2023 22:48:08 +0800 Subject: [PATCH 10/12] debug in compiler_python --- src/backend/compiler_python.jl | 390 ++++++++++++++++----------------- 1 file changed, 195 insertions(+), 195 deletions(-) diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 5c7d2622..1fbea267 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -1,195 +1,195 @@ -# ms = pyimport("mindspore") - -""" - function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) - -Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python. -""" -function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) - error( - "Static representation for computational graph nodes with operator $(operator) not yet implemented! " * - "Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`." - ) -end - -function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} - if length(subgraphs) == 1 - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "(g$(subgraphs[1].id)$factor_str)" - else - terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] - return "(" * join(terms, " + ") * ")" - end -end - -function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} - if length(subgraphs) == 1 - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "(g$(subgraphs[1].id)$factor_str)" - else - terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] - return "(" * join(terms, " * ") * ")" - # return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")" - end -end - -function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "((g$(subgraphs[1].id))**$N$factor_str)" -end - -function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} - if length(subgraphs) == 1 - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "(g$(subgraphs[1].id)$factor_str)" - else - terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] - return "(" * join(terms, " + ") * ")" - end -end - -function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} - if length(subgraphs) == 1 - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "(g$(subgraphs[1].id)$factor_str)" - else - terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] - return "(" * join(terms, " * ") * ")" - end -end - -function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} - factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" - return "((g$(subgraphs[1].id))**$N$factor_str)" -end - -""" - function to_python_str(graphs::AbstractVector{<:AbstractGraph}) - Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework. - - # Arguments: - - `graphs` vector of computational graphs - - `framework` the type of the python frameworks, including `:jax` and `mindspore`. -""" -function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) - if framework == :jax - head = "from jax import jit\n" - elseif framework == :mindspore - head = "import mindspore as ms\n@ms.jit\n" - else - error("no support for $type framework") - end - body = "" - leafidx = 0 - root = [id(g) for g in graphs] - inds_visitedleaf = Int[] - inds_visitednode = Int[] - gid_to_leafid = Dict{String,Int64}() - rootidx = 0 - for graph in graphs - for g in PostOrderDFS(graph) #leaf first search - g_id = id(g) - target = "g$(g_id)" - isroot = false - if g_id in root - isroot = true - end - if isempty(subgraphs(g)) #leaf - g_id in inds_visitedleaf && continue - factor_str = factor(g) == 1 ? "" : " * $(factor(g))" - body *= " $target = leaf[$(leafidx)]$factor_str\n" - gid_to_leafid[target] = leafidx - leafidx += 1 - push!(inds_visitedleaf, g_id) - else - g_id in inds_visitednode && continue - factor_str = factor(g) == 1 ? "" : " * $(factor(g))" - body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" - push!(inds_visitednode, g_id) - end - if isroot - body *= " root$(rootidx) = $target\n" - rootidx += 1 - end - end - end - head *= "def graphfunc(leaf):\n" - output = ["root$(i)" for i in 0:rootidx-1] - output = join(output,",") - tail = " return $output\n\n" - - if framework == :jax - tail *="graphfunc_jit = jit(graphfunc)" - end - expr = head * body * tail - - return expr, leafidx , gid_to_leafid -end -function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py") - py_string, leafnum, leafmap = to_python_str(graphs,framework) - println("The number of leaves: $leafnum") - open(filename, "w") do f - write(f, py_string) - end - return leafnum, leafmap -end - -# function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) -# if framework == :jax -# head = "" -# elseif framework == :mindspore -# head = "import mindspore as ms\n@ms.jit\n" -# else -# error("no support for $type framework") -# end -# body = "" -# leafidx = 1 -# root = [id(g) for g in graphs] -# inds_visitedleaf = Int[] -# inds_visitednode = Int[] -# gid_to_leafid = Dict{String, Int64}() -# rootidx = 1 -# for graph in graphs -# for g in PostOrderDFS(graph) #leaf first search -# g_id = id(g) -# target = "g$(g_id)" -# isroot = false -# if g_id in root -# isroot = true -# end -# if isempty(subgraphs(g)) #leaf -# g_id in inds_visitedleaf && continue -# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" -# body *= " $target = l$(leafidx)$factor_str\n" -# gid_to_leafid[target] = leafidx -# leafidx += 1 -# push!(inds_visitedleaf, g_id) -# else -# g_id in inds_visitednode && continue -# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" -# body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" -# push!(inds_visitednode, g_id) -# end -# if isroot -# body *= " out$(rootidx)=$target\n" -# rootidx +=1 -# end -# end -# end -# input = ["l$(i)" for i in 1:leafidx-1] -# input = join(input,",") -# output = ["out$(i)" for i in 1:rootidx-1] -# output = join(output,",") -# head *="def graphfunc($input):\n" -# tail = " return $output\n" -# # tail*= "def to_StaticGraph(leaf):\n" -# # tail*= " output = graphfunc(leaf)\n" -# # tail*= " return output" -# expr = head * body * tail -# println(expr) -# # return head * body * tail -# f = open("GraphFunc.py", "w") -# write(f, expr) -# return expr, leafidx-1, gid_to_leafid -# end - +# ms = pyimport("mindspore") + +""" + function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) + +Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors` in python. +""" +function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) + error( + "Static representation for computational graph nodes with operator $(operator) not yet implemented! " * + "Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`." + ) +end + +function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " + ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " * ") * ")" + # return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "((g$(subgraphs[1].id))**$N$factor_str)" +end + +function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " + ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " * ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "((g$(subgraphs[1].id))**$N$factor_str)" +end + +""" + function to_python_str(graphs::AbstractVector{<:AbstractGraph}) + Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework. + + # Arguments: + - `graphs` vector of computational graphs + - `framework` the type of the python frameworks, including `:jax` and `mindspore`. +""" +function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) + if framework == :jax + head = "from jax import jit\n" + elseif framework == :mindspore + head = "import mindspore as ms\n@ms.jit\n" + else + error("no support for $type framework") + end + body = "" + leafidx = 0 + root = [id(g) for g in graphs] + inds_visitedleaf = Int[] + inds_visitednode = Int[] + gid_to_leafid = Dict{String,Int64}() + rootidx = 0 + for graph in graphs + for g in PostOrderDFS(graph) #leaf first search + g_id = id(g) + target = "g$(g_id)" + isroot = false + if g_id in root + isroot = true + end + if isempty(subgraphs(g)) #leaf + g_id in inds_visitedleaf && continue + factor_str = factor(g) == 1 ? "" : " * $(factor(g))" + body *= " $target = leaf[$(leafidx)]$factor_str\n" + gid_to_leafid[target] = leafidx + leafidx += 1 + push!(inds_visitedleaf, g_id) + else + g_id in inds_visitednode && continue + factor_str = factor(g) == 1 ? "" : " * $(factor(g))" + body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" + push!(inds_visitednode, g_id) + end + if isroot + body *= " root$(rootidx) = $target\n" + rootidx += 1 + end + end + end + head *= "def graphfunc(leaf):\n" + output = ["root$(i)" for i in 0:rootidx-1] + output = join(output,",") + tail = " return $output\n\n" + + if framework == :jax + tail *="graphfunc_jit = jit(graphfunc)" + end + expr = head * body * tail + + return expr, leafidx , gid_to_leafid +end +function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py") + py_string, leafnum, leafmap = to_python_str(graphs,framework) + println("The number of leaves: $leafnum") + open(filename, "w") do f + write(f, py_string) + end + return leafnum, leafmap +end + +# function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax) +# if framework == :jax +# head = "" +# elseif framework == :mindspore +# head = "import mindspore as ms\n@ms.jit\n" +# else +# error("no support for $type framework") +# end +# body = "" +# leafidx = 1 +# root = [id(g) for g in graphs] +# inds_visitedleaf = Int[] +# inds_visitednode = Int[] +# gid_to_leafid = Dict{String, Int64}() +# rootidx = 1 +# for graph in graphs +# for g in PostOrderDFS(graph) #leaf first search +# g_id = id(g) +# target = "g$(g_id)" +# isroot = false +# if g_id in root +# isroot = true +# end +# if isempty(subgraphs(g)) #leaf +# g_id in inds_visitedleaf && continue +# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" +# body *= " $target = l$(leafidx)$factor_str\n" +# gid_to_leafid[target] = leafidx +# leafidx += 1 +# push!(inds_visitedleaf, g_id) +# else +# g_id in inds_visitednode && continue +# factor_str = factor(g) == 1 ? "" : " * $(factor(g))" +# body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" +# push!(inds_visitednode, g_id) +# end +# if isroot +# body *= " out$(rootidx)=$target\n" +# rootidx +=1 +# end +# end +# end +# input = ["l$(i)" for i in 1:leafidx-1] +# input = join(input,",") +# output = ["out$(i)" for i in 1:rootidx-1] +# output = join(output,",") +# head *="def graphfunc($input):\n" +# tail = " return $output\n" +# # tail*= "def to_StaticGraph(leaf):\n" +# # tail*= " output = graphfunc(leaf)\n" +# # tail*= " return output" +# expr = head * body * tail +# println(expr) +# # return head * body * tail +# f = open("GraphFunc.py", "w") +# write(f, expr) +# return expr, leafidx-1, gid_to_leafid +# end + From dfa5e1b2f47306c9714fbe7b0b8f466b3a2c0c9d Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Fri, 22 Dec 2023 02:55:58 +0800 Subject: [PATCH 11/12] merge daniel's branch and modified to_dot.jl --- src/backend/to_dot.jl | 269 +++++++++++++++++++++++------------------- 1 file changed, 146 insertions(+), 123 deletions(-) diff --git a/src/backend/to_dot.jl b/src/backend/to_dot.jl index 69e64f2f..f1ddf522 100644 --- a/src/backend/to_dot.jl +++ b/src/backend/to_dot.jl @@ -8,25 +8,29 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgr node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$(id)[shape=box, label = <($factor)*⊕>, style=filled, fillcolor=cyan,fontsize=18]" else - opr_name = "g$id" + opr_node = "g$(id)[shape=box, label = <⊕>, style=filled, fillcolor=cyan,fontsize=18]" + # opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Add\", style=filled, fillcolor=cyan,]\n" + opr_name = "g$id" + # opr_node = opr_name * "[shape=box, label = <⊕>, style=filled, fillcolor=cyan,]\n" node_temp *= opr_node for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) if gfactor != 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" - subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + # factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" + # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" else - arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end end return node_temp, arrow_temp @@ -36,37 +40,39 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$id[shape=box, label = <($factor)⊗>, style=filled, fillcolor=cornsilk,fontsize=18]\n" else - opr_name = "g$id" + opr_node = "g$id[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,fontsize=18]\n" end - opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + # opr_node = opr_name * "[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" node_temp *= opr_node - if length(subgraphs) == 1 - if subgraph_factors[1] == 1 - arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # if length(subgraphs) == 1 + # if subgraph_factors[1] == 1 + # arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # else + # factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + # node_temp *= factor_str + # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # end + # else + for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) + if gfactor != 1 + # factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" + # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" else - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" - node_temp *= factor_str - arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" - end - else - for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) - if gfactor != 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" - subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - else - arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" - end + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end end + # end return node_temp, arrow_temp end @@ -74,24 +80,26 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$id[shape=box, label = <($factor)*Pow($N)>, style=filled, fillcolor=darkolivegreen,fontsize=18]\n" else - opr_name = "g$id" + opr_node = "g$id[shape=box, label = , style=filled, fillcolor=darkolivegreen,fontsize=18]\n" end - opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" - order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" - node_temp *= opr_node * order_node - arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" + # opr_node = "g$id[shape=box, label = , style=filled, fillcolor=darkolivegreen,]\n" + # order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" + # node_temp *= opr_node * order_node + node_temp *= opr_node + # arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" if subgraph_factors[1] != 1 - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" - subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" - arrow_temp *= "g$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\n" + # factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" + arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor]\n" else arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end @@ -102,23 +110,27 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgr node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$(id)[shape=box, label = <($factor)*⊕>, style=filled, fillcolor=cyan,fontsize=18]" else - opr_name = "g$id" + opr_node = "g$(id)[shape=box, label = <⊕>, style=filled, fillcolor=cyan,fontsize=18]" + # opr_name = "g$id" end - opr_node = opr_name * "[shape=box, label = \"Add\", style=filled, fillcolor=cyan,]\n" + opr_name = "g$id" + # opr_node = opr_name * "[shape=box, label = <⊕>, style=filled, fillcolor=cyan,]\n" node_temp *= opr_node for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) if gfactor != 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" - subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + # factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" + # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,label=$gfactor]\n" else arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" end @@ -130,37 +142,39 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$id[shape=box, label = <($factor)⊗>, style=filled, fillcolor=cornsilk,fontsize=18]\n" else - opr_name = "g$id" + opr_node = "g$id[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,fontsize=18]\n" end - opr_node = opr_name * "[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" + # opr_node = opr_name * "[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" node_temp *= opr_node - if length(subgraphs) == 1 - if subgraph_factors[1] == 1 - arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # if length(subgraphs) == 1 + # if subgraph_factors[1] == 1 + # arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # else + # factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + # node_temp *= factor_str + # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" + # end + # else + for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) + if gfactor != 1 + # factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" + # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" else - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" - node_temp *= factor_str - arrow_temp *= "factor$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\ng$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" - end - else - for (gix, (g, gfactor)) in enumerate(zip(subgraphs, subgraph_factors)) - if gfactor != 1 - factor_str = "factor$(g.id)_$(id)_$gix[label=$(gfactor), style=filled, fillcolor=lavender]\n" - subg_str = "g$(g.id)_$(id)_$gix[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - else - arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" - end + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end end + # end return node_temp, arrow_temp end @@ -168,24 +182,26 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, node_temp = "" arrow_temp = "" if factor != 1 - opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" - opr_name = "g$(id)_t" - node_str = "g$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" - node_temp *= opr_fac * node_str + # opr_fac = "factor$(id)[label=$(factor), style=filled, fillcolor=lavender]\n" + # opr_name = "g$(id)_t" + # node_str = "g$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # arrow_temp *= "factor$(id)->g$(id)[arrowhead=vee,]\ng$(id)_t->g$(id)[arrowhead=vee,]\n" + # node_temp *= opr_fac * node_str + opr_node = "g$id[shape=box, label = <($factor)*Pow($N)>, style=filled, fillcolor=darkolivegreen,fontsize=18]\n" else - opr_name = "g$id" + opr_node = "g$id[shape=box, label = , style=filled, fillcolor=darkolivegreen,fontsize=18]\n" end - opr_node = opr_name * "[shape=box, label = \"Pow\", style=filled, fillcolor=darkolivegreen,]\n" - order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" - node_temp *= opr_node * order_node - arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" + # opr_node = "g$id[shape=box, label = , style=filled, fillcolor=darkolivegreen,]\n" + # order_node = "order$(id)[label=$N, style=filled, fillcolor=lavender]\n" + # node_temp *= opr_node * order_node + node_temp *= opr_node + # arrow_temp *= "order$(id)->$opr_name[arrowhead=vee,]\n" if subgraph_factors[1] != 1 - factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" - subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - node_temp *= factor_str * subg_str - arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" - arrow_temp *= "g$(subgraphs[1].id)_$(id)->$opr_name[arrowhead=vee,]\n" + # factor_str = "factor$(subgraphs[1].id)_$(id)[label=$(subgraph_factors[1]), style=filled, fillcolor=lavender]\n" + # subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + # node_temp *= factor_str * subg_str + # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" + arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor]\n" else arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end @@ -203,7 +219,7 @@ end """ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", diagram_id_map=nothing) head = "digraph ComputationalGraph { \nlabel=\"$name\"\n" - head *= "ReturnNode[shape=box, label = \"Return\", style=filled, fillcolor=darkorange,]\n" + head *= "ReturnNode[shape=box, label = \"Return\", style=filled, fillcolor=darkorange,fontsize=18]\n" body_node = "" body_arrow = "" leafidx = 1 @@ -222,14 +238,19 @@ function to_dot_str(graphs::AbstractVector{<:AbstractGraph}, name::String="", di g_id in inds_visitedleaf && continue leafname = get_leafname(g, leafidx, diagram_id_map) if factor(g) == 1 - gnode_str = "g$g_id[label=$leafname, style=filled, fillcolor=paleturquoise]\n" + gnode_str = "g$g_id[label=<$leafname>, style=filled, fillcolor=paleturquoise,fontsize=18]\n" + body_node *= gnode_str + elseif factor(g) == -1 + gnode_str = "g$g_id[label=<-$leafname>, style=filled, fillcolor=paleturquoise,fontsize=18]\n" body_node *= gnode_str else - factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n" - leaf_node = "l$(leafidx)[label=$leafname, style=filled, fillcolor=paleturquoise]\n" - gnode_str = "g$g_id[shape=box, label = \"Mul\", style=filled, fillcolor=cornsilk,]\n" - body_node *= factor_str * leaf_node * gnode_str - body_arrow *= "factor$(leafidx)_inp->g$g_id[arrowhead=vee,]\nl$(leafidx)->g$g_id[arrowhead=vee,]\n" + # factor_str = "factor$(leafidx)_inp[label=$(factor(g)), style=filled, fillcolor=lavender]\n" + # leaf_node = "l$(leafidx)[label=$leafname, style=filled, fillcolor=paleturquoise]\n" + # gnode_str = "g$g_id[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" + gnode_str = "l$(leafidx)[label=<$(factor(g))$leafname>, style=filled, fillcolor=paleturquoise,fontsize=18]\n" + # body_node *= factor_str * leaf_node * gnode_str + body_node *= gnode_str + # body_arrow *= "factor$(leafidx)_inp->g$g_id[arrowhead=vee,]\nl$(leafidx)->g$g_id[arrowhead=vee,]\n" end leafidx += 1 push!(inds_visitedleaf, g_id) @@ -267,27 +288,29 @@ function get_leafname(g, leafidx, diagram_id_map=nothing) elseif g isa Graph if isnothing(diagram_id_map) == false leaftype = typeof(diagram_id_map[g.id]) + else + leaftype = typeof(g.properties) end else error("Unknown graph type: $(typeof(g))") end if leaftype in [BareGreenId, ComputationalGraphs.Propagator] - leafname = "<G$leafidx>" + leafname = "G$leafidx" elseif leaftype in [BareInteractionId, ComputationalGraphs.Interaction] - leafname = "<V$leafidx>" + leafname = "V$leafidx" elseif leaftype == PolarId - leafname = "<Π$leafidx>" + leafname = "Π$leafidx" elseif leaftype == Ver3Id - leafname = "<Γ(3)$leafidx>" + leafname = "Γ(3)$leafidx" elseif leaftype == Ver4Id - leafname = "<Γ(4)$leafidx>" + leafname = "Γ(4)$leafidx" else - leafname = "$leafidx>" + leafname = "L$leafidx" end - println() - println(g) - println(leaftype) - println(leafname) - println() + # println() + # println(g) + # println(leaftype) + # println(leafname) + # println() return leafname end From 77236efce4a253e3b96c48b8301769edbffe4b4b Mon Sep 17 00:00:00 2001 From: Lizhiyi Date: Fri, 22 Dec 2023 14:36:38 +0800 Subject: [PATCH 12/12] modified fontsize in dot file --- src/backend/to_dot.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/backend/to_dot.jl b/src/backend/to_dot.jl index f1ddf522..6fc944dd 100644 --- a/src/backend/to_dot.jl +++ b/src/backend/to_dot.jl @@ -28,7 +28,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgr # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end @@ -67,7 +67,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end @@ -99,7 +99,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, # subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" - arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end @@ -130,7 +130,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Sum}, id::Int, factor::F, subgr # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(g.id)->$opr_name[arrowhead=vee,]\n" end @@ -169,7 +169,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Prod}, id::Int, factor::F, subg # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(g.id)_$(id)_$gix->g$(g.id)_$(id)_$gix[arrowhead=vee,]\ng$(g.id)->g$(g.id)_$(id)_$gix[arrowhead=vee,]\n" # arrow_temp *= "g$(g.id)_$(id)_$gix->$opr_name[arrowhead=vee,]\n" - arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(g.id)->g$(id)[arrowhead=vee,]\n" end @@ -201,7 +201,7 @@ function to_dotstatic(::Type{ComputationalGraphs.Power{N}}, id::Int, factor::F, # subg_str = "g$(subgraphs[1].id)_$(id)[shape=box, label = <⊗>, style=filled, fillcolor=cornsilk,]\n" # node_temp *= factor_str * subg_str # arrow_temp *= "factor$(subgraphs[1].id)_$(id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\ng$(subgraphs[1].id)->g$(subgraphs[1].id)_$(id)[arrowhead=vee,]\n" - arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor]\n" + arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,label=$gfactor,fontsize=16]\n" else arrow_temp *= "g$(subgraphs[1].id)->$opr_name[arrowhead=vee,]\n" end @@ -281,7 +281,7 @@ function compile_dot(graphs::AbstractVector{<:AbstractGraph}, filename::String; end function get_leafname(g, leafidx, diagram_id_map=nothing) - println(typeof(g)) + # println(typeof(g)) leaftype = Nothing if g isa FeynmanGraph leaftype = g.properties.diagtype