Parallel is incompatible with Zygote nested gradient #1199

Closed
Wu-Chenyang opened this issue Jan 13, 2025 · 1 comment

@Wu-Chenyang

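Parallel appears to be incompatible with nested Zygote gradients: taking a Zygote.gradient inside the loss function and then differentiating that loss with AutoZygote() throws a mutation error. An MWE follows.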

using Lux, Zygote, Random, Optimisers
neural_net = Parallel(
    +,
    Dense(2 => 32, relu),
    Dense(1 => 32, relu),
)
ps, st = Lux.setup(Random.default_rng(), neural_net)
train_state = Training.TrainState(neural_net, ps, st, Adam(1f-3))
function grad_mseloss(model, ps, st, ((u, t), targets))
    stateful_net = StatefulLuxLayer{true}(model, ps, st)
    grad = Zygote.gradient(sum ∘ stateful_net ∘ Base.Fix2(tuple, t), u) |> first  # Base.Fix2(tuple, t)(u) == (u, t)
    loss = MSELoss()(grad, targets)
    return loss, st, (;)
end
Training.single_train_step!(AutoZygote(), grad_mseloss, ((rand(Float32, 2, 100), rand(Float32, 1, 100)), rand(Float32, 2, 100)), train_state)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/array.jl:70
  [3] (::Zygote.var"#544#545"{Matrix{Float32}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#546"{Zygote.var"#544#545"{Matrix{Float32}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] _mapreducedim!
    @ ./reducedim.jl:289 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
  [7] mapreducedim!
    @ ./reducedim.jl:296 [inlined]
  [8] #sum!#957
    @ ./reducedim.jl:1006 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [10] sum!
    @ ./reducedim.jl:1006 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [12] #sum!#958
    @ ./reducedim.jl:1008 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [14] sum!
    @ ./reducedim.jl:1008 [inlined]
 [15] reduce_sum
    @ ~/.julia/packages/LuxLib/0MWSE/src/impl/common_ops.jl:33 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [17] ∇bias_add
    @ ~/.julia/packages/LuxLib/0MWSE/src/impl/common_ops.jl:27 [inlined]
 [18] ∇matmul_bias
    @ ~/.julia/packages/LuxLib/0MWSE/src/impl/dense.jl:217 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [20] #78
    @ ~/.julia/packages/LuxLib/0MWSE/src/impl/dense.jl:52 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [22] ZBack
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:212 [inlined]
 [23] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/LuxLib/0MWSE/src/impl/dense.jl:11 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [26] Pullback
    @ ~/.julia/packages/LuxLib/0MWSE/src/api/dense.jl:35 [inlined]
 [27] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [28] Pullback
    @ ~/.julia/packages/Lux/9hFIj/src/layers/basic.jl:343 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [30] Pullback
    @ ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:155 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Matrix{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [32] Pullback
    @ ~/.julia/packages/Lux/9hFIj/src/layers/containers.jl:0 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Tuple{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/.julia/packages/Lux/9hFIj/src/layers/containers.jl:173 [inlined]
 [35] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Tuple{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [36] Pullback
    @ ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:155 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Tuple{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [38] Pullback
    @ ~/.julia/packages/Lux/9hFIj/src/helpers/stateful.jl:119 [inlined]
--- the above 2 lines are repeated 1 more time ---
 [41] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [42] Pullback
    @ ./operators.jl:1053 [inlined]
--- the above 2 lines are repeated 1 more time ---
 [45] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Tuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [46] Pullback
    @ ./operators.jl:1050 [inlined]
 [47] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [48] #295
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:205 [inlined]
 [49] (::Zygote.Pullback{Tuple{Zygote.var"#295#296"{…}, Float32}, Any})(Δ::Tuple{Nothing, Nothing, Nothing, Tuple{Matrix{…}}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [50] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [51] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [52] Pullback
    @ ./operators.jl:1050 [inlined]
 [53] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [54] #78
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91 [inlined]
 [55] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [56] gradient
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148 [inlined]
 [57] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [58] grad_mseloss
    @ ./REPL[7]:3 [inlined]
 [59] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [60] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Tuple{Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [61] compute_gradients_impl(::AutoZygote, objective_function::typeof(grad_mseloss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxZygoteExt ~/.julia/packages/Lux/9hFIj/ext/LuxZygoteExt/training.jl:5
 [62] compute_gradients
    @ ~/.julia/packages/Lux/9hFIj/src/helpers/training.jl:198 [inlined]
 [63] single_train_step_impl!(backend::AutoZygote, obj_fn::typeof(grad_mseloss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/9hFIj/src/helpers/training.jl:301
 [64] single_train_step!(backend::AutoZygote, obj_fn::typeof(grad_mseloss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/9hFIj/src/helpers/training.jl:276
Some type information was truncated. Use `show(err)` to see complete types.
@Wu-Chenyang

I just realized that the problem is not specific to the Parallel layer. A @compact layer like the following also fails.

net1 = Dense(2=>32, relu)
net2 = Dense(1=>32, relu)
neural_net = @compact(net1=net1, net2=net2) do (u, t)
    @return net1(u)+net2(t)
end

Therefore, the problem is more likely related to passing a tuple as the input. The specific cause is still unknown, but it does not seem to be an issue with Lux.jl itself.
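
To make the tuple-input hypothesis concrete, here is a minimal sketch in plain Base Julia (no Lux or Zygote) of what the composed function in the MWE hands to the model. The names pack, packed, f, and result are illustrative, and first is only a placeholder for the network, not the real model. It does not reproduce the error; it only shows that the function being differentiated with respect to u wraps u into the tuple (u, t) before the model sees it.

```julia
# Stand-ins for the inputs in the MWE (shapes copied from the report).
u = rand(Float32, 2, 100)
t = rand(Float32, 1, 100)

# Base.Fix2(tuple, t) fixes the second argument of `tuple`,
# so calling it with u returns the tuple (u, t).
pack = Base.Fix2(tuple, t)
packed = pack(u)

# Placeholder "model": takes the tuple and returns its first element.
# The MWE uses a StatefulLuxLayer in this position instead.
f = sum ∘ first ∘ pack
result = f(u)
```

In both the Parallel and the @compact versions, the model sits in the same position as first here and receives a single tuple argument, which is consistent with the guess that the tuple input, rather than either container layer, is what the nested gradient trips over.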
