From b118e0797d35d305971dad81ac2c478b6057b26d Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Mon, 1 Oct 2018 17:19:50 -0500
Subject: [PATCH 01/11] Update Nabla to support Julia versions 0.7 and up

Summary of changes:

* The minimum required Julia version is now 0.7. In particular, this means we no longer support or test on Julia 0.6.
* Functions and syntax that are deprecated in 0.7 have been updated.
* The required version of the DualNumbers package has been raised. This allows us to remove some pirated methods, as they've been upstreamed to that package.
* Macro call expressions now contain line number information, which we don't have access to in our expression manipulation functions, so we insert `nothing` instead (example below). This requires adding some logic to the tests so that we can properly compare expressions containing macro calls, by stripping all line information from both sides.
* Type definition expressions now use `:struct` as the expression head rather than `:type`, so we follow suit to ensure correct expressions are emitted (example below).
* Binary linear algebra optimizations and BLAS matrix multiplications have been entirely rewritten for Julia 0.7, since the `A_mul_B` family of functions has been deprecated in favor of dispatch on simple operators such as `*` with lazy container types such as `Transpose` (example below).
* The overloading of `broadcast` has been rewritten to use the new machinery for broadcast customization provided in Base. We define our own `BroadcastStyle` which allows us to intercept `broadcast` calls and insert `Branch`es containing `broadcasted` calls. `broadcasted` is the lazy function to which fused broadcast expressions lower; `broadcast` now simply constructs and eagerly materializes that lazy representation (sketched below).
* The interface to `diagm` has changed in 0.7 to accept `Pair`s specifying the diagonal and its contents. This breaks the machinery we have in place, so `diagm` is handled slightly differently from the rest of the functions: it dispatches to an internal helper function with simple methods that can use the existing machinery (example below).
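To make the line-number change concrete, here is a minimal illustration of the 0.7 parsing behavior; `@foo` is a placeholder macro name, not anything from Nabla:

```julia
# On Julia >= 0.7, parsing a macro call records its source location:
ex = :(@foo x)
ex.args[1]  # Symbol("@foo")
ex.args[2]  # a LineNumberNode carrying file/line information
ex.args[3]  # :x

# Our expression transformations have no source location to supply, so they
# construct macro calls with `nothing` in the LineNumberNode slot:
Expr(:macrocall, Symbol("@foo"), nothing, :x)
```

Because of that slot, a transformed expression never compares `==` to a quoted expression literal, hence the test helper that strips line information from both sides before comparing.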
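The `:type` to `:struct` head change, concretely (standalone illustration):

```julia
ex = :(struct Foo end)
ex.head     # :struct on Julia >= 0.7; this was :type on 0.6
ex.args[1]  # false, the `mutable` flag
ex.args[2]  # :Foo
```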
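For reference, the deprecation driving the linear algebra rewrite looks like this (a sketch with throwaway matrices, not code from this package):

```julia
using LinearAlgebra

A, B = rand(3, 3), rand(3, 3)

# Julia 0.6 encoded transposition in the function name; Julia 0.7 wraps the
# argument lazily and lets dispatch on the wrapper type pick the kernel:
transpose(A) * B  # replaces At_mul_B(A, B); transpose(A) is a lazy Transpose
A * B'            # replaces A_mul_Bc(A, B); B' is a lazy Adjoint
transpose(A) \ B  # replaces At_ldiv_B(A, B)
```

This is why `generic.jl` and `strided.jl` below define sensitivities for `*`, `/`, and `\` over combinations of plain arrays, `Transpose`, and `Adjoint` instead of one method per named variant.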
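The following sketch shows the broadcast interception mechanism in isolation, using a hypothetical `Tracked` wrapper in place of Nabla's `Node`; the real implementation is in `src/sensitivities/functional/functional.jl` below:

```julia
using Base.Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle

# Hypothetical stand-in for Nabla's Node type.
struct Tracked{T}
    val::T
end

struct TrackedStyle <: BroadcastStyle end
Base.BroadcastStyle(::Type{<:Tracked}) = TrackedStyle()
# Our style wins when mixed with ordinary arrays and scalars, so any
# broadcast involving a Tracked argument gets routed to us.
Base.BroadcastStyle(s::TrackedStyle, ::DefaultArrayStyle) = s
Base.Broadcast.broadcastable(x::Tracked) = x
Base.axes(x::Tracked) = axes(x.val)

# `x .+ y` lowers to `materialize(broadcasted(+, x, y))`, and materializing a
# lazy Broadcasted of our style ends up in this `copy` method. Nabla records
# a Branch on the tape here; the sketch just unwraps and broadcasts eagerly.
function Base.copy(bc::Broadcasted{TrackedStyle})
    args = map(a -> a isa Tracked ? a.val : a, bc.args)
    return Tracked(broadcast(bc.f, args...))
end

Tracked([1.0, 2.0]) .+ [3.0, 4.0]  # Tracked([4.0, 6.0])
```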
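And the `diagm` interface change, for reference (standalone example, not package code):

```julia
using LinearAlgebra

x = [1.0, 2.0, 3.0]

diagm(0 => x)  # main diagonal; was diagm(x) on Julia 0.6
diagm(1 => x)  # first superdiagonal; was diagm(x, 1)
```

Since the intercept machinery can't dispatch on the parameters of the `Pair` type, the patch adds a positional helper `_diagm(x, k=0)` with methods it can handle, and extends `diagm` to unpack a `Pair` whose value is a `Node` and forward it to `_diagm`.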
--- .travis.yml | 7 +- REQUIRE | 4 +- appveyor.yml | 42 +++-- docs/src/index.md | 4 +- src/Nabla.jl | 7 +- src/code_transformation/differentiable.jl | 20 +-- src/core.jl | 39 +++-- src/finite_differencing.jl | 42 ++--- src/sensitivities/array.jl | 5 +- src/sensitivities/functional/functional.jl | 38 +++-- src/sensitivities/functional/reduce.jl | 23 ++- src/sensitivities/functional/reducedim.jl | 2 +- src/sensitivities/indexing.jl | 2 +- src/sensitivities/linalg/blas.jl | 42 ++--- src/sensitivities/linalg/diagonal.jl | 132 +++++++-------- .../linalg/factorization/cholesky.jl | 8 +- src/sensitivities/linalg/generic.jl | 160 +++++++++--------- src/sensitivities/linalg/strided.jl | 40 ++--- src/sensitivities/linalg/symmetric.jl | 2 +- src/sensitivities/linalg/triangular.jl | 5 +- src/sensitivities/scalar.jl | 8 +- src/sensitivity.jl | 6 +- test/code_transformation/differentiable.jl | 34 ++-- test/core.jl | 10 +- test/finite_differencing.jl | 4 +- test/runtests.jl | 13 +- test/sensitivities/functional/functional.jl | 90 +++++----- test/sensitivities/functional/reduce.jl | 14 +- test/sensitivities/functional/reducedim.jl | 6 +- test/sensitivities/indexing.jl | 2 +- test/sensitivities/linalg/blas.jl | 21 +-- test/sensitivities/linalg/diagonal.jl | 18 +- .../linalg/factorization/cholesky.jl | 8 +- test/sensitivities/linalg/generic.jl | 4 +- test/sensitivities/linalg/strided.jl | 4 +- test/sensitivities/scalar.jl | 37 ++-- test/sensitivity.jl | 2 +- 37 files changed, 463 insertions(+), 442 deletions(-) diff --git a/.travis.yml b/.travis.yml index ca91b670..b9d68c21 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ os: - linux - osx julia: - - 0.6 + - 0.7 # TODO: Add 1.0 once 0.7 works - nightly notifications: email: false @@ -18,7 +18,6 @@ matrix: # - julia -e 'Pkg.clone(pwd()); Pkg.build("Nabla"); Pkg.test("Nabla"; coverage=true)' after_success: # push coverage results to Codecov - - julia -e 'cd(Pkg.dir("Nabla")); Pkg.add("Coverage"); using Coverage; Codecov.submit(Codecov.process_folder())' + - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(Codecov.process_folder())' # build documentation - - julia -e 'Pkg.add("Documenter")' - - julia -e 'cd(Pkg.dir("Nabla")); include(joinpath("docs", "make.jl"))' + - julia -e 'using Pkg; Pkg.add("Documenter"); include(joinpath("docs", "make.jl"))' diff --git a/REQUIRE b/REQUIRE index b8ea1aa5..31e93ed4 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,5 +1,5 @@ -julia 0.6 -DualNumbers 0.3.0 +julia 0.7 +DualNumbers 0.6.0 DiffRules 0.0.1 FDM 0.1.0 SpecialFunctions 0.3.0 diff --git a/appveyor.yml b/appveyor.yml index 2c7edadb..18e3ab4d 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,14 +1,16 @@ environment: matrix: - - JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x86/0.6/julia-0.6-latest-win32.exe" - - JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x64/0.6/julia-0.6-latest-win64.exe" - - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x86/julia-latest-win32.exe" - - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x64/julia-latest-win64.exe" + - julia_version: 0.7 + #- julia_version: 1 + - julia_version: nightly + +platform: + - x86 # 32-bit + - x64 # 64-bit matrix: allow_failures: - - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x86/julia-latest-win32.exe" - - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x64/julia-latest-win64.exe" + - julia_version: nightly branches: only: @@ -22,24 +24,18 @@ notifications: on_build_status_changed: 
false install: - - ps: "[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls12" -# If there's a newer build queued for the same PR, cancel this one - - ps: if ($env:APPVEYOR_PULL_REQUEST_NUMBER -and $env:APPVEYOR_BUILD_NUMBER -ne ((Invoke-RestMethod ` - https://ci.appveyor.com/api/projects/$env:APPVEYOR_ACCOUNT_NAME/$env:APPVEYOR_PROJECT_SLUG/history?recordsNumber=50).builds | ` - Where-Object pullRequestId -eq $env:APPVEYOR_PULL_REQUEST_NUMBER)[0].buildNumber) { ` - throw "There are newer queued builds for this pull request, failing early." } -# Download most recent Julia Windows binary - - ps: (new-object net.webclient).DownloadFile( - $env:JULIA_URL, - "C:\projects\julia-binary.exe") -# Run installer silently, output to C:\projects\julia - - C:\projects\julia-binary.exe /S /D=C:\projects\julia + - ps: iex ((new-object net.webclient).DownloadString("https://raw.githubusercontent.com/JuliaCI/Appveyor.jl/version-1/bin/install.ps1")) build_script: -# Need to convert from shallow to complete for Pkg.clone to work - - IF EXIST .git\shallow (git fetch --unshallow) - - C:\projects\julia\bin\julia -e "versioninfo(); - Pkg.clone(pwd(), \"Nabla\"); Pkg.build(\"Nabla\")" + - echo "%JL_BUILD_SCRIPT%" + - C:\julia\bin\julia -e "%JL_BUILD_SCRIPT%" test_script: - - C:\projects\julia\bin\julia -e "Pkg.test(\"Nabla\")" + - echo "%JL_TEST_SCRIPT%" + - C:\julia\bin\julia -e "%JL_TEST_SCRIPT%" + +# # Uncomment to support code coverage upload. Should only be enabled for packages +# # which would have coverage gaps without running on Windows +# on_success: +# - echo "%JL_CODECOV_SCRIPT%" +# - C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%" diff --git a/docs/src/index.md b/docs/src/index.md index d0fdd4e4..0a438190 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -22,14 +22,14 @@ x, y = randn.(rng, [N, N]) A = randn(rng, N, N) # Construct a vector-quadratic function in `x` and `y`. -f(x, y) = y.' * (A * x) +f(x, y) = y' * (A * x) f(x, y) ``` Only a small amount of [matrix calculus](https://en.wikipedia.org/wiki/Matrix_calculus) is required to the find the gradient of `f(x, y)` w.r.t. `x` and `y`, which we denote by `∇x` and `∇y` respectively, to be ```@example toy -(∇x, ∇y) = (A.'y, A * x) +(∇x, ∇y) = (A'y, A * x) ``` ## High-Level Interface diff --git a/src/Nabla.jl b/src/Nabla.jl index 7c519090..597b3dc6 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -2,6 +2,9 @@ __precompile__() module Nabla + using SpecialFunctions + using LinearAlgebra + # Some aliases used repeatedly throughout the package. export ∇Scalar, ∇Array, SymOrExpr, ∇ArrayOrScalar const ∇Scalar = Number @@ -34,7 +37,7 @@ module Nabla # Sensitivities for functionals. include("sensitivities/functional/functional.jl") include("sensitivities/functional/reduce.jl") - include("sensitivities/functional/reducedim.jl") + #include("sensitivities/functional/reducedim.jl") # Linear algebra optimisations. 
include("sensitivities/linalg/generic.jl") @@ -43,6 +46,6 @@ module Nabla include("sensitivities/linalg/blas.jl") include("sensitivities/linalg/diagonal.jl") include("sensitivities/linalg/triangular.jl") - include("sensitivities/linalg/factorization/cholesky.jl") + #include("sensitivities/linalg/factorization/cholesky.jl") end # module Nabla diff --git a/src/code_transformation/differentiable.jl b/src/code_transformation/differentiable.jl index ea9fd565..6d9b89dd 100644 --- a/src/code_transformation/differentiable.jl +++ b/src/code_transformation/differentiable.jl @@ -12,7 +12,7 @@ unionise_arg(arg::Expr) = Expr(Symbol("::"), arg.args[1:end-1]..., unionise_type(arg.args[end])) : arg.head == Symbol("...") ? Expr(Symbol("..."), unionise_arg(arg.args[1])) : - throw(error("Unrecognised argument in Symbol ($arg).")) + throw(ArgumentError("Unrecognised argument in Symbol ($arg).")) """ unionise_subtype(arg::Union{Symbol, Expr}) @@ -23,7 +23,7 @@ unionise_subtype(arg::Symbol) = arg unionise_subtype(arg::Expr) = arg.head == Symbol("<:") ? Expr(Symbol("<:"), arg.args[1:end-1]..., unionise_type(arg.args[end])) : - throw(error("Unrecognised argument in arg ($arg).")) + throw(ArgumentError("Unrecognised argument in arg ($arg).")) """ get_quote_body(code) @@ -42,7 +42,7 @@ Unionise the code inside a call to `eval`, such that when the `eval` call actual the code inside will be unionised. """ function unionise_eval(code::Expr) - body = Expr(:macrocall, Symbol("@unionise"), deepcopy(get_quote_body(code.args[end]))) + body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(get_quote_body(code.args[end]))) return length(code.args) == 3 ? Expr(:call, :eval, deepcopy(code.args[2]), quot(body)) : Expr(:call, :eval, quot(body)) @@ -55,10 +55,10 @@ Unionise the code in a call to @eval, such that when the `eval` call actually oc code inside will be unionised. """ function unionise_macro_eval(code::Expr) - body = Expr(:macrocall, Symbol("@unionise"), deepcopy(code.args[end])) - return length(code.args) == 3 ? - Expr(:macrocall, Symbol("@eval"), deepcopy(code.args[2]), body) : - Expr(:macrocall, Symbol("@eval"), body) + body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(code.args[end])) + return length(code.args) == 4 ? + Expr(:macrocall, Symbol("@eval"), nothing, deepcopy(code.args[3]), body) : + Expr(:macrocall, Symbol("@eval"), nothing, body) end """ @@ -95,13 +95,13 @@ function unionise_struct(code::Expr) curly = Expr(:curly, name.args[1], unionise_subtype.(name.args[2:end])...) if is_subtype_expr return Expr( - :type, + :struct, code.args[1], Expr(Symbol("<:"), curly, tmp.args[2]), code.args[3], ) else - return Expr(:type, code.args[1], curly, code.args[3]) + return Expr(:struct, code.args[1], curly, code.args[3]) end else return code @@ -131,7 +131,7 @@ function unionise(code::Expr) return unionise_eval(code) elseif code.head == :macrocall && code.args[1] == Symbol("@eval") return unionise_macro_eval(code) - elseif code.head == :type + elseif code.head == :struct return unionise_struct(code) else return Expr(code.head, [unionise(arg) for arg in code.args]...) 
diff --git a/src/core.jl b/src/core.jl index fd013f77..944b2ba4 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1,17 +1,18 @@ using DualNumbers -import Base: push!, length, show, getindex, setindex!, endof, eachindex, isassigned, -isapprox, zero, one +import Base: push!, length, show, getindex, setindex!, eachindex, isassigned, + isapprox, zero, one, lastindex + export Leaf, Tape, Node, Branch, ∇ """ Basic unit on the computational graph.""" abstract type Node{T} end """ A topologically ordered collection of Nodes. """ -immutable Tape +struct Tape tape::Vector{Any} Tape() = new(Vector{Any}()) - Tape(N::Int) = new(Vector{Any}(N)) + Tape(N::Int) = new(Vector{Any}(undef, N)) end function show(io::IO, tape::Tape) if length(tape) == 0 @@ -24,7 +25,7 @@ function show(io::IO, tape::Tape) end @inline getindex(tape::Tape, n::Int) = Base.getindex(tape.tape, n) @inline getindex(tape::Tape, node::Node) = Base.getindex(tape, node.pos) -@inline endof(tape::Tape) = length(tape) +@inline lastindex(tape::Tape) = length(tape) @inline setindex!(tape::Tape, x, n::Int) = (tape.tape[n] = x; tape) @inline eachindex(tape::Tape) = eachindex(tape.tape) @inline length(tape::Tape) = length(tape.tape) @@ -32,6 +33,9 @@ end @inline isassigned(tape::Tape, n::Int) = isassigned(tape.tape, n) @inline isassigned(tape::Tape, node::Node) = isassigned(tape, node.pos) +# Make `Tape`s broadcast as scalars without a warning on 0.7 +Base.Broadcast.broadcastable(tape::Tape) = Ref(tape) + """ An element at the 'bottom' of the computational graph. @@ -63,7 +67,7 @@ args - Values indicating which elements in the tape will require updating by thi tape - The Tape to which this Branch is assigned. pos - the location of this Branch in the tape to which it is assigned. """ -immutable Branch{T} <: Node{T} +struct Branch{T} <: Node{T} val::T f args::Tuple @@ -99,7 +103,7 @@ Get `.val` if `x` is a Node, otherwise is equivalent to `identity`. unbox(x::Node) = x.val unbox(x) = x -isapprox(n::Node, f) = Nabla.unbox(n) ≈ f +isapprox(n::Node, f) = unbox(n) ≈ f isapprox(f, n::Node) = n ≈ f zero(n::Node) = zero(unbox(n)) @@ -177,8 +181,8 @@ function ∇(f, get_output::Bool=false) y isa Node || return zero.(args) ∇f = ∇(y) ∇args = ([isassigned(∇f, arg_) ? ∇f[arg_] : zero(arg) - for (arg_, arg) in zip(args_, args)]...) - return get_output ? (y, ∇args) : ∇args + for (arg_, arg) in zip(args_, args)]...,) + return get_output ? (y, ∇args) : ∇args end end @@ -205,16 +209,16 @@ end # A collection of methods for initialising nested indexable containers to zero. for (f_name, scalar_init, array_init) in zip((:zerod_container, :oned_container, :randned_container), - (:zero, :one, Nullable()), - (:zeros, :ones, Nullable())) - if !isnull(scalar_init) + (:zero, :one, nothing), + (:zeros, :ones, nothing)) + if scalar_init !== nothing @eval @inline $f_name(x::Number) = $scalar_init(x) end - if !isnull(array_init) - @eval @inline $f_name(x::AbstractArray{<:Real}) = $array_init(x) + if array_init !== nothing + @eval @inline $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x)) end eval(quote - @inline $f_name(x::Tuple) = ([$f_name(n) for n in x]...) 
+ @inline $f_name(x::Tuple) = map($f_name, x) @inline function $f_name(x) y = Base.copy(x) for n in eachindex(y) @@ -250,8 +254,3 @@ function fmad_expr(f, x::Type{<:Tuple}) return body end @generated fmad(f, x) = fmad_expr(f, x) - -function Base.exp10(x::Dual) - y = exp10(DualNumbers.value(x)) - return Dual(y, y * log(10) * DualNumbers.epsilon(x)) -end diff --git a/src/finite_differencing.jl b/src/finite_differencing.jl index eb2eb074..2befd501 100644 --- a/src/finite_differencing.jl +++ b/src/finite_differencing.jl @@ -38,7 +38,7 @@ approximate_Dv(f, ȳ::∇ArrayOrScalar, x::∇ArrayOrScalar, v::∇ArrayOrScala Compute the directional derivative of `f` at `x` in direction `v` using AD. Use this result to back-propagate the sensitivity ȳ. If ȳ, x and v are column vectors, then this is -equivalent to computing `ȳ.'(J f)(x) v`, where `(J f)(x)` denotes the Jacobian of `f` +equivalent to computing `ȳ'(J f)(x) v`, where `(J f)(x)` denotes the Jacobian of `f` evaluated at `x`. Analogous operations happen for scalars and N-dimensional arrays. """ function compute_Dv( @@ -65,7 +65,7 @@ function compute_Dv_update( rtape = reverse_tape(y, ȳ) # Randomly initialise `Leaf`s. - inits = Vector(length(rtape)) + inits = Vector(undef, length(rtape)) for i = 1:length(rtape) if isleaf(y.tape[i]) inits[i] = randned_container(y.tape[i].val) @@ -127,7 +127,7 @@ Check whether an input `x` is in a scalar, real function `f`'s domain. function in_domain(f::Function, x::Float64...) try y = f(x...) - return issubtype(typeof(y), Real) && !isnan(y) + return isa(y, Real) && !isnan(y) catch err return isa(err, DomainError) ? false : throw(err) end @@ -148,7 +148,7 @@ Attempt to find a domain for a unary, scalar function `f`. - `measure::Function`: Function that measures the size of a set of points for `f`. - `points::Vector{T}`: Ordered set of test points to construct the domain from. """ -function domain1{T}(in_domain::Function, measure::Function, points::Vector{T}) +function domain1(in_domain::Function, measure::Function, points::Vector{T}) where T # Find the connected sets of points that are in f's domain. connected_sets, set = Vector{Vector{T}}(), Vector{T}() for x in points @@ -165,17 +165,17 @@ function domain1{T}(in_domain::Function, measure::Function, points::Vector{T}) # Add the possibly yet unadded set. length(set) > 0 && push!(connected_sets, set) - # Return null if no domain could be found. - length(connected_sets) == 0 && return Nullable{Vector{T}}() + # Return nothing if no domain could be found. + length(connected_sets) == 0 && return # Pick the largest domain. - return Nullable(connected_sets[indmax(measure.(connected_sets))]) + return connected_sets[argmax(measure.(connected_sets))] end function domain1(f::Function) set = domain1(x -> in_domain(f, x), x -> maximum(x) - minimum(x), points) - isnull(set) && return Nullable{NTuple{2, Float64}}() - return Nullable((minimum(get(set)), maximum(get(set)))) + set === nothing && return + return (minimum(set), maximum(set)) end """ @@ -183,9 +183,9 @@ end Slice of a Float64 x Float64 domain. """ -type Slice2 +mutable struct Slice2 x::Float64 - y_range::Nullable{Tuple{Float64, Float64}} + y_range::Union{Tuple{Float64, Float64}, Nothing} end """ @@ -199,29 +199,29 @@ function domain2(f::Function) # Extract a set of in-domain slices. 
measure = x -> maximum(getfield.(x, :x)) - minimum(getfield.(x, :x)) - in_domain_slices = domain1(x -> !isnull(x.y_range), measure, slices) - isnull(in_domain_slices) && return Nullable{NTuple{2, NTuple{2, Float64}}}() + in_domain_slices = domain1(x -> x.y_range !== nothing, measure, slices) + in_domain_slices === nothing && return # Extract the x range of the domain. - xs = getfield.(get(in_domain_slices), :x) + xs = getfield.(in_domain_slices, :x) x_range = (minimum(xs), maximum(xs)) # Extract the y range of the domain. - y_ranges = get.(getfield.(get(in_domain_slices), :y_range)) + y_ranges = getfield.(in_domain_slices, :y_range) y_lower, y_upper = maximum(getindex.(y_ranges, 1)), minimum(getindex.(y_ranges, 2)) - y_lower >= y_upper && return Nullable{NTuple{2, NTuple{2, Float64}}}() + y_lower >= y_upper && return y_range = (y_lower, y_upper) - return Nullable((x_range, y_range)) + return (x_range, y_range) end # `beta`s domain cannot be determined correctly, since `beta(-.2, -.2)` doesn't throw an # error, strangely enough. -domain2(::typeof(beta)) = Nullable(((minimum(points[points .> 0]), maximum(points)), - (minimum(points[points .> 0]), maximum(points)))) +domain2(::typeof(beta)) = ((minimum(points[points .> 0]), maximum(points)), + (minimum(points[points .> 0]), maximum(points))) # Both of these functions are technically defined on the entire real line, but the left # half is troublesome due to the large number of points at which it isn't defined. As such # we restrict unit testing to the right-half. -domain1(::typeof(gamma)) = Nullable((minimum(points[points .> 0]), maximum(points[points .> 0]))) -domain1(::typeof(trigamma)) = Nullable((minimum(points[points .> 0]), maximum(points[points .> 0]))) +domain1(::typeof(gamma)) = (minimum(points[points .> 0]), maximum(points[points .> 0])) +domain1(::typeof(trigamma)) = (minimum(points[points .> 0]), maximum(points[points .> 0])) diff --git a/src/sensitivities/array.jl b/src/sensitivities/array.jl index 82d569d6..a9e0c01a 100644 --- a/src/sensitivities/array.jl +++ b/src/sensitivities/array.jl @@ -21,7 +21,8 @@ function Nabla.∇( ) where i l = sum([size(A[j], 2) for j in 1:(i - 1)]) u = l + size(A[i], 2) - return u > l + 1 ? slicedim(ȳ, 2, (l+1):u) : slicedim(ȳ, 2, u) + # Using copy materializes the views returned by selectdim + return copy(u > l + 1 ? selectdim(ȳ, 2, (l+1):u) : selectdim(ȳ, 2, u)) end @union_intercepts vcat Tuple{Vararg{∇Array}} Tuple{Vararg{AbstractArray}} @@ -35,5 +36,5 @@ function Nabla.∇( ) where i l = sum([size(A[j], 1) for j in 1:(i - 1)]) u = l + size(A[i], 1) - return slicedim(ȳ, 1, (l+1):u) + return copy(selectdim(ȳ, 1, (l+1):u)) end diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index ba7e70a2..8d226ad4 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -1,5 +1,4 @@ # Implementation of functionals (i.e. higher-order functions). -import Base.Broadcast.broadcast_shape # Implementation of sensitivities w.r.t. `map`. import Base.map @@ -10,20 +9,37 @@ import Base.map ∇(::typeof(map), ::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N = _∇(map, Arg{N-1}, p, y, ȳ, f, A...) _∇(::typeof(map), arg::Type{Arg{N}}, p, y, ȳ, f::Function, A::∇Array...) where N = - method_exists(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? + hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? map((yn, ȳn, An...)->∇(f, Arg{N}, p, yn, ȳn, An...), y, ȳ, A...) 
: map((ȳn, An...)->ȳn * fmad(f, An, Val{N}), ȳ, A...) # Deal with ambiguities introduced by `map`. map(f, x::AbstractArray{<:Number}...) = invoke(map, Tuple{Any, Vararg{Any}}, f, x...) map(f, x::AbstractArray{<:Number}) = - invoke(map, Tuple{Any, Union{AbstractArray, AbstractSet, Associative}}, f, x) + invoke(map, Tuple{Any, Union{AbstractArray, AbstractSet, AbstractDict}}, f, x) # Implementation of sensitivities w.r.t. `broadcast`. -import Base.broadcast -@union_intercepts broadcast Tuple{Any, Vararg{∇Scalar}} Tuple{Any, Vararg{Number}} -@explicit_intercepts broadcast Tuple{Any, Any} [false, true] -@union_intercepts broadcast Tuple{Any, Vararg{∇ArrayOrScalar}} Tuple{Any, Any, Vararg} +using Base.Broadcast +using Base.Broadcast: Broadcasted, broadcastable, broadcast_axes, broadcast_shape + +struct NodeStyle{S} <: BroadcastStyle end + +Base.BroadcastStyle(::Type{<:Node{T}}) where {T} = NodeStyle{BroadcastStyle(T)}() + +Base.BroadcastStyle(::NodeStyle{S}, ::NodeStyle{S}) where {S} = NodeStyle{S}() +Base.BroadcastStyle(::NodeStyle{S1}, ::NodeStyle{S2}) where {S1,S2} = + NodeStyle{BroadcastStyle(S1, S2)}() +Base.BroadcastStyle(::NodeStyle{S}, B::BroadcastStyle) where {S} = + NodeStyle{BroadcastStyle(S, B)}() + +Broadcast.broadcast_axes(x::Node) = broadcast_axes(x.val) +Broadcast.broadcastable(x::Node) = x + +function Base.copy(bc::Broadcasted{NodeStyle{S}}) where S + args = bc.args + tape = getfield(args[findfirst(x->x isa Node, args)], :tape) + return Branch(broadcast, (bc.f, args...), tape) +end """ broadcastsum!(f::Function, add::Bool, z, As...) @@ -34,7 +50,7 @@ the current value of z, otherwise it is overwritten. function broadcastsum!(f::Function, add::Bool, z, As...) tmp_shape = broadcast_shape(map(size, As)...) if size(z) != tmp_shape - tmp = Array{eltype(z)}(tmp_shape) + tmp = Array{eltype(z)}(undef, tmp_shape) return sum!(z, broadcast!(f, tmp, As...), init=!add) else return add ? @@ -49,7 +65,7 @@ end Allocating version of broadcastsum! specialised for Arrays. """ broadcastsum(f, add::Bool, z::AbstractArray, As...) = - broadcastsum!(f, add, Array{eltype(z)}(size(z)), As...) + broadcastsum!(f, add, Array{eltype(z)}(undef, size(z)), As...) """ broadcastsum(f, add::Bool, z::Number, As...) @@ -57,7 +73,7 @@ broadcastsum(f, add::Bool, z::AbstractArray, As...) = Specialisation of broadcastsum to Number-sized outputs. """ function broadcastsum(f, add::Bool, z::Number, As...) - tmp = Array{eltype(z)}(broadcast_shape(map(size, As)...)) + tmp = Array{eltype(z)}(undef, broadcast_shape(map(size, As)...)) return sum(broadcast!(f, tmp, As...)) + (add ? z : zero(z)) end @@ -65,7 +81,7 @@ end ∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A::∇ArrayOrScalar...) where N = _∇(broadcast, Arg{N-1}, p, y, ȳ, f, A...) _∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N = - method_exists(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? + hasmethod(∇, Tuple{typeof(f), Type{Arg{N}}, Any, Any, Any, map(eltype, A)...}) ? broadcastsum((yn, ȳn, xn...)->∇(f, Arg{N}, p, yn, ȳn, xn...), false, A[N], y, ȳ, A...) : broadcastsum((ȳn, xn...)->ȳn * fmad(f, xn, Val{N}), false, A[N], ȳ, A...) diff --git a/src/sensitivities/functional/reduce.jl b/src/sensitivities/functional/reduce.jl index 0ba53c0f..41bd630d 100644 --- a/src/sensitivities/functional/reduce.jl +++ b/src/sensitivities/functional/reduce.jl @@ -1,14 +1,13 @@ -export mapreduce, mapfoldl, mapfoldr, mapreducedim - # Intercepts for `mapreduce`, `mapfoldl` and `mapfoldr` under `op` `+`. 
-type_tuple = :(Tuple{Any, typeof(+), ∇ArrayOrScalar}) -for (f, base_f_name) in ((:mapreduce, :(Base.mapreduce)), - (:mapfoldl, :(Base.mapfoldl)), - (:mapfoldr, :(Base.mapfoldr))) - @eval import Base.$f - @eval @explicit_intercepts $f $type_tuple [false, false, true] - @eval ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::typeof(+), A::∇ArrayOrScalar) = - method_exists(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ? - broadcast(An->ȳ * ∇(f, Arg{1}, An), A) : - broadcast(An->ȳ * fmad(f, (An,), Val{1}), A) +const plustype = Union{typeof(+), typeof(Base.add_sum)} +const type_tuple = :(Tuple{Any, $plustype, ∇ArrayOrScalar}) +for f in (:mapreduce, :mapfoldl, :mapfoldr) + @eval begin + import Base: $f + @explicit_intercepts $f $type_tuple [false, false, true] + ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) = + hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ? + broadcast(An->ȳ * ∇(f, Arg{1}, An), A) : + broadcast(An->ȳ * fmad(f, (An,), Val{1}), A) + end end diff --git a/src/sensitivities/functional/reducedim.jl b/src/sensitivities/functional/reducedim.jl index ed3faac3..6ee81b2e 100644 --- a/src/sensitivities/functional/reducedim.jl +++ b/src/sensitivities/functional/reducedim.jl @@ -12,7 +12,7 @@ accept_w_default = :(Tuple{Function, typeof(+), AbstractArray{<:∇Scalar}, Any, A::AbstractArray{<:∇Scalar}, region, v0=nothing, -) = method_exists(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? +) = hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? broadcast((An, ȳn)->ȳn * ∇(f, Arg{1}, An), A, ȳ) : broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ) diff --git a/src/sensitivities/indexing.jl b/src/sensitivities/indexing.jl index dc149865..2fa87359 100644 --- a/src/sensitivities/indexing.jl +++ b/src/sensitivities/indexing.jl @@ -6,7 +6,7 @@ for i = 1:7 end function ∇(Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y, ȳ, A, inds...) - Ā[inds...] .+= ȳ + Ā[inds...] += ȳ return Ā end function ∇(Ā, ::typeof(getindex), ::Type{Arg{1}}, p, y::AbstractArray, ȳ::AbstractArray, A, inds...) diff --git a/src/sensitivities/linalg/blas.jl b/src/sensitivities/linalg/blas.jl index 5b9770e8..42b2ed94 100644 --- a/src/sensitivities/linalg/blas.jl +++ b/src/sensitivities/linalg/blas.jl @@ -1,4 +1,4 @@ -import Base.LinAlg.BLAS: asum, dot, blascopy!, nrm2, scal, scal!, gemm, gemm!, gemv, gemv!, +import LinearAlgebra.BLAS: asum, dot, blascopy!, nrm2, scal, scal!, gemm, gemm!, gemv, gemv!, syrk, symm, symm!, symv, symv!, trmm, trsm, trmv, trsv, trsv!, ger! const SA = StridedArray @@ -17,13 +17,13 @@ const SA = StridedArray [false, true, false, true, false], ) ∇(::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, y, iy, zeros(x), ix), ix) + scal!(n, z̄, blascopy!(n, y, iy, zeros(eltype(x), size(x)), ix), ix) ∇(::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, x, ix, zeros(y), iy), iy) + scal!(n, z̄, blascopy!(n, x, ix, zeros(eltype(y), size(y)), iy), iy) ∇(x̄, ::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, zeros(x), ix), ix)) + (x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, zeros(eltype(x), size(x)), ix), ix)) ∇(ȳ, ::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, zeros(y), iy), iy)) + (ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, zeros(eltype(y), size(y)), iy), iy)) # Short-form `nrm2`. 
@explicit_intercepts nrm2 Tuple{Union{StridedVector, Array}} @@ -37,9 +37,9 @@ const SA = StridedArray [false, true, false], ) ∇(::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(x), inc), inc) + scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(eltype(x), size(x)), inc), inc) ∇(x̄, ::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(x), inc), inc)) + (x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(eltype(x), size(x)), inc), inc)) # Short-form `asum`. @explicit_intercepts asum Tuple{Union{StridedVector, Array}} @@ -53,9 +53,9 @@ const SA = StridedArray [false, true, false], ) ∇(::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(x), inc), inc) + scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(eltype(x), size(x)), inc), inc) ∇(x̄, ::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(x), inc), inc)) + (x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(eltype(x), size(x)), inc), inc)) # Some weird stuff going on that I haven't figured out yet. @@ -187,7 +187,7 @@ const SA = StridedArray α::T, A::StridedMatrix{T}, x::StridedVector{T}, -) where T<:∇Scalar = uppercase(tA) == 'N' ? α * ȳ * x.' : α * x * ȳ.' +) where T<:∇Scalar = uppercase(tA) == 'N' ? α * ȳ * x' : α * x * ȳ' ∇(Ā::StridedMatrix{T}, ::typeof(gemv), ::Type{Arg{3}}, _, y, ȳ, tA::Char, α::T, @@ -256,8 +256,8 @@ const SA = StridedArray # A::StridedVecOrMat{<:∇Scalar}, # ) # triȲ = uppercase(uplo) == 'L' ? tril(Ȳ) : triu(Ȳ) -# out = gemm('N', trans, α, triȲ .+ triȲ.', A) -# return uppercase(trans) == 'N' ? out : out.' +# out = gemm('N', trans, α, triȲ .+ triȲ', A) +# return uppercase(trans) == 'N' ? out : out' # end # function ∇(Ā::StridedVecOrMat{T}, ::typeof(syrk), ::Type{Arg{4}}, p, Y, Ȳ, # uplo::Char, @@ -266,8 +266,8 @@ const SA = StridedArray # A::StridedVecOrMat{T}, # ) where T<:∇Scalar # triȲ = uppercase(uplo) == 'L' ? tril(Ȳ) : triu(Ȳ) -# out = gemm('N', trans, α, triȲ .+ triȲ.', A) -# return broadcast!((ā, δā)->ā+δā, Ā, Ā, uppercase(trans) == 'N' ? out : out.') +# out = gemm('N', trans, α, triȲ .+ triȲ', A) +# return broadcast!((ā, δā)->ā+δā, Ā, Ā, uppercase(trans) == 'N' ? out : out') # end # # `syrk` sensitivity implementations for `α=1`. @@ -307,7 +307,7 @@ function ∇(::typeof(symm), ::Type{Arg{4}}, p, Y, Ȳ, A::StridedMatrix{T}, B::StridedVecOrMat{T}, ) where T<:∇Scalar - tmp = uppercase(side) == 'L' ? Ȳ * B.' : B.'Ȳ + tmp = uppercase(side) == 'L' ? Ȳ * B' : B'Ȳ g! = uppercase(ul) == 'L' ? tril! : triu! return α * g!(tmp + tmp' - Diagonal(tmp)) end @@ -318,7 +318,7 @@ function ∇(Ā::StridedMatrix{T}, ::typeof(symm), ::Type{Arg{4}}, p, Y, Ȳ, A::StridedMatrix{T}, B::StridedVecOrMat{T}, ) where T<:∇Scalar - tmp = uppercase(side) == 'L' ? Ȳ * B.' : B.'Ȳ + tmp = uppercase(side) == 'L' ? Ȳ * B' : B'Ȳ g! = uppercase(ul) == 'L' ? tril! : triu! return broadcast!((ā, δā)->ā + δā, Ā, Ā, α * g!(tmp + tmp' - Diagonal(tmp))) end @@ -485,7 +485,7 @@ function ∇(::typeof(trmv), ::Type{Arg{4}}, p, y, ȳ, A::StridedMatrix{T}, b::StridedVector{T}, ) where T<:∇Scalar - Ā = (uppercase(ul) == 'L' ? tril! : triu!)(uppercase(ta) == 'N' ? ȳ * b.' : b * ȳ.') + Ā = (uppercase(ul) == 'L' ? tril! : triu!)(uppercase(ta) == 'N' ? 
ȳ * b' : b * ȳ') dA == 'U' && fill!(view(Ā, diagind(Ā)), zero(T)) return Ā end @@ -515,11 +515,11 @@ function ∇(::typeof(trsm), ::Type{Arg{6}}, p, Y, Ȳ, ) where T<:∇Scalar Ā_full = uppercase(side) == 'L' ? uppercase(ta) == 'N' ? - trsm('L', ul, 'T', dA, -1.0, A, Ȳ * Y.') : - trsm('R', ul, 'T', dA, -1.0, A, Y * Ȳ.') : + trsm('L', ul, 'T', dA, -1.0, A, Ȳ * Y') : + trsm('R', ul, 'T', dA, -1.0, A, Y * Ȳ') : uppercase(ta) == 'N' ? - trsm('R', ul, 'T', dA, -1.0, A, Y.'Ȳ) : - trsm('L', ul, 'T', dA, -1.0, A, Ȳ.'Y) + trsm('R', ul, 'T', dA, -1.0, A, Y'Ȳ) : + trsm('L', ul, 'T', dA, -1.0, A, Ȳ'Y) dA == 'U' && fill!(view(Ā_full, diagind(Ā_full)), zero(T)) return (uppercase(ul) == 'L' ? tril! : triu!)(Ā_full) end diff --git a/src/sensitivities/linalg/diagonal.jl b/src/sensitivities/linalg/diagonal.jl index 787a5c2f..0b4ab85e 100644 --- a/src/sensitivities/linalg/diagonal.jl +++ b/src/sensitivities/linalg/diagonal.jl @@ -1,5 +1,4 @@ -import Base: det, logdet, diagm, Diagonal, diag -export diag, diagm, Diagonal +import LinearAlgebra: det, logdet, diagm, Diagonal, diag const ∇ScalarDiag = Diagonal{<:∇Scalar} @@ -12,7 +11,7 @@ function ∇( ȳ::∇AbstractVector, x::∇AbstractMatrix, ) - x̄ = zeros(x) + x̄ = zeros(eltype(x), size(x)) x̄[diagind(x̄)] = ȳ return x̄ end @@ -40,7 +39,7 @@ function ∇( x::∇AbstractMatrix, k::Integer, ) - x̄ = zeros(x) + x̄ = zeros(eltype(x), size(x)) x̄[diagind(x̄, k)] = ȳ return x̄ end @@ -59,67 +58,6 @@ function ∇( return x̄ end -@explicit_intercepts diagm Tuple{∇AbstractVector} -function ∇( - ::typeof(diagm), - ::Type{Arg{1}}, - p, - Y::∇AbstractMatrix, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, -) - return copy!(similar(x), view(Ȳ, diagind(Ȳ))) -end -function ∇( - x̄::∇AbstractVector, - ::typeof(diagm), - ::Type{Arg{1}}, - p, - Y::∇AbstractMatrix, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, -) - return broadcast!(+, x̄, x̄, view(Ȳ, diagind(Ȳ))) -end - -@explicit_intercepts diagm Tuple{∇AbstractVector, Integer} [true, false] -function ∇( - ::typeof(diagm), - ::Type{Arg{1}}, - p, - Y::∇AbstractMatrix, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, - k::Integer, -) - return copy!(similar(x), view(Ȳ, diagind(Ȳ, k))) -end -function ∇( - x̄::∇AbstractVector, - ::typeof(diagm), - ::Type{Arg{1}}, - p, - Y::∇AbstractMatrix, - Ȳ::∇AbstractMatrix, - x::∇AbstractVector, - k::Integer, -) - return broadcast!(+, x̄, x̄, view(Ȳ, diagind(Ȳ, k))) -end - -@explicit_intercepts diagm Tuple{∇Scalar} -function ∇( - ::typeof(diagm), - ::Type{Arg{1}}, - p, - Y::∇AbstractMatrix, - Ȳ::∇AbstractMatrix, - x::∇Scalar, -) - length(Ȳ) != 1 && throw(error("Ȳ isn't a 1x1 matrix.")) - return Ȳ[1] -end - @explicit_intercepts Diagonal Tuple{∇AbstractVector} function ∇( ::Type{Diagonal}, @@ -129,7 +67,7 @@ function ∇( Ȳ::∇ScalarDiag, x::∇AbstractVector, ) - return copy!(similar(x), Ȳ.diag) + return copyto!(similar(x), Ȳ.diag) end function ∇( x̄::∇AbstractVector, @@ -152,8 +90,8 @@ function ∇( Ȳ::∇ScalarDiag, X::∇AbstractMatrix, ) - X̄ = zeros(X) - copy!(view(X̄, diagind(X)), Ȳ.diag) + X̄ = zeros(eltype(X), size(X)) + copyto!(view(X̄, diagind(X)), Ȳ.diag) return X̄ end function ∇( @@ -201,3 +139,61 @@ function ∇( broadcast!((x̄, x, ȳ)->x̄ + ȳ / x, X̄.diag, X̄.diag, X.diag, ȳ) return X̄ end + +# NOTE: diagm can't go through the @explicit_intercepts machinery directly because as of +# Julia 0.7, its methods are not sufficiently straightforward; we need to dispatch on one +# of the parameters in the parametric type of diagm's one argument. 
However, we can cheat +# a little bit and use an internal helper function _diagm that has simple methods that +# dispatch to diagm when no arguments are Nodes, and we'll extend diagm to dispatch to +# _diagm when it receives arguments that are nodes. _diagm can go through the intercepts +# machinery, so it knows how to deal. + +_diagm(x::∇AbstractVector, k::Integer=0) = diagm(k => x) +LinearAlgebra.diagm(x::Pair{<:Integer, <:Node{<:∇AbstractVector}}) = _diagm(last(x), first(x)) + +@explicit_intercepts _diagm Tuple{∇AbstractVector} +function ∇( + ::typeof(_diagm), + ::Type{Arg{1}}, + p, + Y::∇AbstractMatrix, + Ȳ::∇AbstractMatrix, + x::∇AbstractVector, +) + return copyto!(similar(x), view(Ȳ, diagind(Ȳ))) +end +function ∇( + x̄::∇AbstractVector, + ::typeof(_diagm), + ::Type{Arg{1}}, + p, + Y::∇AbstractMatrix, + Ȳ::∇AbstractMatrix, + x::∇AbstractVector, +) + return broadcast!(+, x̄, x̄, view(Ȳ, diagind(Ȳ))) +end +@explicit_intercepts _diagm Tuple{∇AbstractVector, Integer} [true, false] +function ∇( + ::typeof(_diagm), + ::Type{Arg{1}}, + p, + Y::∇AbstractMatrix, + Ȳ::∇AbstractMatrix, + x::∇AbstractVector, + k::Integer, +) + return copyto!(similar(x), view(Ȳ, diagind(Ȳ, k))) +end +function ∇( + x̄::∇AbstractVector, + ::typeof(_diagm), + ::Type{Arg{1}}, + p, + Y::∇AbstractMatrix, + Ȳ::∇AbstractMatrix, + x::∇AbstractVector, + k::Integer, +) + return broadcast!(+, x̄, x̄, view(Ȳ, diagind(Ȳ, k))) +end diff --git a/src/sensitivities/linalg/factorization/cholesky.jl b/src/sensitivities/linalg/factorization/cholesky.jl index 2ca003b6..1855a77b 100644 --- a/src/sensitivities/linalg/factorization/cholesky.jl +++ b/src/sensitivities/linalg/factorization/cholesky.jl @@ -1,5 +1,5 @@ -import Base.LinAlg.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! -import Base.LinAlg.chol +import LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! +import LinearAlgebra.chol #= See [1] for implementation details: pages 5-9 in particular. The derivations presented in @@ -14,7 +14,7 @@ const AM = AbstractMatrix const UT = UpperTriangular @explicit_intercepts chol Tuple{AbstractMatrix{<:∇Scalar}} ∇(::typeof(chol), ::Type{Arg{1}}, p, U::UT{T}, Ū::AM{T}, Σ::AM{T}) where T<:∇Scalar = - chol_blocked_rev(full(Ū), full(U), 25, true) + chol_blocked_rev(Matrix(Ū), Matrix(U), 25, true) """ level2partition(A::AbstractMatrix, j::Int, upper::Bool) @@ -139,7 +139,7 @@ function chol_blocked_rev!(Σ̄::AM{T}, L::AM{T}, Nb::Int, upper::Bool) where T< M, N = size(Σ̄) M != N && throw(ArgumentError("Σ̄ is not square.")) - tmp = Matrix{T}(Nb, Nb) + tmp = Matrix{T}(undef, Nb, Nb) # Compute the reverse-mode diff. 
k = N diff --git a/src/sensitivities/linalg/generic.jl b/src/sensitivities/linalg/generic.jl index 6c7f7ce6..30b3d478 100644 --- a/src/sensitivities/linalg/generic.jl +++ b/src/sensitivities/linalg/generic.jl @@ -2,17 +2,21 @@ _ϵ, lb, ub = 3e-2, -3.0, 3.0 unary_linalg_optimisations = [ (:-, ∇Array, ∇Array, :(map(-, Ȳ)), (lb, ub)), - (:trace, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)), + (:tr, ∇Array, ∇Scalar, :(Diagonal(fill!(similar(X), Ȳ))), (lb, ub)), (:inv, ∇Array, ∇Array, :(-transpose(Y) * Ȳ * transpose(Y)), (lb, ub)), (:det, ∇Array, ∇Scalar, :(Y * Ȳ * transpose(inv(X))), (_ϵ, ub)), (:logdet, ∇Array, ∇Scalar, :(Ȳ * transpose(inv(X))), (_ϵ, ub)), (:transpose, ∇Array, ∇Array, :(transpose(Ȳ)), (lb, ub)), - (:ctranspose, ∇Array, ∇Array, :(ctranspose(Ȳ)), (lb, ub)), - (:vecnorm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)), - (:vecnorm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub)) + (:adjoint, ∇Array, ∇Array, :(adjoint(Ȳ)), (lb, ub)), + (:norm, ∇Array, ∇Scalar, :(Ȳ ./ Y .* abs2.(X) ./ X), (lb, ub)), + (:norm, ∇Scalar, ∇Scalar, :(Ȳ * sign(X)), (lb, ub)) ] for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations - @eval import Base.$f + if f === :- + @eval import Base: - + else + @eval import LinearAlgebra: $f + end @eval @explicit_intercepts $f Tuple{$T_In} @eval ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Out, Ȳ::$T_Out, X::$T_In) = $X̄ end @@ -21,88 +25,92 @@ end const A = ∇Array const S = ∇Scalar const AS = Union{∇Scalar, ∇Array} +const AT = Transpose{<:∇Scalar, ∇Array} +const AH = Adjoint{<:∇Scalar, ∇Array} δ = 1e-5 binary_linalg_optimisations = [ - (:*, A, A, AS, - :(A_mul_Bc(Ȳ, B)), - :(Ac_mul_B(A, Ȳ))), - (:At_mul_B, A, A, AS, - :(A_mul_Bt(B, Ȳ)), - :(getfield(Base, :*)(A, Ȳ))), - (:A_mul_Bt, A, A, AS, - :(getfield(Base, :*)(Ȳ, B)), - :(At_mul_B(Ȳ, A))), - (:At_mul_Bt, A, A, AS, - :(At_mul_Bt(B, Ȳ)), - :(At_mul_Bt(Ȳ, A))), - (:Ac_mul_B, A, A, AS, - :(A_mul_Bt(B, Ȳ)), - :(getfield(Base, :*)(A, Ȳ))), - (:A_mul_Bc, A, A, AS, - :(getfield(Base, :*)(Ȳ, B)), - :(Ac_mul_B(Ȳ, A))), - (:Ac_mul_Bc, A, A, AS, - :(Ac_mul_Bc(B, Ȳ)), - :(Ac_mul_Bc(Ȳ, A))), - (:/, A, A, AS, - :(A_rdiv_Bt(Ȳ, B)), - :(-At_mul_B(Y, A_rdiv_Bt(Ȳ, B)))), - (:At_rdiv_B, A, A, AS, - :(A_ldiv_Bt(B, Ȳ)), - :(-At_mul_B(Y, A_rdiv_Bt(Ȳ, B)))), - (:A_rdiv_Bt, A, A, AS, - :(getfield(Base, :/)(Ȳ, B)), - :(-At_ldiv_Bt(B, Ȳ) * Y)), - (:At_rdiv_Bt, A, A, AS, - :(At_ldiv_Bt(B, Ȳ)), - :(-At_ldiv_Bt(B, Ȳ) * Y)), - (:Ac_rdiv_B, A, A, AS, - :(A_ldiv_Bc(B, Ȳ)), - :(-Ac_mul_B(Y, A_rdiv_Bc(Ȳ, B)))), - (:A_rdiv_Bc, A, A, AS, - :(getfield(Base, :/)(Ȳ, B)), - :(-At_ldiv_Bt(B, Ȳ) * Y)), - (:Ac_rdiv_Bc, A, A, AS, - :(Ac_ldiv_Bc(B, Ȳ)), - :(-Ac_ldiv_Bc(B, Ȳ) * Y)), - (:\, A, A, AS, - :(-A_mul_Bt(At_ldiv_B(A, Ȳ), Y)), - :(At_ldiv_B(A, Ȳ))), - (:At_ldiv_B, A, A, AS, - :(-A_mul_Bt(Y, getfield(Base, :\)(A, Ȳ))), - :(getfield(Base, :\)(A, Ȳ))), - (:A_ldiv_Bt, A, A, AS, - :(-At_mul_Bt(At_rdiv_B(Ȳ, A), Y)), - :(At_rdiv_B(Ȳ, A))), - (:At_ldiv_Bt, A, A, AS, - :(-Y * At_rdiv_Bt(Ȳ, A)), - :(At_rdiv_Bt(Ȳ, A))), - (:Ac_ldiv_B, A, A, AS, - :(-A_mul_Bc(Y, getfield(Base, :\)(A, Ȳ))), - :(getfield(Base, :\)(A, Ȳ))), - (:A_ldiv_Bc, A, A, AS, - :(-Ac_mul_Bc(Ac_rdiv_B(Ȳ, A), Y)), - :(Ac_rdiv_B(Ȳ, A))), - (:Ac_ldiv_Bc, A, A, AS, - :(-Y * Ac_rdiv_Bc(Ȳ, A)), - :(Ac_rdiv_Bc(Ȳ, A))), - (:vecnorm, A, S, S, + (:*, A, A, AS, + :(Ȳ * B'), + :(A' * Ȳ)), + (:*, AT, A, AS, + :(B * transpose(Ȳ)), + :(A * Ȳ)), + (:*, A, AT, AS, + :(Ȳ * B), + :(transpose(Ȳ) * A)), + (:*, AT, AT, AS, + :(transpose(B) * transpose(Ȳ)), + :(transpose(Ȳ) * transpose(A))), 
+ (:*, AH, A, AS, + :(B * transpose(Ȳ)), + :(A * Ȳ)), + (:*, A, AH, AS, + :(Ȳ * B), + :(Ȳ' * A)), + (:*, AH, AH, AS, + :(B' * Ȳ'), + :(Ȳ' * A')), + (:/, A, A, AS, + :(Ȳ / transpose(B)), + :(-transpose(Y) * (Ȳ / transpose(B)))), + (:/, AT, A, AS, + :(B \ transpose(Ȳ)), + :(-transpose(Y) * (Ȳ / transpose(B)))), + (:/, A, AT, AS, + :(Ȳ / B), + :(-(transpose(B) \ transpose(Ȳ)) * Y)), + (:/, AT, AT, AS, + :(transpose(B) \ transpose(Ȳ)), + :(-(transpose(B) \ transpose(Ȳ)) * Y)), + (:/, AH, A, AS, + :(B \ Ȳ'), + :(-Y' * (Ȳ / B'))), + (:/, A, AH, AS, + :(Ȳ / B), + :(-(transpose(B) \ transpose(Ȳ)) * Y)), + (:/, AH, AH, AS, + :(B' \ Ȳ'), + :(-(B' \ Ȳ') * Y)), + (:\, A, A, AS, + :(-(transpose(A) \ Ȳ) * transpose(Y)), + :(transpose(A) \ Ȳ)), + (:\, AT, A, AS, + :(-Y * transpose(A \ Ȳ)), + :(A \ Ȳ)), + (:\, A, AT, AS, + :(-transpose(transpose(Ȳ) / A) * transpose(Y)), + :(transpose(Ȳ) / A)), + (:\, AT, AT, AS, + :(-Y * (transpose(Ȳ) / transpose(A))), + :(transpose(Ȳ) / transpose(A))), + (:\, AH, A, AS, + :(-Y * (A \ Ȳ)'), + :(A \ Ȳ)), + (:\, A, AH, AS, + :(-(Ȳ' / A)' * Y), + :(Ȳ' / A)), + (:\, AH, AH, AS, + :(-Y * (Ȳ' / A')), + :(Ȳ' / A')), + (:norm, A, S, S, :(Ȳ .* Y^(1 - B) .* abs.(A).^B ./ A), :(Ȳ * (Y^(1 - B) * sum(abs.(A).^B .* log.(abs.(A))) - Y * log(Y)) / B)), - (:vecnorm, S, S, S, + (:norm, S, S, S, :(Ȳ * sign(A)), :(0)), - ] +import Base: *, /, \ +import LinearAlgebra: norm for (f, T_A, T_B, T_Y, Ā, B̄) in binary_linalg_optimisations - @eval import Base.$f - @eval @explicit_intercepts $f Tuple{$T_A, $T_B} - @eval ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $Ā - @eval ∇(::typeof($f), ::Type{Arg{2}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $B̄ + @eval begin + @explicit_intercepts $f Tuple{$T_A, $T_B} + ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $Ā + ∇(::typeof($f), ::Type{Arg{2}}, p, Y::$T_Y, Ȳ::$T_Y, A::$T_A, B::$T_B) = $B̄ + end end # Sensitivities for the Kronecker product: -import Base.kron +import LinearAlgebra: kron @explicit_intercepts kron Tuple{A, A} # The allocating versions simply allocate and then call the in-place versions. diff --git a/src/sensitivities/linalg/strided.jl b/src/sensitivities/linalg/strided.jl index 80ddeecd..6aa2708f 100644 --- a/src/sensitivities/linalg/strided.jl +++ b/src/sensitivities/linalg/strided.jl @@ -2,32 +2,34 @@ # BLAS for matrix-vector stuff yet. Definitely an optimisation that we might want to # consider at some point in the future though. const RS = StridedMatrix{<:∇Scalar} +const RST = Transpose{<:∇Scalar, RS} +const RSA = Adjoint{<:∇Scalar, RS} strided_matmul = [ - (:*, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), - (:At_mul_B, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (:A_mul_Bt, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), - (:At_mul_Bt, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), - (:Ac_mul_B, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), - (:A_mul_Bc, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), - (:Ac_mul_Bc, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), + (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ), + (RST, RS, 'N', 'T', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RST, 'N', 'N', :Ȳ, :B, 'T', 'N', :Ȳ, :A), + (RST, RST, 'T', 'T', :B, :Ȳ, 'T', 'T', :Ȳ, :A), + (RSA, RS, 'N', 'C', :B, :Ȳ, 'N', 'N', :A, :Ȳ), + (RS, RSA, 'N', 'N', :Ȳ, :B, 'C', 'N', :Ȳ, :A), + (RSA, RSA, 'C', 'C', :B, :Ȳ, 'C', 'C', :Ȳ, :A), ] -for (f, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in strided_matmul +import Base: * +for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in strided_matmul # Add intercepts and export names. 
- @eval import Base.$f - @eval @explicit_intercepts $f Tuple{RS, RS} + @eval @explicit_intercepts $(Symbol("*")) Tuple{$TA, $TB} # Define allocating and non-allocating sensitivities for each output. - alloc_Ā = :(Base.BLAS.gemm($tCA, $tDA, $CA, $DA)) - alloc_B̄ = :(Base.BLAS.gemm($tCB, $tDB, $CB, $DB)) - no_alloc_Ā = :(Base.BLAS.gemm!($tCA, $tDA, 1., $CA, $DA, 1., Ā)) - no_alloc_B̄ = :(Base.BLAS.gemm!($tCB, $tDB, 1., $CB, $DB, 1., B̄)) + alloc_Ā = :(LinearAlgebra.BLAS.gemm($tCA, $tDA, $CA, $DA)) + alloc_B̄ = :(LinearAlgebra.BLAS.gemm($tCB, $tDB, $CB, $DB)) + no_alloc_Ā = :(LinearAlgebra.BLAS.gemm!($tCA, $tDA, 1., $CA, $DA, 1., Ā)) + no_alloc_B̄ = :(LinearAlgebra.BLAS.gemm!($tCB, $tDB, 1., $CB, $DB, 1., B̄)) # Add sensitivity definitions. - @eval ∇(::typeof($f), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::RS, B::RS) = $alloc_Ā - @eval ∇(::typeof($f), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::RS, B::RS) = $alloc_B̄ - @eval ∇(Ā, ::typeof($f), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::RS, B::RS) = $no_alloc_Ā - @eval ∇(B̄, ::typeof($f), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::RS, B::RS) = $no_alloc_B̄ + @eval ∇(::typeof(*), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $alloc_Ā + @eval ∇(::typeof(*), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $alloc_B̄ + @eval ∇(Ā, ::typeof(*), ::Type{Arg{1}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $no_alloc_Ā + @eval ∇(B̄, ::typeof(*), ::Type{Arg{2}}, p, Y::RS, Ȳ::RS, A::$TA, B::$TB) = $no_alloc_B̄ end # # Not every permutation of transpositions makes sense for matrix-vector multiplication. This @@ -38,7 +40,7 @@ end # (:Ac_mul_B, 'C', :b, :ȳ, 'N'), # ] # for (f, tdA, CA, dA, tCb) in strided_matvecmul -# n_Ā, u_Ā = tdA == 'C' ? :(Ā = $CA * $dA') : :(Ā = $CA * $dA.'), :(ger!(1., $CA, $dA, Ā)) +# n_Ā, u_Ā = tdA == 'C' ? :(Ā = $CA * $dA') : :(Ā = $CA * $dA'), :(ger!(1., $CA, $dA, Ā)) # n_b̄, u_b̄ = :(b̄ = gemv($tCb, A, ȳ)), :(b̄ = gemv!($tCb, 1., A, ȳ, 1., b̄)) # generate_primitive(f, [:(T <: StridedMatrix), :(V <: StridedVector)], # [:A, :b], [:Ā, :b̄], [:T, :V], [true, true], :y, :ȳ, [n_Ā, n_b̄], [u_Ā, u_b̄]) diff --git a/src/sensitivities/linalg/symmetric.jl b/src/sensitivities/linalg/symmetric.jl index b7033f76..f362452a 100644 --- a/src/sensitivities/linalg/symmetric.jl +++ b/src/sensitivities/linalg/symmetric.jl @@ -1,4 +1,4 @@ -import Base.Symmetric +import LinearAlgebra: Symmetric @explicit_intercepts Symmetric Tuple{∇Array} ∇(::typeof(Symmetric), ::Type{Arg{1}}, p, Y::∇Array, Ȳ::∇Array, X::∇Array) = UpperTriangular(Ȳ) + LowerTriangular(Ȳ)' - Diagonal(Ȳ) diff --git a/src/sensitivities/linalg/triangular.jl b/src/sensitivities/linalg/triangular.jl index d343538a..6fcb4c83 100644 --- a/src/sensitivities/linalg/triangular.jl +++ b/src/sensitivities/linalg/triangular.jl @@ -1,5 +1,4 @@ -import Base: det, logdet, LowerTriangular, UpperTriangular -export det, logdet, LowerTriangular, UpperTriangular +import LinearAlgebra: det, logdet, LowerTriangular, UpperTriangular const ∇ScalarLT = LowerTriangular{<:∇Scalar} const ∇ScalarUT = UpperTriangular{<:∇Scalar} @@ -7,7 +6,7 @@ const ∇ScalarUT = UpperTriangular{<:∇Scalar} for (ctor, T) in zip([:LowerTriangular, :UpperTriangular], [:∇ScalarLT, :∇ScalarUT]) @eval @explicit_intercepts $ctor Tuple{∇AbstractMatrix} - @eval ∇(::Type{$ctor}, ::Type{Arg{1}}, p, Y::$T, Ȳ::$T, X::∇AbstractMatrix) = full(Ȳ) + @eval ∇(::Type{$ctor}, ::Type{Arg{1}}, p, Y::$T, Ȳ::$T, X::∇AbstractMatrix) = Matrix(Ȳ) @eval ∇( X̄::∇AbstractMatrix, ::Type{$ctor}, diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl index 02e927f9..09cd7008 100644 
--- a/src/sensitivities/scalar.jl +++ b/src/sensitivities/scalar.jl @@ -9,16 +9,14 @@ import Base.identity @inline ∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ @inline ∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x) -DualNumbers.epsilon(::Real) = 0.0 - # Ignore functions that have complex ranges. This may change when Nabla supports complex # numbers. ignored_fs = [(:SpecialFunctions, :hankelh1), (:SpecialFunctions, :hankelh2), - (:(Base.Math.JuliaLibm), :log1p), + (:Base, :log1p), (:Base, :rem2pi), (:Base, :mod), - (:Base, :atan2), + (:Base, :atan), (:Base, :rem)] unary_sensitivities, binary_sensitivities = [], [] @@ -26,7 +24,7 @@ unary_sensitivities, binary_sensitivities = [], [] for (package, f, arity) in diffrules() (package == :NaNMath || (package, f) in ignored_fs) && continue - @eval import $package.$f + @eval import $package: $f if arity == 1 push!(unary_sensitivities, (package, f)) ∂f∂x = diffrule(package, f, :x) diff --git a/src/sensitivity.jl b/src/sensitivity.jl index 007a5a16..9392501b 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -20,7 +20,7 @@ The work-horse for `@union_intercepts`. function union_intercepts(f::Symbol, type_tuple::Expr, invoke_type_tuple::Expr) call, arg_names = get_union_call(f, type_tuple) body = get_body(f, type_tuple, arg_names, invoke_type_tuple) - return Expr(:macrocall, Symbol("@generated"), Expr(:function, call, body)) + return Expr(:macrocall, Symbol("@generated"), nothing, Expr(:function, call, body)) end """ @@ -133,7 +133,7 @@ function get_body( Expr(Symbol("="), :x, Expr(:tuple, args_dotted)), Expr(Symbol("="), :x_syms, Expr(:tuple, args_dotted_quot)), Expr(Symbol("="), :x_dots, sym_arg_tuple), - Expr(Symbol("="), :is_node, :([any(issubtype.(xj, Node)) for xj in x])), + Expr(Symbol("="), :is_node, :([any((<:).(xj, Node)) for xj in x])), Expr(:return, Expr(:if, Expr(:call, :any, :is_node), :(Nabla.branch_expr( @@ -175,7 +175,7 @@ Get an expression which will obtain the tape from a Node object in `x`. function tape_expr(x::Tuple, syms::NTuple{N, Symbol} where N, is_node::Vector{Bool}) idx = findfirst(is_node) if idx == length(is_node) && isa(x[end], Tuple) - node_idx = findfirst([issubtype(varg, Node) for varg in x[end]]) + node_idx = findfirst([varg <: Node for varg in x[end]]) return Expr(:call, :getfield, Expr(:ref, syms[end], node_idx), quot(:tape)) else return Expr(:call, :getfield, syms[idx], quot(:tape)) diff --git a/test/code_transformation/differentiable.jl b/test/code_transformation/differentiable.jl index 7951feb0..7e9b6e2f 100644 --- a/test/code_transformation/differentiable.jl +++ b/test/code_transformation/differentiable.jl @@ -1,3 +1,17 @@ +# NOTE: As of https://github.com/JuliaLang/julia/pull/21746, all macro calls have +# a LineNumberNode. Since our functions that do expression transformation don't have +# access to the source code locations, we don't insert a LineNumberNode, which means +# that a transformed expression won't compare equal to an expression literal due to +# line information. Thus for testing purposes we define a function that removes line +# information from expressions as well as a custom equality operator that uses this. +# Note that it's only necessary for comparisons that deal with macro calls. +function skip_line_info(ex::Expr) + map!(arg->arg isa LineNumberNode ? 
nothing : skip_line_info(arg), ex.args, ex.args) + ex +end +skip_line_info(ex) = ex +≃(a, b) = skip_line_info(a) == skip_line_info(b) + @testset "code_transformation/differentiable" begin import Nabla.Nabla: unionise_type, unionise_arg, unionise_subtype, unionise_eval, @@ -26,17 +40,17 @@ @test unionise_subtype(:(T<:V)) == :(T<:$(unionise_type(:V))) # Test Nabla.unionise_eval. - @test unionise_eval(:(eval(:foo))) == :(eval(:(@unionise foo))) - @test unionise_eval(:(eval(DiffBase, :foo))) == :(eval(DiffBase, :(@unionise foo))) - @test unionise_eval(:(eval(:(println("foo"))))) == :(eval(:(@unionise println("foo")))) - @test unionise_eval(:(eval(DiffBase, :(println("foo"))))) == + @test unionise_eval(:(eval(:foo))) ≃ :(eval(:(@unionise foo))) + @test unionise_eval(:(eval(DiffBase, :foo))) ≃ :(eval(DiffBase, :(@unionise foo))) + @test unionise_eval(:(eval(:(println("foo"))))) ≃ :(eval(:(@unionise println("foo")))) + @test unionise_eval(:(eval(DiffBase, :(println("foo"))))) ≃ :(eval(DiffBase, :(@unionise println("foo")))) # Test Nabla.unionise_macro_eval. - @test unionise_macro_eval(:(@eval foo)) == :(@eval @unionise foo) - @test unionise_macro_eval(:(@eval DiffBase foo)) == :(@eval DiffBase @unionise foo) - @test unionise_macro_eval(:(@eval println("foo"))) == :(@eval @unionise println("foo")) - @test unionise_macro_eval(:(@eval DiffBase println("foo"))) == + @test unionise_macro_eval(:(@eval foo)) ≃ :(@eval @unionise foo) + @test unionise_macro_eval(:(@eval DiffBase foo)) ≃ :(@eval DiffBase @unionise foo) + @test unionise_macro_eval(:(@eval println("foo"))) ≃ :(@eval @unionise println("foo")) + @test unionise_macro_eval(:(@eval DiffBase println("foo"))) ≃ :(@eval DiffBase @unionise println("foo")) # Test Nabla.unionise. This depends upon Nabla.unionise_arg, so we express @@ -95,7 +109,7 @@ Expr(Symbol("="), unionise_sig(:(foo(x::T, y::T) where T)), :x) @test unionise(:(eval(:foo))) == unionise_eval(:(eval(:foo))) @test unionise(:(eval(DiffBase, :foo))) == unionise_eval(:(eval(DiffBase, :foo))) - @test unionise(:(@eval foo)) == unionise_macro_eval(:(@eval foo)) - @test unionise(:(@eval DiffBase foo)) == unionise_macro_eval(:(@eval DiffBase foo)) + @test unionise(:(@eval foo)) ≃ unionise_macro_eval(:(@eval foo)) + @test unionise(:(@eval DiffBase foo)) ≃ unionise_macro_eval(:(@eval DiffBase foo)) @test unionise(:(struct Foo{T<:V} end)) == unionise_struct(:(struct Foo{T<:V} end)) end diff --git a/test/core.jl b/test/core.jl index d8916ed5..5a093823 100644 --- a/test/core.jl +++ b/test/core.jl @@ -3,7 +3,7 @@ let # Simple tests for `Tape`. @test getindex(setindex!(Tape(5), "hi", 5), 5) == "hi" - @test endof(Tape(50)) == 50 + @test lastindex(Tape(50)) == 50 @test eachindex(Tape(50)) == Base.OneTo(50) @test length(Tape()) == 0 @test length(Tape(50)) == 50 @@ -14,24 +14,24 @@ let let buffer = IOBuffer() show(buffer, Tape()) - @test String(buffer) == "Empty tape.\n" + @test String(take!(buffer)) == "Empty tape.\n" end let buffer = IOBuffer() show(buffer, Tape(1)) - @test String(buffer) == "1 #undef\n" + @test String(take!(buffer)) == "1 #undef\n" end let buffer = IOBuffer() show(buffer, Tape(2)) - @test String(buffer) == "1 #undef\n2 #undef\n" + @test String(take!(buffer)) == "1 #undef\n2 #undef\n" end let buffer = IOBuffer() tape = Tape(1) tape[1] = 5 show(buffer, tape) - @test String(buffer) == "1 5\n" + @test String(take!(buffer)) == "1 5\n" end # Check isassigned consistency. 
diff --git a/test/finite_differencing.jl b/test/finite_differencing.jl index 656f13d3..a33fd26d 100644 --- a/test/finite_differencing.jl +++ b/test/finite_differencing.jl @@ -13,8 +13,8 @@ import Nabla: ∇, compute_Dv, approximate_Dv, compute_Dv_update @explicit_intercepts foo Tuple{Matrix{<:∇Scalar}, Matrix{<:∇Scalar}} # Define sensitivity implementations. - const _Vec = Vector{<:∇Scalar} - const _Mat = Matrix{<:∇Scalar} + _Vec = Vector{<:∇Scalar} + _Mat = Matrix{<:∇Scalar} Nabla.∇(::typeof(foo), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar) = 5z̄ Nabla.∇(::typeof(foo), ::Type{Arg{1}}, p, z, z̄, x::∇Scalar, y::∇Scalar) = 5z̄ Nabla.∇(::typeof(foo), ::Type{Arg{1}}, p, z, z̄, x::_Vec, y::_Vec) = 10z̄ diff --git a/test/runtests.jl b/test/runtests.jl index 23364a6f..6448f31b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ -using Base.Test, Nabla -using Distributions, BenchmarkTools +using Nabla +using Test, LinearAlgebra, Random +using Distributions, BenchmarkTools, SpecialFunctions, DualNumbers @testset "Core" begin include("core.jl") @@ -20,7 +21,7 @@ end @testset "Functional" begin include("sensitivities/functional/functional.jl") include("sensitivities/functional/reduce.jl") - include("sensitivities/functional/reducedim.jl") + #include("sensitivities/functional/reducedim.jl") end # Test sensitivities for linear algebra optimisations. @@ -33,8 +34,8 @@ end include("sensitivities/linalg/strided.jl") include("sensitivities/linalg/blas.jl") - @testset "Factorisations" begin - include("sensitivities/linalg/factorization/cholesky.jl") - end + #@testset "Factorisations" begin + # include("sensitivities/linalg/factorization/cholesky.jl") + #end end end diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl index edb1d52c..425b0116 100644 --- a/test/sensitivities/functional/functional.jl +++ b/test/sensitivities/functional/functional.jl @@ -1,9 +1,14 @@ using SpecialFunctions using DiffRules: diffrule, hasdiffrule +# ones(::AbstractArray) is deprecated in 0.7 and removed in 1.0, but it's a pretty useful +# method, so we'll define our own for testing purposes +oneslike(a::AbstractArray) = ones(eltype(a), size(a)) +oneslike(n::Integer) = ones(n) + @testset "Functional" begin # Apparently Distributions.jl doesn't implement the following, so we'll have to do it. - Distributions.rand(rng::AbstractRNG, a::Distributions.Distribution, n::Integer) = + Random.rand(rng::AbstractRNG, a::Distribution, n::Integer) = [rand(rng, a) for _ in 1:n] let rng = MersenneTwister(123456) @@ -22,12 +27,12 @@ using DiffRules: diffrule, hasdiffrule function check_unary_broadcast(f, x) x_ = Leaf(Tape(), x) s = broadcast(f, x_) - return ∇(s, ones(s.val))[x_] ≈ ∇.(f, Arg{1}, x) + return ∇(s, oneslike(s.val))[x_] ≈ ∇.(f, Arg{1}, x) end for (package, f) in Nabla.unary_sensitivities domain = domain1(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - x_dist = Uniform(get(domain)...) + domain === nothing && error("Could not determine domain for $f.") + x_dist = Uniform(domain...) 
x = rand(rng, x_dist, 100) @test check_unary_broadcast(eval(f), x) end @@ -38,11 +43,12 @@ using DiffRules: diffrule, hasdiffrule tape = Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) s = broadcast(f, x_, y_) - ∇s = ∇(s, ones(s.val)) + o = oneslike(s.val) + ∇s = ∇(s, o) ∇x = broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y) + s.val, o, x, y) ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y) + s.val, o, x, y) @test broadcast(f, x, y) == s.val @test ∇s[x_] ≈ ∇x @test ∇s[y_] ≈ ∇y @@ -51,11 +57,12 @@ using DiffRules: diffrule, hasdiffrule tape = Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) s = broadcast(f, x_, y_) - ∇s = ∇(s, ones(s.val)) + o = oneslike(s.val) + ∇s = ∇(s, o) ∇x = sum(broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y)) + s.val, o, x, y)) ∇y = broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y) + s.val, o, x, y) @test broadcast(f, x, y) == s.val @test ∇s[x_] ≈ ∇x @test ∇s[y_] ≈ ∇y @@ -64,11 +71,12 @@ using DiffRules: diffrule, hasdiffrule tape = Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) s = broadcast(f, x_, y_) - ∇s = ∇(s, ones(s.val)) + o = oneslike(s.val) + ∇s = ∇(s, o) ∇x = broadcast((z, z̄, x, y)->∇(f, Arg{1}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y) + s.val, o, x, y) ∇y = sum(broadcast((z, z̄, x, y)->∇(f, Arg{2}, nothing, z, z̄, x, y), - s.val, ones(s.val), x, y)) + s.val, o, x, y)) @test broadcast(f, x, y) == s.val @test ∇s[x_] ≈ ∇x @test ∇s[y_] ≈ ∇y @@ -86,8 +94,8 @@ using DiffRules: diffrule, hasdiffrule # argument. (∂f∂x == :NaN || ∂f∂y == :NaN) && continue domain = domain2(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - (x_lb, x_ub), (y_lb, y_ub) = get(domain) + domain === nothing && error("Could not determine domain for $f.") + (x_lb, x_ub), (y_lb, y_ub) = domain x_dist, y_dist = Uniform(x_lb, x_ub), Uniform(y_lb, y_ub) x, y = rand(rng, x_dist, 100), rand(rng, y_dist, 100) check_binary_broadcast(eval(f), x, y) @@ -100,7 +108,7 @@ using DiffRules: diffrule, hasdiffrule x, y, z = randn(rng, 5), randn(rng, 5), randn(rng, 5) x_, y_, z_ = Leaf.(Tape(), (x, y, z)) s_ = broadcast(f, x_, y_, z_) - ∇s = ∇(s_, ones(s_.val)) + ∇s = ∇(s_, oneslike(s_.val)) @test s_.val == broadcast(f, x, y, z) @test ∇s[x_] == getindex.(broadcast((x, y, z)->fmad(f, (x, y, z)), x, y, z), 1) @test ∇s[y_] == getindex.(broadcast((x, y, z)->fmad(f, (x, y, z)), x, y, z), 2) @@ -110,47 +118,47 @@ using DiffRules: diffrule, hasdiffrule let x, y, tape = 5.0, randn(rng, 5), Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) - z_ = x_ + y_ + z_ = x_ .+ y_ z2_ = broadcast(+, x_, y_) - @test z_.val == x + y - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test z_.val == x .+ y + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) z_ = x_ * y_ z2_ = broadcast(*, x_, y_) - @test z_.val == x * y - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test z_.val == x .* y + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) - z_ = x_ - y_ + z_ = x_ 
.- y_ z2_ = broadcast(-, x_, y_) - @test z_.val == x - y - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test z_.val == x .- y + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end let x, y, tape = randn(rng, 5), 5.0, Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) z_ = x_ / y_ z2_ = broadcast(/, x_, y_) - @test z_.val == x / y - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test z_.val == x ./ y + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end let x, y, tape = 5.0, randn(rng, 5), Tape() x_, y_ = Leaf(tape, x), Leaf(tape, y) z_ = x_ \ y_ z2_ = broadcast(\, x_, y_) - @test z_.val == x \ y - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test z_.val == x .\ y + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end # Check that dot notation works as expected for all unary function in Nabla for both @@ -160,7 +168,7 @@ using DiffRules: diffrule, hasdiffrule z_ = f.(x_) z2_ = broadcast(f, x_) @test z_.val == f.(x) - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] end function check_unary_dot(f, x::∇Scalar) x_ = Leaf(Tape(), x) @@ -170,8 +178,8 @@ using DiffRules: diffrule, hasdiffrule end for (package, f) in Nabla.unary_sensitivities domain = domain1(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - x_dist = Uniform(get(domain)...) + domain === nothing && error("Could not determine domain for $f.") + x_dist = Uniform(domain...) check_unary_dot(eval(f), rand(rng, x_dist)) check_unary_dot(eval(f), rand(rng, x_dist, 100)) end @@ -183,8 +191,8 @@ using DiffRules: diffrule, hasdiffrule z_ = f.(x_, y_) z2_ = broadcast(f, x_, y_) @test z_.val == f.(x, y) - @test ∇(z_, ones(z_.val))[x_] == ∇(z2_, ones(z2_.val))[x_] - @test ∇(z_, ones(z_.val))[y_] == ∇(z2_, ones(z2_.val))[y_] + @test ∇(z_, oneslike(z_.val))[x_] == ∇(z2_, oneslike(z2_.val))[x_] + @test ∇(z_, oneslike(z_.val))[y_] == ∇(z2_, oneslike(z2_.val))[y_] end function check_binary_dot(f, x::∇Scalar, y::∇Scalar) x_, y_ = Leaf.(Tape(), (x, y)) @@ -194,7 +202,7 @@ using DiffRules: diffrule, hasdiffrule end for (package, f) in Nabla.binary_sensitivities # TODO: More care needs to be taken to test the following. - f in [:atan2, :mod, :rem] && continue + f in [:atan, :mod, :rem] && continue if hasdiffrule(package, f, 2) ∂f∂x, ∂f∂y = diffrule(package, f, :x, :y) else @@ -204,8 +212,8 @@ using DiffRules: diffrule, hasdiffrule # argument. 
(∂f∂x == :NaN || ∂f∂y == :NaN) && continue domain = domain2(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - (x_lb, x_ub), (y_lb, y_ub) = get(domain) + domain === nothing && error("Could not determine domain for $f.") + (x_lb, x_ub), (y_lb, y_ub) = domain x_distr = Uniform(x_lb, x_ub) y_distr = Uniform(y_lb, y_ub) x = rand(rng, x_distr, 100) diff --git a/test/sensitivities/functional/reduce.jl b/test/sensitivities/functional/reduce.jl index ad456cdd..3eec855c 100644 --- a/test/sensitivities/functional/reduce.jl +++ b/test/sensitivities/functional/reduce.jl @@ -13,9 +13,9 @@ # Generate some data and get the function to be mapped. f = eval(f) domain = domain1(f) - isnull(domain) && error("Could not determine domain for $f.") - lb, ub = get(domain) - x = rand(rng, N) * (ub - lb) + lb + domain === nothing && error("Could not determine domain for $f.") + lb, ub = domain + x = rand(rng, N) .* (ub - lb) .+ lb # Test +. x_ = Leaf(Tape(), x) @@ -53,7 +53,7 @@ x_ = Leaf(Tape(), x) s = functional(+, x_) @test s.val == functional(+, x) - @test ∇(s)[x_] ≈ ones(x) + @test ∇(s)[x_] ≈ ones(Float64, 100) end end @@ -66,9 +66,9 @@ # Generate some data and get the function to be mapped. f = eval(f) domain = domain1(f) - isnull(domain) && error("Could not determine domain for $f.") - lb, ub = get(domain) - x = rand(rng, N) * (ub - lb) + lb + domain === nothing && error("Could not determine domain for $f.") + lb, ub = domain + x = rand(rng, N) .* (ub .- lb) .+ lb # Test +. x_ = Leaf(Tape(), x) diff --git a/test/sensitivities/functional/reducedim.jl b/test/sensitivities/functional/reducedim.jl index 5d96d97a..bb73044e 100644 --- a/test/sensitivities/functional/reducedim.jl +++ b/test/sensitivities/functional/reducedim.jl @@ -10,19 +10,19 @@ x2_ = reshape([1.0, 2.0, 3.0, 4.0,], (2, 2)) x2 = Leaf(Tape(), x2_) s = mapreducedim(abs2, +, x2, 1) - @test ∇(s, ones(s.val))[x2] ≈ 2.0 * x2_ + @test ∇(s, ones(eltype(s.val), size(s.val)))[x2] ≈ 2.0 * x2_ # mapreducedim under `exp` should trigger the first conditional in the ∇ impl. x3_ = randn(rng, 5, 4) x3 = Leaf(Tape(), x3_) s = mapreducedim(exp, +, x3, 1) - @test ∇(s, ones(s.val))[x3] == exp.(x3_) + @test ∇(s, ones(eltype(s.val), size(s.val)))[x3] == exp.(x3_) # mapreducedim under an anonymous-function should trigger fmad. x4_ = randn(rng, 5, 4) x4 = Leaf(Tape(), x4_) s = mapreducedim(x->x*x, +, x4, 2) - @test ∇(s, ones(s.val))[x4] == 2x4_ + @test ∇(s, ones(eltype(s.val), size(s.val)))[x4] == 2x4_ # Check that `sum` works correctly with `Node`s. 
x_sum = Leaf(Tape(), randn(rng, 5, 4, 3)) diff --git a/test/sensitivities/indexing.jl b/test/sensitivities/indexing.jl index 94874efc..1a5de6e8 100644 --- a/test/sensitivities/indexing.jl +++ b/test/sensitivities/indexing.jl @@ -10,6 +10,6 @@ x = Leaf(Tape(), 10 * [1, 1, 1]) y = x[2:3] @test y.val == [10, 10] - @test ∇(y, ones(y.val))[x] == [0, 1, 1] + @test ∇(y, ones(eltype(y.val), size(y.val)))[x] == [0, 1, 1] end end diff --git a/test/sensitivities/linalg/blas.jl b/test/sensitivities/linalg/blas.jl index 1f5133d3..42af819c 100644 --- a/test/sensitivities/linalg/blas.jl +++ b/test/sensitivities/linalg/blas.jl @@ -1,20 +1,20 @@ +using LinearAlgebra.BLAS + @testset "BLAS" begin - import Base.BLAS.dot let rng = MersenneTwister(123456) for _ in 1:10 x, y, vx, vy = randn.(rng, [5, 5, 5, 5]) - @test check_errs(dot, dot(x, y), (x, y), (vx, vy)) + @test check_errs(BLAS.dot, BLAS.dot(x, y), (x, y), (vx, vy)) end end let rng = MersenneTwister(123456) for _ in 1:10 x, y, vx, vy = randn.(rng, [10, 6, 10, 6]) - _dot = (x, y)->dot(5, x, 2, y, 1) + _dot = (x, y)->BLAS.dot(5, x, 2, y, 1) @test check_errs(_dot, _dot(x, y), (x, y), (vx, vy)) end end - import Base.BLAS.nrm2 let rng = MersenneTwister(123456) for _ in 1:10 x, vx = randn(rng, 100), randn(rng, 100) @@ -29,7 +29,6 @@ end end - import Base.BLAS.asum let rng = MersenneTwister(123456) λ = x->asum(50, x, 2) for _ in 1:10 @@ -40,7 +39,6 @@ end # Test each of the four permutations of `gemm`. - import Base.BLAS.gemm let rng = MersenneTwister(123456), N = 100 for tA in ['T', 'N'], tB in ['T', 'N'] λ, γ = (α, A, B)->gemm(tA, tB, α, A, B), (A, B)->gemm(tA, tB, A, B) @@ -54,7 +52,6 @@ end # Test both permutations of `gemv`. - import Base.BLAS.gemv let rng = MersenneTwister(123456), N = 100 for tA in ['T', 'N'] λ, γ = (α, A, x)->gemv('T', α, A, x), (A, x)->gemv('T', A, x) @@ -69,9 +66,8 @@ end # Test all four permutations of `symm`. 
- import Base.BLAS.symm let rng = MersenneTwister(123456), N = 100 - lmask, umask = full(LowerTriangular(ones(N, N))), full(UpperTriangular(ones(N, N))) + lmask, umask = Matrix(LowerTriangular(ones(N, N))), Matrix(UpperTriangular(ones(N, N))) for side in ['L', 'R'], ul in ['L', 'U'] λ, γ = (α, A, B)->symm(side, ul, α, A, B), (A, B)->symm(side, ul, A, B) for _ in 1:10 @@ -83,7 +79,6 @@ end end - import Base.BLAS.symv let rng = MersenneTwister(123456), N = 100 for ul in ['L', 'U'] λ, γ = (α, A, x)->symv(ul, α, A, x), (A, x)->symv(ul, A, x) @@ -97,7 +92,6 @@ end end - import Base.BLAS.trmm let rng = MersenneTwister(123456), N = 10 for side in ['L', 'R'], ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] λ = (α, A, B)->trmm(side, ul, tA, dA, α, A, B) @@ -109,7 +103,6 @@ end end - import Base.BLAS.trmv let rng = MersenneTwister(123456), N = 10 for ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] λ = (A, b)->trmv(ul, tA, dA, A, b) @@ -121,7 +114,6 @@ end end - import Base.BLAS.trsm let rng = MersenneTwister(123456), N = 10 for side in ['L', 'R'], ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] λ = (α, A, X)->trsm(side, ul, tA, dA, α, A, X) @@ -134,13 +126,12 @@ end end - import Base.BLAS.trsv let rng = MersenneTwister(123456), N = 10 for ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] λ = (A, x)->trsv(ul, tA, dA, A, x) for _ in 1:10 A = randn(rng, N, N) + UniformScaling(1) - A = A.'A + A = A'A VA = randn(rng, N, N) x, vx = randn.(rng, [N, N]) @test check_errs(λ, λ(A, x), (A, x), (VA, vx)) diff --git a/test/sensitivities/linalg/diagonal.jl b/test/sensitivities/linalg/diagonal.jl index e1c24261..e756b8a6 100644 --- a/test/sensitivities/linalg/diagonal.jl +++ b/test/sensitivities/linalg/diagonal.jl @@ -1,25 +1,19 @@ @testset "Diagonal" begin let rng = MersenneTwister(123456), N = 10 - λ_2 = x->diagm(x, 2) - λ_m3 = x->diagm(x, -3) - λ_0 = x->diagm(x, 0) - λ_false = x->diagm(x, false) - λ_true = x->diagm(x, true) + λ_2 = x->diagm(2 => x) + λ_m3 = x->diagm(-3 => x) + λ_0 = x->diagm(0 => x) + λ_false = x->diagm(false => x) + λ_true = x->diagm(true => x) for _ in 1:10 - - # Test vector case. x, vx = randn.(rng, [N, N]) - @test check_errs(diagm, diagm(randn(rng, N)), x, vx) + @test check_errs(x->diagm(0 => x), diagm(0 => randn(rng, N)), x, vx) @test check_errs(λ_2, λ_2(randn(rng, N)), x, vx) @test check_errs(λ_m3, λ_m3(randn(rng, N)), x, vx) @test check_errs(λ_0, λ_0(randn(rng, N)), x, vx) @test check_errs(λ_false, λ_false(randn(rng, N)), x, vx) @test check_errs(λ_true, λ_true(randn(rng, N)), x, vx) - - # Test scalar case. - x, vx = randn(rng), randn(rng) - @test check_errs(diagm, diagm(randn(rng)), x, vx) end end let rng = MersenneTwister(123456), N = 10 diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index 2125198e..6422daec 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -5,7 +5,7 @@ A = randn(rng, N, N) r, d, B2, c = level2partition(A, 4, false) R, D, B3, C = level3partition(A, 4, 4, false) - @test all(r .== R.') + @test all(r .== R') @test all(d .== D) @test B2[1] == B3[1] @test all(c .== C) @@ -14,7 +14,7 @@ rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true) @test r == rᵀ @test d == dᵀ - @test B2.' == B2ᵀ + @test B2' == B2ᵀ @test c == cᵀ # Check that level3partition with 'U' is consistent with 'L'. 
@@ -28,7 +28,7 @@ import Nabla: chol_unblocked_rev, chol_blocked_rev let rng = MersenneTwister(123456), N = 10 - A, Ā = full.(LowerTriangular.(randn.(rng, [N, N], [N, N]))) + A, Ā = Matrix.(LowerTriangular.(randn.(rng, [N, N], [N, N]))) B, B̄ = transpose.([A, Ā]) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) @@ -45,7 +45,7 @@ let rng = MersenneTwister(123456), N = 10 for _ in 1:10 B, VB = randn.(rng, [N, N], [N, N]) - A, VA = B.'B + 1e-6I, VB.'VB + 1e-6I + A, VA = B'B + 1e-6I, VB'VB + 1e-6I Ū = UpperTriangular(randn(rng, N, N)) @test check_errs(chol, Ū, A, 1e-2 .* VA) end diff --git a/test/sensitivities/linalg/generic.jl b/test/sensitivities/linalg/generic.jl index 10fdb7c1..b6e26bed 100644 --- a/test/sensitivities/linalg/generic.jl +++ b/test/sensitivities/linalg/generic.jl @@ -3,7 +3,7 @@ let N = 5, rng = MersenneTwister(123456) # Generate random test quantities for specific types. - ∇Arrays = Union{Type{∇Array}, Type{∇ArrayOrScalar}} + ∇Arrays = Union{Type{∇Array}, Type{∇ArrayOrScalar}, Type{<:Transpose}, Type{<:Adjoint}} trandn(rng::AbstractRNG, ::∇Arrays) = randn(rng, N, N) trandn2(rng::AbstractRNG, ::∇Arrays) = randn(rng, N^2, N^2) trandn(rng::AbstractRNG, ::Type{∇Scalar}) = randn(rng) @@ -13,7 +13,7 @@ for _ in 1:5 # Test unary linalg sensitivities. for (f, T_In, T_Out, X̄, bounds) in Nabla.unary_linalg_optimisations - Z = trand(rng, T_In) .* (bounds[2] - bounds[1]) + bounds[1] + Z = trand(rng, T_In) .* (bounds[2] .- bounds[1]) .+ bounds[1] X = Z'Z + 1e-6 * one(Z) Ȳ, V = eval(f)(X), trandn(rng, T_In) @test check_errs(eval(f), Ȳ, X, 1e-1 .* V) diff --git a/test/sensitivities/linalg/strided.jl b/test/sensitivities/linalg/strided.jl index d73bb8c9..ec4350de 100644 --- a/test/sensitivities/linalg/strided.jl +++ b/test/sensitivities/linalg/strided.jl @@ -3,9 +3,9 @@ let rng = MersenneTwister(123456), N = 100 # Test strided matrix-matrix multiplication sensitivities. - for (f, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in Nabla.strided_matmul + for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in Nabla.strided_matmul A, B, VA, VB = randn.(rng, [N, N, N, N], [N, N, N, N]) - @test check_errs(eval(f), eval(f)(A, B), (A, B), (VA, VB)) + @test check_errs(*, A * B, (A, B), (VA, VB)) end end end diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl index 1aa6189e..214cdd45 100644 --- a/test/sensitivities/scalar.jl +++ b/test/sensitivities/scalar.jl @@ -5,15 +5,15 @@ using DiffRules: diffrule, hasdiffrule @test in_domain(cos, 10.) @test !in_domain(acos, 10.) @test !in_domain(asin, 10.) 
- @test get(domain1(sin)) == (minimum(points), maximum(points)) - @test get(domain1(log)) == (minimum(points[points .> 0]), maximum(points)) - @test get(domain1(acos)) == (minimum(points[points .> -1]), - maximum(points[points .< 1])) - @test get(domain2((+))) == ((minimum(points), maximum(points)), - (minimum(points), maximum(points))) - @test get(domain2((^))) == ((minimum(points[points .> 0]), maximum(points)), - (minimum(points), maximum(points))) - @test get(domain2(beta)) == ((minimum(points[points .> 0]), maximum(points)), + @test domain1(sin) == (minimum(points), maximum(points)) + @test domain1(log) == (minimum(points[points .> 0]), maximum(points)) + @test domain1(acos) == (minimum(points[points .> -1]), + maximum(points[points .< 1])) + @test domain2((+)) == ((minimum(points), maximum(points)), + (minimum(points), maximum(points))) + @test domain2((^)) == ((minimum(points[points .> 0]), maximum(points)), + (minimum(points), maximum(points))) + @test domain2(beta) == ((minimum(points[points .> 0]), maximum(points)), (minimum(points[points .> 0]), maximum(points))) end @@ -28,8 +28,8 @@ end unary_check(f, x) = check_errs(eval(f), ȳ, x, v) for (package, f) in Nabla.unary_sensitivities domain = domain1(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - lb, ub = get(domain) + domain === nothing && error("Could not determine domain for $f.") + lb, ub = domain randx = () -> rand(rng) * (ub - lb) + lb for _ in 1:10 @@ -52,8 +52,8 @@ end if ∂f∂x == :NaN && ∂f∂y != :NaN # Assume that the first argument is integer-valued. domain = domain1(y -> eval(f)(0, y)) - isnull(domain) && error("Could not determine domain for $f.") - lb, ub = get(domain) + domain === nothing && error("Could not determine domain for $f.") + lb, ub = domain randx = () -> rand(rng, 0:5) randy = () -> rand(rng) * (ub - lb) + lb @@ -64,8 +64,8 @@ end elseif ∂f∂x != :NaN && ∂f∂y == :NaN # Assume that the second argument is integer-valued. domain = domain1(x -> eval(f)(x, 0)) - isnull(domain) && error("Could not determine domain for $f.") - lb, ub = get(domain) + domain === nothing && error("Could not determine domain for $f.") + lb, ub = domain randx = () -> rand(rng) * (ub - lb) + lb randy = () -> rand(rng, 0:5) @@ -75,8 +75,8 @@ end end elseif ∂f∂x != :NaN && ∂f∂y != :NaN domain = domain2(eval(f)) - isnull(domain) && error("Could not determine domain for $f.") - (x_lb, x_ub), (y_lb, y_ub) = get(domain) + domain === nothing && error("Could not determine domain for $f.") + (x_lb, x_ub), (y_lb, y_ub) = domain randx = () -> rand(rng) * (x_ub - x_lb) + x_lb randy = () -> rand(rng) * (y_ub - y_lb) + y_lb @@ -91,7 +91,4 @@ end # Test whether the exponentiation amibiguity is resolved. @test ∇(x -> x^2)(1) == (2.0,) end - - # Miscellaneous test for addition to DualNumbers. 
- @test DualNumbers.epsilon(5.0) == 0.0 end diff --git a/test/sensitivity.jl b/test/sensitivity.jl index df73b4ea..1221b8ec 100644 --- a/test/sensitivity.jl +++ b/test/sensitivity.jl @@ -95,7 +95,7 @@ @test from_func == expected end let - from_func = branch_expr(:bar, [true], ((Leaf{Float64},),), (:x,), :((x...))) + from_func = branch_expr(:bar, [true], ((Leaf{Float64},),), (:x,), :((x...,))) tape = Expr(:call, :getfield, Expr(:ref, :x, :1), quot(:tape)) expected = Expr(:call, :Branch, :bar, :((x...,)), tape) @test from_func == expected From ce13fb51796871d6202de2393564c7c2c6b24150 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 19 Oct 2018 13:16:32 -0700 Subject: [PATCH 02/11] Add Travis and AppVeyor testing on Julia 1.0 --- .travis.yml | 3 ++- appveyor.yml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index b9d68c21..6db61a1f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,8 @@ os: - linux - osx julia: - - 0.7 # TODO: Add 1.0 once 0.7 works + - 0.7 + - 1.0 - nightly notifications: email: false diff --git a/appveyor.yml b/appveyor.yml index 18e3ab4d..4f179d3c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,7 +1,7 @@ environment: matrix: - julia_version: 0.7 - #- julia_version: 1 + - julia_version: 1 - julia_version: nightly platform: From 9d46a5fb26558598360bc270c912f58414d22b2e Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 19 Oct 2018 13:37:14 -0700 Subject: [PATCH 03/11] Correctly handle broadcast calls containing RNG arguments RNGs don't have broadcast behaviors defined, so in order to use them in broadcasting expressions, e.g. `randn.(rng, x)`, we need to wrap the argument in `Ref` to ensure it broadcasts as a scalar. --- test/sensitivities/linalg/blas.jl | 26 +++++++++---------- test/sensitivities/linalg/diagonal.jl | 2 +- .../linalg/factorization/cholesky.jl | 4 +-- test/sensitivities/linalg/generic.jl | 4 +-- test/sensitivities/linalg/strided.jl | 2 +- test/sensitivities/linalg/symmetric.jl | 2 +- test/sensitivities/linalg/triangular.jl | 4 +-- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/test/sensitivities/linalg/blas.jl b/test/sensitivities/linalg/blas.jl index 42af819c..a3cf5eef 100644 --- a/test/sensitivities/linalg/blas.jl +++ b/test/sensitivities/linalg/blas.jl @@ -3,13 +3,13 @@ using LinearAlgebra.BLAS @testset "BLAS" begin let rng = MersenneTwister(123456) for _ in 1:10 - x, y, vx, vy = randn.(rng, [5, 5, 5, 5]) + x, y, vx, vy = randn.(Ref(rng), [5, 5, 5, 5]) @test check_errs(BLAS.dot, BLAS.dot(x, y), (x, y), (vx, vy)) end end let rng = MersenneTwister(123456) for _ in 1:10 - x, y, vx, vy = randn.(rng, [10, 6, 10, 6]) + x, y, vx, vy = randn.(Ref(rng), [10, 6, 10, 6]) _dot = (x, y)->BLAS.dot(5, x, 2, y, 1) @test check_errs(_dot, _dot(x, y), (x, y), (vx, vy)) end @@ -44,7 +44,7 @@ using LinearAlgebra.BLAS λ, γ = (α, A, B)->gemm(tA, tB, α, A, B), (A, B)->gemm(tA, tB, A, B) for _ in 1:10 α, vα = randn.([rng, rng]) - A, B, VA, VB = randn.(rng, [N, N, N, N], [N, N, N, N]) + A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) @test check_errs(λ, λ(α, A, B), (α, A, B), (vα, VA, VB)) @test check_errs(γ, γ(A, B), (A, B), (VA, VB)) end @@ -57,8 +57,8 @@ using LinearAlgebra.BLAS λ, γ = (α, A, x)->gemv('T', α, A, x), (A, x)->gemv('T', A, x) for _ in 1:10 α, vα = randn.([rng, rng]) - A, VA = randn.(rng, [N, N], [N, N]) - x, vx = randn.(rng, [N, N]) + A, VA = randn.(Ref(rng), [N, N], [N, N]) + x, vx = randn.(Ref(rng), [N, N]) @test check_errs(λ, λ(α, A, x), (α, A, x), (vα, VA, vx)) @test 
check_errs(γ, γ(A, x), (A, x), (VA, vx)) end @@ -72,7 +72,7 @@ using LinearAlgebra.BLAS λ, γ = (α, A, B)->symm(side, ul, α, A, B), (A, B)->symm(side, ul, A, B) for _ in 1:10 α, vα = randn.([rng, rng]) - A, B, VA, VB = randn.(rng, [N, N, N, N], [N, N, N, N]) + A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) @test check_errs(λ, λ(α, A, B), (α, A, B), (vα, VA, VB)) @test check_errs(γ, γ(A, B), (A, B), (VA, VB)) end @@ -84,8 +84,8 @@ using LinearAlgebra.BLAS λ, γ = (α, A, x)->symv(ul, α, A, x), (A, x)->symv(ul, A, x) for _ in 1:10 α, vα = randn.([rng, rng]) - A, VA = randn.(rng, [N, N], [N, N]) - x, vx = randn.(rng, [N, N]) + A, VA = randn.(Ref(rng), [N, N], [N, N]) + x, vx = randn.(Ref(rng), [N, N]) @test check_errs(λ, λ(α, A, x), (α, A, x), (vα, VA, vx)) @test check_errs(γ, γ(A, x), (A, x), (VA, vx)) end @@ -97,7 +97,7 @@ using LinearAlgebra.BLAS λ = (α, A, B)->trmm(side, ul, tA, dA, α, A, B) for _ in 1:10 α, vα = randn.([rng, rng]) - A, B, VA, VB = randn.(rng, [N, N, N, N], [N, N, N, N]) + A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) @test check_errs(λ, λ(α, A, B), (α, A, B), (vα, VA, VB)) end end @@ -107,8 +107,8 @@ using LinearAlgebra.BLAS for ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] λ = (A, b)->trmv(ul, tA, dA, A, b) for _ in 1:10 - A, VA = randn.(rng, [N, N], [N, N]) - b, vb = randn.(rng, [N, N]) + A, VA = randn.(Ref(rng), [N, N], [N, N]) + b, vb = randn.(Ref(rng), [N, N]) @test check_errs(λ, λ(A, b), (A, b), (VA, vb)) end end @@ -119,7 +119,7 @@ using LinearAlgebra.BLAS λ = (α, A, X)->trsm(side, ul, tA, dA, α, A, X) for _ in 1:10 α, vα = randn.([rng, rng]) - A, X, VA, VX = randn.(rng, [N, N, N, N], [N, N, N, N]) + A, X, VA, VX = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) A = randn(rng, N, N) + UniformScaling(3) @test check_errs(λ, λ(α, A, X), (α, A, X), (vα, VA, VX)) end @@ -133,7 +133,7 @@ using LinearAlgebra.BLAS A = randn(rng, N, N) + UniformScaling(1) A = A'A VA = randn(rng, N, N) - x, vx = randn.(rng, [N, N]) + x, vx = randn.(Ref(rng), [N, N]) @test check_errs(λ, λ(A, x), (A, x), (VA, vx)) end end diff --git a/test/sensitivities/linalg/diagonal.jl b/test/sensitivities/linalg/diagonal.jl index e756b8a6..0e731a55 100644 --- a/test/sensitivities/linalg/diagonal.jl +++ b/test/sensitivities/linalg/diagonal.jl @@ -6,7 +6,7 @@ λ_false = x->diagm(false => x) λ_true = x->diagm(true => x) for _ in 1:10 - x, vx = randn.(rng, [N, N]) + x, vx = randn.(Ref(rng), [N, N]) @test check_errs(x->diagm(0 => x), diagm(0 => randn(rng, N)), x, vx) @test check_errs(λ_2, λ_2(randn(rng, N)), x, vx) diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index 6422daec..8c7fcd77 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -28,7 +28,7 @@ import Nabla: chol_unblocked_rev, chol_blocked_rev let rng = MersenneTwister(123456), N = 10 - A, Ā = Matrix.(LowerTriangular.(randn.(rng, [N, N], [N, N]))) + A, Ā = Matrix.(LowerTriangular.(randn.(Ref(rng), [N, N], [N, N]))) B, B̄ = transpose.([A, Ā]) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) @@ -44,7 +44,7 @@ # Check sensitivities for lower-triangular version. 
let rng = MersenneTwister(123456), N = 10 for _ in 1:10 - B, VB = randn.(rng, [N, N], [N, N]) + B, VB = randn.(Ref(rng), [N, N], [N, N]) A, VA = B'B + 1e-6I, VB'VB + 1e-6I Ū = UpperTriangular(randn(rng, N, N)) @test check_errs(chol, Ū, A, 1e-2 .* VA) diff --git a/test/sensitivities/linalg/generic.jl b/test/sensitivities/linalg/generic.jl index b6e26bed..fdc8e3eb 100644 --- a/test/sensitivities/linalg/generic.jl +++ b/test/sensitivities/linalg/generic.jl @@ -21,10 +21,10 @@ # Test binary linalg sensitivities. for (f, T_A, T_B, T_Y, Ā, B̄) in Nabla.binary_linalg_optimisations - A, B, VA, VB = trandn.(rng, (T_A, T_B, T_A, T_B)) + A, B, VA, VB = trandn.(Ref(rng), (T_A, T_B, T_A, T_B)) @test check_errs(eval(f), eval(f)(A, B), (A, B), (VA, VB)) end - A, B, VA, VB = trandn.(rng, (∇Array, ∇Array, ∇Array, ∇Array)) + A, B, VA, VB = trandn.(Ref(rng), (∇Array, ∇Array, ∇Array, ∇Array)) @test check_errs(kron, kron(A, B), (A, B), (VA, VB)) end diff --git a/test/sensitivities/linalg/strided.jl b/test/sensitivities/linalg/strided.jl index ec4350de..db4a19d4 100644 --- a/test/sensitivities/linalg/strided.jl +++ b/test/sensitivities/linalg/strided.jl @@ -4,7 +4,7 @@ # Test strided matrix-matrix multiplication sensitivities. for (TA, TB, tCA, tDA, CA, DA, tCB, tDB, CB, DB) in Nabla.strided_matmul - A, B, VA, VB = randn.(rng, [N, N, N, N], [N, N, N, N]) + A, B, VA, VB = randn.(Ref(rng), [N, N, N, N], [N, N, N, N]) @test check_errs(*, A * B, (A, B), (VA, VB)) end end diff --git a/test/sensitivities/linalg/symmetric.jl b/test/sensitivities/linalg/symmetric.jl index f884a245..20d6063a 100644 --- a/test/sensitivities/linalg/symmetric.jl +++ b/test/sensitivities/linalg/symmetric.jl @@ -1,7 +1,7 @@ @testset "Symmetric" begin let rng = MersenneTwister(123456), N = 100 for _ in 1:10 - X, V, Ȳ = randn.(rng, [N, N, N], [N, N, N]) + X, V, Ȳ = randn.(Ref(rng), [N, N, N], [N, N, N]) @test check_errs(Symmetric, Ȳ, X, V) end end diff --git a/test/sensitivities/linalg/triangular.jl b/test/sensitivities/linalg/triangular.jl index cfa25d85..642674b9 100644 --- a/test/sensitivities/linalg/triangular.jl +++ b/test/sensitivities/linalg/triangular.jl @@ -1,13 +1,13 @@ @testset "Triangular" begin let rng = MersenneTwister(123456), N = 10 for _ in 1:10 - A, VA, L = randn.(rng, [N, N, N], [N, N, N]) + A, VA, L = randn.(Ref(rng), [N, N, N], [N, N, N]) @test check_errs(LowerTriangular, LowerTriangular(L), A, VA) end end let rng = MersenneTwister(123456), N = 10 for _ in 1:10 - A, VA, U = randn.(rng, [N, N, N], [N, N, N]) + A, VA, U = randn.(Ref(rng), [N, N, N], [N, N, N]) @test check_errs(UpperTriangular, UpperTriangular(U), A, VA) end end From d98976ba5afb323474e2b4b719aab8c6f6a63061 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 22 Oct 2018 13:12:33 -0700 Subject: [PATCH 04/11] Fix Cholesky issue by defining our own chol function See discussion in issue 105. 
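The whole workaround is the one-line shim added in src/sensitivities/linalg/factorization/cholesky.jl below; a usage sketch, assuming a positive-definite input:

    using LinearAlgebra
    chol(X::AbstractMatrix{<:Real}) = cholesky(X).U  # restores the 0.6-era chol
    A = [2.0 1.0; 1.0 2.0]
    U = chol(A)   # UpperTriangular factor
    U'U ≈ A       # true
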
--- src/Nabla.jl | 2 +- src/sensitivities/linalg/factorization/cholesky.jl | 10 +++++++++- test/runtests.jl | 6 +++--- test/sensitivities/linalg/factorization/cholesky.jl | 5 +++-- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/Nabla.jl b/src/Nabla.jl index 597b3dc6..9deee2b8 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -46,6 +46,6 @@ module Nabla include("sensitivities/linalg/blas.jl") include("sensitivities/linalg/diagonal.jl") include("sensitivities/linalg/triangular.jl") - #include("sensitivities/linalg/factorization/cholesky.jl") + include("sensitivities/linalg/factorization/cholesky.jl") end # module Nabla diff --git a/src/sensitivities/linalg/factorization/cholesky.jl b/src/sensitivities/linalg/factorization/cholesky.jl index 1855a77b..f67d884d 100644 --- a/src/sensitivities/linalg/factorization/cholesky.jl +++ b/src/sensitivities/linalg/factorization/cholesky.jl @@ -1,5 +1,13 @@ import LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! -import LinearAlgebra.chol + +# NOTE: Cholesky factorizations pose a significant issue for us as of Julia 0.7, since +# the simple function chol, which produced the U in the factorization U'U, has been +# deprecated in favor of accessing the .U field of a Cholesky object produced by cholesky. +# This does not lend itself well to tracing. To get around this, we'll define our own +# chol that users of Nabla can use to obtain the Julia 0.6 behavior. +# See issue #105 for discussion. +export chol +chol(X::AbstractMatrix{<:Real}) = cholesky(X).U #= See [1] for implementation details: pages 5-9 in particular. The derivations presented in diff --git a/test/runtests.jl b/test/runtests.jl index 6448f31b..8f7382c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,8 +34,8 @@ end include("sensitivities/linalg/strided.jl") include("sensitivities/linalg/blas.jl") - #@testset "Factorisations" begin - # include("sensitivities/linalg/factorization/cholesky.jl") - #end + @testset "Factorisations" begin + include("sensitivities/linalg/factorization/cholesky.jl") + end end end diff --git a/test/sensitivities/linalg/factorization/cholesky.jl b/test/sensitivities/linalg/factorization/cholesky.jl index 8c7fcd77..5759d217 100644 --- a/test/sensitivities/linalg/factorization/cholesky.jl +++ b/test/sensitivities/linalg/factorization/cholesky.jl @@ -26,10 +26,11 @@ @test transpose(C) == Cᵀ end - import Nabla: chol_unblocked_rev, chol_blocked_rev + import Nabla: chol, chol_unblocked_rev, chol_blocked_rev let rng = MersenneTwister(123456), N = 10 A, Ā = Matrix.(LowerTriangular.(randn.(Ref(rng), [N, N], [N, N]))) - B, B̄ = transpose.([A, Ā]) + # NOTE: BLAS gets angry if we don't materialize the Transpose objects first + B, B̄ = Matrix.(transpose.([A, Ā])) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false) From 301cc70f5f5d978db93a1b5dc70af58858bd55ca Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 22 Oct 2018 18:19:00 -0700 Subject: [PATCH 05/11] Add keyword specification functionality, implement mapreduce In 0.7, mapreducedim became mapreduce with a dims keyword argument. This causes a lot of issues for us, since our macros and other machinery aren't set up to allow and use keyword arguments... until now! 
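For reference, the Base API change being accommodated, shown on a hypothetical array `A` (not part of the patch):

    A = reshape(1.0:4.0, 2, 2)
    # Julia 0.6:  mapreducedim(abs2, +, A, 1)
    # Julia 0.7+: the reduced dimension is a keyword argument of mapreduce
    mapreduce(abs2, +, A; dims=1)  # 1x2 Matrix: [5.0 25.0]
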
--- src/Nabla.jl | 2 +- src/code_transformation/util.jl | 25 ++++- src/core.jl | 4 +- src/sensitivities/functional/reduce.jl | 6 +- src/sensitivities/functional/reducedim.jl | 30 +++--- src/sensitivity.jl | 113 +++++++++++---------- test/runtests.jl | 2 +- test/sensitivities/functional/reducedim.jl | 12 +-- 8 files changed, 112 insertions(+), 82 deletions(-) diff --git a/src/Nabla.jl b/src/Nabla.jl index 9deee2b8..df161641 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -37,7 +37,7 @@ module Nabla # Sensitivities for functionals. include("sensitivities/functional/functional.jl") include("sensitivities/functional/reduce.jl") - #include("sensitivities/functional/reducedim.jl") + include("sensitivities/functional/reducedim.jl") # Linear algebra optimisations. include("sensitivities/linalg/generic.jl") diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index c5f00d19..ff0a7985 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -1,12 +1,31 @@ -function get_union_call(foo::Symbol, type_tuple::Expr) +function add_kwargs!(ex::Expr; kwargs...) + ex.head === :call || throw(ArgumentError("expression is not a function call")) + isempty(ex.args) && throw(ArgumentError("expression body is empty")) + if !isempty(kwargs) + params = Expr(:parameters) + for (name, value) in kwargs + push!(params.args, Expr(:kw, name, value)) + end + # Parameters need to come after the function name and before positional arguments + if length(ex.args) == 1 + push!(ex.args, params) + else + insert!(ex.args, 2, params) + end + end + ex +end +function get_union_call(foo::Symbol, type_tuple::Expr; kwargs...) # Get types from tuple and create a collection of symbols for use in the call. types = get_types(get_body(type_tuple)) arg_names = [Symbol("x$j") for j in 1:length(types)] # Generate the call. - typed_args = [:($name::$typ) for (name, typ) in zip(arg_names, unionise_type.(types))] - return replace_body(type_tuple, Expr(:call, foo, typed_args...)), arg_names + typed_args = map((name, typ)->:($name::$(unionise_type(typ))), arg_names, types) + call = add_kwargs!(Expr(:call, foo, typed_args...); kwargs...) + + return replace_body(type_tuple, call), arg_names end """ diff --git a/src/core.jl b/src/core.jl index 944b2ba4..59f32d69 100644 --- a/src/core.jl +++ b/src/core.jl @@ -74,9 +74,9 @@ struct Branch{T} <: Node{T} tape::Tape pos::Int end -function Branch(f, args::Tuple, tape::Tape) +function Branch(f, args::Tuple, tape::Tape; kwargs...) unboxed = unbox.(args) - branch = Branch(f(unboxed...), f, args, tape, length(tape) + 1) + branch = Branch(f(unboxed...; kwargs...), f, args, tape, length(tape) + 1) push!(tape, branch) return branch end diff --git a/src/sensitivities/functional/reduce.jl b/src/sensitivities/functional/reduce.jl index 41bd630d..ed02d4e0 100644 --- a/src/sensitivities/functional/reduce.jl +++ b/src/sensitivities/functional/reduce.jl @@ -1,11 +1,11 @@ # Intercepts for `mapreduce`, `mapfoldl` and `mapfoldr` under `op` `+`. 
const plustype = Union{typeof(+), typeof(Base.add_sum)} const type_tuple = :(Tuple{Any, $plustype, ∇ArrayOrScalar}) -for f in (:mapreduce, :mapfoldl, :mapfoldr) +for f in (:mapfoldl, :mapfoldr) @eval begin import Base: $f - @explicit_intercepts $f $type_tuple [false, false, true] - ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar) = + @explicit_intercepts $f $type_tuple [false, false, true] (init=nothing,) + ∇(::typeof($f), ::Type{Arg{3}}, p, y, ȳ, f, ::$plustype, A::∇ArrayOrScalar; init=nothing) = hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, Real}) ? broadcast(An->ȳ * ∇(f, Arg{1}, An), A) : broadcast(An->ȳ * fmad(f, (An,), Val{1}), A) diff --git a/src/sensitivities/functional/reducedim.jl b/src/sensitivities/functional/reducedim.jl index 6ee81b2e..43018935 100644 --- a/src/sensitivities/functional/reducedim.jl +++ b/src/sensitivities/functional/reducedim.jl @@ -1,20 +1,24 @@ -import Base: mapreducedim, sum +import Base: mapreduce, sum -accept_wo_default = :(Tuple{Function, typeof(+), AbstractArray{<:∇Scalar}, Any}) -accept_w_default = :(Tuple{Function, typeof(+), AbstractArray{<:∇Scalar}, Any, ∇Scalar}) -@eval @explicit_intercepts mapreducedim $accept_wo_default [false, false, true, false] -@eval @explicit_intercepts mapreducedim $accept_w_default [false, false, true, false, true] - -∇(::typeof(mapreducedim), +@explicit_intercepts( + mapreduce, + Tuple{Function, Union{typeof(+), typeof(Base.add_sum)}, AbstractArray{<:∇Scalar}}, + [false, false, true], + (dims=:, init=nothing), +) +function ∇( + ::typeof(mapreduce), ::Type{Arg{3}}, p, y, ȳ, f, - ::typeof(+), - A::AbstractArray{<:∇Scalar}, - region, - v0=nothing, -) = hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? + ::Union{typeof(+), typeof(Base.add_sum)}, + A::AbstractArray{<:∇Scalar}; + dims=:, + init=nothing, +) + hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? broadcast((An, ȳn)->ȳn * ∇(f, Arg{1}, An), A, ȳ) : broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ) +end # Make `sum` work. It currently fails as the type specification is too restrictive. -sum(n::Node{<:AbstractArray}, region) = mapreducedim(identity, +, n, region) +sum(n::Node{<:AbstractArray}; dims=:) = mapreduce(identity, Base.add_sum, n, dims=dims) diff --git a/src/sensitivity.jl b/src/sensitivity.jl index 9392501b..aeb61adb 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -1,15 +1,17 @@ -import Base.Meta.quot +using Base.Meta export Arg, add_∇, add_∇!, ∇, preprocess, @explicit_intercepts, @union_intercepts """ - @union_intercepts f type_tuple invoke_type_tuple + @union_intercepts f type_tuple invoke_type_tuple [kwargs] Interception strategy based on adding a method to `f` which accepts the union of each of the types specified by `type_tuple`. If none of the arguments are `Node`s then the method -of `f` specified by `invoke_type_tuple` is invoked. +of `f` specified by `invoke_type_tuple` is invoked. If applicable, keyword arguments +should be provided as a `NamedTuple` and be added to the generated function's signature. """ -macro union_intercepts(f::Symbol, type_tuple::Expr, invoke_type_tuple::Expr) - return esc(union_intercepts(f, type_tuple, invoke_type_tuple)) +macro union_intercepts(f::Symbol, type_tuple::Expr, invoke_type_tuple::Expr, kwargs::Expr=:(())) + kwargs.head === :tuple || throw(ArgumentError("malformed keyword argument specification")) + return esc(union_intercepts(f, type_tuple, invoke_type_tuple; eval(kwargs)...)) end """ @@ -17,28 +19,30 @@ end The work-horse for `@union_intercepts`. 
""" -function union_intercepts(f::Symbol, type_tuple::Expr, invoke_type_tuple::Expr) - call, arg_names = get_union_call(f, type_tuple) - body = get_body(f, type_tuple, arg_names, invoke_type_tuple) +function union_intercepts(f::Symbol, type_tuple::Expr, invoke_type_tuple::Expr; kwargs...) + call, arg_names = get_union_call(f, type_tuple; kwargs...) + body = get_body(f, type_tuple, arg_names, invoke_type_tuple; kwargs...) return Expr(:macrocall, Symbol("@generated"), nothing, Expr(:function, call, body)) end """ - @explicit_intercepts(f::Symbol, type_tuple::Expr, is_node::Expr) + @explicit_intercepts(f::Symbol, type_tuple::Expr, is_node::Expr[, kwargs::Expr]) @explicit_intercepts(f::Symbol, type_tuple::Expr) Create a collection of methods which intecept the function calls to `f` in which at least one argument is a `Node`. Types of arguments are specified by the type tuple expression in `type_tuple`. If there are arguments which are not differentiable, they can be specified by providing a boolean vector `is_node` which indicates those arguments that are -differentiable with `true` values and those which are not as `false`. +differentiable with `true` values and those which are not as `false`. Keyword arguments +to add to the function signature can be specified in `kwargs`, which must be a `NamedTuple`. """ -macro explicit_intercepts(f::SymOrExpr, type_tuple::Expr, is_node::Expr) - return esc(explicit_intercepts(f, type_tuple, eval(is_node))) -end -macro explicit_intercepts(f::SymOrExpr, type_tuple::Expr) - is_node = [true for _ in get_types(get_body(type_tuple))] - return esc(explicit_intercepts(f, type_tuple, is_node)) +macro explicit_intercepts( + f::SymOrExpr, + type_tuple::Expr, + is_node::Expr=:([true for _ in $(get_types(get_body(type_tuple)))]), + kwargs::Expr=:(()), +) + return esc(explicit_intercepts(f, type_tuple, eval(is_node); eval(kwargs)...)) end """ @@ -48,10 +52,10 @@ Return a `:block` expression which evaluates to declare all of the combinations that could be required to catch if a `Node` is ever passed to the function specified in `expr`. """ -function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}) +function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}; kwargs...) function explicit_intercepts_(states::Vector{Bool}) if length(states) == length(is_node) - return any(states .== true) ? boxed_method(f, types, states) : [] + return any(states) ? boxed_method(f, types, states; kwargs...) : [] else return vcat( explicit_intercepts_(vcat(states, false)), @@ -67,23 +71,25 @@ end f::SymOrExpr, type_tuple::Expr, is_node::Vector{Bool}, - arg_names::Vector{Symbol}, + arg_names::Vector{Symbol}; + kwargs... ) Construct a method of the Function `f`, whose argument's types are specified by `type_tuple`. Arguments which are potentially `Node`s should be indicated by `true` values -in `is_node`. +in `is_node`. Any provided keyword arguments will be added to the method. """ function boxed_method( f::SymOrExpr, type_tuple::Expr, is_node::Vector{Bool}, - arg_names::Vector{Symbol}, + arg_names::Vector{Symbol}; + kwargs... ) # Get the argument types and create the function call. types = get_types(get_body(type_tuple)) noded_types = [node ? :(Node{<:$tp}) : tp for (node, tp) in zip(is_node, types)] - call = replace_body(type_tuple, get_sig(f, arg_names, noded_types)) + call = replace_body(type_tuple, get_sig(f, arg_names, noded_types; kwargs...)) # Construct body of call. tuple_expr = Expr(:tuple, arg_names...) 
@@ -93,64 +99,62 @@ function boxed_method( # Combine call signature with the body to create a new function. return Expr(Symbol("="), call, body) end -boxed_method(f, t, n) = boxed_method(f, t, n, [gensym() for _ in n]) +boxed_method(f, t, n; kwargs...) = boxed_method(f, t, n, [gensym() for _ in n]; kwargs...) """ - get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector) + get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector; kwargs...) Generate a function signature for `f` in which the arguments, whose names are `arg_names`, specified by the `true` entires of `is_node` have type `Node`. The other arguments have -types specified by `types`. +types specified by `types`. If keyword arguments are provided, they will be added to +the method signature. """ -get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector) = - Expr(:call, f, [Expr(Symbol("::"), nm, tp) for (nm, tp) in zip(arg_names, types)]...) +get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector; kwargs...) = + add_kwargs!(Expr(:call, f, map((nm, tp)->:($nm::$tp), arg_names, types)...); kwargs...) """ - get_body(foo::Symbol, type_tuple::Expr, arg_names::Vector, invoke_type_tuple::Expr) + get_body(foo::Symbol, type_tuple::Expr, arg_names::Vector, invoke_type_tuple::Expr; kwargs...) -Get the body of the @generated function which is used to intercept the invokations +Get the body of the @generated function which is used to intercept the invocations specified by type_tuple. """ function get_body( foo::Symbol, type_tuple::Expr, arg_names::Vector, - invoke_type_tuple::Expr + invoke_type_tuple::Expr; + kwargs... ) + quot_arg_names = map(quot, arg_names) + dots = Symbol("...") + arg_tuple = any(isa_vararg.(get_types(get_body(type_tuple)))) ? - Expr(:tuple, arg_names[1:end-1]..., Expr(Symbol("..."), arg_names[end])) : + Expr(:tuple, arg_names[1:end-1]..., Expr(dots, arg_names[end])) : Expr(:tuple, arg_names...) sym_arg_tuple = any(isa_vararg.(get_types(get_body(type_tuple)))) ? - Expr(:tuple, quot.(arg_names[1:end-1])..., - quot(Expr(Symbol("..."), arg_names[end]))) : - Expr(:tuple, quot.(arg_names)...) - quot_arg_names = [quot(arg_name) for arg_name in arg_names] + Expr(:tuple, quot_arg_names[1:end-1]..., quot(Expr(dots, arg_names[end]))) : + Expr(:tuple, quot_arg_names...) - dots = Symbol("...") args_dotted = Expr(dots, Expr(:vect, arg_names...)) args_dotted_quot = Expr(dots, Expr(:vect, quot_arg_names...)) + + branch = :(Nabla.branch_expr($(quot(foo)), is_node, x, x_syms, $(quot(arg_tuple)))) + add_kwargs!(branch; kwargs...) + + invoke = :(Nabla.invoke_expr($(quot(foo)), $(quot(invoke_type_tuple)), x_dots)) + add_kwargs!(invoke; kwargs...) + return Expr(:block, Expr(Symbol("="), :x, Expr(:tuple, args_dotted)), Expr(Symbol("="), :x_syms, Expr(:tuple, args_dotted_quot)), Expr(Symbol("="), :x_dots, sym_arg_tuple), Expr(Symbol("="), :is_node, :([any((<:).(xj, Node)) for xj in x])), - Expr(:return, - Expr(:if, Expr(:call, :any, :is_node), - :(Nabla.branch_expr( - $(quot(foo)), - is_node, - x, - x_syms, - $(quot(arg_tuple)), - )), - :(Nabla.invoke_expr($(quot(foo)), $(quot(invoke_type_tuple)), x_dots)) - ) - ) + Expr(:return, Expr(:if, Expr(:call, :any, :is_node), branch, invoke)) ) end """ - branch_expr(foo::Symbol, is_node::Vector{Bool}, x::Tuple, arg_tuple::Expr) + branch_expr(foo::Symbol, is_node::Vector{Bool}, x::Tuple, arg_tuple::Expr; kwargs...) Generate an expression to call Branch. 
""" @@ -159,13 +163,16 @@ function branch_expr( is_node::Vector{Bool}, x::Tuple, syms::NTuple{<:Any, Symbol}, - arg_tuple::Expr, + arg_tuple::Expr; + kwargs... ) - return Expr(:call, :Branch, foo, arg_tuple, tape_expr(x, syms, is_node)) + call = Expr(:call, :Branch, foo, arg_tuple, tape_expr(x, syms, is_node)) + add_kwargs!(call; kwargs...) + return call end -invoke_expr(f::Symbol, invoke_tuple::Expr, arg_syms) = - Expr(:call, :invoke, f, invoke_tuple, arg_syms...) +invoke_expr(f::Symbol, invoke_tuple::Expr, arg_syms; kwargs...) = + Expr(:call, :invoke, f, invoke_tuple, arg_syms...; kwargs...) """ tape_expr(x::Tuple, syms::NTuple{N, Symbol} where N, is_node::Vector{Bool}) diff --git a/test/runtests.jl b/test/runtests.jl index 8f7382c6..dd1b5469 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,7 +21,7 @@ end @testset "Functional" begin include("sensitivities/functional/functional.jl") include("sensitivities/functional/reduce.jl") - #include("sensitivities/functional/reducedim.jl") + include("sensitivities/functional/reducedim.jl") end # Test sensitivities for linear algebra optimisations. diff --git a/test/sensitivities/functional/reducedim.jl b/test/sensitivities/functional/reducedim.jl index bb73044e..a89d7d84 100644 --- a/test/sensitivities/functional/reducedim.jl +++ b/test/sensitivities/functional/reducedim.jl @@ -2,30 +2,30 @@ let rng = MersenneTwister(123456) # mapreducedim on a single-dimensional array should be consistent with mapreduce. x = Leaf(Tape(), [1.0, 2.0, 3.0, 4.0, 5.0]) - s = 5.0 * mapreducedim(abs2, +, x, 1)[1] + s = 5.0 * mapreduce(abs2, +, x, dims=1) @test ∇(s)[x] ≈ 5.0 * [2.0, 4.0, 6.0, 8.0, 10.0] - # mapreducedim on a two-dimensional array when reduced over a single dimension + # mapreduce on a two-dimensional array when reduced over a single dimension # should give different results to mapreduce over the same array. x2_ = reshape([1.0, 2.0, 3.0, 4.0,], (2, 2)) x2 = Leaf(Tape(), x2_) - s = mapreducedim(abs2, +, x2, 1) + s = mapreduce(abs2, +, x2, dims=1) @test ∇(s, ones(eltype(s.val), size(s.val)))[x2] ≈ 2.0 * x2_ # mapreducedim under `exp` should trigger the first conditional in the ∇ impl. x3_ = randn(rng, 5, 4) x3 = Leaf(Tape(), x3_) - s = mapreducedim(exp, +, x3, 1) + s = mapreduce(exp, +, x3, dims=1) @test ∇(s, ones(eltype(s.val), size(s.val)))[x3] == exp.(x3_) # mapreducedim under an anonymous-function should trigger fmad. x4_ = randn(rng, 5, 4) x4 = Leaf(Tape(), x4_) - s = mapreducedim(x->x*x, +, x4, 2) + s = mapreduce(x->x*x, +, x4, dims=2) @test ∇(s, ones(eltype(s.val), size(s.val)))[x4] == 2x4_ # Check that `sum` works correctly with `Node`s. 
x_sum = Leaf(Tape(), randn(rng, 5, 4, 3)) - @test sum(x_sum, [2, 3]).val == mapreducedim(identity, +, x_sum, [2, 3]).val + @test sum(x_sum, dims=[2, 3]).val == mapreduce(identity, +, x_sum, dims=[2, 3]).val end end From 381f158d8877088e39e8cba119715ca791f7ae17 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 22 Oct 2018 22:54:18 -0700 Subject: [PATCH 06/11] Add Project files and deploy docs on 1.0 --- .travis.yml | 11 +++++++++-- Project.toml | 22 ++++++++++++++++++++++ docs/Project.toml | 5 +++++ docs/make.jl | 2 +- 4 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 Project.toml create mode 100644 docs/Project.toml diff --git a/.travis.yml b/.travis.yml index 6db61a1f..6d453f5e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,5 +20,12 @@ matrix: after_success: # push coverage results to Codecov - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(Codecov.process_folder())' - # build documentation - - julia -e 'using Pkg; Pkg.add("Documenter"); include(joinpath("docs", "make.jl"))' +jobs: + include: + - state: "Documentation" + julia: 1.0 + os: linux + script: + - julia --project=docs/ -e 'using Pkg; Pkg.instantiate(); Pkg.develop(PackageSpec(path=pwd()))' + - julia --project=docs/ docs/make.jl + after_success: skip diff --git a/Project.toml b/Project.toml new file mode 100644 index 00000000..7d5d5bfa --- /dev/null +++ b/Project.toml @@ -0,0 +1,22 @@ +name = "Nabla" +uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78" +version = "0.2.0" + +[compat] +julia = "0.7, 1.0" + +[deps] +DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["BenchmarkTools", "Distributions", "Random", "Test"] diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..53bc6f84 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,5 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "~0.19" diff --git a/docs/make.jl b/docs/make.jl index 7d6110dc..b6bb4457 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,7 +17,7 @@ makedocs( deploydocs( repo = "github.com/invenia/Nabla.jl.git", - julia = "0.6", + julia = "1.0", target = "build", deps = nothing, make = nothing, From a1a00827bb68f2e9d9f96962a98a30896b9b73b4 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Tue, 23 Oct 2018 09:49:16 -0700 Subject: [PATCH 07/11] Address review comments, part 1 Summary of changes: * Move `oneslike` and `zeroslike` inside of the package, since they're useful more generally, not just in testing. * Add spaces around `->` in the anonymous functions I've introduced. * Reuse the `plustype` variable in the `mapreduce` signature. 
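The `@eval` loop added to src/Nabla.jl below generates methods equivalent to this sketch:

    oneslike(a::AbstractArray) = ones(eltype(a), size(a))
    oneslike(n::Integer) = ones(n)
    zeroslike(a::AbstractArray) = zeros(eltype(a), size(a))
    zeroslike(n::Integer) = zeros(n)

    oneslike(rand(Float32, 2, 3))  # 2x3 Matrix{Float32} of ones
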
--- src/Nabla.jl | 9 +++++ src/code_transformation/util.jl | 2 +- src/sensitivities/functional/functional.jl | 4 +-- src/sensitivities/functional/reducedim.jl | 38 +++++++++++---------- src/sensitivities/linalg/blas.jl | 16 ++++----- src/sensitivities/linalg/diagonal.jl | 6 ++-- src/sensitivities/linalg/generic.jl | 10 +++--- test/runtests.jl | 2 ++ test/sensitivities/functional/functional.jl | 5 --- test/sensitivities/functional/reducedim.jl | 6 ++-- test/sensitivities/indexing.jl | 2 +- 11 files changed, 55 insertions(+), 45 deletions(-) diff --git a/src/Nabla.jl b/src/Nabla.jl index df161641..5835010b 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -14,6 +14,15 @@ module Nabla const ∇ArrayOrScalar = Union{AbstractArray{<:∇Scalar}, ∇Scalar} const SymOrExpr = Union{Symbol, Expr} + # ones/zeros(::AbstractArray) is deprecated in 0.7 and removed in 1.0, but it's a + # pretty useful method, so we'll define our own for internal use + for f in (:ones, :zeros) + like = Symbol(f, "like") + @eval begin + $(like)(a::AbstractArray) = $(f)(eltype(a), size(a)) + $(like)(n::Integer) = $(f)(n) + end + end # Meta-programming utilities specific to Nabla. include("code_transformation/util.jl") diff --git a/src/code_transformation/util.jl b/src/code_transformation/util.jl index ff0a7985..79dbeba6 100644 --- a/src/code_transformation/util.jl +++ b/src/code_transformation/util.jl @@ -22,7 +22,7 @@ function get_union_call(foo::Symbol, type_tuple::Expr; kwargs...) arg_names = [Symbol("x$j") for j in 1:length(types)] # Generate the call. - typed_args = map((name, typ)->:($name::$(unionise_type(typ))), arg_names, types) + typed_args = map((name, typ) -> :($name::$(unionise_type(typ))), arg_names, types) call = add_kwargs!(Expr(:call, foo, typed_args...); kwargs...) return replace_body(type_tuple, call), arg_names diff --git a/src/sensitivities/functional/functional.jl b/src/sensitivities/functional/functional.jl index 8d226ad4..18b10e5d 100644 --- a/src/sensitivities/functional/functional.jl +++ b/src/sensitivities/functional/functional.jl @@ -35,9 +35,9 @@ Base.BroadcastStyle(::NodeStyle{S}, B::BroadcastStyle) where {S} = Broadcast.broadcast_axes(x::Node) = broadcast_axes(x.val) Broadcast.broadcastable(x::Node) = x -function Base.copy(bc::Broadcasted{NodeStyle{S}}) where S +function Base.copy(bc::Broadcasted{<:NodeStyle}) args = bc.args - tape = getfield(args[findfirst(x->x isa Node, args)], :tape) + tape = getfield(args[findfirst(x -> x isa Node, args)], :tape) return Branch(broadcast, (bc.f, args...), tape) end diff --git a/src/sensitivities/functional/reducedim.jl b/src/sensitivities/functional/reducedim.jl index 43018935..2e9ca67b 100644 --- a/src/sensitivities/functional/reducedim.jl +++ b/src/sensitivities/functional/reducedim.jl @@ -1,23 +1,25 @@ import Base: mapreduce, sum -@explicit_intercepts( - mapreduce, - Tuple{Function, Union{typeof(+), typeof(Base.add_sum)}, AbstractArray{<:∇Scalar}}, - [false, false, true], - (dims=:, init=nothing), -) -function ∇( - ::typeof(mapreduce), - ::Type{Arg{3}}, - p, y, ȳ, f, - ::Union{typeof(+), typeof(Base.add_sum)}, - A::AbstractArray{<:∇Scalar}; - dims=:, - init=nothing, -) - hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? 
- broadcast((An, ȳn)->ȳn * ∇(f, Arg{1}, An), A, ȳ) : - broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ) +@eval begin + @explicit_intercepts( + mapreduce, + Tuple{Function, $plustype, AbstractArray{<:∇Scalar}}, + [false, false, true], + (dims=:, init=nothing), + ) + function ∇( + ::typeof(mapreduce), + ::Type{Arg{3}}, + p, y, ȳ, f, + ::$plustype, + A::AbstractArray{<:∇Scalar}; + dims=:, + init=nothing, + ) + hasmethod(∇, Tuple{typeof(f), Type{Arg{1}}, ∇Scalar}) ? + broadcast((An, ȳn)->ȳn * ∇(f, Arg{1}, An), A, ȳ) : + broadcast((An, ȳn)->ȳn * fmad(f, (An,), Val{1}), A, ȳ) + end end # Make `sum` work. It currently fails as the type specification is too restrictive. diff --git a/src/sensitivities/linalg/blas.jl b/src/sensitivities/linalg/blas.jl index 42b2ed94..6b9319e6 100644 --- a/src/sensitivities/linalg/blas.jl +++ b/src/sensitivities/linalg/blas.jl @@ -17,13 +17,13 @@ const SA = StridedArray [false, true, false, true, false], ) ∇(::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, y, iy, zeros(eltype(x), size(x)), ix), ix) + scal!(n, z̄, blascopy!(n, y, iy, zeroslike(x), ix), ix) ∇(::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - scal!(n, z̄, blascopy!(n, x, ix, zeros(eltype(y), size(y)), iy), iy) + scal!(n, z̄, blascopy!(n, x, ix, zeroslike(y), iy), iy) ∇(x̄, ::typeof(dot), ::Type{Arg{2}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, zeros(eltype(x), size(x)), ix), ix)) + (x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, zeroslike(x), ix), ix)) ∇(ȳ, ::typeof(dot), ::Type{Arg{4}}, p, z, z̄, n::Int, x::SA, ix::Int, y::SA, iy::Int) = - (ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, zeros(eltype(y), size(y)), iy), iy)) + (ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, zeroslike(y), iy), iy)) # Short-form `nrm2`. @explicit_intercepts nrm2 Tuple{Union{StridedVector, Array}} @@ -37,9 +37,9 @@ const SA = StridedArray [false, true, false], ) ∇(::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(eltype(x), size(x)), inc), inc) + scal!(n, ȳ / y, blascopy!(n, x, inc, zeroslike(x), inc), inc) ∇(x̄, ::typeof(nrm2), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, zeros(eltype(x), size(x)), inc), inc)) + (x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, zeroslike(x), inc), inc)) # Short-form `asum`. @explicit_intercepts asum Tuple{Union{StridedVector, Array}} @@ -53,9 +53,9 @@ const SA = StridedArray [false, true, false], ) ∇(::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(eltype(x), size(x)), inc), inc) + scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeroslike(x), inc), inc) ∇(x̄, ::typeof(asum), ::Type{Arg{2}}, p, y, ȳ, n::Integer, x, inc::Integer) = - (x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeros(eltype(x), size(x)), inc), inc)) + (x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, zeroslike(x), inc), inc)) # Some weird stuff going on that I haven't figured out yet. 
diff --git a/src/sensitivities/linalg/diagonal.jl b/src/sensitivities/linalg/diagonal.jl
index 0b4ab85e..8bd3281f 100644
--- a/src/sensitivities/linalg/diagonal.jl
+++ b/src/sensitivities/linalg/diagonal.jl
@@ -11,7 +11,7 @@ function ∇(
     ȳ::∇AbstractVector,
     x::∇AbstractMatrix,
 )
-    x̄ = zeros(eltype(x), size(x))
+    x̄ = zeroslike(x)
     x̄[diagind(x̄)] = ȳ
     return x̄
 end
@@ -39,7 +39,7 @@ function ∇(
     x::∇AbstractMatrix,
     k::Integer,
 )
-    x̄ = zeros(eltype(x), size(x))
+    x̄ = zeroslike(x)
     x̄[diagind(x̄, k)] = ȳ
     return x̄
 end
@@ -90,7 +90,7 @@ function ∇(
     Ȳ::∇ScalarDiag,
     X::∇AbstractMatrix,
 )
-    X̄ = zeros(eltype(X), size(X))
+    X̄ = zeroslike(X)
     copyto!(view(X̄, diagind(X)), Ȳ.diag)
     return X̄
 end
diff --git a/src/sensitivities/linalg/generic.jl b/src/sensitivities/linalg/generic.jl
index 30b3d478..3a124d1c 100644
--- a/src/sensitivities/linalg/generic.jl
+++ b/src/sensitivities/linalg/generic.jl
@@ -17,8 +17,10 @@ for (f, T_In, T_Out, X̄, bounds) in unary_linalg_optimisations
     else
         @eval import LinearAlgebra: $f
     end
-    @eval @explicit_intercepts $f Tuple{$T_In}
-    @eval ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Out, Ȳ::$T_Out, X::$T_In) = $X̄
+    @eval begin
+        @explicit_intercepts $f Tuple{$T_In}
+        ∇(::typeof($f), ::Type{Arg{1}}, p, Y::$T_Out, Ȳ::$T_Out, X::$T_In) = $X̄
+    end
 end
 
 # Implementation of sensitivities for binary linalg optimisations.
@@ -115,9 +117,9 @@ import LinearAlgebra: kron
 
 # The allocating versions simply allocate and then call the in-place versions.
 ∇(::typeof(kron), ::Type{Arg{1}}, p, Y::A, Ȳ::A, A::A, B::A) =
-    ∇(zeros(eltype(A), size(A)), kron, Arg{1}, p, Y, Ȳ, A, B)
+    ∇(zeroslike(A), kron, Arg{1}, p, Y, Ȳ, A, B)
 ∇(::typeof(kron), ::Type{Arg{2}}, p, Y::A, Ȳ::A, A::A, B::A) =
-    ∇(zeros(eltype(B), size(B)), kron, Arg{2}, p, Y, Ȳ, A, B)
+    ∇(zeroslike(B), kron, Arg{2}, p, Y, Ȳ, A, B)
 
 function ∇(Ā::A, ::typeof(kron), ::Type{Arg{1}}, p, Y::A, Ȳ::A, A::A, B::A)
     (I, J), (K, L), m = size(A), size(B), length(Y)
diff --git a/test/runtests.jl b/test/runtests.jl
index dd1b5469..28c8ee3b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,6 +2,8 @@ using Nabla
 using Test, LinearAlgebra, Random
 using Distributions, BenchmarkTools, SpecialFunctions, DualNumbers
 
+using Nabla: oneslike, zeroslike
+
 @testset "Core" begin
     include("core.jl")
     include("code_transformation/util.jl")
diff --git a/test/sensitivities/functional/functional.jl b/test/sensitivities/functional/functional.jl
index 425b0116..85b18d7e 100644
--- a/test/sensitivities/functional/functional.jl
+++ b/test/sensitivities/functional/functional.jl
@@ -1,11 +1,6 @@
 using SpecialFunctions
 using DiffRules: diffrule, hasdiffrule
 
-# ones(::AbstractArray) is deprecated in 0.7 and removed in 1.0, but it's a pretty useful
-# method, so we'll define our own for testing purposes
-oneslike(a::AbstractArray) = ones(eltype(a), size(a))
-oneslike(n::Integer) = ones(n)
-
 @testset "Functional" begin
     # Apparently Distributions.jl doesn't implement the following, so we'll have to do it.
     Random.rand(rng::AbstractRNG, a::Distribution, n::Integer) =
diff --git a/test/sensitivities/functional/reducedim.jl b/test/sensitivities/functional/reducedim.jl
index a89d7d84..91632953 100644
--- a/test/sensitivities/functional/reducedim.jl
+++ b/test/sensitivities/functional/reducedim.jl
@@ -10,19 +10,19 @@
     x2_ = reshape([1.0, 2.0, 3.0, 4.0,], (2, 2))
     x2 = Leaf(Tape(), x2_)
     s = mapreduce(abs2, +, x2, dims=1)
-    @test ∇(s, ones(eltype(s.val), size(s.val)))[x2] ≈ 2.0 * x2_
+    @test ∇(s, oneslike(s.val))[x2] ≈ 2.0 * x2_
 
     # mapreducedim under `exp` should trigger the first conditional in the ∇ impl.
     x3_ = randn(rng, 5, 4)
     x3 = Leaf(Tape(), x3_)
     s = mapreduce(exp, +, x3, dims=1)
-    @test ∇(s, ones(eltype(s.val), size(s.val)))[x3] == exp.(x3_)
+    @test ∇(s, oneslike(s.val))[x3] == exp.(x3_)
 
     # mapreducedim under an anonymous-function should trigger fmad.
     x4_ = randn(rng, 5, 4)
     x4 = Leaf(Tape(), x4_)
     s = mapreduce(x->x*x, +, x4, dims=2)
-    @test ∇(s, ones(eltype(s.val), size(s.val)))[x4] == 2x4_
+    @test ∇(s, oneslike(s.val))[x4] == 2x4_
 
     # Check that `sum` works correctly with `Node`s.
     x_sum = Leaf(Tape(), randn(rng, 5, 4, 3))
diff --git a/test/sensitivities/indexing.jl b/test/sensitivities/indexing.jl
index 1a5de6e8..1ee4a3a6 100644
--- a/test/sensitivities/indexing.jl
+++ b/test/sensitivities/indexing.jl
@@ -10,6 +10,6 @@
         x = Leaf(Tape(), 10 * [1, 1, 1])
         y = x[2:3]
         @test y.val == [10, 10]
-        @test ∇(y, ones(eltype(y.val), size(y.val)))[x] == [0, 1, 1]
+        @test ∇(y, oneslike(y.val))[x] == [0, 1, 1]
     end
 end
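The test updates in this patch all follow the same seeding pattern: `∇(y, ȳ)` runs
the reverse pass from the output cotangent `ȳ`, so for an array-valued output the
seed has to match the output's shape and element type — hence `oneslike(y.val)`
rather than the removed `ones(::AbstractArray)` method. A minimal sketch of the
pattern, mirroring the indexing test above (illustrative values only):

    using Nabla
    using Nabla: oneslike

    x = Leaf(Tape(), [1.0, 2.0, 3.0])  # record x on a fresh tape
    y = x[2:3]                         # indexing is intercepted; y.val == [2.0, 3.0]
    x̄ = ∇(y, oneslike(y.val))[x]      # seed the reverse pass with all ones
    # x̄ == [0.0, 1.0, 1.0]: only the selected entries receive sensitivity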
"276daf66-3868-5448-9aa4-cd146d93841b" +[compat] +julia = "0.7, 1.0" +DiffRules = "0.0.1" +DualNumbers = "0.6.0" +FDM = "0.1.0" +SpecialFunctions = "0.3.0" + [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" From 7badedc42d0e5c3c4e3abee441692fe71a6789bd Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Tue, 23 Oct 2018 12:34:20 -0700 Subject: [PATCH 11/11] Fix version bounds --- Project.toml | 6 +++--- REQUIRE | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 8e92c862..315d57f2 100644 --- a/Project.toml +++ b/Project.toml @@ -11,10 +11,10 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] julia = "0.7, 1.0" -DiffRules = "0.0.1" +DiffRules = "0.0" DualNumbers = "0.6.0" -FDM = "0.1.0" -SpecialFunctions = "0.3.0" +FDM = "0.1.0, 0.2.0" +SpecialFunctions = ">=0.5.0" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/REQUIRE b/REQUIRE index 31e93ed4..236d5f49 100644 --- a/REQUIRE +++ b/REQUIRE @@ -2,4 +2,4 @@ julia 0.7 DualNumbers 0.6.0 DiffRules 0.0.1 FDM 0.1.0 -SpecialFunctions 0.3.0 +SpecialFunctions 0.5.0