Skip to content

Commit

Permalink
Fix autodiff on tensor contraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 10, 2023
1 parent f6e4248 commit f040859
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
TenetChainRulesCoreExt = "ChainRulesCore"
TenetMakieExt = "Makie"
TenetQuacExt = "Quac"
TenetZygoteExt = "Zygote"

[compat]
Bijections = "0.1"
ChainRulesCore = "1.0"
Combinatorics = "1.0"
DeltaArrays = "0.1.1"
EinExprs = "0.5.2"
Expand All @@ -39,4 +44,5 @@ OMEinsum = "0.7"
Permutations = "0.4"
Quac = "0.2"
ValSplit = "0.1"
Zygote = "0.6"
julia = "1.9"
18 changes: 18 additions & 0 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module TenetChainRulesCoreExt

using Tenet
using ChainRulesCore

ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor} =
ProjectTo{T}(; data = ProjectTo(tensor.data), inds = tensor.inds, meta = tensor.meta)

(projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:Tensor} =
T(projector.data(dx.data), projector.inds; projector.meta...)

function ChainRulesCore.rrule(::Type{Tensor{T,N,A}}, data, inds; meta...) where {T,N,A}
return Tensor(data, inds; meta...), function Tensor_pullback(_)
(NoTangent(), data, NoTangent())
end
end

end
16 changes: 16 additions & 0 deletions ext/TenetZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TenetZygoteExt

using Tenet
using Zygote

Zygote.@adjoint (T::Type{<:Tensor})(data, inds; meta...) = T(data, inds; meta...), y -> (nothing, y.data, nothing)

# WARN type-piracy
Zygote.@adjoint Base.setdiff(s, itrs...) =
setdiff(s, itrs...), _ -> (nothing, nothing, [nothing for _ in 1:length(itrs)]...)
Zygote.@adjoint Base.union(s, itrs...) =
union(s, itrs...), _ -> (nothing, nothing, [nothing for _ in 1:length(itrs)]...)
Zygote.@adjoint Base.intersect(s, itrs...) =
intersect(s, itrs...), _ -> (nothing, nothing, [nothing for _ in 1:length(itrs)]...)

end

0 comments on commit f040859

Please sign in to comment.