Skip to content

Commit

Permalink
fix ad type instabilities in new zygote (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho authored Jan 11, 2025
1 parent c9aaf45 commit 2adad27
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,19 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
dA = @thunk let
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
= @thunk begin
= @thunk let
_dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), !conjA,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...))
return projectα(_dα)
end
= @thunk begin
= @thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), true,
ΔC, (trivtuple(numind(pA)), ()), false,
Expand Down Expand Up @@ -165,7 +165,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
pΔC = (TupleTools.getindices(ipAB, trivtuple(numout(pA))),
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))))
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
dA = @thunk let
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
Expand All @@ -177,7 +177,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
dB = @thunk begin
dB = @thunk let
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
Expand All @@ -189,15 +189,15 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
conjB ? α : conj(α), Zero(), ba...)
return projectB(_dB)
end
= @thunk begin
= @thunk let
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...))
return projectα(_dα)
end
= @thunk begin
= @thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
Expand Down Expand Up @@ -249,7 +249,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
return one(TensorOperations.tensoralloc_add(scalartype(A), A,
Expand All @@ -263,15 +263,15 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
= @thunk begin
= @thunk let
C_αβ = tensortrace(A, p, q, false, One(), ba...)
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(p))),
!conjA,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...))
return projectα(_dα)
end
= @thunk begin
= @thunk let
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), true,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...))
Expand Down

0 comments on commit 2adad27

Please sign in to comment.