diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index eb856e78..206fe576 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -17,7 +17,6 @@ function optimize!(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose else graphs = collect(graphs) leaf_mapping = remove_duplicated_leaves!(graphs, verbose=verbose, normalize=normalize) - # merge_all_chains!(graphs, verbose=verbose) flatten_all_chains!(graphs, verbose=verbose) merge_all_linear_combinations!(graphs, verbose=verbose) return leaf_mapping @@ -44,146 +43,10 @@ function optimize(graphs::Union{Tuple,AbstractVector{<:AbstractGraph}}; verbose= return graphs_new, leaf_mapping end -""" - function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0) - - In-place merge prefactors of all nodes representing trivial unary chains towards the root level for a single graph. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_chain_prefactors!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") - # Post-order DFS - for sub_g in g.subgraphs - merge_all_chain_prefactors!(sub_g) - merge_chain_prefactors!(sub_g) - end - merge_chain_prefactors!(g) - return g -end - -""" - function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - - In-place merge prefactors of all nodes representing trivial unary chains towards the root level for given graphs. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_chain_prefactors!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") - # Post-order DFS - for g in graphs - merge_all_chain_prefactors!(g.subgraphs) - merge_chain_prefactors!(g) - end - return graphs -end - -""" - function merge_all_factorless_chains!(g::AbstractGraph; verbose=0) - - In-place merge all nodes representing factorless trivial unary chains within a single graph. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_factorless_chains!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") - # Post-order DFS - for sub_g in g.subgraphs - merge_all_factorless_chains!(sub_g) - merge_factorless_chain!(sub_g) - end - merge_factorless_chain!(g) - return g -end - -""" - function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - - In-place merge all nodes representing factorless trivial unary chains within given graphs. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_factorless_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") - # Post-order DFS - for g in graphs - merge_all_factorless_chains!(g.subgraphs) - merge_factorless_chain!(g) - end - return graphs -end - -""" - function merge_all_chains!(g::AbstractGraph; verbose=0) - - In-place merge all nodes representing trivial unary chains within a single graph. - This function consolidates both chain prefactors and factorless chains. - -# Arguments: -- `g`: An AbstractGraph. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graph. -# -""" -function merge_all_chains!(g::AbstractGraph; verbose=0) - verbose > 0 && println("merge all nodes representing trivial unary chains.") - merge_all_chain_prefactors!(g, verbose=verbose) - merge_all_factorless_chains!(g, verbose=verbose) - return g -end - -""" - function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) where {G<:AbstractGraph} - - In-place merge all nodes representing trivial unary chains in given graphs. - This function consolidates both chain prefactors and factorless chains. - -# Arguments: -- `graphs`: An AbstractVector of graphs. -- `verbose`: Level of verbosity (default: 0). - -# Returns: -- Optimized graphs. -# -""" -function merge_all_chains!(graphs::AbstractVector{<:AbstractGraph}; verbose=0) - verbose > 0 && println("merge all nodes representing trivial unary chains.") - merge_all_chain_prefactors!(graphs, verbose=verbose) - merge_all_factorless_chains!(graphs, verbose=verbose) - return graphs -end - """ function flatten_all_chains!(g::AbstractGraph; verbose=0) - - In-place flattens all nodes representing trivial unary chains in the given graph `g`. +F + Flattens all nodes representing trivial unary chains in-place in the given graph `g`. # Arguments: - `graphs`: The graph to be processed. diff --git a/src/computational_graph/transform.jl b/src/computational_graph/transform.jl index 7c291ca4..a0683380 100644 --- a/src/computational_graph/transform.jl +++ b/src/computational_graph/transform.jl @@ -157,134 +157,6 @@ function replace_subgraph(g::AbstractGraph, w::AbstractGraph, m::AbstractGraph) return g_new end -""" - function merge_factorless_chain!(g::AbstractGraph) - - Simplifies `g` in-place if it represents a factorless trivial unary chain. For example, +(+(+g)) ↦ g. - - Does nothing unless g has the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, - a node with non-unity multiplicative prefactor, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_factorless_chain!(g::AbstractGraph) - if unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - child = eldest(g) - for field in fieldnames(typeof(g)) - value = getproperty(child, field) - setproperty!(g, field, value) - end - end - while unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - child = eldest(g) - for field in fieldnames(typeof(g)) - value = getproperty(child, field) - setproperty!(g, field, value) - end - end - return g -end - -""" - function merge_factorless_chain(g::AbstractGraph) - - Returns a simplified copy of `g` if it represents a factorless trivial unary chain. - Otherwise, returns the original graph. For example, +(+(+g)) ↦ g. - - Does nothing unless g has the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, - a node with non-unity multiplicative prefactor, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_factorless_chain(g::AbstractGraph) - while unary_istrivial(g.operator) && onechild(g) && isfactorless(g) - g = eldest(g) - end - return g -end - -""" - function merge_chain_prefactors!(g::AbstractGraph) - - Simplifies subgraphs of g representing trivial unary chains by merging their - subgraph factors toward root level, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*(*(*g)) + 63*(*h). - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_chain_prefactors!(g::AbstractGraph) - for (i, child) in enumerate(g.subgraphs) - total_chain_factor = 1 - while onechild(child) - # Break case: end of trivial unary chain - unary_istrivial(child.operator) == false && break - # Move this subfactor to running total - total_chain_factor *= child.subgraph_factors[1] - child.subgraph_factors[1] = 1 - # Descend one level - child = eldest(child) - end - # Update g subfactor with total factors from children - g.subgraph_factors[i] *= total_chain_factor - end - return g -end - -""" - function merge_chain_prefactors(g::AbstractGraph) - - Returns a copy of g with subgraphs representing trivial unary chains simplified by merging - their subgraph factors toward root level, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*(*(*g)) + 63*(*h). - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -merge_chain_prefactors(g::AbstractGraph) = merge_chain_prefactors!(deepcopy(g)) - -""" - function merge_chains!(g::AbstractGraph) - - Converts subgraphs of g representing trivial unary chains - to in-place form, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*g + 63*h. - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -function merge_chains!(g::AbstractGraph) - merge_chain_prefactors!(g) # shift chain subgraph factors towards root level - for sub_g in g.subgraphs # prune factorless chain subgraphs - merge_factorless_chain!(sub_g) - end - return g -end - -""" - function merge_chains(g::AbstractGraph) - - Returns a copy of a graph g with subgraphs representing trivial unary chain - simplified to in-place form, e.g., 2*(3*(5*g)) + 7*(9*(h)) ↦ 30*g + 63*h. - - Acts only on subgraphs of g with the following structure: 𝓞 --- 𝓞' --- ⋯ --- 𝓞'' ⋯ (!), - where the stop-case (!) represents a leaf, a non-trivial unary operator 𝓞'''(g) != g, or a non-unary operation. - -# Arguments: -- `g::AbstractGraph`: graph to be modified -""" -merge_chains(g::AbstractGraph) = merge_chains!(deepcopy(g)) - """ function flatten_chains!(g::AbstractGraph) diff --git a/test/computational_graph.jl b/test/computational_graph.jl index e56643a3..37ab49bf 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -111,14 +111,8 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true g4p = Graph([g3p,]; operator=Graphs.Sum()) @test Graphs.unary_istrivial(Graphs.Prod) @test Graphs.unary_istrivial(Graphs.Sum) - @test Graphs.merge_factorless_chain(g2) == g1 - @test Graphs.merge_factorless_chain(g3) == g1 - @test Graphs.merge_factorless_chain(g4) == g1 - @test Graphs.merge_factorless_chain(g3p) == g3p - @test Graphs.merge_factorless_chain(g4p) == g3p g5 = Graph([g1,]; operator=O()) @test Graphs.unary_istrivial(O) == false - @test Graphs.merge_factorless_chain(g5) == g5 end g1 = Graph([]) g2 = Graph([g1,]; subgraph_factors=[5,], operator=Graphs.Prod()) @@ -130,43 +124,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true g2p = Graph([g1, g2]; operator=Graphs.Sum()) g3p = Graph([g2p,]; subgraph_factors=[3,], operator=Graphs.Prod()) gp = Graph([g3p,]; subgraph_factors=[2,], operator=Graphs.Prod()) - @testset "Merge chains" begin - # g ↦ 30*(*(*g1)) - g_merged = Graphs.merge_chain_prefactors(g) - @test g_merged.subgraph_factors == [30,] - @test all(isfactorless(node) for node in PreOrderDFS(eldest(g_merged))) - # in-place form - gc = deepcopy(g) - Graphs.merge_chain_prefactors!(gc) - @test isequiv(gc, g_merged, :id) - # gp ↦ 6*(*(g1 + 5*g1)) - gp_merged = Graphs.merge_chain_prefactors(gp) - @test gp_merged.subgraph_factors == [6,] - @test isfactorless(eldest(gp)) == false - @test isfactorless(eldest(gp_merged)) - @test eldest(eldest(gp_merged)) == g2p - # g ↦ 30*g1 - g_merged = merge_chains(g) - @test isequiv(g_merged, 30 * g1, :id) - # in-place form - merge_chains!(g) - @test isequiv(g, 30 * g1, :id) - # gp ↦ 6*(g1 + 5*g1) - gp_merged = merge_chains(gp) - @test isequiv(gp_merged, 6 * g2p, :id) - # Test a generic trivial unary chain - # *(O3(5 * O2(3 * O1(2 * h)))) ↦ 30 * h - h = Graph([]) - h1 = Graph([h,]; subgraph_factors=[2,], operator=O1()) - h2 = Graph([h1,]; subgraph_factors=[3,], operator=O2()) - h3 = Graph([h2,]; subgraph_factors=[5,], operator=O3()) - h4 = Graph([h3,]; operator=Graphs.Prod()) - h4_merged = merge_chains(h4) - @test isequiv(h4_merged, 30 * h, :id) - # in-place form - merge_chains!(h4) - @test isequiv(h4, 30 * h, :id) - end @testset "Merge prefactors" begin g1 = propagator(𝑓⁺(1)𝑓⁻(2)) h1 = FeynmanGraph([g1, g1], drop_topology(g1.properties); subgraph_factors=[1, 2], operator=Graphs.Sum()) @@ -239,40 +196,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true end end @testset verbose = true "Optimizations" begin - @testset "Remove one-child parents" begin - # h = O(7 * (5 * (3 * (2 * g)))) ↦ O(210 * g) - g1 = Graph([]) - g2 = 2 * g1 - g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) - h = Graph([g4,]; subgraph_factors=[7,], operator=O()) - hvec = repeat([deepcopy(h)], 3) - # Test on a single graph - Graphs.merge_all_chains!(h) - @test h.operator == O - @test h.subgraph_factors == [210,] - @test eldest(h) == g1 - # Test on a vector of graphs - Graphs.merge_all_chains!(hvec) - @test all(h.operator == O for h in hvec) - @test all(h.subgraph_factors == [210,] for h in hvec) - @test all(eldest(h) == g1 for h in hvec) - - g2 = 2 * g1 - g3 = Graph([g2,]; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = Graph([g3,]; subgraph_factors=[5,], operator=Graphs.Prod()) - h0 = Graph([g1, g4]; subgraph_factors=[2, 7], operator=O()) - Graphs.merge_all_chains!(h0) - @test h0.subgraph_factors == [2, 210] - @test h0.subgraphs[2] == g1 - - h1 = Graph([h0]; subgraph_factors=[3,], operator=Graphs.Prod()) - h2 = Graph([h1]; subgraph_factors=[5,], operator=Graphs.Prod()) - h = Graph([h2]; subgraph_factors=[7,], operator=O()) - Graphs.merge_all_chains!(h) - @test h.subgraph_factors == [105] - @test eldest(h) == h0 - end @testset "Flatten all chains" begin l0 = Graph([]) l1 = Graph([l0]; subgraph_factors=[2]) @@ -298,14 +221,6 @@ Graphs.unary_istrivial(::Type{O}) where {O<:Union{O1,O2,O3}} = true @test isequiv(r3, Graph([g1, g1,]; subgraph_factors=[12, 210], operator=O()), :id) Graphs.flatten_all_chains!(rvec) @test rvec == [r1, r2, r3] - - Graphs.merge_all_chains!(rvec1) - @test rvec1[1].subgraph_factors == [210] - @test eldest(rvec1[1]) == g1 - @test rvec1[2].subgraph_factors == [-1] - @test eldest(rvec1[2]) == g1 # BUG! - @test rvec1[3].subgraph_factors == [2, 7] - @test rvec1[3].subgraphs == [g1, g1] # BUG! end @testset "merge all linear combinations" begin g1 = Graph([]) @@ -540,14 +455,8 @@ end g4p = FeynmanGraph([g3p,], drop_topology(g3p.properties); operator=Graphs.Sum()) @test Graphs.unary_istrivial(Graphs.Prod) @test Graphs.unary_istrivial(Graphs.Sum) - @test Graphs.merge_factorless_chain(g2) == g1 - @test Graphs.merge_factorless_chain(g3) == g1 - @test Graphs.merge_factorless_chain(g4) == g1 - @test Graphs.merge_factorless_chain(g3p) == g3p - @test Graphs.merge_factorless_chain(g4p) == g3p g5 = FeynmanGraph([g1,], drop_topology(g1.properties); operator=O()) @test Graphs.unary_istrivial(O) == false - @test Graphs.merge_factorless_chain(g5) == g5 end g1 = propagator(𝑓⁻(1)𝑓⁺(2)) g2 = FeynmanGraph([g1,], g1.properties; subgraph_factors=[5,], operator=Graphs.Prod()) @@ -559,43 +468,6 @@ end g2p = FeynmanGraph([g1, g2], drop_topology(g1.properties)) g3p = FeynmanGraph([g2p,], g2p.properties; subgraph_factors=[3,], operator=Graphs.Prod()) gp = FeynmanGraph([g3p,], g3p.properties; subgraph_factors=[2,], operator=Graphs.Prod()) - @testset "Merge chains" begin - # g ↦ 30*(*(*g1)) - g_merged = Graphs.merge_chain_prefactors(g) - @test g_merged.subgraph_factors == [30,] - @test all(isfactorless(node) for node in PreOrderDFS(eldest(g_merged))) - # in-place form - gc = deepcopy(g) - Graphs.merge_chain_prefactors!(gc) - @test isequiv(gc, g_merged, :id) - # gp ↦ 6*(*(g1 + 5*g1)) - gp_merged = Graphs.merge_chain_prefactors(gp) - @test gp_merged.subgraph_factors == [6,] - @test isfactorless(eldest(gp)) == false - @test isfactorless(eldest(gp_merged)) - @test isequiv(eldest(eldest(gp_merged)), g2p, :id) - # g ↦ 30*g1 - g_merged = merge_chains(g) - @test isequiv(g_merged, 30 * g1, :id) - # in-place form - merge_chains!(g) - @test isequiv(g, 30 * g1, :id) - # gp ↦ 6*(g1 + 5*g1) - gp_merged = merge_chains(gp) - @test isequiv(gp_merged, 6 * g2p, :id) - # Test a generic trivial unary chain - # *(O3(5 * O2(3 * O1(2 * h)))) ↦ 30 * h - h = propagator(𝑓⁻(1)𝑓⁺(2)) - h1 = FeynmanGraph([h,], h.properties; subgraph_factors=[2,], operator=O1()) - h2 = FeynmanGraph([h1,], h1.properties; subgraph_factors=[3,], operator=O2()) - h3 = FeynmanGraph([h2,], h2.properties; subgraph_factors=[5,], operator=O3()) - h4 = FeynmanGraph([h3,], h3.properties; operator=Graphs.Prod()) - h4_merged = merge_chains(h4) - @test isequiv(h4_merged, 30 * h, :id) - # in-place form - merge_chains!(h4) - @test isequiv(h4, 30 * h, :id) - end @testset "Merge prefactors" begin g1 = propagator(𝑓⁺(1)𝑓⁻(2)) h1 = FeynmanGraph([g1, g1], drop_topology(g1.properties), subgraph_factors=[1, 2]) @@ -630,32 +502,6 @@ end end @testset verbose = true "Optimizations" begin - @testset "Remove one-child parents" begin - g1 = propagator(𝑓⁻(1)𝑓⁺(2)) - g2 = 2 * g1 - # h = O(7 * (5 * (3 * (2 * g)))) ↦ O(210 * g) - g3 = FeynmanGraph([g2,], g2.properties; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = FeynmanGraph([g3,], g3.properties; subgraph_factors=[5,], operator=Graphs.Prod()) - h = FeynmanGraph([g4,], drop_topology(g4.properties); subgraph_factors=[7,], operator=O()) - hvec = repeat([h], 3) - # Test on a single graph - Graphs.merge_all_chains!(h) - @test h.operator == O - @test h.subgraph_factors == [210,] - @test isequiv(eldest(h), g1, :id) - # Test on a vector of graphs - Graphs.merge_all_chains!(hvec) - @test all(h.operator == O for h in hvec) - @test all(h.subgraph_factors == [210,] for h in hvec) - @test all(isequiv(eldest(h), g1, :id) for h in hvec) - - g2 = 2 * g1 - g3 = FeynmanGraph([g2,], g2.properties; subgraph_factors=[3,], operator=Graphs.Prod()) - g4 = FeynmanGraph([g3,], g3.properties; subgraph_factors=[5,], operator=Graphs.Prod()) - h = FeynmanGraph([g1, g4], drop_topology(g4.properties); subgraph_factors=[2, 7], operator=O()) - Graphs.merge_all_chains!(h) - @test h.subgraph_factors == [2, 210] - end @testset "optimize" begin g1 = propagator(𝑓⁻(1)𝑓⁺(2)) g2 = 2 * g1