
Multiple active arguments? #508

Open
gdalle opened this issue Sep 27, 2024 · 2 comments
Labels
core Related to the core utilities of the package

Comments

gdalle commented Sep 27, 2024

With v0.6, DI supports any number of context arguments (constants or caches), but only a single active argument x, which must come first.
The last step toward full generality is support for any number of active arguments. This would improve performance, e.g. with Enzyme, and also (apparently) avoid some differentiation errors.
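For concreteness, here is a minimal sketch of the current v0.6 convention, assuming ForwardDiff as the backend (the backend choice is arbitrary; f is an illustrative function, not from this thread):

using DifferentiationInterface
import ForwardDiff

# one active argument x, one constant context c
f(x, c) = c * sum(abs2, x)

# only x is differentiated; c is wrapped as an inactive Constant context
DifferentiationInterface.gradient(f, AutoForwardDiff(), [1.0, 2.0, 3.0], Constant(2.0))
# returns [4.0, 8.0, 12.0]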

The issue is that with multiple active arguments, every operator output must become a tuple instead of a single object. Besides being another breaking change, this might slow down backends that do not support multiple arguments natively.

My workaround idea for now is:

  1. force the first argument to always be active
  2. if there are no other active arguments (easy to check by constraining the Vararg), return single objects
  3. otherwise, return tuples (see the sketch below)
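
Purely as an illustration of points 2 and 3, the dispatch could look something like this. This is a hypothetical sketch, not actual DI code: my_gradient and the local Context/Constant types are made up to mimic DI's context wrappers.

# Hypothetical sketch (not actual DI code) of the proposed return convention.
abstract type Context end          # stands in for DI's context supertype
struct Constant{T} <: Context      # stands in for DI's Constant wrapper
    value::T
end

# All trailing arguments are contexts: keep today's behavior, return a single object.
my_gradient(f, backend, x, contexts::Vararg{Context}) = :single_gradient

# At least one trailing argument is active: return a tuple, one gradient per active argument.
my_gradient(f, backend, x, args...) = :tuple_of_gradients

my_gradient(sum, nothing, [1.0], Constant(2.0))  # :single_gradient
my_gradient(sum, nothing, [1.0], [3.0])          # :tuple_of_gradients

The Vararg{Context} method is strictly more specific, so dispatch picks it whenever all trailing arguments are contexts, which is the "easy to check" constraint from point 2.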

Related past discussions:

gdalle changed the title from "Automatically gather several active arguments?" to "Multiple active arguments?" on Sep 28, 2024

wsmoses commented Sep 28, 2024

recopying from slackhole:

using Enzyme, BenchmarkTools, ADTypes, DifferentiationInterface

mul(a, b) = @inbounds (a[1] * b[1])

const a = [2.0]
const b = [3.0]
Enzyme.gradient(Reverse, mul, a, b)

@btime Enzyme.gradient($Reverse, $mul, $a, $b) # 52.496 ns (2 allocations: 128 bytes)

fuse(tup) = mul(tup[1], tup[2])

function bench_di(backend, a, b)
    prep = DifferentiationInterface.prepare_gradient(fuse, backend, (a, b))
    @btime DifferentiationInterface.gradient($fuse, $prep, $backend, $((a, b)))
    return nothing
end

bench_di(AutoEnzyme(), a, b) #  131.136 ns (4 allocations: 464 bytes)

@btime Enzyme.gradient(Reverse, fuse, $((a, b))) #  129.619 ns (4 allocations: 464 bytes)


wsmoses commented Sep 28, 2024

and

using Enzyme, BenchmarkTools, ADTypes, DifferentiationInterface, StaticArrays

mul(a, b) = @inbounds a[1] * sum(b)

const a = [2.0]
const b = @SMatrix ones(8,8)
Enzyme.gradient(Reverse, mul, a, b)

@btime Enzyme.gradient($Reverse, $mul, $a, $b) #  7.183 μs (13 allocations: 3.45 KiB)

fuse(tup) = mul(tup[1], tup[2])

function bench_di(backend, a, b)
    prep = DifferentiationInterface.prepare_gradient(fuse, backend, (a, b))
    @btime DifferentiationInterface.gradient($fuse, $prep, $backend, $((a, b)))
    return nothing
end

bench_di(AutoEnzyme(), a, b) #   27.910 μs (164 allocations: 11.50 KiB)

@btime Enzyme.gradient(Reverse, fuse, $((a, b))) #  27.810 μs (163 allocations: 10.97 KiB)

gdalle added the core label on Sep 30, 2024