diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index bd98bd3ab..84786f31c 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.14" +version = "0.6.15" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index f069579c3..08208a221 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -13,6 +13,7 @@ DifferentiationInterface ```@docs Context Constant +Cache ``` ## First order diff --git a/DifferentiationInterface/docs/src/explanation/advanced.md b/DifferentiationInterface/docs/src/explanation/advanced.md index d5c170e43..76c5ba012 100644 --- a/DifferentiationInterface/docs/src/explanation/advanced.md +++ b/DifferentiationInterface/docs/src/explanation/advanced.md @@ -10,22 +10,29 @@ However, the release v0.6 introduced the possibility of additional "context" arg Contexts can be useful if you have a function `y = f(x, a, b, c, ...)` or `f!(y, x, a, b, c, ...)` and you want derivatives of `y` with respect to `x` only. Another option would be creating a closure, but that is sometimes undesirable. -!!! warning - This feature is still experimental, and will likely not be supported by all backends. - At the moment, it only works with certain backends, among which ForwardDiff, Zygote and Enzyme. - ### Types of contexts Every context argument must be wrapped in a subtype of [`Context`](@ref) and come after the differentiated input `x`. -Right now, there is only one kind of context, namely [`Constant`](@ref), but we might add more. -Semantically, calling +Right now, there are two kinds of context: [`Constant`](@ref) and [`Cache`](@ref). + +!!! warning + This feature is still experimental and will not be supported by all backends. + At the moment: + - `Constant` is supported by all backends except symbolic ones + - `Cache` is only supported by finite difference backends + +Semantically, both of these calls compute the partial gradient of `f(x, c)` with respect to `x`, but they consider `c` differently: ```julia gradient(f, backend, x, Constant(c)) +gradient(f, backend, x, Cache(c)) ``` -computes the partial gradient of `f(x, c)` with respect to `x`, while keeping `c` constant. -Importantly, one can prepare an operator with an arbitrary value `c'` of the constant (subject to the usual restrictions on preparation). +In the first call, `c` is kept unchanged throughout the function evaluation. +In the second call, `c` can be mutated with values computed during the function. + +Importantly, one can prepare an operator with an arbitrary value `c'` of the `Constant` (subject to the usual restrictions on preparation). +The values in a provided `Cache` never matter anyway. ## Sparsity diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 7e9454372..9831aba8e 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -65,7 +65,7 @@ include("misc/zero_backends.jl") ## Exported -export Context, Constant +export Context, Constant, Cache export SecondOrder export value_and_pushforward!, value_and_pushforward diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 9bbbe56a4..9815fd647 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -15,6 +15,7 @@ Abstract supertype for additional context arguments, which can be passed to diff # See also - [`Constant`](@ref) +- [`Cache`](@ref) """ abstract type Context end @@ -58,6 +59,25 @@ end Base.convert(::Type{Constant{T}}, x) where {T} = Constant(convert(T, x)) +""" + Cache + +Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation. + +The initial values present inside the cache do not matter. +""" +struct Cache{T} <: Context + data::T +end + +unwrap(c::Cache) = c.data + +function Base.convert(::Type{Cache{T}}, x::Cache) where {T} + return Cache(convert(T, x.data)) +end + +Base.convert(::Type{Cache{T}}, x) where {T} = Cache(convert(T, x)) + struct Rewrap{C,T} function Rewrap(contexts::Vararg{Context,C}) where {C} T = typeof(contexts) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index 5965f3c86..bf583915c 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -14,7 +14,7 @@ end test_differentiation( AutoFiniteDiff(), - default_scenarios(; include_constantified=true); + default_scenarios(; include_constantified=true, include_cachified=true); excluded=[:second_derivative, :hvp], logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index a1ad1bda0..3a54f2123 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -14,7 +14,7 @@ end test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), - default_scenarios(; include_constantified=true); + default_scenarios(; include_constantified=true, include_cachified=true); excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index a619c678c..24dbf51b6 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.8.3" +version = "0.8.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 2f0afbc58..da535e5f6 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -21,13 +21,13 @@ function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T} return sin.(x .* a) end -myjl(f::DIT.MultiplyByConstant) = f -myjl(f::DIT.WritableClosure) = f +myjl(f::DIT.FunctionModifier) = f myjl(x::Number) = x myjl(x::AbstractArray) = jl(x) myjl(x::Tuple) = map(myjl, x) myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x))) +myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x))) myjl(::Nothing) = nothing function myjl(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index 0b8d0d9d2..938caad16 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -7,45 +7,45 @@ using Random: AbstractRNG, default_rng using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector -mySArray(f::Function) = f -mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T}) -mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6}) -mySArray(f::DIT.MultiplyByConstant) = f -mySArray(f::DIT.WritableClosure) = f +mystatic(f::Function) = f +mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T}) +mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6}) +mystatic(f::DIT.FunctionModifier) = f -mySArray(x::Number) = x -myMArray(x::Number) = x +mystatic(x::Number) = x +mymutablestatic(x::Number) = x -mySArray(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x) -myMArray(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x) +mystatic(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x) +mymutablestatic(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x) -function mySArray(x::AbstractMatrix{T}) where {T} +function mystatic(x::AbstractMatrix{T}) where {T} return convert(SMatrix{size(x, 1),size(x, 2),T,length(x)}, x) end -function myMArray(x::AbstractMatrix{T}) where {T} +function mymutablestatic(x::AbstractMatrix{T}) where {T} return convert(MMatrix{size(x, 1),size(x, 2),T,length(x)}, x) end -mySArray(x::Tuple) = map(mySArray, x) -mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x))) -mySArray(::Nothing) = nothing +mystatic(x::Tuple) = map(mystatic, x) +mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x))) +mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x))) +mystatic(::Nothing) = nothing -function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function mystatic(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f, x, y, tang, contexts, res1, res2) = scen return Scenario{op,pl_op,pl_fun}( - mySArray(f); - x=mySArray(x), - y=pl_fun == :in ? myMArray(y) : mySArray(y), - tang=mySArray(tang), - contexts=mySArray(contexts), - res1=mySArray(res1), - res2=mySArray(res2), + mystatic(f); + x=mystatic(x), + y=pl_fun == :in ? mymutablestatic(y) : mystatic(y), + tang=mystatic(tang), + contexts=mystatic(contexts), + res1=mystatic(res1), + res2=mystatic(res2), ) end function DIT.static_scenarios(args...; kwargs...) scens = DIT.default_scenarios(args...; kwargs...) - return mySArray.(scens) + return mystatic.(scens) end end diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index a6c9f8c3b..340d9ef44 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -459,6 +459,7 @@ function default_scenarios( include_batchified=true, include_closurified=false, include_constantified=false, + include_cachified=false, ) x_ = rand(rng) dx_ = rand(rng) @@ -504,6 +505,7 @@ function default_scenarios( include_normal && append!(final_scens, scens) include_closurified && append!(final_scens, closurify(scens)) include_constantified && append!(final_scens, constantify(scens)) + include_cachified && append!(final_scens, cachify(scens)) return final_scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index a3f63a2f7..5e497c137 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -1,3 +1,5 @@ +abstract type FunctionModifier end + """ zero(scen::Scenario) @@ -56,7 +58,7 @@ function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} end end -struct WritableClosure{pl_fun,F,X,Y} +struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier f::F x_buffer::Vector{X} y_buffer::Vector{Y} @@ -90,13 +92,14 @@ Return a new `Scenario` identical to `scen` except for the function `f` which is """ function closurify(scen::Scenario) (; f, x, y) = scen + @assert isempty(scen.contexts) x_buffer = [zero(x)] y_buffer = [zero(y)] closure_f = WritableClosure{function_place(scen)}(f, x_buffer, y_buffer) return change_function(scen, closure_f) end -struct MultiplyByConstant{pl_fun,F} +struct MultiplyByConstant{pl_fun,F} <: FunctionModifier f::F end @@ -123,6 +126,7 @@ The output and result fields are updated accordingly. """ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f,) = scen + @assert isempty(scen.contexts) multiply_f = MultiplyByConstant{pl_fun}(f) a = 3.0 return Scenario{op,pl_op,pl_fun}( @@ -136,11 +140,63 @@ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} ) end +struct StoreInCache{pl_fun,F} <: FunctionModifier + f::F +end + +function StoreInCache{pl_fun}(f::F) where {pl_fun,F} + return StoreInCache{pl_fun,F}(f) +end + +Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") + +function (sc::StoreInCache{:out})(x, y_cache) + y = sc.f(x) + if y isa Number + y_cache[1] = y + return y_cache[1] + else + copyto!(y_cache, y) + return copy(y_cache) + end +end + +function (sc::StoreInCache{:in})(y, x, y_cache) + sc.f(y_cache, x) + copyto!(y, y_cache) + return nothing +end + +""" + cachify(scen::Scenario) + +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `a` to store the result before it is returned. +""" +function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} + (; f,) = scen + @assert isempty(scen.contexts) + cache_f = StoreInCache{pl_fun}(f) + y_cache = if scen.y isa Number + [myzero(scen.y)] + else + mysimilar(scen.y) + end + return Scenario{op,pl_op,pl_fun}( + cache_f; + x=scen.x, + y=scen.y, + tang=scen.tang, + contexts=(Cache(y_cache),), + res1=scen.res1, + res2=scen.res2, + ) +end + function batchify(scens::AbstractVector{<:Scenario}) batchifiable_scens = filter(s -> operator(s) in (:pushforward, :pullback, :hvp), scens) return batchify.(batchifiable_scens) end closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) - constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) +cachify(scens::AbstractVector{<:Scenario}) = cachify.(scens) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 252475851..8c32d06f0 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -69,7 +69,15 @@ function Base.:(==)( eq_x = scen1.x == scen2.x eq_y = scen1.y == scen2.y eq_tang = scen1.tang == scen2.tang - eq_contexts = scen1.contexts == scen2.contexts + eq_contexts = all( + map(scen1.contexts, scen2.contexts) do c1, c2 + if c1 isa Cache || c2 isa Cache + return true + else + return c1 == c2 + end + end, + ) eq_res1 = scen1.res1 == scen2.res1 eq_res2 = scen1.res2 == scen2.res2 return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index ce275f713..10afcb8a5 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -24,10 +24,16 @@ LOGGING = get(ENV, "CI", "false") == "false" ## Generate all scenarios gpu_scenarios(; - include_constantified=true, include_closurified=true, include_batchified=true + include_constantified=true, + include_closurified=true, + include_batchified=true, + include_cachified=true, ) static_scenarios(; - include_constantified=true, include_closurified=true, include_batchified=true + include_constantified=true, + include_closurified=true, + include_batchified=true, + include_cachified=true, ) ## Weird arrays @@ -40,11 +46,13 @@ test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, logging=LOGGING) -## Closures +## Closures & caches test_differentiation( AutoFiniteDiff(), - default_scenarios(; include_normal=false, include_closurified=true); + default_scenarios(; + include_normal=false, include_closurified=true, include_cachified=true + ); excluded=SECOND_ORDER, logging=LOGGING, );