Skip to content

Commit

Permalink
Define AD rules for TensorNetwork constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 13, 2023
1 parent f8e118d commit e227526
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
23 changes: 23 additions & 0 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,27 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; met
@non_differentiable intersect(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
@non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)

function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork}
ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata)
end

function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {A<:Ansatz,T<:TensorNetwork{A}}
TensorNetwork{A}(projector.tensors(dx.tensors); projector.metadata...)
end

function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
tensors = map(+, x.tensors, Δ.tensors)
TensorNetwork{A}(tensors; x.metadata...)
end

function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...)
T(tensors; metadata...), Tangent{TensorNetwork}(tensors = Δ)
end

TensorNetwork_pullback::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensors)
TensorNetwork_pullback::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
function ChainRulesCore.rrule(T::Type{<:TensorNetwork}, tensors; metadata...)
T(tensors; metadata...), TensorNetwork_pullback
end

end
11 changes: 11 additions & 0 deletions test/integration/ChainRules_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,15 @@
test_frule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j])
test_rrule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j])
end

# NOTE fixes some problems on testing, not sure why
Base.collect(tn::TensorNetwork) = tensors(tn)

@testset "TensorNetwork" begin
a = Tensor(rand(4, 2), (:i, :j))
b = Tensor(rand(2, 3), (:j, :k))

test_frule(TensorNetwork, Tensor[a, b])
test_rrule(TensorNetwork, Tensor[a, b])
end
end

0 comments on commit e227526

Please sign in to comment.