
DataInterpolations support #178

Merged: 35 commits merged on Sep 2, 2024


@SouthEndMusic (Contributor)

No description provided.

@SouthEndMusic marked this pull request as draft on August 20, 2024 19:32
@SouthEndMusic (Contributor, Author)

Is this the way to go with dispatching on callable structs? Or should it be done per interpolation type separately?

@adrhill (Owner) commented Aug 20, 2024

Wow, this PR came impressively quickly after merging #177!

Is this the way to go with dispatching on callable structs?

I'm afraid this might just overload the constructors.
You can inspect the generated code by omitting the eval statement, e.g. calling

SCT.overload_gradient_1_to_1(:DataInterpolations, AbstractInterpolation)

which will most likely return a function similar to

function DataInterpolations.AbstractInterpolation(t::SparseConnectivityTracer.GradientTracer)
    return SparseConnectivityTracer.gradient_tracer_1_to_1(t, false)
end

which unfortunately is not what you want.

I'll try to figure out how to work around it. The simplest solution would be to add a new specialized overload_gradient_callable_1_to_1 code generation utility.
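A minimal, self-contained sketch of why the generated code misses the target and what a callable-struct generator would emit instead. `ToyTracer`, `ToyInterp`, `constructor_overload` and `callable_overload` are hypothetical stand-ins, not SparseConnectivityTracer or DataInterpolations API; the proposed `overload_gradient_callable_1_to_1` utility is only imitated here in spirit.

```julia
# Hypothetical stand-in for a gradient tracer: it only records the indices
# of the inputs a value depends on. Not SparseConnectivityTracer's type.
struct ToyTracer <: Real
    inds::Set{Int}
end

# Hypothetical stand-in for an interpolation object (not DataInterpolations).
struct ToyInterp
    u::Vector{Float64}
    t::Vector{Float64}
end

# Interpolating the type name into the function position yields a
# *constructor* method, ToyInterp(x::ToyTracer). Not what we want here.
constructor_overload(T::Symbol) = :(function $T(x::ToyTracer)
    return ToyTracer(copy(x.inds))
end)

# Hypothetical callable-struct generator: overloads interp(x) for every
# interp::T, treating the interpolant as a 1-to-1 operator whose output
# depends on the same inputs as the query x.
callable_overload(T::Symbol) = :(function (interp::$T)(x::ToyTracer)
    return ToyTracer(copy(x.inds))
end)

# Inspect the generated code instead of evaluating it.
println(constructor_overload(:ToyInterp))
println(callable_overload(:ToyInterp))

# Only define the callable overload.
eval(callable_overload(:ToyInterp))

interp = ToyInterp([1.0, 2.0, 5.0], [0.0, 1.0, 3.0])
y = interp(ToyTracer(Set([1])))
println(y.inds)
```

Printing both expressions shows the difference: the first defines a method on the type `ToyInterp` itself, the second a method on its instances.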

@SouthEndMusic (Contributor, Author) commented Aug 21, 2024

There must be an easier way to get all the interpolation types. Also, for some reason the tests cannot find DataInterpolations.

Edit: nvm, I'll try some different things.

@adrhill (Owner) commented Aug 21, 2024

I think your initial attempt was the intuitive way to do it, it just needs some work from our side.

@SouthEndMusic (Contributor, Author) commented Aug 21, 2024

I think my latest approach is the best; this is also the level on which other DataInterpolations extensions operate. It still doesn't work, and looking at the generated methods I think I know why:

[6] _interpolate(tx::T, ty::T) where T<:SparseConnectivityTracer.GradientTracer
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:119
  [7] _interpolate(tx::T, ty::T) where T<:SparseConnectivityTracer.HessianTracer
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:188
  [8] _interpolate(tx::SparseConnectivityTracer.GradientTracer, ::Real)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:125
  [9] _interpolate(::Real, ty::SparseConnectivityTracer.GradientTracer)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:129
 [10] _interpolate(dx::D, y::Real) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:163
 [11] _interpolate(dx::D, dy::D) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:145
 [12] _interpolate(dx::D, dy::D) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:227
 [13] _interpolate(x::Real, dy::D) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\gradient_tracer.jl:196
 [14] _interpolate(dx::D, y::Real) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:256
 [15] _interpolate(x::Real, ty::SparseConnectivityTracer.HessianTracer)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:204
 [16] _interpolate(tx::SparseConnectivityTracer.HessianTracer, y::Real)
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:200
 [17] _interpolate(x::Real, dy::D) where {P, T<:SparseConnectivityTracer.HessianTracer, D<:SparseConnectivityTracer.Dual{P, T}}
     @ SparseConnectivityTracerDataInterpolationsExt C:\Users\konin_bt\SciML\Other\SparseConnectivityTracer.jl\src\overloads\hessian_tracer.jl:284

The original methods are e.g. _interpolate(A::LinearInterpolation, t), so the first argument is not a real number at all, yet the generated methods assume it is. I think this is quite a general issue you have to address.

@adrhill (Owner) commented Aug 21, 2024

Yes, this kind of function requires writing manual overloads. I should add to the documentation that our "N-to-M operators" assume Real inputs.

Judging by its name, _interpolate seems to be an internal function. We should avoid touching these, as they can break without SemVer notice.

Overloading on callable AbstractInterpolation structs has the additional advantage of them being "1-to-1 operators" on Reals.

@SouthEndMusic (Contributor, Author)

For my application I only need scalar to scalar, but for many interpolation types the output can also be a vector. Is that still considered 1-to-1?

@adrhill (Owner) commented Aug 21, 2024

Are interpolations generic functions $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$?
If so, we'll have to think about what the sparsity patterns should look like and implement the overloads manually.

An overly conservative estimate could be obtained by using tools from src/overloads/arrays.jl:

  1. if all scalar entries in the input vector interact with each other, the union of their index sets can be computed using first_order_or or second_order_or
  2. we can then return a Fill vector from FillArrays.jl that has the combined tracer on each entry.

(This is how we handle matrix inversion for example.)
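A sketch of this conservative N-to-M estimate, assuming a toy tracer that stores a plain `Set{Int}` of input indices. `GTracer`, `tracer_union` and `conservative_n_to_m` are hypothetical; `tracer_union` only plays the role that `first_order_or` plays in the thread, and `Base.fill` stands in for `FillArrays.Fill`.

```julia
# Hypothetical toy gradient tracer: the set of input indices a value depends on.
struct GTracer <: Real
    inds::Set{Int}
end

# Union of all index sets, playing the role of a first-order "or"
# (hypothetical helper, not SCT's first_order_or).
tracer_union(xs) = GTracer(reduce(union, (x.inds for x in xs); init=Set{Int}()))

# Conservative overload for an operator R^n -> R^m: every output entry is
# assumed to depend on every input entry, so all outputs share one tracer.
function conservative_n_to_m(xs::AbstractVector{GTracer}, m::Integer)
    t = tracer_union(xs)
    # FillArrays.Fill(t, m) would store `t` only once; Base.fill keeps
    # this sketch dependency-free.
    return fill(t, m)
end

# A scalar query (n = 1) mapped to a length-3 vector output, as for a
# vector-valued interpolation.
ys_scalar = conservative_n_to_m([GTracer(Set([1]))], 3)

# Two inputs that interact: each output depends on both.
ys = conservative_n_to_m([GTracer(Set([1])), GTracer(Set([3]))], 2)
println([sort(collect(y.inds)) for y in ys])
```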

@SouthEndMusic (Contributor, Author)

DataInterpolations input is always 1D, but the output can be in any vector space, though most support is for scalar and vector output.

@adrhill (Owner) commented Aug 21, 2024

Could you rebase this PR on main and move the test cases to the new test/ext folder?

I’ll add the needed overloads for you tomorrow. Any tests you add will help me out greatly.

@SouthEndMusic (Contributor, Author)

In general what I think should be supported (maybe each can be a separate issue/PR):

  • scalar to scalar interpolation: that's just calling the interpolation object as a callable struct
  • scalar to array interpolation: the same, but with array output (vectors, and to some extent matrices)
  • code optimization: all derivatives of constant interpolation are zero, second derivative of linear interpolation is zero
  • derivatives and integrals: computed as e.g. derivative(A::LinearInterpolation, t), integral(A::LinearInterpolation, t) or integral(A::LinearInterpolation, t1, t2)
  • local: in certain intervals certain derivatives of certain interpolation types are 0. Might not be worth the effort though
  • derivatives w.r.t. constructor input: quite niche, but sometimes people want to compute derivatives w.r.t. input data (for e.g. optimization). Here the sparsity pattern depends on the interpolation type.
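The "code optimization" item can be illustrated on second-order patterns. By the chain rule, for y = f(x) we have ∇²y = f'(x) ∇²x + f''(x) ∇x ∇xᵀ, so a zero f'' drops the second term and a zero f' as well drops everything. The sketch assumes a hypothetical toy `HTracer` with explicit index pairs; it is not SCT's HessianTracer.

```julia
# Hypothetical toy Hessian tracer: first-order dependencies plus
# second-order interactions stored as explicit index pairs.
struct HTracer <: Real
    grad::Set{Int}
    hess::Set{Tuple{Int,Int}}
end

# All pairs (i, j) with i and j in s: the pattern of the outer product ∇x ∇xᵀ.
outer_pairs(s::Set{Int}) = Set{Tuple{Int,Int}}((i, j) for i in s for j in s)

# Constant interpolation: f' = f'' = 0, so nothing propagates.
constant_pattern(x::HTracer) = HTracer(Set{Int}(), Set{Tuple{Int,Int}}())

# Linear interpolation (away from knots): f'' = 0, so only x's own
# second-order terms survive, scaled by f'.
linear_pattern(x::HTracer) = HTracer(copy(x.grad), copy(x.hess))

# Generic smooth interpolation (e.g. cubic): f' and f'' may both be non-zero.
generic_pattern(x::HTracer) = HTracer(copy(x.grad), union(x.hess, outer_pairs(x.grad)))

x = HTracer(Set([1, 2]), Set{Tuple{Int,Int}}())
println(length(generic_pattern(x).hess))  # 4 pairs: (1,1), (1,2), (2,1), (2,2)
```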

@adrhill (Owner) commented Aug 21, 2024

Thanks for the list! I think that should all be doable in here.

@SouthEndMusic (Contributor, Author) commented Aug 22, 2024

By the way, what's the philosophy on where extensions are located? DataInterpolations has its own extensions for Symbolics, Zygote etc. For those extensions it's also not that bad if they use DataInterpolations internals.

@adrhill (Owner) commented Aug 22, 2024

I'd say we should have them in here for now:

  • SCT is less stable and less established than DataInterpolations
  • the overloads won't use internals from DataInterpolations but (in this case) will have to use internals from SCT

@codecov-commenter commented Aug 22, 2024

Codecov Report

Attention: Patch coverage is 87.12871% with 26 lines in your changes missing coverage. Please review.

Project coverage is 90.42%. Comparing base (c7355d3) to head (23ee8ab).
Report is 1 commits behind head on main.

Files with missing lines                Patch %   Lines
test/ext/test_DataInterpolations.jl     86.01%    20 missing
test/runtests.jl                        69.23%    4 missing
src/overloads/ambiguities.jl            0.00%     2 missing
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #178      +/-   ##
==========================================
- Coverage   91.36%   90.42%   -0.94%     
==========================================
  Files          41       44       +3     
  Lines        1772     2037     +265     
==========================================
+ Hits         1619     1842     +223     
- Misses        153      195      +42     


@adrhill (Owner) commented Aug 22, 2024

1D interpolations should be working now.
Do you have an example for N-dimensional interpolations? I couldn't find any in the DataInterpolations docs.

@adrhill marked this pull request as ready for review on August 22, 2024 17:08
@adrhill (Owner) commented Aug 23, 2024

Urgh, I wasn't aware DataInterpolations had separate methods for AbstractArray inputs...

@adrhill (Owner) commented Aug 28, 2024

What about things like this:

import SparseConnectivityTracer as SCT
using DataInterpolations

method = SCT.TracerLocalSparsityDetector()

t = [0.0, 1.0, 3.0]
f(u) = ConstantInterpolation(u, t)(2.0)

u = [1.0, 2.0, 5.0]

SCT.jacobian_sparsity(f, u, method)
# 1×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
# ⋅  1  ⋅

As you can see it already works for some methods, but for other methods it doesn't yet (e.g. LinearInterpolation).

Wait, this is an entirely different use-case than what we are currently implementing. You don't want to differentiate through the interpolation, you want to differentiate through the creation of the interpolant?

@adrhill (Owner) commented Aug 28, 2024

As you can see it already works for some methods

I would say that with the additional overloads on tracers, this is unintended behavior, as this PR only returns the input indices of the "interpolation query tracers t", not the potential "interpolant constructor tracers" u. Pushing tracers through interpolants constructed from tracers would require additional methods.
We should overload the constructors to throw a NotImplementedError on Tracers to not bloat this PR.

@adrhill (Owner) commented Aug 28, 2024

The differentiation of the result of an interpolation w.r.t. the input arguments of its constructor is also a bit more nuanced than what we are currently doing.

@gdalle (Collaborator) commented Aug 28, 2024

You don't want to differentiate through the interpolation, you want to differentiate through the creation of the interpolant?

As usual, I assume that this works out of the box for local tracers, and that only global tracers are an issue?

The differentiation of the result of an interpolation w.r.t. the input arguments of its constructor is also a bit more nuanced than what we are currently doing.

IIUC this would require constructing each field of the corresponding Interpolation type (some of which may not even be public API) and filling them with the appropriate tracers in a mathematically sound way. So there's no hope of doing this in a batched way. I think this is out of scope for the current PR, although I understand the appeal.

Of course there's also the crazy option: make LinearInterpolation(u::Vector{Tracer}, t::Vector{Tracer}) return a new TracerLinearInterpolation object without all the caches and SciML shenanigans.

@adrhill (Owner) commented Aug 28, 2024

As usual, I assume that this works out of the box for local tracers, and that only global tracers are an issue?

Maybe before this PR, but with the new overloads we possibly return the wrong patterns.

IIUC this would require constructing each field of the corresponding Interpolation type (some of which may not even be public API) and filling them with the appropriate tracers in a mathematically sound way.

I’m not even sure whether there is a way for Global tracers to return any kind of sparsity. Let’s take LinearInterpolation as an example. For a given query t, the output depends on the closest lower and higher data points. Global tracers don’t have the required primal value for such an ordering.
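To make the role of the primal concrete, a sketch of *local* index selection on sorted knots with `searchsortedlast`. The helper names are hypothetical, and the piecewise-constant rule assumes a left-continuous convention; with t = [0.0, 1.0, 3.0] and query 2.0 it picks index 2, consistent with the ConstantInterpolation pattern shown earlier in the thread. A global tracer has no primal query, so only the dense fallback is safe.

```julia
# Hypothetical helpers for *local* sparsity, which needs the primal query tq.
# Assumes `t` is sorted ascending and t[1] <= tq <= t[end].

# Piecewise-constant, left-continuous convention (an assumption): the value
# on [t[i], t[i+1]) is u[i].
function constant_dependency(t::AbstractVector{<:Real}, tq::Real)
    return clamp(searchsortedlast(t, tq), 1, length(t))
end

# Linear interpolation: the value depends on the two neighbouring data points.
# At an exact knot this is conservative (one of the two weights is zero).
function linear_dependencies(t::AbstractVector{<:Real}, tq::Real)
    i = clamp(searchsortedlast(t, tq), 1, length(t) - 1)
    return (i, i + 1)
end

# Without a primal query (global tracers), only the dense answer is safe.
dense_dependencies(t::AbstractVector) = Tuple(eachindex(t))

t = [0.0, 1.0, 3.0]
println(constant_dependency(t, 2.0))   # 2
println(linear_dependencies(t, 2.0))   # (2, 3)
println(dense_dependencies(t))         # (1, 2, 3)
```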

@adrhill (Owner) commented Aug 28, 2024

If a dense pattern is the correct solution for global tracers, the solution might not be too complicated:

Collect all tracers from the data (if there are any) and take a union over all of them. Then take a union with the “query tracers” and return a Fill array of the correct size.

@SouthEndMusic (Contributor, Author)

I’m not even sure whether there is a way for Global tracers to return any kind of sparsity

You're right about that, sparsity of an interpolation call w.r.t. the input u is only meaningful locally, for global you cannot say anything or you would have to assume dense dependency on the whole of u.

We should overload the constructors to throw a NotImplementedError on Tracers to not bloat this PR

That is fine by me; this is quite a niche use case. This was also implemented for DataInterpolations + Enzyme only partially and relatively recently, because someone specifically asked for it. I think the most important remaining use cases are derivative and integral.

@gdalle (Collaborator) commented Aug 28, 2024

Global tracers don’t have the required primal value for such an ordering.

That depends, if t is assumed to be ordered then we have that neighborness information

@adrhill (Owner) commented Aug 28, 2024

That depends, if t is assumed to be ordered then we have that neighborness information

No, the query is entirely independent from the data. We don’t know which data points are closest to the query without primals.

@gdalle (Collaborator) commented Aug 28, 2024

In that case, my option of a dummy TracerInterpolation containing the merged tracers of all u and all t may not be too dumb?
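A sketch of such a dummy interpolant, which also realizes the dense fallback for global tracers proposed earlier in the thread: union every tracer in u and t, then union with the query's tracer. `GT` and `TracerInterp` are hypothetical toy types; for vector-valued data the result would additionally be wrapped in a Fill array, which is omitted here.

```julia
# Hypothetical toy gradient tracer.
struct GT <: Real
    inds::Set{Int}
end

# Hypothetical "TracerInterpolation": it drops all interpolation structure
# (caches, knot search) and keeps only the union of every tracer in u and t.
struct TracerInterp
    merged::Set{Int}
end

function TracerInterp(u::AbstractVector, t::AbstractVector)
    merged = Set{Int}()
    for x in Iterators.flatten((u, t))
        x isa GT && union!(merged, x.inds)   # plain numbers carry no dependencies
    end
    return TracerInterp(merged)
end

# Query with a tracer: dense in all data tracers plus the query's own.
(A::TracerInterp)(q::GT) = GT(union(A.merged, q.inds))
# Query with a plain number: only the data tracers matter.
(A::TracerInterp)(q::Real) = GT(copy(A.merged))

u = [GT(Set([1])), GT(Set([2])), GT(Set([3]))]
t = [0.0, 1.0, 3.0]
A = TracerInterp(u, t)
println(sort(collect(A(GT(Set([4]))).inds)))
```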

@gdalle (Collaborator) commented Aug 29, 2024

On second thought, I don't think you can systematically throw an error on constructors even if u and t have Tracer elements, because the method that gets called will depend primarily on the vector type, not on the element type.
If DataInterpolations only implements something like LinearInterpolation(u::AbstractVector, t::AbstractVector) then you're in luck and you can be more specific, but if they have special cases then you're screwed because Vector{<:Real} is more specific than AbstractVector{<:Tracer}.
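A sketch of this dispatch problem, with a hypothetical function `build` standing in for an interpolation constructor and a hypothetical `DTracer` element type. With only a generic package method, an overload on `AbstractVector{<:DTracer}` is strictly more specific and wins. Once the package adds a special case on `Vector{<:Real}`, the tracer overload is less specific in the second argument, so Julia either picks the package method or reports an ambiguity; the sketch only checks that the tracer overload no longer runs.

```julia
# Hypothetical tracer element type.
struct DTracer <: Real
    inds::Set{Int}
end

# Stand-in for a package constructor that only has a generic method.
build(u::AbstractVector, t::AbstractVector) = :package_generic

# Tracer-side overload: strictly more specific in the element type of `u`,
# so it wins for any vector of tracers.
build(u::AbstractVector{<:DTracer}, t::AbstractVector) = :tracer_overload

u = [DTracer(Set([1])), DTracer(Set([2]))]
t = [0.0, 1.0]
r1 = build(u, t)
println(r1)

# Now suppose the package also had a special case on concrete container types.
build(u::Vector{<:Real}, t::Vector{<:Real}) = :package_special_case

# Either the special case is selected or the call is ambiguous (MethodError);
# in both cases the tracer overload is bypassed.
r2 = try
    build(u, t)
catch err
    err
end
println(r2)
```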

@adrhill (Owner) commented Aug 29, 2024

@SouthEndMusic do you have an idea why the 1.6 tests fail and others don't?

@adrhill (Owner) commented Sep 2, 2024

I spent a bit too much time writing tests for interpolants containing tracers in c993787 just to figure out we can't add overloads on them for the same reason @gdalle mentioned we can't overload the constructors:

[...] The method that gets called will depend primarily on the vector type, not on the element type.

If DataInterpolations only implements something like LinearInterpolation(u::AbstractVector, t::AbstractVector) then you're in luck and you can be more specific, but if they have special cases then you're screwed because Vector{<:Real} is more specific than AbstractVector{<:Tracer}.

The same applies here: we can easily overload (interp::QuadraticInterpolation)(t::HessianTracer), but overloads like (interp::QuadraticInterpolation{T})(x) where {T<:AbstractArray{<:AbstractTracer}} might get missed unless we generate a lot of code over a big loop of common array types.

That is fine by me; this is quite a niche use case.

I'm afraid we'll have to leave it at supporting tracer inputs, at least for now.

@adrhill merged commit cb273a1 into adrhill:main on Sep 2, 2024 (6 checks passed)
Labels: new overloads (A new method on tracers is required by a user.)
4 participants