Skip to content

Commit

Permalink
Merge branch 'computgraph' into computgraph_pchou
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Oct 3, 2023
2 parents 66e7940 + 69cf3d8 commit b063a68
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
Given a graph G and an id of graph g, calculate the derivative d G / d g by forward propagation algorithm.
# Arguments:
- `diag::Graph{F,W}`: Graph to be differentiated
- `ID::Int`: ID of the graph that is
- `diag::Graph{F,W}`: Graph G to be differentiated
- `ID::Int`: ID of the graph g
"""
function forwardAD(diag::Graph{F,W}, ID::Int) where {F,W}
# use a dictionary to host the dual diagram of a diagram for a given hash number
Expand Down Expand Up @@ -111,6 +111,14 @@ function forwardAD(diag::Graph{F,W}, ID::Int) where {F,W}
return dual[rootid]
end

"""
function all_parent(diag::Graph{F,W}) where {F,W}
Given a graph, find all parents node of each node, and return them in a dictionary.
# Arguments:
- `diag::Graph{F,W}`: Target graph
"""

function all_parent(diag::Graph{F,W}) where {F,W}
result = Dict{Int,Vector{Graph{F,W}}}()
for d in PostOrderDFS(diag)
Expand All @@ -129,13 +137,22 @@ function all_parent(diag::Graph{F,W}) where {F,W}
return result
end

"""
function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W}
Return the local derivative d g1/ dg2 at node g1. The local derivative only considers the subgraph of node g1, and ignores g2 that appears in deeper layers.
Example: For g1 = G *g2, and G = g3*g2, return d g1/ dg2 = G = g3*g2 instead of 2 g2*g3.
# Arguments:
- `diag::Graph{F,W}`: Target graph
"""

function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W} #return d g1/ d g2
if isleaf(g1)
return nothing
elseif g1.operator == Sum
sum_factor = 0.0
exist = false #track if g2 exist in g1 subgraphs.
for i in 1:length(g1.subgraphs)
for i in eachindex(g1.subgraphs)
if g1.subgraphs[i].id == g2.id
exist = true
sum_factor += g1.subgraph_factors[i]
Expand All @@ -152,7 +169,7 @@ function node_derivative(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W} #return d g
subgraphfactors = []
factor = nothing
first_time = true #Track if its the first time we find g2 in g1 subgraph.
for i in 1:length(g1.subgraphs)
for i in eachindex(g1.subgraphs)
if g1.subgraphs[i].id == g2.id
if first_time # We should remove the first g2 in g1
first_time = false
Expand Down Expand Up @@ -284,6 +301,14 @@ function build_all_leaf_derivative(diag::Graph{F,W}, max_order::Int) where {F,W}
return result
end

# function build_all_variable_derivative(diag::Graph{F,W}, max_order::Int, variable_number::Int) where {F,W}
# leaf_derivative, leafmap = build_all_leaf_derivative(diag, max_order)
# for (id, idx) in leafmap
# for order in max_order
# for
# end
# end
# end

function insert_dualDict!(dict_kv::Dict{Tk,Tv}, dict_vk::Dict{Tv,Tk}, key::Tk, value::Tv) where {Tk,Tv}
dict_kv[key] = value
Expand Down

0 comments on commit b063a68

Please sign in to comment.