Skip to content

Commit

Permalink
Replace partial with a lambda
Browse files Browse the repository at this point in the history
implying we relax from 'batch' named kwarg to any arg for the batch in 'forward'
  • Loading branch information
BenjaminDev committed May 24, 2024
1 parent 3ab0b67 commit 3e8036d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions posteriors/laplace/diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def outer_loss(z, batch):

with torch.no_grad(), CatchAuxError():
diag_ggn_batch, aux = diag_ggn(
partial(forward, batch=batch),
partial(outer_loss, batch=batch),
lambda params: forward(params, batch),
lambda z: outer_loss(z, batch),
forward_has_aux=True,
loss_has_aux=False,
normalize=False,
Expand Down

0 comments on commit 3e8036d

Please sign in to comment.