Skip to content

Commit

Permalink
add burn_from_targetleaves! optimization and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Oct 3, 2023
1 parent 41e7aae commit 66e7940
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 9 deletions.
30 changes: 30 additions & 0 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,21 @@ function insert_dualDict!(dict_kv::Dict{Tk,Tv}, dict_vk::Dict{Tv,Tk}, key::Tk, v
push!(dict_vk[value], key)
end


"""
function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1,
dual::Dict{Tuple{Int,NTuple{N,Bool}},G}=Dict{Tuple{Int,Tuple{Bool}},G}()) where {G<:Graph,N}
Computes the forward automatic differentiation (AD) of the given graphs beginning from the roots.
# Arguments:
- `graphs`: A vector of graphs.
- `idx`: Index for differentiation (default: 1).
- `dual`: A dictionary that holds the result of differentiation.
# Returns:
- The dual dictionary populated with all differentiated graphs, including the intermediate AD.
"""
function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1,
dual::Dict{Tuple{Int,NTuple{N,Bool}},G}=Dict{Tuple{Int,Tuple{Bool}},G}()) where {G<:Graph,N}
# dualinv::Dict{G,Tuple{Int,NTuple{N,Int}}}=Dict{G,Tuple{Int,Tuple{Int}}}()) where {G<:Graph,N}
Expand Down Expand Up @@ -380,6 +395,21 @@ end
end
end

"""
function build_derivative_graph(graphs::AbstractVector{G}, orders::NTuple{N,Int};
nodes_id=nothing) where {G<:Graph,N}
Constructs a derivative graph using forward automatic differentiation with given graphs and derivative orders.
# Arguments:
- `graphs`: A vector of graphs.
- `orders::NTuple{N,Int}`: A tuple indicating the orders of differentiation. `N` represents the number of independent variables to be differentiated.
- `nodes_id`: Optional node IDs to indicate saving their derivative graph.
# Returns:
- A dictionary containing the dual derivative graphs for all indicated nodes.
If `isnothing(nodes_id)`, indicated nodes include all leaf and root nodes. Otherwise, indicated nodes include all root nodes and other nodes from `nodes_id`.
"""
function build_derivative_graph(graphs::AbstractVector{G}, orders::NTuple{N,Int};
nodes_id=nothing) where {G<:Graph,N}

Expand Down
235 changes: 228 additions & 7 deletions src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
"""
function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph}
In-place optimization of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
- `normalize`: Optional function to normalize the graphs (default: nothing).
# Returns:
- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)).
"""
function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph}
if isempty(graphs)
return nothing
Expand All @@ -10,12 +23,39 @@ function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=
end
end

"""
function optimize!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph}
Optimize a copy of given `graphs`. Removes duplicated leaves, merges chains, and merges linear combinations.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
- `normalize`: Optional function to normalize the graphs (default: nothing).
# Returns:
- A tuple/vector of optimized graphs.
- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)).
"""
function optimize(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing) where {G<:AbstractGraph}
graphs_new = deepcopy(graphs)
leaf_mapping = optimize!(graphs_new)
leaf_mapping = optimize!(graphs_new, verbose=verbose, normalize=normalize)
return graphs_new, leaf_mapping
end

"""
function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph}
In-place merge prefactors of all nodes representing trivial unary chains towards the root level for a single graph.
# Arguments:
- `g::G`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graph.
#
"""
function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.")
# Post-order DFS
Expand All @@ -27,7 +67,20 @@ function merge_all_chain_prefactors!(g::G; verbose=0) where {G<:AbstractGraph}
return g
end

function merge_all_chain_prefactors!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph}
"""
function merge_all_chain_prefactors!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
In-place merge prefactors of all nodes representing trivial unary chains towards the root level for given graphs.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graphs.
#
"""
function merge_all_chain_prefactors!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge prefactors of all nodes representing trivial unary chains toward root level.")
# Post-order DFS
for g in graphs
Expand All @@ -37,6 +90,19 @@ function merge_all_chain_prefactors!(graphs::AbstractVector{G}; verbose=0) where
return graphs
end

"""
function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing factorless trivial unary chains within a single graph.
# Arguments:
- `g::G`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graph.
#
"""
function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge all nodes representing factorless trivial unary chains.")
# Post-order DFS
Expand All @@ -48,7 +114,20 @@ function merge_all_factorless_chains!(g::G; verbose=0) where {G<:AbstractGraph}
return g
end

function merge_all_factorless_chains!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph}
"""
function merge_all_factorless_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing factorless trivial unary chains within given graphs.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graphs.
#
"""
function merge_all_factorless_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge all nodes representing factorless trivial unary chains.")
# Post-order DFS
for g in graphs
Expand All @@ -58,20 +137,61 @@ function merge_all_factorless_chains!(graphs::AbstractVector{G}; verbose=0) wher
return graphs
end

"""
function merge_all_chains!(g::G; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing trivial unary chains within a single graph.
This function consolidates both chain prefactors and factorless chains.
# Arguments:
- `g::G`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graph.
#
"""
function merge_all_chains!(g::G; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(g, verbose=verbose)
merge_all_factorless_chains!(g, verbose=verbose)
return g
end

function merge_all_chains!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph}
"""
function merge_all_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing trivial unary chains in given graphs.
This function consolidates both chain prefactors and factorless chains.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graphs.
#
"""
function merge_all_chains!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge all nodes representing trivial unary chains.")
merge_all_chain_prefactors!(graphs, verbose=verbose)
merge_all_factorless_chains!(graphs, verbose=verbose)
return graphs
end

"""
function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing a linear combination of a non-unique list of subgraphs within a single graph.
# Arguments:
- `g::G`: An AbstractGraph.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graph.
#
"""
function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.")
# Post-order DFS
Expand All @@ -83,7 +203,20 @@ function merge_all_linear_combinations!(g::G; verbose=0) where {G<:AbstractGraph
return g
end

function merge_all_linear_combinations!(graphs::AbstractVector{G}; verbose=0) where {G<:AbstractGraph}
"""
function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
In-place merge all nodes representing a linear combination of a non-unique list of subgraphs in given graphs.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- Optimized graphs.
#
"""
function merge_all_linear_combinations!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("merge nodes representing a linear combination of a non-unique list of graphs.")
# Post-order DFS
for g in graphs
Expand All @@ -93,7 +226,19 @@ function merge_all_linear_combinations!(graphs::AbstractVector{G}; verbose=0) wh
return graphs
end

function unique_leaves(_graphs::AbstractVector{G}) where {G<:AbstractGraph}
"""
function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph}
Identify and retrieve unique leaf nodes from a set of graphs.
# Arguments:
- `_graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
# Returns:
- The vector of unique leaf nodes.
- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)).
"""
function unique_leaves(_graphs::Union{Tuple,AbstractVector{G}};) where {G<:AbstractGraph}
############### find the unique Leaves #####################
uniqueGraph = []
mapping = Dict{Int,Int}()
Expand All @@ -117,7 +262,20 @@ function unique_leaves(_graphs::AbstractVector{G}) where {G<:AbstractGraph}
return uniqueGraph, mapping
end

function remove_duplicated_leaves!(graphs::AbstractVector{G}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph}
"""
function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph}
In-place remove duplicated leaf nodes from a collection of graphs. It also provides optional normalization for these leaves.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `verbose`: Level of verbosity (default: 0).
- `normalize`: Optional function to normalize the graphs (default: nothing).
# Returns:
- A mapping dictionary from the id of each unique leaf node to its index in collect(1:length(leafs)).
"""
function remove_duplicated_leaves!(graphs::Union{Tuple,AbstractVector{G}}; verbose=0, normalize=nothing, kwargs...) where {G<:AbstractGraph}
verbose > 0 && println("remove duplicated leaves.")
leaves = Vector{G}()
for g in graphs
Expand Down Expand Up @@ -147,3 +305,66 @@ function remove_duplicated_leaves!(graphs::AbstractVector{G}; verbose=0, normali

return leafMap
end

"""
function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph}
In-place remove all nodes connected to the target leaves via "Prod" operators.
# Arguments:
- `graphs::Union{Tuple,AbstractVector{G}}`: A tuple or vector of graphs.
- `targetleaves_id::AbstractVector{Int}`: Vector of target leafs' id.
- `verbose`: Level of verbosity (default: 0).
# Returns:
- The id of a constant graph with a zero factor if any graph in `graphs` was completely burnt; otherwise, `nothing`.
"""
function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::AbstractVector{Int}; verbose=0) where {G<:AbstractGraph}
verbose > 0 && println("remove all nodes connected to the target leaves via Prod operators.")

graphs_sum = linear_combination(graphs, one.(eachindex(graphs)))
ftype = typeof(graphs[1].factor)

for leaf in Leaves(graphs_sum)
if !isdisjoint(leaf.id, targetleaves_id)
leaf.name = "Burn"
end
end

for node in PostOrderDFS(graphs_sum)
if any(x -> x.name == "Burn", node.subgraphs)
if node.operator == Prod
node.subgraphs = G[]
node.subgraph_factors = ftype[]
node.name = "Burn"
else
subgraphs = G[]
subgraph_factors = ftype[]
for (i, subg) in enumerate(node.subgraphs)
if subg.name != "Burn"
push!(subgraphs, subg)
push!(subgraph_factors, node.subgraph_factors[i])
end
end
node.subgraphs = subgraphs
node.subgraph_factors = subgraph_factors
if isempty(subgraph_factors)
node.name = "Burn"
end
end
end
end

g_c0 = constant_graph(ftype(0))
has_c0 = false
for g in graphs
if g.name == "Burn"
has_c0 = true
g.id = g_c0.id
g.operator = Constant
g.factor = ftype(0)
end
end

has_c0 ? (return g_c0.id) : (return nothing)
end
Loading

0 comments on commit 66e7940

Please sign in to comment.