Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Add sensitivies for copy (#138)
Browse files Browse the repository at this point in the history
`copy` materializes `Adjoint` and `Transpose` wrappers, which can be
useful, as those are immutable. It is also safe to call on any value.
  • Loading branch information
ararslan authored Mar 15, 2019
1 parent a830fd4 commit 95844bd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,8 @@ end
(::typeof(LinearAlgebra.dot), ::Type{Arg{2}}, p, z, z̄, x::A, y::A) =.* x
(x̄, ::typeof(LinearAlgebra.dot), ::Type{Arg{1}}, p, z, z̄, x::A, y::A) = (x̄ .=.+.* y)
(ȳ, ::typeof(LinearAlgebra.dot), ::Type{Arg{2}}, p, z, z̄, x::A, y::A) = (ȳ .=.+.* x)

# `copy` materializes `Adjoint` and `Transpose` wrappers but can be called on anything
import Base: copy
@explicit_intercepts copy Tuple{Any}
(::typeof(copy), ::Type{Arg{1}}, p, Y, Ȳ, A) = copy(Ȳ)
22 changes: 22 additions & 0 deletions test/sensitivities/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,26 @@
@test check_errs(LinearAlgebra.dot, LinearAlgebra.dot(x, y), (x, y), (vx, vy))
end
end

@testset "copy" begin
rng = MersenneTwister(12345)

# Scalars (no-op)
x = randn(rng)
y = randn(rng)
@test check_errs(copy, x, x, y)
x_ = Leaf(Tape(), x)
c = copy(x_)
@test c isa Branch{Float64}
@test getfield(c, :f) === Base.copy

# Unwrapping adjoint/transposes
X = randn(rng, 6, 6)'
Y = randn(rng, 6, 6)
@test check_errs(copy, X, copy(X), Y)
X_ = Leaf(Tape(), X)
C = copy(X_)
@test C isa Branch{Matrix{Float64}}
@test getfield(c, :f) === Base.copy
end
end

0 comments on commit 95844bd

Please sign in to comment.