
Append primal results to batched jacobian computations #1198

Open
floffy-f opened this issue Jan 9, 2025 · 4 comments

Comments

@floffy-f

floffy-f commented Jan 9, 2025

Problem statement

While reading the docs I came to wonder about the following MWE in the nested-AD section:

# Packages the snippet relies on: StableRNGs for StableRNG, LinearAlgebra for norm,
# ADTypes/ForwardDiff for AutoForwardDiff, and Lux for the layers and batched_jacobian.
using ADTypes, ForwardDiff, LinearAlgebra, Lux, StableRNGs

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

function loss_function_batched(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `AutoZygote()` as well but `AutoForwardDiff()` tends to be more efficient here
    J = batched_jacobian(smodel, AutoForwardDiff(), x)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

loss_function_batched(model, x, ps, st, y)

Note that the forward pass must run at least twice here (once at ŷ = smodel(x) and at least once inside J = batched_jacobian(smodel, AutoForwardDiff(), x)) to obtain both $\hat{y}$ and $J$, which seems inefficient (or is it?). This may waste a significant amount of compute when the number of chunks is small.
batched_jacobian computes $\hat{y}$ internally anyway, so it would be interesting to expose it.

Design ideas to fix it

Here are some ways one might fix this by fetching the primal result from the internal API.
In all of them, batched_jacobian accepts a new parameter, defaulting to false, that controls whether the primal is returned alongside the Jacobian (new API) or not (as before).

function batched_jacobian(f::F, backend::AbstractADType, x::AbstractArray; primal::Bool=false) where {F}
    return batched_jacobian_internal(f, backend, x, Val(primal))
end
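
With such a keyword, the MWE above could reuse the primal returned by the Jacobian call instead of running the forward pass a second time. A minimal sketch, assuming the primal=true keyword proposed above (a hypothetical API, not current Lux behavior):

function loss_function_batched_primal(model, x, ps, st, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # hypothetical: a single forward pass yields both the primal and the Jacobian
    ŷ, J = batched_jacobian(smodel, AutoForwardDiff(), x; primal=true)
    loss_emp = sum(abs2, ŷ .- y)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end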

ForwardDiff possible solutions

Since both Zygote and ForwardDiff must compute $\hat{y}$ internally at some point, might it be possible to extract it?
For the ForwardDiff part, for example, the relevant code seems to be src/autodiff/batched_autodiff.jl, lines 160:166 (main branch). One could do the same as partials_wrap on line 159 for the primal "values" while computing the first chunk.

# line 124
@views function batched_forwarddiff_jacobian_first_chunk(
        f::F,
        x::AbstractMatrix{T}, ::Type{Tag}, ::ForwardDiff.Chunk{CK}, ::Type{Dual},
        ::Type{Partials}) where {F, T, Tag, CK, Dual, Partials}
    N, B = size(x)

    n_idxs = min(CK, N)
    idxs = 1:n_idxs
    idxs_next = (1 + CK):N

    dev = get_device(x)

    partials = map(𝒾 -> Partials(ntuple(𝒿 -> ifelse(𝒾 == 𝒿, oneunit(T), zero(T)), CK)),
        dev(collect(1:n_idxs)))
    x_part_duals = Dual.(x[idxs, :], partials)

    if length(idxs_next) == 0
        x_part_next = similar(x_part_duals, 0, B)
    else
        x_part_next = Dual.(x[idxs_next, :],
            map(𝒾 -> Partials(ntuple(_ -> zero(T), CK)), dev(collect(1:length(idxs_next)))))
    end

    x_duals = vcat(x_part_duals, x_part_next)
    y_duals_ = f(x_duals)
    @argcheck ndims(y_duals_) > 1 && size(y_duals_, ndims(y_duals_)) == B
    y_duals = reshape(y_duals_, :, B)

    value_wrap(y) = ForwardDiff.value(Tag, y)
    partials_wrap(y, i) = ForwardDiff.partials(Tag, y, i)
    return value_wrap.(y_duals_), stack(i -> partials_wrap.(y_duals, i), 1:CK; dims=2)
end

Then the first chunk would be computed with this method by changing line 100 as follows, and then returning (or not) $y$.

     y, J_partial = batched_forwarddiff_jacobian_first_chunk(f, x, Tag, ck, dual_type, partials_type)
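
The caller would then forward the primal only when requested; a hedged sketch of the final return (the variable name primal_val is an assumption, not an actual Lux internal):

    # `primal_val` stands for the Val(true)/Val(false) forwarded from the proposed keyword;
    # the remaining chunks are computed and stacked into J exactly as before
    primal_val isa Val{true} && return y, J
    return J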

Zygote possible solutions

If I understood the code correctly, this one may be simpler, in ext/LuxZygoteExt/batched_autodiff.jl:

    # line 31
    if primal
        return y, J
    end
    return J
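
Zygote makes this easy because its pullback API already returns the primal together with the backward closure, so no extra forward pass is needed. A minimal standalone illustration (independent of the Lux wrapper):

using Zygote

f(x) = sum(abs2, x)
y, back = Zygote.pullback(f, [1.0f0, 2.0f0, 3.0f0])  # y == 14.0f0 is the primal
(x̄,) = back(1.0f0)                                   # gradient computed from the same pass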

About Enzyme

When ho/ho-enzyme (#954) is ready, using e.g. AutoEnzyme(; mode=Enzyme.ForwardWithPrimal) will make this issue almost trivial.
From personal experience with Reactant+Enzyme+nested AD, though, this looks particularly difficult for now.
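
For illustration, this is roughly what forward mode with the primal looks like in Enzyme itself, outside of any Lux wrapper (a minimal sketch; the exact packaging of the result follows Enzyme's docs):

using Enzyme

g(x) = x^2
# ForwardWithPrimal returns both the directional derivative g'(3.0) * 1.0 and the primal g(3.0)
res = Enzyme.autodiff(ForwardWithPrimal, g, Duplicated(3.0, 1.0))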

vjp/jvp

All the ideas mentioned above also seem applicable to both vector_jacobian_product and jacobian_vector_product.
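
For instance, the operator APIs could grow the same hypothetical keyword (the primal kwarg below mirrors the proposal and does not exist in Lux today):

# hypothetical: return the primal alongside the products
ŷ, vJ = vector_jacobian_product(smodel, AutoZygote(), x, v; primal=true)
ŷ, Ju = jacobian_vector_product(smodel, AutoForwardDiff(), x, u; primal=true)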

Conclusion

Are those ideas applicable as described, or in some other way? Would they be useful additions to the API?
I could implement the changes myself if needed, but I might need guidance with respect to the codebase, especially the tests targeting this part of the code. If this seems relevant, I'll open a PR.

Disclaimer

(I am not proficient in Julia, especially regarding the autodiff libraries and GPU technicalities. I may have misunderstood some mechanisms, e.g. an in-place operation that would make my proposed solutions inapplicable, or scalar-indexing issues that may arise, in which case I would not know how to address them.)

@floffy-f floffy-f changed the title Add primal computations to batched jacobian computations Append primal results to batched jacobian computations Jan 9, 2025
@avik-pal
Member

Thanks for the detailed write-up. In terms of API, I think primal::Utils.BoolLike = False() is a good way to go about it (the False() keeps us type-stable by default while still allowing users to pass in false/true). One thing that needs to be worked out is updating the chain rules for all the respective functions, which might get hard.
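
A hedged sketch of what that signature could look like (Utils.BoolLike and False() are the Lux/Static.jl names mentioned above; the use of static below is an assumption about how the keyword would be normalized):

function batched_jacobian(f::F, backend::AbstractADType, x::AbstractArray;
        primal::Utils.BoolLike=False()) where {F}
    # normalize Bool/True/False into a compile-time constant before dispatching
    return batched_jacobian_internal(f, backend, x, static(primal))
end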

From personal experience with Reactant+Enzyme+nestedAD, this looks like it is particularly difficult for now though.

This is actually the way to go. We have been using Reactant internally for PINNs with higher-order AD and it works like a charm. Some operations are still missing, but we can add them quite easily (if you open issues with models of yours that are not working, I would be more than happy to implement them in EnzymeJAX). FWIW, I have been adding benchmarks in Reactant, and nested AD there is about 27x faster than doing the Zygote + ForwardDiff trick (https://github.com/EnzymeAD/Reactant.jl/blob/ap/perf_static/perf/HNN/results_gpu.csv).

@floffy-f
Author

Thank you for your kind answer.
I'll give the Reactant+Enzyme+nested-AD route another try then, taking your HNN as an example. Thank you for providing the benchmarks; they are (very) convincing.
If I run into difficulties I'll ask on Discourse for help, so that new issues land in the right package.

I agree that writing the custom AD rules for the proposed version of the batched Jacobian seems like quite a nightmare. I don't think I could write those in a reasonable amount of time for a PR, so I'd understand if you want to close this issue in the meantime.

Quick question about EnzymeJAX: is it faster in the end than Reactant+Enzyme+NestedAD in your experience? And is it better in some way than regular Jax?

@avik-pal
Member

Quick question about EnzymeJAX: is it faster in the end than Reactant+Enzyme+NestedAD in your experience? And is it better in some way than regular Jax?

cc @wsmoses do we have any numbers for EnzymeJAX vs JAX on nested AD?

@wsmoses
Contributor

wsmoses commented Jan 16, 2025

Not at the moment, but for tensor programs I presume it would be faster.

For EnzymeJAX vs pure JAX, the various tensor optimizations have shown double-digit performance speedups on various ML training codes, but obviously this is program dependent.
