cache the LU factorisation in the direct linear solver and better static array support #64

Merged · 8 commits · Jul 30, 2023
Changes from 1 commit
2 changes: 1 addition & 1 deletion docs/Manifest.toml
@@ -2,7 +2,7 @@

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "e7f4896b7e8c3921c7466749f467ab7680867992"
project_hash = "ff84ddc3d5227f964f2cd507ce5cbc83b4fba207"

[[deps.AMD]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,5 +10,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
19 changes: 18 additions & 1 deletion docs/make.jl
@@ -3,6 +3,7 @@ using Documenter
using ForwardDiff: ForwardDiff
using ImplicitDifferentiation
using Literate
using StaticArrays: StaticArrays

DocMeta.setdocmeta!(
ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true
@@ -62,8 +63,24 @@ fmt = Documenter.HTML(;
edit_link=:commit,
)

if isdefined(Base, :get_extension)
extension_modules = [
Base.get_extension(ImplicitDifferentiation, :ImplicitDifferentiationChainRulesExt)
Base.get_extension(ImplicitDifferentiation, :ImplicitDifferentiationForwardDiffExt)
Base.get_extension(
ImplicitDifferentiation, :ImplicitDifferentiationStaticArraysExt
)
]
else
extension_modules = [
ImplicitDifferentiation.ImplicitDifferentiationChainRulesExt,
ImplicitDifferentiation.ImplicitDifferentiationForwardDiffExt,
ImplicitDifferentiation.ImplicitDifferentiationStaticArraysExt,
]
end

makedocs(;
modules=[ImplicitDifferentiation],
modules=vcat([ImplicitDifferentiation], extension_modules),
authors="Guillaume Dalle, Mohamed Tarek and contributors",
repo="https://github.com/gdalle/ImplicitDifferentiation.jl/blob/{commit}{path}#{line}",
sitename="ImplicitDifferentiation.jl",
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -8,6 +8,7 @@ DirectLinearSolver
IterativeLinearSolver
HandleByproduct
ReturnByproduct
ChainRulesCore.rrule
```

## Internals
46 changes: 29 additions & 17 deletions docs/src/faq.md
@@ -2,25 +2,36 @@

## Supported autodiff backends

- Forward mode: [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
- Reverse mode: all the packages compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)
| Mode | Backend | Support |
| ------- | ---------------------------------------------------------- | ------- |
| Forward | [ForwardDiff.jl] | yes |
| Reverse | [ChainRules.jl]-compatible ([Zygote.jl], [ReverseDiff.jl]) | yes |
| Forward | [ChainRules.jl]-compatible ([Diffractor.jl]) | soon |
| Both | [Enzyme.jl] | someday |

In the future, we would like to add
[ForwardDiff.jl]: https://github.com/JuliaDiff/ForwardDiff.jl
[ChainRules.jl]: https://github.com/JuliaDiff/ChainRules.jl
[Zygote.jl]: https://github.com/FluxML/Zygote.jl
[ReverseDiff.jl]: https://github.com/JuliaDiff/ReverseDiff.jl
[Enzyme.jl]: https://github.com/EnzymeAD/Enzyme.jl
[Diffractor.jl]: https://github.com/JuliaDiff/Diffractor.jl

- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)
## Writing conditions

## Higher-dimensional arrays
We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient.
Otherwise, you will need to make sure that nested autodiff works well in your case.
For instance, if you're differentiating your implicit function in reverse mode with Zygote.jl, you may want to use [`Zygote.forwarddiff`](https://fluxml.ai/Zygote.jl/stable/utils/#Zygote.forwarddiff) to wrap the conditions and differentiate them with ForwardDiff.jl instead.
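As a minimal sketch of that wrapping (illustrative names; it reuses the componentwise square root conditions from the basic example and assumes `x` and `y` have equal length):

```julia
using Zygote

# Gradient-based conditions wrapped with Zygote.forwarddiff: Zygote runs the closure
# as usual but differentiates it with ForwardDiff.jl instead of nested reverse mode.
function conditions_sqrt_wrapped(x, y)
    return Zygote.forwarddiff([x; y]) do xy
        n = length(xy) ÷ 2
        x_, y_ = xy[1:n], xy[(n + 1):end]
        @. 4 * (y_^2 - x_) * y_  # optimality condition of the componentwise square root
    end
end
```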

## Matrices and higher-order arrays

For simplicity, our examples only display functions that eat and spit out vectors.
However, arbitrary array shapes are supported, as long as the forward mapping _and_ conditions return similar arrays.
Beware, however: sparse arrays will be densified in the differentiation process.
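For instance, here is a hedged sketch with matrix input and output (illustrative names, assuming the two-argument `ImplicitFunction` constructor used in the examples):

```julia
using ImplicitDifferentiation

# Componentwise cube root on a matrix: the forward mapping and the conditions
# both take and return arrays of the same (non-vector) shape.
forward_cbrt(X::AbstractMatrix) = cbrt.(X)
conditions_cbrt(X, Y) = Y .^ 3 .- X

implicit_cbrt = ImplicitFunction(forward_cbrt, conditions_cbrt)
Y = implicit_cbrt(rand(3, 3))  # a 3×3 matrix, just like the input
```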

## Scalar input / output
## Scalars

Functions that eat or spit out a single number are not supported.
The forward mapping _and_ conditions need arrays: for example, instead of returning `value` you should return `[value]` (a 1-element `Vector`).
Consider using an `SVector` from [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) if you seek increased performance.
The forward mapping _and_ conditions need arrays: for example, instead of returning `val` you should return `[val]` (a 1-element `Vector`).
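A hedged sketch of this wrapping (illustrative names; the underlying scalar function is the square root):

```julia
using ImplicitDifferentiation

# Scalars are not supported, so the scalar problem is wrapped in 1-element vectors.
forward_scalar(x) = [sqrt(only(x))]
conditions_scalar(x, y) = [only(y)^2 - only(x)]

implicit_scalar = ImplicitFunction(forward_scalar, conditions_scalar)
implicit_scalar([2.0])  # ≈ [1.4142135623730951]
```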

## Multiple inputs / outputs

@@ -44,17 +55,18 @@ The same trick works for multiple outputs.

## Using byproducts

At first glance, it is not obvious why we impose that the forward mapping should return a byproduct `z` in addition to `y`.
It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the `conditions`.
We will provide simple examples soon.
In the meantime, an advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).
Why would the forward mapping return a byproduct `z` in addition to `y`?
It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the conditions.
In that case, you may want to write the differentiation rules yourself for the conditions.
A more advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).

Keep in mind that derivatives of `z` will not be computed: the byproduct is considered constant during differentiation (unlike the case of multiple outputs outlined above).
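Here is a minimal hedged sketch of such a reuse (the problem and names are illustrative; it assumes the `ImplicitFunction(f, c, HandleByproduct())` constructor documented in the API reference):

```julia
using LinearAlgebra
using ImplicitDifferentiation

const M = [2.0 1.0; 1.0 3.0]  # fixed, x-independent problem data

# The forward solve factorizes M once and returns the factorization as a byproduct.
function forward_byprod(x)
    F = cholesky(M)  # expensive object created during the solve
    y = F \ x        # solves M * y = x
    return y, F
end

# The conditions reuse the byproduct; since M (hence F) depends on neither x nor y,
# holding F constant during differentiation is exact.
conditions_byprod(x, y, F) = F.L * (F.U * y) .- x  # equivalent to M * y - x

implicit_byprod = ImplicitFunction(forward_byprod, conditions_byprod, HandleByproduct())
y = implicit_byprod([1.0, 2.0])  # pass ReturnByproduct() to also retrieve F
```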

## Differentiating byproducts
## Performance tips

Nope. Sorry. Don't even think about it.
The package is not designed to compute derivatives of `z`, only `y`, which is why the byproduct is considered constant during differentiation.
If you work with small arrays (say, fewer than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
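A hedged sketch of the kind of call this pull request aims to support efficiently (illustrative problem, the same componentwise square root used elsewhere in the docs):

```julia
using ImplicitDifferentiation, StaticArrays

forward_sqrt(x) = sqrt.(x)
conditions_sqrt(x, y) = y .^ 2 .- x
implicit_sqrt = ImplicitFunction(forward_sqrt, conditions_sqrt)

x_static = SVector(1.0, 4.0, 9.0)  # small, fixed-size input
implicit_sqrt(x_static)            # componentwise square root of the input
```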

## Modeling constrained optimization problems
## Modeling tips

To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions.
See [_Efficient and modular implicit differentiation_](https://arxiv.org/abs/2105.15183) for precise formulations.
2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -7,7 +7,7 @@ CurrentModule = ImplicitDifferentiation
[ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl) is a package for automatic differentiation of functions defined implicitly, i.e., _forward mappings_

```math
x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m
f: x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m
```

whose output is defined by _conditions_
4 changes: 2 additions & 2 deletions examples/1_basic.jl
@@ -66,7 +66,7 @@ The forward mapping and the conditions should accept the same set of keyword arg
=#

function conditions_optim(x, y; method)
∇₂f = 2 .* (y .^ 2 .- x)
∇₂f = @. 4 * (y^2 - x) * y
return ∇₂f
end

@@ -129,7 +129,7 @@ In this case, the optimization problem boils down to the componentwise square ro
=#

function forward_nlsolve(x; method)
F!(storage, y) = (storage .= y .^ 2 - x)
F!(storage, y) = (storage .= y .^ 2 .- x)
initial_y = similar(x)
initial_y .= 1
result = nlsolve(F!, initial_y; method)
2 changes: 1 addition & 1 deletion examples/2_advanced.jl
@@ -55,7 +55,7 @@ function proj_hypercube(p)
end

function conditions_cstr_optim(x, y)
∇₂f = 2 .* (y .^ 2 .- x)
∇₂f = @. 4 * (y^2 - x) * y
η = 0.1
return y .- proj_hypercube(y .- η .* ∇₂f)
end
6 changes: 3 additions & 3 deletions ext/ImplicitDifferentiationChainRulesExt.jl
@@ -19,7 +19,7 @@ This is only available if ChainRulesCore.jl is loaded (extension), except on Jul
- By default, this returns a single output `y(x)` with a pullback accepting a single cotangent `dy`.
- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y(x),z(x))` with a pullback accepting a couple of cotangents `(dy, dz)` (remember that `z(x)` is not differentiated so its cotangent is ignored).

We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`.
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`).
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
@@ -53,7 +53,7 @@ function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R}
(y, z), implicit_pullback = rrule(rc, implicit, x, ReturnByproduct(); kwargs...)
implicit_pullback_no_byproduct(dy) = implicit_pullback((dy, nothing))
implicit_pullback_no_byproduct(dy) = Base.front(implicit_pullback((dy, nothing)))
return y, implicit_pullback_no_byproduct
end

@@ -73,7 +73,7 @@ function (implicit_pullback::ImplicitPullback)((dy, dz))
mul!(dx_vec, Bᵀ_op, dF_vec)
lmul!(-one(R), dx_vec)
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx)
return (NoTangent(), dx, NoTangent())
end

end
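For reference, a short derivation of this rule (a hedged sketch, assuming that `A` and `B` are the Jacobians of the conditions with respect to `y` and `x` at the solution, as referenced in the docstring above):

```math
c(x, y(x)) = 0 \implies A J + B = 0 \implies J = -A^{-1} B,
\qquad J^\top v = -B^\top A^{-\top} v = -B^\top u \quad \text{where} \quad A^\top u = v.
```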
2 changes: 1 addition & 1 deletion ext/ImplicitDifferentiationForwardDiffExt.jl
@@ -25,7 +25,7 @@ This is only available if ForwardDiff.jl is loaded (extension).
- By default, this returns a single output `y_and_dy(x)`.
- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y_and_dy(x),z(x))` (remember that `z(x)` is not differentiated so `dz(x)` doesn't exist).

We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`).
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function (implicit::ImplicitFunction)(
2 changes: 2 additions & 0 deletions src/conditions.jl
@@ -12,9 +12,11 @@
end
end

function Base.show(io::IO, conditions::Conditions{byproduct}) where {byproduct}
return print(io, "Conditions{$byproduct}($(conditions.c))")
end

(conditions::Conditions{true})(x, y, z; kwargs...) = conditions.c(x, y, z; kwargs...)
(conditions::Conditions{false})(x, y, z; kwargs...) = conditions.c(x, y; kwargs...)

handles_byproduct(::Conditions{byproduct}) where {byproduct} = byproduct
2 changes: 2 additions & 0 deletions src/forward.jl
@@ -12,8 +12,8 @@
end
end

function Base.show(io::IO, forward::Forward{byproduct}) where {byproduct}
return print(io, "Forward{$byproduct}($(forward.f))")
end

function (forward::Forward{true})(x; kwargs...)
@@ -26,3 +26,5 @@
z = 0
return y, z
end

handles_byproduct(::Forward{byproduct}) where {byproduct} = byproduct
6 changes: 5 additions & 1 deletion src/implicit_function.jl
@@ -66,9 +66,9 @@
return ImplicitFunction(f, c, linear_solver, HandleByproduct())
end

function Base.show(io::IO, implicit::ImplicitFunction)
@unpack forward, conditions, linear_solver = implicit
return print(io, "ImplicitFunction($(forward.f), $(conditions.c), $linear_solver)")
end

function (implicit::ImplicitFunction)(x::AbstractArray; kwargs...)
@@ -77,6 +77,10 @@
end

function (implicit::ImplicitFunction)(x::AbstractArray, ::ReturnByproduct; kwargs...)
y, z = implicit.forward(x, ; kwargs...)
y, z = implicit.forward(x; kwargs...)
return (y, z)
end

function handles_byproduct(implicit::ImplicitFunction)
return handles_byproduct(implicit.forward) && handles_byproduct(implicit.conditions)
end
8 changes: 6 additions & 2 deletions src/utils.jl
@@ -47,13 +47,17 @@ end
function (pfm::PushforwardMul!)(res::AbstractVector, δinput_vec::AbstractVector)
δinput = reshape(δinput_vec, pfm.input_size)
δoutput = only(pfm.pushforward(δinput))
return res .= vec(δoutput)
for i in eachindex(IndexLinear(), res, δoutput)
res[i] = δoutput[i]
end
end

function (pbm::PullbackMul!)(res::AbstractVector, δoutput_vec::AbstractVector)
δoutput = reshape(δoutput_vec, pbm.output_size)
δinput = only(pbm.pullback(δoutput))
return res .= vec(δinput)
for i in eachindex(IndexLinear(), res, δinput)
res[i] = δinput[i]
end
end

## Override this function from LinearOperators to avoid generating the whole methods table