diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index 449c9702..d2b1da3f 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -11,11 +11,9 @@ ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor} = ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), T(Δ, inds; meta...) -function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) - Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent()) - Tensor_pullback(Δ::Thunk) = (NoTangent(), unthunk(Δ).data, NoTangent()) - return T(data, inds; meta...), Tensor_pullback -end +_Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent()) +_Tensor_pullback(Δ::AbstractThunk) = _Tensor_pullback(unthunk(Δ)) +ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), _Tensor_pullback @non_differentiable copy(tn::TensorNetwork)