Skip to content

Commit

Permalink
Replace Zygote.@adjoint rules for ChainRules.rrules
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 10, 2023
1 parent f040859 commit 42c4e58
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
3 changes: 0 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
TenetChainRulesCoreExt = "ChainRulesCore"
TenetMakieExt = "Makie"
TenetQuacExt = "Quac"
TenetZygoteExt = "Zygote"

[compat]
Bijections = "0.1"
Expand All @@ -44,5 +42,4 @@ OMEinsum = "0.7"
Permutations = "0.4"
Quac = "0.2"
ValSplit = "0.1"
Zygote = "0.6"
julia = "1.9"
21 changes: 21 additions & 0 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,25 @@ function ChainRulesCore.rrule(::Type{Tensor{T,N,A}}, data, inds; meta...) where
end
end

function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...)
Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent())
return T(data, inds; meta...), Tensor_pullback
end

# WARN type-piracy
function ChainRulesCore.rrule(::typeof(setdiff), s, itrs...)
setdiff_pullback(_) = fill(NoTangent(), 2 + length(itrs))
return setdiff(s, itrs...), setdiff_pullback
end

function ChainRulesCore.rrule(::typeof(union), s, itrs...)
union_pullback(_) = fill(NoTangent(), 2 + length(itrs))
return union(s, itrs...), union_pullback
end

function ChainRulesCore.rrule(::typeof(intersect), s, itrs...)
intersect_pullback(_) = fill(NoTangent(), 2 + length(itrs))
return intersect(s, itrs...), intersect_pullback
end

end
16 changes: 0 additions & 16 deletions ext/TenetZygoteExt.jl

This file was deleted.

0 comments on commit 42c4e58

Please sign in to comment.