diff --git a/src/computational_graph/operation.jl b/src/computational_graph/operation.jl index 40c09a83..fd51e428 100644 --- a/src/computational_graph/operation.jl +++ b/src/computational_graph/operation.jl @@ -318,6 +318,21 @@ function insert_dualDict!(dict_kv::Dict{Tk,Tv}, dict_vk::Dict{Tv,Tk}, key::Tk, v push!(dict_vk[value], key) end + +""" + function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1, + dual::Dict{Tuple{Int,NTuple{N,Bool}},G}=Dict{Tuple{Int,Tuple{Bool}},G}()) where {G<:Graph,N} + + Computes the forward automatic differentiation (AD) of the given graphs beginning from the roots. + +# Arguments: +- `graphs`: A vector of graphs. +- `idx`: Index for differentiation (default: 1). +- `dual`: A dictionary that holds the result of differentiation. + +# Returns: +- The dual dictionary populated with all differentiated graphs, including the intermediate AD. +""" function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1, dual::Dict{Tuple{Int,NTuple{N,Bool}},G}=Dict{Tuple{Int,Tuple{Bool}},G}()) where {G<:Graph,N} # dualinv::Dict{G,Tuple{Int,NTuple{N,Int}}}=Dict{G,Tuple{Int,Tuple{Int}}}()) where {G<:Graph,N} @@ -405,6 +420,21 @@ end end end +""" + function build_derivative_graph(graphs::AbstractVector{G}, orders::NTuple{N,Int}; + nodes_id=nothing) where {G<:Graph,N} + + Constructs a derivative graph using forward automatic differentiation with given graphs and derivative orders. + +# Arguments: +- `graphs`: A vector of graphs. +- `orders::NTuple{N,Int}`: A tuple indicating the orders of differentiation. `N` represents the number of independent variables to be differentiated. +- `nodes_id`: Optional node IDs to indicate saving their derivative graph. + +# Returns: +- A dictionary containing the dual derivative graphs for all indicated nodes. +If `isnothing(nodes_id)`, indicated nodes include all leaf and root nodes. Otherwise, indicated nodes include all root nodes and other nodes from `nodes_id`. +""" function build_derivative_graph(graphs::AbstractVector{G}, orders::NTuple{N,Int}; nodes_id=nothing) where {G<:Graph,N} diff --git a/src/computational_graph/optimize.jl b/src/computational_graph/optimize.jl index aa8e6143..77f7bc2b 100644 --- a/src/computational_graph/optimize.jl +++ b/src/computational_graph/optimize.jl @@ -1,3 +1,16 @@ +""" + function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph} + + In-place optimization of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). +- `normalize`: Optional function to normalize the graphs (default: nothing). + +# Returns: +- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)). +""" function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph} if isempty(graphs) return nothing @@ -10,12 +23,39 @@ function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize= end end +""" + function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph} + + Optimize a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). +- `normalize`: Optional function to normalize the graphs (default: nothing). + +# Returns: +- A tuple/vector of optimized graphs. +- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)). +""" function optimize(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph} graphs_new = deepcopy(graphs) - leaf_mapping = optimize!(graphs_new) + leaf_mapping = optimize!(graphs_new, verbose=verbose, normalize=normalize) return graphs_new, leaf_mapping end +""" + function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph} + + In-place merge prefactors of all nodes representing trivial unary chains towards the root level for a single graph. + +# Arguments: +- `g::G`: An AbstractGraph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") # Post-order DFS @@ -27,7 +67,20 @@ function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph} return g end -function merge_all_chain_prefactors!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph} +""" + function merge_all_chain_prefactors!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} + + In-place merge prefactors of all nodes representing trivial unary chains towards the root level for given graphs. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function merge_all_chain_prefactors!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.") # Post-order DFS for g in graphs @@ -37,6 +90,19 @@ function merge_all_chain_prefactors!(graphs::AbstractVector{G}; verbose=0) where return graphs end +""" + function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing factorless trivial unary chains within a single graph. + +# Arguments: +- `g::G`: An AbstractGraph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") # Post-order DFS @@ -48,7 +114,20 @@ function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph} return g end -function merge_all_factorless_chains!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph} +""" + function merge_all_factorless_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing factorless trivial unary chains within given graphs. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function merge_all_factorless_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge all nodes representing factorless trivial unary chains.") # Post-order DFS for g in graphs @@ -58,6 +137,20 @@ function merge_all_factorless_chains!(graphs::AbstractVector{G}; verbose=0) wher return graphs end +""" + function merge_all_chains!(g::G; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing trivial unary chains within a single graph. + This function consolidates both chain prefactors and factorless chains. + +# Arguments: +- `g::G`: An AbstractGraph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" function merge_all_chains!(g::G; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge all nodes representing trivial unary chains.") merge_all_chain_prefactors!(g, verbose=verbose) @@ -65,13 +158,40 @@ function merge_all_chains!(g::G; verbose=0) where {G<:AbstractGraph} return g end -function merge_all_chains!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph} +""" + function merge_all_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing trivial unary chains in given graphs. + This function consolidates both chain prefactors and factorless chains. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function merge_all_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge all nodes representing trivial unary chains.") merge_all_chain_prefactors!(graphs, verbose=verbose) merge_all_factorless_chains!(graphs, verbose=verbose) return graphs end +""" + function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing a linear combination of a non-unique list of subgraphs within a single graph. + +# Arguments: +- `g::G`: An AbstractGraph. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graph. +# +""" function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.") # Post-order DFS @@ -83,7 +203,20 @@ function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph return g end -function merge_all_linear_combinations!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph} +""" + function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} + + In-place merge all nodes representing a linear combination of a non-unique list of subgraphs in given graphs. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- Optimized graphs. +# +""" +function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph} verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.") # Post-order DFS for g in graphs @@ -93,7 +226,19 @@ function merge_all_linear_combinations!(graphs::AbstractVector{G}; verbose=0) wh return graphs end -function unique_leaves(_graphs::AbstractVector{G}) where {G<:AbstractGraph} +""" + function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph} + + Identify and retrieve unique leaf nodes from a set of graphs. + +# Arguments: +- `_graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. + +# Returns: +- The vector of unique leaf nodes. +- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)). +""" +function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph} ############### find the unique Leaves ##################### uniqueGraph = [] mapping = Dict{Int,Int}() @@ -117,7 +262,20 @@ function unique_leaves(_graphs::AbstractVector{G}) where {G<:AbstractGraph} return uniqueGraph, mapping end -function remove_duplicated_leaves!(graphs::AbstractVector{G}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph} +""" + function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph} + + In-place remove duplicated leaf nodes from a collection of graphs. It also provides optional normalization for these leaves. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `verbose`: Level of verbosity (default: 0). +- `normalize`: Optional function to normalize the graphs (default: nothing). + +# Returns: +- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)). +""" +function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph} verbose > 0 && println("remove duplicated leaves.") leaves = Vector{G}() for g in graphs @@ -147,3 +305,66 @@ function remove_duplicated_leaves!(graphs::AbstractVector{G}; verbose=0, normali return leafMap end + +""" + function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph} + + In-place remove all nodes connected to the target leaves via "Prod" operators. + +# Arguments: +- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs. +- `targetleaves_id::AbstractVector{Int}`: Vector of target leafs' id. +- `verbose`: Level of verbosity (default: 0). + +# Returns: +- The id of a constant graph with a zero factor if any graph in `graphs` was completely burnt; otherwise, `nothing`. +""" +function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph} + verbose > 0 && println("remove all nodes connected to the target leaves via Prod operators.") + + graphs_sum = linear_combination(graphs, one.(eachindex(graphs))) + ftype = typeof(graphs[1].factor) + + for leaf in Leaves(graphs_sum) + if !isdisjoint(leaf.id, targetleaves_id) + leaf.name = "Burn" + end + end + + for node in PostOrderDFS(graphs_sum) + if any(x -> x.name == "Burn", node.subgraphs) + if node.operator == Prod + node.subgraphs = G[] + node.subgraph_factors = ftype[] + node.name = "Burn" + else + subgraphs = G[] + subgraph_factors = ftype[] + for (i, subg) in enumerate(node.subgraphs) + if subg.name != "Burn" + push!(subgraphs, subg) + push!(subgraph_factors, node.subgraph_factors[i]) + end + end + node.subgraphs = subgraphs + node.subgraph_factors = subgraph_factors + if isempty(subgraph_factors) + node.name = "Burn" + end + end + end + end + + g_c0 = constant_graph(ftype(0)) + has_c0 = false + for g in graphs + if g.name == "Burn" + has_c0 = true + g.id = g_c0.id + g.operator = Constant + g.factor = ftype(0) + end + end + + has_c0 ? (return g_c0.id) : (return nothing) +end \ No newline at end of file diff --git a/test/computational_graph.jl b/test/computational_graph.jl index 5a6f4f17..b7e4f445 100644 --- a/test/computational_graph.jl +++ b/test/computational_graph.jl @@ -722,7 +722,7 @@ end @testset verbose = true "Auto Differentiation" begin using FeynmanDiagram.ComputationalGraphs: - eval!, forwardAD, node_derivative, backAD, forwardAD_root!, build_all_leaf_derivative, build_derivative_graph + eval!, forwardAD, node_derivative, backAD, forwardAD_root!, build_all_leaf_derivative, build_derivative_graph, burn_from_targetleaves! g1 = Graph([]) g2 = Graph([]) g3 = Graph([], factor=2.0) @@ -857,15 +857,17 @@ end leafmap = Dict{Int,Int}() leafmap[g1.id], leafmap[g2.id], leafmap[g3.id] = 1, 2, 3 orders = (3, 2, 2) - dual = build_derivative_graph(F1, orders) + dual = Graphs.build_derivative_graph(F1, orders) leafmap[dual[(g1.id, (1, 0, 0))].id], leafmap[dual[(g2.id, (0, 1, 0))].id], leafmap[dual[(g3.id, (0, 0, 1))].id] = 4, 5, 6 + burnleafs_id = Int[] for order in Iterators.product((0:x for x in orders)...) order == (0, 0, 0) && continue for g in [g1, g2, g3] if !haskey(leafmap, dual[(g.id, order)].id) leafmap[dual[(g.id, order)].id] = 7 + push!(burnleafs_id, dual[(g.id, order)].id) end end end @@ -874,6 +876,59 @@ end @test eval!(dual[(F1.id, (2, 0, 0))], leafmap, leaf) == 426 @test eval!(dual[(F1.id, (3, 0, 0))], leafmap, leaf) == 90 @test eval!(dual[(F1.id, (3, 1, 0))], leafmap, leaf) == 0 + + # optimize the derivative graph + c0_id = burn_from_targetleaves!([dual[(F1.id, (1, 0, 0))], dual[(F1.id, (2, 0, 0))], dual[(F1.id, (3, 0, 0))], dual[(F1.id, (3, 1, 0))]], burnleafs_id) + if !isnothing(c0_id) + leafmap[c0_id] = 7 + end + @test eval!(dual[(F1.id, (1, 0, 0))], leafmap, leaf) == 1002 + @test eval!(dual[(F1.id, (2, 0, 0))], leafmap, leaf) == 426 + @test eval!(dual[(F1.id, (3, 0, 0))], leafmap, leaf) == 90 + @test eval!(dual[(F1.id, (3, 1, 0))], leafmap, leaf) == 0 + + # Test on a vector of graphs + F0 = F1 * F3 + F0_r1 = F1 + F3 + dual = Graphs.build_derivative_graph([F0, F0_r1], orders) + + leafmap = Dict{Int,Int}() + leafmap[g1.id], leafmap[g2.id], leafmap[g3.id] = 1, 2, 3 + leafmap[dual[(g1.id, (1, 0, 0))].id], leafmap[dual[(g2.id, (0, 1, 0))].id], leafmap[dual[(g3.id, (0, 0, 1))].id] = 4, 5, 6 + burnleafs_id = Int[] + for order in Iterators.product((0:x for x in orders)...) + order == (0, 0, 0) && continue + for g in [g1, g2, g3] + if !haskey(leafmap, dual[(g.id, order)].id) + leafmap[dual[(g.id, order)].id] = 7 + push!(burnleafs_id, dual[(g.id, order)].id) + end + end + end + @test eval!(dual[(F0.id, (1, 0, 0))], leafmap, leaf) == 5568 + @test eval!(dual[(F0_r1.id, (1, 0, 0))], leafmap, leaf) == 1003 + @test eval!(dual[(F0.id, (2, 0, 0))], leafmap, leaf) == 3708 + @test eval!(dual[(F0_r1.id, (2, 0, 0))], leafmap, leaf) == 426 + @test eval!(dual[(F0.id, (3, 0, 0))], leafmap, leaf) == 1638 + @test eval!(dual[(F0_r1.id, (3, 0, 0))], leafmap, leaf) == 90 + @test eval!(dual[(F0.id, (3, 1, 0))], leafmap, leaf) == 234 + @test eval!(dual[(F0_r1.id, (3, 1, 0))], leafmap, leaf) == 0 + @test eval!(dual[(F0.id, (3, 2, 0))], leafmap, leaf) == eval!(dual[(F0_r1.id, (3, 2, 0))], leafmap, leaf) == 0 + + c0_id = burn_from_targetleaves!([dual[(F0.id, (1, 0, 0))], dual[(F0.id, (2, 0, 0))], dual[(F0.id, (3, 0, 0))], dual[(F0.id, (3, 1, 0))], dual[(F0.id, (3, 2, 0))], + dual[(F0_r1.id, (1, 0, 0))], dual[(F0_r1.id, (2, 0, 0))], dual[(F0_r1.id, (3, 0, 0))], dual[(F0_r1.id, (3, 1, 0))], dual[(F0_r1.id, (3, 2, 0))]], burnleafs_id) + if !isnothing(c0_id) + leafmap[c0_id] = 7 + end + @test eval!(dual[(F0.id, (1, 0, 0))], leafmap, leaf) == 5568 + @test eval!(dual[(F0_r1.id, (1, 0, 0))], leafmap, leaf) == 1003 + @test eval!(dual[(F0.id, (2, 0, 0))], leafmap, leaf) == 3708 + @test eval!(dual[(F0_r1.id, (2, 0, 0))], leafmap, leaf) == 426 + @test eval!(dual[(F0.id, (3, 0, 0))], leafmap, leaf) == 1638 + @test eval!(dual[(F0_r1.id, (3, 0, 0))], leafmap, leaf) == 90 + @test eval!(dual[(F0.id, (3, 1, 0))], leafmap, leaf) == 234 + @test eval!(dual[(F0_r1.id, (3, 1, 0))], leafmap, leaf) == 0 + @test eval!(dual[(F0.id, (3, 2, 0))], leafmap, leaf) == eval!(dual[(F0_r1.id, (3, 2, 0))], leafmap, leaf) == 0 end end