Skip to content

Commit

Permalink
remove factor field in AbstractGraph and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Jan 10, 2024
1 parent 63bbd9e commit 690964d
Show file tree
Hide file tree
Showing 14 changed files with 75 additions and 260 deletions.
8 changes: 4 additions & 4 deletions example/to_dot_parquetV2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using FeynmanDiagram
function recursive_print(diag)
if typeof(diag) <: FeynmanDiagram.ComputationalGraphs.Graph
if !isempty(diag.subgraphs)
print("$(diag.id) $(diag.factor) $(diag.subgraph_factors)\n")
print("$(diag.id) $(diag.subgraph_factors)\n")
for subdiag in diag.subgraphs
recursive_print(subdiag)
end
Expand All @@ -24,7 +24,7 @@ function main()
KinL, KoutL, KinR = zeros(16), zeros(16), zeros(16)
KinL[1], KoutL[2], KinR[3] = 1.0, 1.0, 1.0
# para = GV.diagPara(SigmaDiag, false, spin, order, [NoHartree], KinL - KoutL)
para = DiagParaF64(type=SigmaDiag, innerLoopNum=order, interaction=[Interaction(UpUp, [Instant,])], hasTau=true)
para = DiagPara(type=SigmaDiag, innerLoopNum=order, interaction=[Interaction(UpUp, [Instant,])], hasTau=true)
# para = DiagParaF64(type=SigmaDiag, innerLoopNum=2, interaction=[Interaction(ChargeCharge, [Instant,])], hasTau=true)
parquet_builder = Parquet.build(para)
diag = parquet_builder.diagram
Expand All @@ -38,8 +38,8 @@ function main()
# print("new diag2\n")
# recursive_print(eachd)
# end
G = FrontEnds.Graph!(d[1])
G = [eldest(G)] # drop extraneous Add node at root
# G = FrontEnds.Graph!(d[1])
G = [eldest(d[1])] # drop extraneous Add node at root
# for d in G
# print("graph1\n")
# recursive_print(d)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Compilers
using PyCall
using ..ComputationalGraphs
import ..ComputationalGraphs: id, name, set_name!, operator, subgraphs, subgraph_factors, factor, FeynmanProperties
import ..ComputationalGraphs: id, name, set_name!, operator, subgraphs, subgraph_factors, FeynmanProperties

using ..Parquet
using ..Parquet: PropagatorId, BareGreenId, BareInteractionId
Expand Down
14 changes: 6 additions & 8 deletions src/backend/compiler_python.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,13 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = leaf[$(leafidx)]$factor_str\n"
body *= " $target = leaf[$(leafidx)]\n"
gid_to_leafid[target] = leafidx
leafidx += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))\n"
push!(inds_visitednode, g_id)
end
if isroot
Expand All @@ -115,18 +113,18 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbo
end
head *= "def graphfunc(leaf):\n"
output = ["root$(i)" for i in 0:rootidx-1]
output = join(output,",")
output = join(output, ",")
tail = " return $output\n\n"

if framework == :jax
tail *="graphfunc_jit = jit(graphfunc)"
tail *= "graphfunc_jit = jit(graphfunc)"
end
expr = head * body * tail

return expr, leafidx , gid_to_leafid
return expr, leafidx, gid_to_leafid
end
function compile_python(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax, filename::String="GraphFunc.py")
py_string, leafnum, leafmap = to_python_str(graphs,framework)
py_string, leafnum, leafmap = to_python_str(graphs, framework)
println("The number of leaves: $leafnum")
open(filename, "w") do f
write(f, py_string)
Expand Down
12 changes: 4 additions & 8 deletions src/backend/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,13 @@ function to_julia_str(graphs::AbstractVector{<:AbstractGraph}; root::AbstractVec
end
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = leafVal[$idx_leafVal]$factor_str\n"
body *= " $target = leafVal[$idx_leafVal]\n"
map_validx_leaf[idx_leafVal] = g
idx_leafVal += 1
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_static(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n"
body *= " $target = $(to_static(operator(g), subgraphs(g), subgraph_factors(g)))\n"
push!(inds_visitednode, g_id)
end
if isroot
Expand Down Expand Up @@ -160,16 +158,14 @@ function to_Cstr(graphs::AbstractVector{<:AbstractGraph}; root::AbstractVector{I
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
declare *= " g$g_id,"
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = leafVal[$idx_leafVal]$factor_str;\n"
body *= " $target = leafVal[$idx_leafVal];\n"
idx_leafVal += 1
map_validx_leaf[idx_leafVal] = g
push!(inds_visitedleaf, g_id)
else
g_id in inds_visitednode && continue
declare *= " g$g_id,"
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = $(to_static(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str;\n"
body *= " $target = $(to_static(operator(g), subgraphs(g), subgraph_factors(g)));\n"
push!(inds_visitednode, g_id)
end
if isroot
Expand Down
Loading

0 comments on commit 690964d

Please sign in to comment.