Skip to content

Commit

Permalink
Fix tangent projection of TensorNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 14, 2023
1 parent 24ebdf1 commit ef96be2
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork}
ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata)
end

function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {A<:Ansatz,T<:TensorNetwork{A}}
TensorNetwork{A}(projector.tensors(dx.tensors); projector.metadata...)
function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork}
dx.tensors isa NoTangent && return NoTangent()
Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
end

function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
# TODO match tensors by indices
tensors = map(+, x.tensors, Δ.tensors)
TensorNetwork{A}(tensors; x.metadata...)
end
Expand Down

0 comments on commit ef96be2

Please sign in to comment.