diff --git a/posteriors/laplace/dense_ggn.py b/posteriors/laplace/dense_ggn.py index 16002d7b..04e11eaa 100644 --- a/posteriors/laplace/dense_ggn.py +++ b/posteriors/laplace/dense_ggn.py @@ -137,8 +137,8 @@ def outer_loss(z, batch): with torch.no_grad(), CatchAuxError(): ggn_batch, aux = ggn( - partial(forward, batch=batch), - partial(outer_loss, batch=batch), + lambda params: forward(params, batch), + lambda z: outer_loss(z, batch), forward_has_aux=True, loss_has_aux=False, normalize=False, diff --git a/posteriors/laplace/diag_ggn.py b/posteriors/laplace/diag_ggn.py index 26f5a116..36067a16 100644 --- a/posteriors/laplace/diag_ggn.py +++ b/posteriors/laplace/diag_ggn.py @@ -134,8 +134,8 @@ def outer_loss(z, batch): with torch.no_grad(), CatchAuxError(): diag_ggn_batch, aux = diag_ggn( - partial(forward, batch=batch), - partial(outer_loss, batch=batch), + lambda params: forward(params, batch), + lambda z: outer_loss(z, batch), forward_has_aux=True, loss_has_aux=False, normalize=False, diff --git a/tests/laplace/test_dense_ggn.py b/tests/laplace/test_dense_ggn.py index d7b3909f..bded3a38 100644 --- a/tests/laplace/test_dense_ggn.py +++ b/tests/laplace/test_dense_ggn.py @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch): ) # validate args introduces control flows not yet supported in torch.func.vmap -def forward_m(params, batch, model): - y_pred = functional_call(model, params, batch[0]) +def forward_m(params, b, model): + y_pred = functional_call(model, params, b[0]) return y_pred, torch.tensor([]) diff --git a/tests/laplace/test_diag_ggn.py b/tests/laplace/test_diag_ggn.py index 5961e1c8..89a172b9 100644 --- a/tests/laplace/test_diag_ggn.py +++ b/tests/laplace/test_diag_ggn.py @@ -18,8 +18,8 @@ def normal_log_likelihood(y_pred, batch): ) # validate args introduces control flows not yet supported in torch.func.vmap -def forward_m(params, batch, model): - y_pred = functional_call(model, params, batch[0]) +def forward_m(params, b, model): + y_pred = functional_call(model, params, b[0]) return y_pred, torch.tensor([])