add tests for stateful sum(f,x) #1011

Open
wants to merge 1 commit into
base: master

CarloLucibello
Member

Looks like we have a problem induced by JuliaDiff/ChainRules.jl#441.
On ChainRules 0.8.10 the test introduced in this PR passes, while on 0.8.15 the Zygote gradient errors out with:

sum(f, x): Error During Test at /home/carlo/.julia/dev/Zygote/test/lib/array.jl:11
  Got exception outside of a @test
  MethodError: no method matching +(::Base.RefValue{Any}, ::Base.RefValue{Any})
  Closest candidates are:
    +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
    +(::ChainRulesCore.AbstractThunk, ::Any) at /home/carlo/.julia/packages/ChainRulesCore/e5hAX/src/differential_arithmetic.jl:138
    +(::ChainRulesCore.Tangent{P, T} where T, ::P) where P at /home/carlo/.julia/packages/ChainRulesCore/e5hAX/src/differential_arithmetic.jl:162
    ...
  Stacktrace:
    [1] add_sum(x::Base.RefValue{Any}, y::Base.RefValue{Any})
      @ Base ./reduce.jl:24
    [2] _mapreduce(f::typeof(first), op::typeof(Base.add_sum), #unused#::IndexLinear, A::Vector{Tuple{Base.RefValue{Any}, Float64}})
      @ Base ./reduce.jl:408
    [3] _mapreduce_dim(f::Function, op::Function, #unused#::Base._InitialValue, A::Vector{Tuple{Base.RefValue{Any}, Float64}}, #unused#::Colon)
      @ Base ./reducedim.jl:318
    [4] #mapreduce#672
      @ ./reducedim.jl:310 [inlined]
    [5] mapreduce
      @ ./reducedim.jl:310 [inlined]
    [6] #_sum#682
      @ ./reducedim.jl:878 [inlined]
    [7] _sum
      @ ./reducedim.jl:878 [inlined]
    [8] #sum#680
      @ ./reducedim.jl:874 [inlined]
    [9] sum(f::Function, a::Vector{Tuple{Base.RefValue{Any}, Float64}})
      @ Base ./reducedim.jl:874
   [10] (::ChainRules.var"#sum_pullback#1262"{F, Vector{Zygote.var"#ad_pullback#41"{F, Tuple{Float64}, typeof((λ))}}})(ȳ::Float64)
      @ ChainRules ~/.julia/packages/ChainRules/VDG9a/src/rulesets/Base/mapreduce.jl:58
   [11] ZBack
      @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:91 [inlined]
   [12] Pullback
      @ ~/.julia/dev/Zygote/test/lib/array.jl:29 [inlined]
   [13] (::typeof((#6)))(Δ::Float64)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
   [14] (::Zygote.var"#46#47"{typeof((#6))})(Δ::Float64)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:41
   [15] gradient(f::Function, args::Vector{Float64})
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:59
   [16] macro expansion
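
The failing frame can be reproduced without Zygote or ChainRules. This is a minimal sketch, assuming the pullback ends up reducing over `(gradient of f, gradient of x)` pairs whose first entry is a `Base.RefValue{Any}`, as the argument types in frames [2] and [9] suggest; the variable `grads` is a hypothetical stand-in, not what the rule actually builds.

```julia
# Hypothetical stand-in for the per-element pullback outputs seen in the trace:
# a Vector{Tuple{Base.RefValue{Any}, Float64}}.
grads = [(Ref{Any}(nothing), 1.0), (Ref{Any}(nothing), 2.0)]

# Summing the first entries needs +(::RefValue{Any}, ::RefValue{Any}),
# which Base does not define, so this throws a MethodError.
err = try
    sum(first, grads)
    nothing
catch e
    e
end

# Summing the second entries (plain Float64 values) works as usual.
x_total = sum(last, grads)
```

Under that assumption, only the accumulation over the function's own gradient fails; the accumulation over the `Float64` part is unaffected.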

Besides the problem in the ChainRules rule, I wonder why we don't hit the Zygote adjoint instead:

@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)

I thought Zygote's adjoints had priority over ChainRules rules.

@mcabbott
Member

The adjoint you link looks to be only for arrays of arrays, so it is not called by this example. Why it survived, I don't know. But it does explain why sum(f,x) worked like broadcasting in my gist of examples here:
https://gist.github.com/mcabbott/c6cdc73d45ed3e35c3fd8966863993f8
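
The dispatch point can be checked with plain type queries. A small sketch: `which_rule` is a hypothetical function that only mirrors the `xs::AbstractArray{<:AbstractArray}` constraint from the quoted adjoint signature; it is not part of Zygote.

```julia
# Hypothetical methods mirroring the two candidate signatures.
which_rule(f, xs::AbstractArray{<:AbstractArray}) = :array_of_arrays
which_rule(f, xs::AbstractArray) = :generic

flat = [1.0, 2.0, 3.0]             # Vector{Float64}, as in gradient(f, args::Vector{Float64})
nested = [[1.0, 2.0], [3.0, 4.0]]  # Vector{Vector{Float64}}

flat_matches = flat isa AbstractArray{<:AbstractArray}
nested_matches = nested isa AbstractArray{<:AbstractArray}

flat_rule = which_rule(identity, flat)
nested_rule = which_rule(identity, nested)
```

A `Vector{Float64}` does not satisfy the array-of-arrays constraint, so a method restricted that way is never selected for it.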

Maybe an error is a good outcome, though. The docstring for sum says "as it is unspecified whether init is used for non-empty collections", which is a pretty strong hint that the order in which f is applied can't be relied on. So the forward pass with a stateful f is, I think, already producing garbage.
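
To make the order dependence concrete, here is a sketch with explicit loops rather than `sum` itself, so the traversal order is fixed by the code. Both `make_counter_scaler` and `apply_in_order` are hypothetical helpers introduced only for illustration.

```julia
# Hypothetical stateful function: each call multiplies its input by a running call count.
function make_counter_scaler()
    calls = Ref(0)
    return x -> (calls[] += 1; calls[] * x)
end

# Apply f to the elements in the given index order and accumulate the results.
function apply_in_order(f, xs, order)
    total = zero(eltype(xs))
    for i in order
        total += f(xs[i])
    end
    return total
end

xs = [1.0, 2.0, 3.0]
left_to_right = apply_in_order(make_counter_scaler(), xs, 1:3)     # 1*1 + 2*2 + 3*3
right_to_left = apply_in_order(make_counter_scaler(), xs, 3:-1:1)  # 1*3 + 2*2 + 3*1
```

The two traversals give 14.0 and 10.0, so any result from a stateful f is only meaningful if the reduction order is guaranteed.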
