From f51f45bfd20f89f4ebdaabeeb3af01a872072081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 12 Sep 2023 16:38:06 +0200 Subject: [PATCH] Refactor `Tensor` pullbacks --- ext/TenetChainRulesCoreExt.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index 449c9702..d2b1da3f 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -11,11 +11,9 @@ ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor} = ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), T(Δ, inds; meta...) -function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) - Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent()) - Tensor_pullback(Δ::Thunk) = (NoTangent(), unthunk(Δ).data, NoTangent()) - return T(data, inds; meta...), Tensor_pullback -end +_Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent()) +_Tensor_pullback(Δ::AbstractThunk) = _Tensor_pullback(unthunk(Δ)) +ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), _Tensor_pullback @non_differentiable copy(tn::TensorNetwork)