From 6ac3b73e0f435f6060d4b8e61ca3d6e2bc04d99b Mon Sep 17 00:00:00 2001 From: Tomas Pazderka Date: Wed, 19 Jun 2019 21:58:49 +0200 Subject: [PATCH] Remove RefreshDB and related code from SessionDB --- src/oic/utils/sdb.py | 184 +------------------------------------------ tests/test_sdb.py | 38 --------- 2 files changed, 3 insertions(+), 219 deletions(-) diff --git a/src/oic/utils/sdb.py b/src/oic/utils/sdb.py index f88fbea0a..a05f866d4 100644 --- a/src/oic/utils/sdb.py +++ b/src/oic/utils/sdb.py @@ -313,108 +313,6 @@ def from_json(cls, json_struct): return cls(**dic) -class RefreshDB(object): - """Database for refresh token storage.""" - - def __init__(self): - warnings.warn( - "Using `RefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.", - DeprecationWarning, - stacklevel=2, - ) - - def get(self, refresh_token): - """ - Retrieve info about the authentication proces from the refresh token. - - :return: Dictionary with info - :raises: KeyError - """ - raise NotImplementedError - - def store(self, token, info): - """ - Store the information about the authentication process. - - :param token: Token - :param info: Information associated with token to be stored - """ - raise NotImplementedError - - def remove(self, token): - """ - Remove the token and related information from the internal storage. - - :param token: Token to be removed - """ - raise NotImplementedError - - def create_token(self, client_id, uid, scopes, sub, authzreq, sid): - """ - Create refresh token for given combination of client_id and sub and store it in internal storage. - - :param client_id: Client_id of the consumer - :param uid: User identification - :param scopes: Scopes associated with the token - :param sub: Sub identifier - :param authzreq: Authorization request - :param sid: Session ID - :return: Refresh token - """ - refresh_token = "Refresh_{}".format(rndstr(5 * 16)) - self.store( - refresh_token, - { - "client_id": client_id, - "uid": uid, - "scope": scopes, - "sub": sub, - "authzreq": authzreq, - "sid": sid, - }, - ) - return refresh_token - - def verify_token(self, client_id, refresh_token): - """Verify if the refresh token belongs to client_id.""" - if not refresh_token.startswith("Refresh_"): - raise WrongTokenType - try: - stored_cid = self.get(refresh_token).get("client_id") - except KeyError: - return False - return client_id == stored_cid - - def revoke_token(self, token): - """Remove token from database.""" - self.remove(token) - - -class DictRefreshDB(RefreshDB): - """Dictionary based implementation of RefreshDB.""" - - def __init__(self): - super(DictRefreshDB, self).__init__() - warnings.warn( - "Using `DictRefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.", - DeprecationWarning, - stacklevel=2, - ) - self._db = {} # type: Dict[str, Dict[str, str]] - - def get(self, refresh_token): - """Retrieve info for given token from dictionary.""" - return self._db[refresh_token].copy() - - def store(self, token, info): - """Add token and info to the dictionary.""" - self._db[token] = info - - def remove(self, token): - """Remove the token from the dictionary.""" - self._db.pop(token) - - def create_session_db( base_url, secret, @@ -450,7 +348,6 @@ def create_session_db( return SessionDB( base_url, db, - refresh_db=None, code_factory=code_factory, token_factory=token_factory, refresh_token_factory=refresh_token_factory, @@ -559,8 +456,6 @@ def __init__( self, base_url, db, - refresh_db=None, - refresh_token_expires_in=None, token_factory=None, code_factory=None, refresh_token_factory=None, @@ -570,12 +465,6 @@ def __init__( :param db: Database for storing the session information. """ - if refresh_token_expires_in is not None: - warnings.warn( - "Setting a `refresh_token_expires_in` has no effect, please set the expiration on " - "`refresh_token_factory`.", - DeprecationWarning, - ) self.base_url = base_url if not isinstance(db, SessionBackend): warnings.warn( @@ -588,27 +477,11 @@ def __init__( self.token_factory_order = ["code", "access_token"] - # TODO: This should simply be a factory like all the others too, - # even for the default case. - if refresh_token_factory: - if refresh_db: - raise ImproperlyConfigured( - "Only use one of refresh_db or refresh_token_factory" - ) - self._refresh_db = None self.token_factory["refresh_token"] = refresh_token_factory self.token_factory_order.append("refresh_token") - elif refresh_db: - warnings.warn( - "Using `refresh_db` is deprecated, please use `refresh_token_factory`", - DeprecationWarning, - stacklevel=2, - ) - self._refresh_db = refresh_db else: # Not configured - self._refresh_db = None self.token_factory["refresh_token"] = None self.access_token = self.token_factory["access_token"] @@ -829,26 +702,7 @@ def upgrade_to_token( dic["oidreq"] = oidreq if issue_refresh: - if "authn_event" in dic: - authn_event = AuthnEvent.from_json(dic["authn_event"]) - else: - authn_event = None - if authn_event: - uid = authn_event.uid - else: - uid = None - - if self._refresh_db: - refresh_token = self._refresh_db.create_token( - dic["client_id"], - uid, - dic.get("scope"), - dic["sub"], - dic["authzreq"], - key, - ) - dic["refresh_token"] = refresh_token - elif self.token_factory["refresh_token"] is not None: + if self.token_factory["refresh_token"] is not None: refresh_token = self.token_factory["refresh_token"](key, sinfo=dic) dic["refresh_token"] = refresh_token self._db[key] = dic @@ -865,34 +719,7 @@ def refresh_token(self, rtoken, client_id): WrongTokenType for wrong token type """ # assert that it is a refresh token and that it is valid - if self._refresh_db: - if self._refresh_db.verify_token(client_id, rtoken): - # Valid refresh token - _info = self._refresh_db.get(rtoken) - try: - sid = _info["sid"] - except KeyError: - areq = json.loads(_info["authzreq"]) - sid = self.token_factory["code"].key(user=_info["uid"], areq=areq) - dic = _info - dic["response_type"] = areq["response_type"].split(" ") - else: - try: - dic = self._db[sid] - except KeyError: - dic = _info - - access_token = self.access_token(sid=sid, sinfo=dic) - try: - at = dic["access_token"] - except KeyError: - pass - else: - if at: - self.access_token.invalidate(at) - else: - raise ExpiredToken() - elif self.token_factory["refresh_token"] is None: + if self.token_factory["refresh_token"] is None: raise WrongTokenType() elif self.token_factory["refresh_token"].valid(rtoken): if self.token_factory["refresh_token"].is_expired(rtoken): @@ -931,9 +758,6 @@ def is_valid(self, token, client_id=None): :param token: Access or refresh token :param client_id: Client ID, needed only for Refresh token """ - if token.startswith("Refresh_"): - return self._refresh_db.verify_token(client_id, token) - try: typ, sid = self._get_token_type_and_key(token) except KeyError: @@ -987,9 +811,7 @@ def revoke_refresh_token(self, rtoken): :param rtoken: Refresh token """ - if self._refresh_db: - self._refresh_db.revoke_token(rtoken) - elif self.token_factory["refresh_token"] is not None: + if self.token_factory["refresh_token"] is not None: self.token_factory["refresh_token"].invalidate(rtoken) return True diff --git a/tests/test_sdb.py b/tests/test_sdb.py index f716aa1fd..e9068b16b 100644 --- a/tests/test_sdb.py +++ b/tests/test_sdb.py @@ -16,7 +16,6 @@ from oic.utils.sdb import AuthnEvent from oic.utils.sdb import Crypt from oic.utils.sdb import DefaultToken -from oic.utils.sdb import DictRefreshDB from oic.utils.sdb import DictSessionBackend from oic.utils.sdb import ExpiredToken from oic.utils.sdb import WrongTokenType @@ -86,43 +85,6 @@ def test_to_json(self): } -class TestDictRefreshDB(object): - @pytest.fixture(autouse=True) - def create_rdb(self): - self.rdb = DictRefreshDB() - - def test_verify_token(self): - token = self.rdb.create_token( - "client1", "uid", "openid", "sub1", "authzreq", "sid" - ) - assert self.rdb.verify_token("client1", token) - assert self.rdb.verify_token("client2", token) is False - - def test_revoke_token(self): - token = self.rdb.create_token( - "client1", "uid", "openid", "sub1", "authzreq", "sid" - ) - self.rdb.remove(token) - assert self.rdb.verify_token("client1", token) is False - with pytest.raises(KeyError): - self.rdb.get(token) - - def test_get_token(self): - with pytest.raises(KeyError): - self.rdb.get("token") - token = self.rdb.create_token( - "client1", "uid", ["openid"], "sub1", "authzreq", "sid" - ) - assert self.rdb.get(token) == { - "client_id": "client1", - "sub": "sub1", - "scope": ["openid"], - "uid": "uid", - "authzreq": "authzreq", - "sid": "sid", - } - - class TestToken(object): @pytest.fixture(autouse=True) def create_token(self):