From 438a362b5ac786aac70d18adea0c5c23a80981a7 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Sat, 30 Sep 2023 01:14:55 -0400 Subject: [PATCH] minor change --- src/computational_graph/operation.jl | 33 ++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/computational_graph/operation.jl b/src/computational_graph/operation.jl index 307cc93e..bd6de437 100644 --- a/src/computational_graph/operation.jl +++ b/src/computational_graph/operation.jl @@ -47,8 +47,8 @@ end Given a graph G and an id of graph g, calculate the derivative d G / d g by forward propagation algorithm. # Arguments: -- `diag::Graph{F,W}`: Graph to be differentiated -- `ID::Int`: ID of the graph that is +- `diag::Graph{F,W}`: Graph G to be differentiated +- `ID::Int`: ID of the graph g """ function forwardAD(diag::Graph{F,W}, ID::Int) where {F,W} # use a dictionary to host the dual diagram of a diagram for a given hash number @@ -111,6 +111,14 @@ function forwardAD(diag::Graph{F,W}, ID::Int) where {F,W} return dual[rootid] end +""" + function all_parent(diag::Graph{F,W}) where {F,W} + + Given a graph, find all parents node of each node, and return them in a dictionary. +# Arguments: +- `diag::Graph{F,W}`: Target graph +""" + function all_parent(diag::Graph{F,W}) where {F,W} result = Dict{Int,Vector{Graph{F,W}}}() for d in PostOrderDFS(diag) @@ -129,13 +137,22 @@ function all_parent(diag::Graph{F,W}) where {F,W} return result end +""" +function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W} + + Return the local derivative d g1/ dg2 at node g1. The local derivative only considers the subgraph of node g1, and ignores g2 that appears in deeper layers. + Example: For g1 = G *g2, and G = g3*g2, return d g1/ dg2 = G = g3*g2 instead of 2 g2*g3. +# Arguments: +- `diag::Graph{F,W}`: Target graph +""" + function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W} #return d g1/ d g2 if isleaf(g1) return nothing elseif g1.operator == Sum sum_factor = 0.0 exist = false #track if g2 exist in g1 subgraphs. - for i in 1:length(g1.subgraphs) + for i in eachindex(g1.subgraphs) if g1.subgraphs[i].id == g2.id exist = true sum_factor += g1.subgraph_factors[i] @@ -152,7 +169,7 @@ function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W} #return d g subgraphfactors = [] factor = nothing first_time = true #Track if its the first time we find g2 in g1 subgraph. - for i in 1:length(g1.subgraphs) + for i in eachindex(g1.subgraphs) if g1.subgraphs[i].id == g2.id if first_time # We should remove the first g2 in g1 first_time = false @@ -284,6 +301,14 @@ function build_all_leaf_derivative(diag::Graph{F,W}, max_order::Int) where {F,W} return result end +# function build_all_variable_derivative(diag::Graph{F,W}, max_order::Int, variable_number::Int) where {F,W} +# leaf_derivative, leafmap = build_all_leaf_derivative(diag, max_order) +# for (id, idx) in leafmap +# for order in max_order +# for +# end +# end +# end function forwardAD_root(diags::AbstractVector{G}) where {G<:Graph}