gradient of cat which introduces new dims does not match the dims of the input #1061
Comments
Fixed in ChainRules, but I think Zygote still uses its own older versions:
|
Seems to be fixed already. Tested with Zygote 0.6.34:
julia> Zygote.gradient(randn(3,3)) do x
sum(sin.(cat(x; dims=4)))
end[1]
3×3 Matrix{Float64}:
0.598398 0.988603 0.999602
0.84835 -0.101217 0.391286
0.9785 -0.717415 0.87662
|
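FWIW, this is because the shape is only corrected in the final result of `gradient`; the pullback itself still returns a 4-dimensional array:
julia> Zygote.pullback(randn(3,3)) do x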
sum(sin.(cat(x; dims=4)))
end[2](1.0)[1]
3×3×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
0.392695 -0.145669 0.450509
0.334361 0.316647 0.980656
0.987435 0.843904 0.985057
julia> gradient(x -> sum(abs2, cat(x * x'; dims=4)), [1 2; 3 4])
ERROR: MethodError: no method matching *(::Array{Int64, 4}, ::Matrix{Int64}) |
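To make the shape problem above concrete, here is a minimal sketch in plain Julia, with no Zygote or ChainRules involved. The helper `cat_with_pullback` is hypothetical and written only for illustration; it is not the rule used by either package. It assumes a single input array and shows the general idea that the cotangent arriving at the pullback has the shape of the output of `cat`, so it has to be reshaped back to the shape of the input before being passed upstream.

```julia
# Hypothetical helper for illustration only (not Zygote's or ChainRules' rule):
# a hand-written forward pass and pullback for `cat` applied to a single array.
function cat_with_pullback(x::AbstractArray; dims::Integer)
    y = cat(x; dims=dims)
    # `y` may have more (singleton) dimensions than `x`. The incoming
    # cotangent `dy` has the shape of `y`, so reshape it to the shape of `x`.
    pullback(dy) = reshape(dy, size(x))
    return y, pullback
end

x = randn(3, 3)
y, back = cat_with_pullback(x; dims=4)

# For the loss sum(sin.(y)), the cotangent with respect to y is cos.(y).
dx = back(cos.(y))

size(y)   # (3, 3, 1, 1)
size(dx)  # (3, 3)
```

With the reshape in place, `dx` has the same shape as `x`, so an upstream rule such as the one for `x * x'` would receive a matrix rather than a 4-dimensional array.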
IIRC the hurdle to simply deleting all of these is JuliaGPU/GPUArrays.jl#362. |
This is solved now, right? |
MWE: