From 911be9a59d0ee896ebfc292017d15fe133993c9b Mon Sep 17 00:00:00 2001
From: Joshua Hamilton
Date: Thu, 23 Jan 2025 21:33:42 -0600
Subject: [PATCH] Optimizer.compute_optimal_parameters should return mutable list

---
 fsrs/optimizer.py       | 2 +-
 tests/test_optimizer.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/fsrs/optimizer.py b/fsrs/optimizer.py
index 3a3494d..f7768de 100644
--- a/fsrs/optimizer.py
+++ b/fsrs/optimizer.py
@@ -274,7 +274,7 @@ def _update_parameters(
         num_reviews = _num_reviews()
 
         if num_reviews < mini_batch_size:
-            return DEFAULT_PARAMETERS
+            return list(DEFAULT_PARAMETERS)
 
         # Define FSRS Scheduler parameters as torch tensors with gradients
         params = torch.tensor(
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 6807765..497243f 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -45,7 +45,7 @@ def test_zero_revlogs(self):
 
         optimal_parameters = optimizer.compute_optimal_parameters()
 
-        assert optimal_parameters == DEFAULT_PARAMETERS
+        assert optimal_parameters == list(DEFAULT_PARAMETERS)
 
     def test_review_logs(self):
         """
@@ -81,7 +81,7 @@ def test_review_logs(self):
         optimal_parameters = optimizer.compute_optimal_parameters()
 
         # the optimal paramaters are no longer equal to the starting parameters
-        assert optimal_parameters != DEFAULT_PARAMETERS
+        assert optimal_parameters != list(DEFAULT_PARAMETERS)
 
         # the output is expected
         assert np.allclose(optimal_parameters, expected_optimal_parameters)
@@ -114,7 +114,7 @@ def test_few_review_logs(self):
 
         optimal_parameters = optimizer.compute_optimal_parameters()
 
-        assert optimal_parameters == DEFAULT_PARAMETERS
+        assert optimal_parameters == list(DEFAULT_PARAMETERS)
 
     def test_unordered_review_logs(self):
         """
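
For context on why the return type matters: if DEFAULT_PARAMETERS is a shared
immutable tuple, callers of compute_optimal_parameters() receive an object they
cannot modify in place, and if it were a shared list, every caller would alias
the same mutable state. Wrapping it in list() hands each caller an independent,
mutable copy. Below is a minimal standalone sketch of the before/after
behavior; it assumes DEFAULT_PARAMETERS is a tuple, and the names and values
are placeholders that only mirror the patch, not the real FSRS defaults.

    # Standalone illustration; names mirror the patch, values are made up.
    DEFAULT_PARAMETERS = (0.2, 1.2, 3.2, 15.7)  # placeholder values

    def compute_before():
        # pre-patch behavior: returns the shared module-level object itself
        return DEFAULT_PARAMETERS

    def compute_after():
        # post-patch behavior: returns a fresh, mutable copy
        return list(DEFAULT_PARAMETERS)

    params = compute_before()
    try:
        params[0] = 0.5
    except TypeError as exc:
        print(f"before: {exc}")  # tuples reject item assignment

    params = compute_after()
    params[0] = 0.5  # fine: the caller owns an independent list
    print(f"after: params[0] == {params[0]}")

This also explains the test changes above: a list never compares equal to a
tuple in Python, so the assertions must compare against list(DEFAULT_PARAMETERS)
rather than DEFAULT_PARAMETERS itself, and because each call returns a fresh
copy, mutating the result can never corrupt the module-level defaults.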