Skip to content

Commit

Permalink
Refactor Tensor pullbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 12, 2023
1 parent 89dfb77 commit f51f45b
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor} =

ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), T(Δ, inds; meta...)

function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...)
Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent())
Tensor_pullback::Thunk) = (NoTangent(), unthunk(Δ).data, NoTangent())
return T(data, inds; meta...), Tensor_pullback
end
_Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent())
_Tensor_pullback::AbstractThunk) = _Tensor_pullback(unthunk(Δ))
ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), _Tensor_pullback

@non_differentiable copy(tn::TensorNetwork)

Expand Down

0 comments on commit f51f45b

Please sign in to comment.