Skip to content

Commit

Permalink
wrap_chainrules_input for mutable struct
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 15, 2021
1 parent 5887e46 commit 2b35bc0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.27"
version = "0.6.28"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
ChainRulesCore = "1.6"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end
# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple)
@inline wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[])

"""
_project(x, dx)
Expand Down
14 changes: 14 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,20 @@ let
@test back(1.) == ((1.0,),)
end

@testset "mutable struct, including Ref" begin
# Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the
# map to ChainRules types and back normalises to (value = 7.0,) same as struct:
@test gradient(x -> x.value^2 + x.value, MyMutable(3)) === ((value = 7.0,),)

# Same for Ref. This doesn't seem to affect `pow_mut` test in this file.
@test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
@test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)

# Broadcasting over Ref is handled specially. Tested elsehwere too.
@test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
end

function type_test()
Complex{<:Real}
end
Expand Down

0 comments on commit 2b35bc0

Please sign in to comment.