Skip to content

Commit

Permalink
Merge pull request #92 from BenjaminDev/91_Laplace_GGN_arguments
Browse files Browse the repository at this point in the history
Laplace ggn arguments
  • Loading branch information
SamDuffield authored May 29, 2024
2 parents c399762 + 3e8036d commit 1585144
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions posteriors/laplace/dense_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def outer_loss(z, batch):

with torch.no_grad(), CatchAuxError():
ggn_batch, aux = ggn(
partial(forward, batch=batch),
partial(outer_loss, batch=batch),
lambda params: forward(params, batch),
lambda z: outer_loss(z, batch),
forward_has_aux=True,
loss_has_aux=False,
normalize=False,
Expand Down
4 changes: 2 additions & 2 deletions posteriors/laplace/diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def outer_loss(z, batch):

with torch.no_grad(), CatchAuxError():
diag_ggn_batch, aux = diag_ggn(
partial(forward, batch=batch),
partial(outer_loss, batch=batch),
lambda params: forward(params, batch),
lambda z: outer_loss(z, batch),
forward_has_aux=True,
loss_has_aux=False,
normalize=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/laplace/test_dense_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch):
) # validate args introduces control flows not yet supported in torch.func.vmap


def forward_m(params, batch, model):
y_pred = functional_call(model, params, batch[0])
def forward_m(params, b, model):
y_pred = functional_call(model, params, b[0])
return y_pred, torch.tensor([])


Expand Down
4 changes: 2 additions & 2 deletions tests/laplace/test_diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch):
) # validate args introduces control flows not yet supported in torch.func.vmap


def forward_m(params, batch, model):
y_pred = functional_call(model, params, batch[0])
def forward_m(params, b, model):
y_pred = functional_call(model, params, b[0])
return y_pred, torch.tensor([])


Expand Down

0 comments on commit 1585144

Please sign in to comment.