Skip to content

Commit

Permalink
Prevent concurrent OAuth token refreshes in Canvas
Browse files Browse the repository at this point in the history
The Canvas integration does not use `OAuthHTTPService` so concurrent refresh
prevention has to be implemented separately. Fortunately this is pretty simple.
  • Loading branch information
robertknight committed Jun 6, 2024
1 parent 5146a00 commit 4c83a6e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lms/services/canvas_api/_authenticated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Access to the authenticated parts of the Canvas API."""

from lms.db import CouldNotAcquireLock
from lms.services.exceptions import ConcurrentTokenRefreshError
from lms.validation.authentication import OAuthTokenResponseSchema

DEFAULT_TIMEOUT = (10, 10)
Expand Down Expand Up @@ -87,6 +89,12 @@ def get_refreshed_token(self, refresh_token):
previous token call
:return: A new access token string
"""

try:
self._oauth2_token_service.try_lock_for_refresh()
except CouldNotAcquireLock as exc:
raise ConcurrentTokenRefreshError() from exc

return self._send_token_request(
grant_type="refresh_token", refresh_token=refresh_token
)
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/lms/services/canvas_api/_authenticated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import pytest

from lms.db import CouldNotAcquireLock
from lms.services import CanvasAPIServerError, OAuth2TokenError
from lms.services.canvas_api._basic import BasicClient
from lms.services.exceptions import ConcurrentTokenRefreshError
from lms.validation.authentication import OAuthTokenResponseSchema
from tests import factories

Expand Down Expand Up @@ -84,6 +86,8 @@ def test_get_refreshed_token(

assert token == "new_access_token"

oauth2_token_service.try_lock_for_refresh.assert_called_once()

basic_client.send.assert_called_once_with(
"POST",
"login/oauth2/token",
Expand All @@ -104,6 +108,14 @@ def test_get_refreshed_token(
token_response["expires_in"],
)

def test_get_refreshed_token_raises_if_lock_not_acquired(
self, authenticated_client, oauth2_token_service
):
oauth2_token_service.try_lock_for_refresh.side_effect = CouldNotAcquireLock()

with pytest.raises(ConcurrentTokenRefreshError):
authenticated_client.get_refreshed_token("refresh_token")

@pytest.fixture
def basic_client(self, token_response):
basic_api = create_autospec(BasicClient)
Expand Down

0 comments on commit 4c83a6e

Please sign in to comment.