Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Fix indexing + add regression test (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt authored Mar 20, 2019
1 parent 8baa157 commit e3640fc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/sensitivities/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function ∇(Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A, inds...)
return
end
function (Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y::AbstractArray, ȳ::AbstractArray, A, inds...)
Ā[inds...] .+= reshape(ȳ, size(y)...)
@views Ā[inds...] .+= reshape(ȳ, size(y)...)
return
end
function (::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A, inds...)
Expand Down
11 changes: 9 additions & 2 deletions test/sensitivities/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
@testset "Indexing" begin
let
@testset "Int" begin
leaf = Leaf(Tape(), 5 * [1, 1, 1, 1, 1])
y = getindex(leaf, 1)
@test unbox(y) == 5
@test (y, one(unbox(y)))[leaf] == [1, 0, 0, 0, 0]
end

let
@testset "Vector" begin
x = Leaf(Tape(), 10 * [1, 1, 1])
y = x[2:3]
@test unbox(y) == [10, 10]
@test (y, oneslike(unbox(y)))[x] == [0, 1, 1]
end

@testset "Overlapping indices (#139)" begin
x = Leaf(Tape(), 10 * [1, 1, 1])
y = x[[2, 3, 3]]
@test unbox(y) == [10, 10, 10]
@test (y, oneslike(unbox(y)))[x] == [0, 1, 2]
end
end

0 comments on commit e3640fc

Please sign in to comment.