From bacd26904e2521908ce5806fbe1805b59ce2013d Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sun, 27 Oct 2024 15:41:32 +0800 Subject: [PATCH] add test of optimizer --- tests/model_test.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/model_test.py b/tests/model_test.py index 69e6972..8d0de47 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -76,6 +76,7 @@ def test_forward(self): def test_loss_and_grad(self): model = FSRS(DEFAULT_PARAMETER) + optimizer = torch.optim.Adam(model.parameters(), lr=4e-2) loss_fn = nn.BCELoss(reduction="none") t_histories = torch.tensor( [ @@ -135,3 +136,31 @@ def test_loss_and_grad(self): ), atol=1e-4, ) + optimizer.step() + assert torch.allclose( + model.w, + torch.tensor( + [ + 0.44255, + 1.22385, + 3.2129998, + 15.65105, + 7.2349, + 0.4945, + 1.4204, + 0.0446, + 1.5057501, + 0.1592, + 0.97925, + 1.9794999, + 0.07000001, + 0.33605, + 2.3097994, + 0.2715, + 2.9498, + 0.47655, + 0.62210006, + ] + ), + atol=1e-4, + )