Start implementing Cache contexts (#587)
* Start implementing caches

* First Cache implementation and tests

* Function modifiers

* Scenario conversion
gdalle authored Oct 16, 2024
1 parent d4b17c1 commit 495d988
Showing 14 changed files with 148 additions and 46 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.14"
version = "0.6.15"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 1 addition & 0 deletions DifferentiationInterface/docs/src/api.md
@@ -13,6 +13,7 @@ DifferentiationInterface
```@docs
Context
Constant
Cache
```

## First order
23 changes: 15 additions & 8 deletions DifferentiationInterface/docs/src/explanation/advanced.md
@@ -10,22 +10,29 @@ However, the release v0.6 introduced the possibility of additional "context" arg
Contexts can be useful if you have a function `y = f(x, a, b, c, ...)` or `f!(y, x, a, b, c, ...)` and you want derivatives of `y` with respect to `x` only.
Another option would be creating a closure, but that is sometimes undesirable.

!!! warning
    This feature is still experimental, and will likely not be supported by all backends.
    At the moment, it only works with certain backends, among which ForwardDiff, Zygote and Enzyme.

### Types of contexts

Every context argument must be wrapped in a subtype of [`Context`](@ref) and come after the differentiated input `x`.
Right now, there is only one kind of context, namely [`Constant`](@ref), but we might add more.
Semantically, calling
Right now, there are two kinds of context: [`Constant`](@ref) and [`Cache`](@ref).

!!! warning
    This feature is still experimental and will not be supported by all backends.
    At the moment:
    - `Constant` is supported by all backends except symbolic ones
    - `Cache` is only supported by finite difference backends

Semantically, both of these calls compute the partial gradient of `f(x, c)` with respect to `x`, but they consider `c` differently:

```julia
gradient(f, backend, x, Constant(c))
gradient(f, backend, x, Cache(c))
```

computes the partial gradient of `f(x, c)` with respect to `x`, while keeping `c` constant.
Importantly, one can prepare an operator with an arbitrary value `c'` of the constant (subject to the usual restrictions on preparation).
In the first call, `c` is kept unchanged throughout the function evaluation.
In the second call, `c` can be mutated with values computed during the function.

Importantly, one can prepare an operator with an arbitrary value `c'` of the `Constant` (subject to the usual restrictions on preparation).
The values in a provided `Cache` never matter anyway.
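
To make the difference between the two contexts concrete, here is a minimal sketch. It assumes the FiniteDiff package is installed alongside DifferentiationInterface; the function names `scaled_sumsq` and `sumsq_with_buffer` are made up for illustration and are not part of the package.

```julia
using DifferentiationInterface
import FiniteDiff  # assumption: installed; loading it enables AutoFiniteDiff()

# Constant context: `a` is read by the function but never modified.
scaled_sumsq(x, a) = a * sum(abs2, x)

# Cache context: `buf` is scratch storage that the function overwrites on every call.
function sumsq_with_buffer(x, buf)
    buf .= x .^ 2
    return sum(buf)
end

backend = AutoFiniteDiff()
x = [1.0, 2.0, 3.0]

# Gradient with respect to `x` only; expected to be close to 6 .* x.
g_const = gradient(scaled_sumsq, backend, x, Constant(3.0))

# The buffer is left uninitialized on purpose, since its initial contents do not matter.
# Expected to be close to 2 .* x.
g_cache = gradient(sumsq_with_buffer, backend, x, Cache(similar(x)))
```

A finite difference backend is used here because, as noted above, `Cache` is only supported by finite difference backends at this point.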

## Sparsity

2 changes: 1 addition & 1 deletion DifferentiationInterface/src/DifferentiationInterface.jl
@@ -65,7 +65,7 @@ include("misc/zero_backends.jl")

## Exported

export Context, Constant
export Context, Constant, Cache
export SecondOrder

export value_and_pushforward!, value_and_pushforward
20 changes: 20 additions & 0 deletions DifferentiationInterface/src/utils/context.jl
@@ -15,6 +15,7 @@ Abstract supertype for additional context arguments, which can be passed to diff
# See also
- [`Constant`](@ref)
- [`Cache`](@ref)
"""
abstract type Context end

@@ -58,6 +59,25 @@ end

Base.convert(::Type{Constant{T}}, x) where {T} = Constant(convert(T, x))

"""
    Cache

Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation.
The initial values present inside the cache do not matter.
"""
struct Cache{T} <: Context
    data::T
end

unwrap(c::Cache) = c.data

function Base.convert(::Type{Cache{T}}, x::Cache) where {T}
    return Cache(convert(T, x.data))
end

Base.convert(::Type{Cache{T}}, x) where {T} = Cache(convert(T, x))

struct Rewrap{C,T}
    function Rewrap(contexts::Vararg{Context,C}) where {C}
        T = typeof(contexts)
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/FiniteDiff/test.jl
@@ -14,7 +14,7 @@ end

test_differentiation(
AutoFiniteDiff(),
default_scenarios(; include_constantified=true);
default_scenarios(; include_constantified=true, include_cachified=true);
excluded=[:second_derivative, :hvp],
logging=LOGGING,
);
@@ -14,7 +14,7 @@ end

test_differentiation(
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
default_scenarios(; include_constantified=true);
default_scenarios(; include_constantified=true, include_cachified=true);
excluded=SECOND_ORDER,
logging=LOGGING,
);
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/Project.toml
@@ -1,7 +1,7 @@
name = "DifferentiationInterfaceTest"
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.8.3"
version = "0.8.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -21,13 +21,13 @@ function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T}
return sin.(x .* a)
end

myjl(f::DIT.MultiplyByConstant) = f
myjl(f::DIT.WritableClosure) = f
myjl(f::DIT.FunctionModifier) = f

myjl(x::Number) = x
myjl(x::AbstractArray) = jl(x)
myjl(x::Tuple) = map(myjl, x)
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x)))
myjl(::Nothing) = nothing

function myjl(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
@@ -7,45 +7,45 @@ using Random: AbstractRNG, default_rng
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector

mySArray(f::Function) = f
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
mySArray(f::DIT.MultiplyByConstant) = f
mySArray(f::DIT.WritableClosure) = f
mystatic(f::Function) = f
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
mystatic(f::DIT.FunctionModifier) = f

mySArray(x::Number) = x
myMArray(x::Number) = x
mystatic(x::Number) = x
mymutablestatic(x::Number) = x

mySArray(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x)
myMArray(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x)
mystatic(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x)
mymutablestatic(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x)

function mySArray(x::AbstractMatrix{T}) where {T}
function mystatic(x::AbstractMatrix{T}) where {T}
return convert(SMatrix{size(x, 1),size(x, 2),T,length(x)}, x)
end
function myMArray(x::AbstractMatrix{T}) where {T}
function mymutablestatic(x::AbstractMatrix{T}) where {T}
return convert(MMatrix{size(x, 1),size(x, 2),T,length(x)}, x)
end

mySArray(x::Tuple) = map(mySArray, x)
mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x)))
mySArray(::Nothing) = nothing
mystatic(x::Tuple) = map(mystatic, x)
mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x)))
mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x)))
mystatic(::Nothing) = nothing

function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
function mystatic(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f, x, y, tang, contexts, res1, res2) = scen
return Scenario{op,pl_op,pl_fun}(
mySArray(f);
x=mySArray(x),
y=pl_fun == :in ? myMArray(y) : mySArray(y),
tang=mySArray(tang),
contexts=mySArray(contexts),
res1=mySArray(res1),
res2=mySArray(res2),
mystatic(f);
x=mystatic(x),
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
tang=mystatic(tang),
contexts=mystatic(contexts),
res1=mystatic(res1),
res2=mystatic(res2),
)
end

function DIT.static_scenarios(args...; kwargs...)
scens = DIT.default_scenarios(args...; kwargs...)
return mySArray.(scens)
return mystatic.(scens)
end

end
2 changes: 2 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/default.jl
@@ -459,6 +459,7 @@ function default_scenarios(
include_batchified=true,
include_closurified=false,
include_constantified=false,
include_cachified=false,
)
x_ = rand(rng)
dx_ = rand(rng)
@@ -504,6 +505,7 @@ function default_scenarios(
include_normal && append!(final_scens, scens)
include_closurified && append!(final_scens, closurify(scens))
include_constantified && append!(final_scens, constantify(scens))
include_cachified && append!(final_scens, cachify(scens))

return final_scens
end
62 changes: 59 additions & 3 deletions DifferentiationInterfaceTest/src/scenarios/modify.jl
@@ -1,3 +1,5 @@
abstract type FunctionModifier end

"""
zero(scen::Scenario)
@@ -56,7 +58,7 @@ function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
end
end

struct WritableClosure{pl_fun,F,X,Y}
struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier
f::F
x_buffer::Vector{X}
y_buffer::Vector{Y}
@@ -90,13 +92,14 @@ Return a new `Scenario` identical to `scen` except for the function `f` which is
"""
function closurify(scen::Scenario)
(; f, x, y) = scen
@assert isempty(scen.contexts)
x_buffer = [zero(x)]
y_buffer = [zero(y)]
closure_f = WritableClosure{function_place(scen)}(f, x_buffer, y_buffer)
return change_function(scen, closure_f)
end

struct MultiplyByConstant{pl_fun,F}
struct MultiplyByConstant{pl_fun,F} <: FunctionModifier
f::F
end

@@ -123,6 +126,7 @@ The output and result fields are updated accordingly.
"""
function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f,) = scen
@assert isempty(scen.contexts)
multiply_f = MultiplyByConstant{pl_fun}(f)
a = 3.0
return Scenario{op,pl_op,pl_fun}(
@@ -136,11 +140,63 @@ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
)
end

struct StoreInCache{pl_fun,F} <: FunctionModifier
    f::F
end

function StoreInCache{pl_fun}(f::F) where {pl_fun,F}
    return StoreInCache{pl_fun,F}(f)
end

Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")

function (sc::StoreInCache{:out})(x, y_cache)
    y = sc.f(x)
    if y isa Number
        y_cache[1] = y
        return y_cache[1]
    else
        copyto!(y_cache, y)
        return copy(y_cache)
    end
end

function (sc::StoreInCache{:in})(y, x, y_cache)
    sc.f(y_cache, x)
    copyto!(y, y_cache)
    return nothing
end

"""
    cachify(scen::Scenario)

Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `y_cache` in which the result is stored before it is returned.
"""
function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
    (; f,) = scen
    @assert isempty(scen.contexts)
    cache_f = StoreInCache{pl_fun}(f)
    y_cache = if scen.y isa Number
        [myzero(scen.y)]
    else
        mysimilar(scen.y)
    end
    return Scenario{op,pl_op,pl_fun}(
        cache_f;
        x=scen.x,
        y=scen.y,
        tang=scen.tang,
        contexts=(Cache(y_cache),),
        res1=scen.res1,
        res2=scen.res2,
    )
end

function batchify(scens::AbstractVector{<:Scenario})
batchifiable_scens = filter(s -> operator(s) in (:pushforward, :pullback, :hvp), scens)
return batchify.(batchifiable_scens)
end

closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)

constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
cachify(scens::AbstractVector{<:Scenario}) = cachify.(scens)
10 changes: 9 additions & 1 deletion DifferentiationInterfaceTest/src/scenarios/scenario.jl
@@ -69,7 +69,15 @@ function Base.:(==)(
    eq_x = scen1.x == scen2.x
    eq_y = scen1.y == scen2.y
    eq_tang = scen1.tang == scen2.tang
    eq_contexts = scen1.contexts == scen2.contexts
    eq_contexts = all(
        map(scen1.contexts, scen2.contexts) do c1, c2
            if c1 isa Cache || c2 isa Cache
                return true
            else
                return c1 == c2
            end
        end,
    )
    eq_res1 = scen1.res1 == scen2.res1
    eq_res2 = scen1.res2 == scen2.res2
    return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2)
16 changes: 12 additions & 4 deletions DifferentiationInterfaceTest/test/weird.jl
@@ -24,10 +24,16 @@ LOGGING = get(ENV, "CI", "false") == "false"
## Generate all scenarios

gpu_scenarios(;
include_constantified=true, include_closurified=true, include_batchified=true
include_constantified=true,
include_closurified=true,
include_batchified=true,
include_cachified=true,
)
static_scenarios(;
include_constantified=true, include_closurified=true, include_batchified=true
include_constantified=true,
include_closurified=true,
include_batchified=true,
include_cachified=true,
)

## Weird arrays
@@ -40,11 +46,13 @@ test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING)

test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, logging=LOGGING)

## Closures
## Closures & caches

test_differentiation(
AutoFiniteDiff(),
default_scenarios(; include_normal=false, include_closurified=true);
default_scenarios(;
include_normal=false, include_closurified=true, include_cachified=true
);
excluded=SECOND_ORDER,
logging=LOGGING,
);