Add optional optimizer (#82)
* small changes so torch tensors are compatible with existing code

* rename previous unit tests as test_basic.py

* add optional torch dependency

* add numpy as optional dependency to avoid torch warning

* add Optimizer class

* temporarily export Optimizer class unconditionally

* add a file for optimizer unit tests

* add job for unit test gh workflow for optimizer

* fix linting errors

* test_optimizer job should run all unit tests

* provide alternative class definition for Optimizer if torch is not installed

* add pandas as a dev dependency

* install pandas in test_optimizer job

* add josh's (my) anki review logs from 1711744352250 to 1728234780857 for optimizer unit test

* add more optimizer unit tests

* reduce duplication of default parameters

* compare optimal parameters to expected parameters using numpy.allclose to pass on multiple machines

* add some documentation to optimizer unit tests

* add poetry as dev dependency

* format code

* add section for Optimizer in README

* add note about python versions when installing dev dependencies

* card ids are now epoch microseconds

* store card_id in ReviewLog object rather than copy of Card object

* bump version to 5.0.0rc1 for pre-release

* add myself (josh) to list of authors in pyproject.toml

* card ids are once again milliseconds

* add note about batch card creation in README to avoid card id collisions

* bump version -> 5.0.0 for release

* move Optimizer class to its own file

* convert DEFAULT_PARAMETERS to tuple

* add short docstring to top of optimizer module

* Optimizer.compute_optimal_parameters should return mutable list
joshdavham authored Jan 24, 2025
1 parent dd1e48e commit 4daa566
Showing 10 changed files with 13,257 additions and 45 deletions.
28 changes: 26 additions & 2 deletions .github/workflows/test.yml
@@ -3,7 +3,7 @@ name: Test Python
on: [push, pull_request]

jobs:
test:
test_basic:

runs-on: ubuntu-latest
strategy:
@@ -24,4 +24,28 @@ jobs:
pip install pytest
- name: Test with pytest
run: |
pytest
pytest tests/test_basic.py
test_optimizer:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pandas
pip install ".[optimizer]"
- name: Test with pytest
run: |
pytest
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
@@ -20,6 +20,8 @@ After cloning this repo, install `fsrs` locally in editable mode along with the
pip install -e ".[dev]"
```

Note: you may not be able to install all of the `dev` dependencies if you are using python 3.13 or 3.14. If this is causing trouble, consider trying to install the `dev` dependencies in a python 3.10-3.12 environment.

Now you're ready to make changes to files in the `fsrs` directory and see your changes reflected immediately.

### Pass the checks
52 changes: 52 additions & 0 deletions README.md
@@ -26,6 +26,7 @@
- [Installation](#installation)
- [Quickstart](#quickstart)
- [Usage](#usage)
- [Optimizer (optional)](#optimizer-optional)
- [Reference](#reference)
- [Other FSRS implementations](#other-fsrs-implementations)
- [Other SRS python packages](#other-srs-python-packages)
@@ -171,6 +172,57 @@ new_card = Card.from_dict(card_dict)
new_review_log = ReviewLog.from_dict(review_log_dict)
```

### Batch card creation

If you batch create `Card` objects, ensure that you leave at least 1 millisecond between creating each individual card:

```python
from fsrs import Card
import time

cards = []
for i in range(100):

card = Card()

cards.append(card)

# wait 1 millisecond
time.sleep(0.001)
```

Each `Card` object has a `card_id` attribute, which is the epoch milliseconds of when the card was created. To keep each card id unique, two cards must not be created within 1 millisecond of each other.
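
As an alternative to sleeping, the ids could be generated up front and passed in explicitly. The sketch below assumes that `Card` accepts a `card_id` argument, as its constructor in `fsrs/fsrs.py` suggests; the helper name `unique_epoch_ms_ids` is hypothetical and not part of `fsrs`.

```python
from datetime import datetime, timezone


def unique_epoch_ms_ids(count: int) -> list[int]:
    """Return `count` distinct ids, starting at the current epoch milliseconds."""
    start = int(datetime.now(timezone.utc).timestamp() * 1000)
    # consecutive integers are unique by construction, so no sleep is needed
    return [start + offset for offset in range(count)]


card_ids = unique_epoch_ms_ids(100)
# assumed usage, not run here:
# cards = [Card(card_id=card_id) for card_id in card_ids]
```

One caveat with this approach: ids after the first lie slightly in the future, so a card created with the default id within the next `count` milliseconds could still collide with one of them.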

## Optimizer (optional)

If you have a collection of `ReviewLog` objects, you can optionally reuse them to compute an optimal set of parameters for the `Scheduler` to make it more accurate at scheduling reviews.

### Installation
To install the optimizer, first ensure you're using `python 3.10-3.12`, then run:
```
pip install "fsrs[optimizer]"
```

### Usage

```python
from fsrs import ReviewLog, Optimizer, Scheduler

# load your ReviewLog objects into a list (order doesn't matter)
review_logs = [ReviewLog1, ReviewLog2, ...]

# initialize the optimizer with the review logs
optimizer = Optimizer(review_logs)

# compute a set of optimized parameters
optimal_parameters = optimizer.compute_optimal_parameters()

# initialize a new scheduler with the optimized parameters!
scheduler = Scheduler(parameters=optimal_parameters)
```

Note: The computed optimal parameters may differ slightly from the parameters computed by Anki for the same set of review logs. This is because the two implementations are slightly different and are updated at different times. If you're interested in the official Rust-based Anki implementation, please see [here](https://github.com/open-spaced-repetition/fsrs-rs).
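
Because two parameter sets are unlikely to match exactly, a tolerance-based comparison is usually more useful than `==`. The commit message mentions `numpy.allclose` for the unit tests; the sketch below swaps in the standard library's `math.isclose` so that it has no dependencies. The function name `parameters_close`, the tolerances, and the sample values are all made up for illustration.

```python
import math


def parameters_close(a, b, rel_tol=1e-3, abs_tol=1e-6):
    """Return True if both sequences have equal length and agree pairwise within tolerance."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b)
    )


# made-up values standing in for two independently computed parameter sets
py_fsrs_parameters = [0.40255, 1.18385, 3.173]
other_parameters = [0.40256, 1.18384, 3.173]

print(parameters_close(py_fsrs_parameters, other_parameters))
```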

## Reference

Card objects have one of three possible states
21 changes: 19 additions & 2 deletions fsrs/__init__.py
@@ -5,6 +5,23 @@
Py-FSRS is the official Python implementation of the FSRS scheduler algorithm, which can be used to develop spaced repetition systems.
"""

from .fsrs import Scheduler, Card, Rating, ReviewLog, State
from .fsrs import (
Scheduler,
Card,
Rating,
ReviewLog,
State,
DEFAULT_PARAMETERS,
)

__all__ = ["Scheduler", "Card", "Rating", "ReviewLog", "State"]
from .optimizer import Optimizer

__all__ = [
"Scheduler",
"Card",
"Rating",
"ReviewLog",
"State",
"Optimizer",
"DEFAULT_PARAMETERS",
]
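
The commit message mentions providing an alternative `Optimizer` definition when torch is not installed, which is what allows the unconditional import above. A minimal sketch of that general pattern follows; it is an illustration under assumptions, not the contents of `fsrs/optimizer.py`. The factory `build_optimizer_class` and both class bodies are hypothetical.

```python
import importlib.util


def build_optimizer_class(dependency: str = "torch"):
    """Return a working class if `dependency` is importable, else a stub that fails loudly."""
    if importlib.util.find_spec(dependency) is not None:

        class Optimizer:
            # placeholder for the real, dependency-backed implementation
            def __init__(self, review_logs):
                self.review_logs = list(review_logs)

    else:

        class Optimizer:
            # stub: importing the name always works, but using it explains what is missing
            def __init__(self, *args, **kwargs):
                raise ImportError(
                    f'Optimizer requires {dependency}; try: pip install "fsrs[optimizer]"'
                )

    return Optimizer
```

With this shape, `from fsrs import Optimizer` would never fail at import time; the error only surfaces when someone tries to construct the class without the optional dependency.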
76 changes: 39 additions & 37 deletions fsrs/fsrs.py
@@ -15,9 +15,31 @@
from __future__ import annotations
import math
from datetime import datetime, timezone, timedelta
from copy import deepcopy
from copy import copy
from enum import IntEnum
import random
from random import random

DEFAULT_PARAMETERS = (
0.40255,
1.18385,
3.173,
15.69105,
7.1949,
0.5345,
1.4604,
0.0046,
1.54575,
0.1192,
1.01925,
1.9395,
0.11,
0.29605,
2.2698,
0.2315,
2.9898,
0.51655,
0.6621,
)

DECAY = -0.5
FACTOR = 0.9 ** (1 / DECAY) - 1
@@ -67,7 +89,7 @@ class Card:
Represents a flashcard in the FSRS system.
Attributes:
card_id (int): The id of the card. Defaults to the epoch miliseconds of when the card was created.
card_id (int): The id of the card. Defaults to the epoch milliseconds of when the card was created.
state (State): The card's current learning state.
step (int | None): The card's current learning or relearning step or None if the card is in the Review state.
stability (float | None): Core mathematical parameter used for future scheduling.
@@ -95,7 +117,7 @@ def __init__(
last_review: datetime | None = None,
) -> None:
if card_id is None:
# epoch miliseconds of when the card was created
# epoch milliseconds of when the card was created
card_id = int(datetime.now(timezone.utc).timestamp() * 1000)
self.card_id = card_id

@@ -203,25 +225,25 @@ class ReviewLog:
Represents the log entry of a Card object that has been reviewed.
Attributes:
card (Card): Copy of the card object that was reviewed.
card_id (int): The id of the card being reviewed.
rating (Rating): The rating given to the card during the review.
review_datetime (datetime): The date and time of the review.
review_duration (int | None): The number of miliseconds it took to review the card or None if unspecified.
"""

card: Card
card_id: int
rating: Rating
review_datetime: datetime
review_duration: int | None

def __init__(
self,
card: Card,
card_id: int,
rating: Rating,
review_datetime: datetime,
review_duration: int | None = None,
) -> None:
self.card = deepcopy(card)
self.card_id = card_id
self.rating = rating
self.review_datetime = review_datetime
self.review_duration = review_duration
@@ -235,11 +257,11 @@ def to_dict(
This method is specifically useful for storing ReviewLog objects in a database.
Returns:
dict: A dictionary representation of the Card object.
dict: A dictionary representation of the ReviewLog object.
"""

return_dict = {
"card": self.card.to_dict(),
"card_id": self.card_id,
"rating": self.rating.value,
"review_datetime": self.review_datetime.isoformat(),
"review_duration": self.review_duration,
@@ -261,13 +283,13 @@ def from_dict(
ReviewLog: A ReviewLog object created from the provided dictionary.
"""

card = Card.from_dict(source_dict["card"])
card_id = source_dict["card_id"]
rating = Rating(int(source_dict["rating"]))
review_datetime = datetime.fromisoformat(source_dict["review_datetime"])
review_duration = source_dict["review_duration"]

return ReviewLog(
card=card,
card_id=card_id,
rating=rating,
review_datetime=review_datetime,
review_duration=review_duration,
@@ -298,27 +320,7 @@ class Scheduler:

def __init__(
self,
parameters: tuple[float, ...] | list[float] = (
0.40255,
1.18385,
3.173,
15.69105,
7.1949,
0.5345,
1.4604,
0.0046,
1.54575,
0.1192,
1.01925,
1.9395,
0.11,
0.29605,
2.2698,
0.2315,
2.9898,
0.51655,
0.6621,
),
parameters: tuple[float, ...] | list[float] = DEFAULT_PARAMETERS,
desired_retention: float = 0.9,
learning_steps: tuple[timedelta, ...] | list[timedelta] = (
timedelta(minutes=1),
@@ -365,7 +367,7 @@ def review_card(
):
raise ValueError("datetime must be timezone-aware and set to UTC")

card = deepcopy(card)
card = copy(card)

if review_datetime is None:
review_datetime = datetime.now(timezone.utc)
@@ -375,7 +377,7 @@ def review_card(
)

review_log = ReviewLog(
card=card,
card_id=card.card_id,
rating=rating,
review_datetime=review_datetime,
review_duration=review_duration,
@@ -665,7 +667,7 @@ def _next_interval(self, stability: float) -> int:
(self.desired_retention ** (1 / DECAY)) - 1
)

next_interval = round(next_interval) # intervals are full days
next_interval = round(float(next_interval)) # intervals are full days

# must be at least 1 day long
next_interval = max(next_interval, 1)
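
To see what the constants imply, the hunk above can be sketched as a standalone function. This assumes the line elided from the diff computes `(stability / FACTOR) * (desired_retention ** (1 / DECAY) - 1)`; only the second factor is actually visible, so treat the full formula as an assumption. The name `next_interval_days` is hypothetical.

```python
DECAY = -0.5
FACTOR = 0.9 ** (1 / DECAY) - 1  # 0.9 ** -2 - 1, roughly 0.2346


def next_interval_days(stability: float, desired_retention: float = 0.9) -> int:
    # assumed formula; only its second factor appears in the diff
    interval = (stability / FACTOR) * (desired_retention ** (1 / DECAY) - 1)
    interval = round(float(interval))  # intervals are full days
    return max(interval, 1)  # must be at least 1 day long
```

Under this assumption, a desired retention of 0.9 makes the second factor equal to `FACTOR`, so the interval is simply the stability rounded to whole days. Lowering the desired retention lengthens the interval.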
@@ -797,7 +799,7 @@ def _get_fuzz_range(interval_days: int) -> tuple[int, int]:
min_ivl, max_ivl = _get_fuzz_range(interval_days)

fuzzed_interval_days = (
random.random() * (max_ivl - min_ivl + 1)
random() * (max_ivl - min_ivl + 1)
) + min_ivl # the next interval is a random value between min_ivl and max_ivl

fuzzed_interval_days = min(round(fuzzed_interval_days), self.maximum_interval)
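
The fuzzing step above can be isolated as a small function. In this sketch `min_ivl` and `max_ivl` are passed in directly, because `_get_fuzz_range` is not shown in the diff; the function name `fuzz_interval` and the default `maximum_interval` of 36500 days are assumptions for illustration.

```python
from random import random


def fuzz_interval(min_ivl: int, max_ivl: int, maximum_interval: int = 36500) -> int:
    # uniform draw from [min_ivl, max_ivl + 1)
    fuzzed = random() * (max_ivl - min_ivl + 1) + min_ivl
    # round to whole days (a draw above max_ivl + 0.5 rounds up to max_ivl + 1),
    # then clamp to the maximum interval
    return min(round(fuzzed), maximum_interval)
```

Importing `random` as a function, as the commit does, keeps the call site short, but it shadows the module name within the file.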