Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

91 laplace ggn arguments #92

Merged

Conversation

BenjaminDev
Copy link
Contributor

Thanks for the very interesting library. I generally like to attempt fixing issues to get to know a code base. Feel free to just close this PR if this change is missing the point.

Fixes #91

implying we relax from 'batch' named kwarg to any arg for the batch in 'forward'
posteriors/laplace/dense_ggn.py Outdated Show resolved Hide resolved
tests/laplace/test_dense_ggn.py Show resolved Hide resolved
@SamDuffield
Copy link
Contributor

Looking good! Can we do the same for laplace.diag_ggn too? Then we should be good to merge

implying we relax from 'batch' named kwarg to any arg for the batch in 'forward'
@BenjaminDev
Copy link
Contributor Author

BenjaminDev commented May 24, 2024

@SamDuffield, Just one comment after thinking this change over. Lambda's in python are late binding which means this change from partial to lambda x: ... is subtly different.

To show the difference consider:

with torch.no_grad(), CatchAuxError():
        forward_partial = partial(forward, batch=batch) # Original pattern which forces consumers of the lib to use the `batch` kw arg. But! batch is bound on this line.
        diag_ggn_fn = diag_ggn(
            forward_partial, 
            lambda z: outer_loss(z, batch), # New pattern BUT! `batch` is only bound when the the lambda is invoked. i.e when `diag_ggn_fn` is called.
            forward_has_aux=True,
            loss_has_aux=False,
            normalize=False,
        )
        batch=None # If batch is modified the `partial` function is unaffected as it bound to `batch` in `forward_partial.keywords['batch']`. But the `lambda` only binds later and will be affected. 
        diag_ggn_batch, aux = diag_ggn_fn(state.params)

A pattern that makes a lambda bind early is:

    ...
     diag_ggn(
            lambda params, _batch=batch: forward(params, _batch)# Now `batch` binds in the  `.__defaults__` 
            ...
    )
 ...

I updated the diff to reflect the way we were both heading and I think it's a perfectly fine solution and all tests pass. But, I just wanted to point out that my small change is, on a "python" level significantly different.

I think we good to merge unless the above triggers a larger concern on your side.

Copy link
Contributor

@SamDuffield SamDuffield left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a very good point, but as you note I think all is good as we enforce the specific (params, batch) signature. So good to merge!

@SamDuffield SamDuffield merged commit 1585144 into normal-computing:main May 29, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

More flexible Laplace GGN arguments
2 participants