diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index b013d0cc..449c9702 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -13,6 +13,7 @@ ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(dat function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent()) + Tensor_pullback(Δ::Thunk) = (NoTangent(), unthunk(Δ).data, NoTangent()) return T(data, inds; meta...), Tensor_pullback end