Skip to content

Commit

Permalink
Remove RefreshDB and related code from SessionDB
Browse files Browse the repository at this point in the history
  • Loading branch information
tpazderka committed Aug 2, 2019
1 parent 2e751ed commit 6ac3b73
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 219 deletions.
184 changes: 3 additions & 181 deletions src/oic/utils/sdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,108 +313,6 @@ def from_json(cls, json_struct):
return cls(**dic)


class RefreshDB(object):
"""Database for refresh token storage."""

def __init__(self):
warnings.warn(
"Using `RefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.",
DeprecationWarning,
stacklevel=2,
)

def get(self, refresh_token):
"""
Retrieve info about the authentication proces from the refresh token.
:return: Dictionary with info
:raises: KeyError
"""
raise NotImplementedError

def store(self, token, info):
"""
Store the information about the authentication process.
:param token: Token
:param info: Information associated with token to be stored
"""
raise NotImplementedError

def remove(self, token):
"""
Remove the token and related information from the internal storage.
:param token: Token to be removed
"""
raise NotImplementedError

def create_token(self, client_id, uid, scopes, sub, authzreq, sid):
"""
Create refresh token for given combination of client_id and sub and store it in internal storage.
:param client_id: Client_id of the consumer
:param uid: User identification
:param scopes: Scopes associated with the token
:param sub: Sub identifier
:param authzreq: Authorization request
:param sid: Session ID
:return: Refresh token
"""
refresh_token = "Refresh_{}".format(rndstr(5 * 16))
self.store(
refresh_token,
{
"client_id": client_id,
"uid": uid,
"scope": scopes,
"sub": sub,
"authzreq": authzreq,
"sid": sid,
},
)
return refresh_token

def verify_token(self, client_id, refresh_token):
"""Verify if the refresh token belongs to client_id."""
if not refresh_token.startswith("Refresh_"):
raise WrongTokenType
try:
stored_cid = self.get(refresh_token).get("client_id")
except KeyError:
return False
return client_id == stored_cid

def revoke_token(self, token):
"""Remove token from database."""
self.remove(token)


class DictRefreshDB(RefreshDB):
"""Dictionary based implementation of RefreshDB."""

def __init__(self):
super(DictRefreshDB, self).__init__()
warnings.warn(
"Using `DictRefreshDB` is deprecated, please use `Token` and `refresh_token_factory` instead.",
DeprecationWarning,
stacklevel=2,
)
self._db = {} # type: Dict[str, Dict[str, str]]

def get(self, refresh_token):
"""Retrieve info for given token from dictionary."""
return self._db[refresh_token].copy()

def store(self, token, info):
"""Add token and info to the dictionary."""
self._db[token] = info

def remove(self, token):
"""Remove the token from the dictionary."""
self._db.pop(token)


def create_session_db(
base_url,
secret,
Expand Down Expand Up @@ -450,7 +348,6 @@ def create_session_db(
return SessionDB(
base_url,
db,
refresh_db=None,
code_factory=code_factory,
token_factory=token_factory,
refresh_token_factory=refresh_token_factory,
Expand Down Expand Up @@ -559,8 +456,6 @@ def __init__(
self,
base_url,
db,
refresh_db=None,
refresh_token_expires_in=None,
token_factory=None,
code_factory=None,
refresh_token_factory=None,
Expand All @@ -570,12 +465,6 @@ def __init__(
:param db: Database for storing the session information.
"""
if refresh_token_expires_in is not None:
warnings.warn(
"Setting a `refresh_token_expires_in` has no effect, please set the expiration on "
"`refresh_token_factory`.",
DeprecationWarning,
)
self.base_url = base_url
if not isinstance(db, SessionBackend):
warnings.warn(
Expand All @@ -588,27 +477,11 @@ def __init__(

self.token_factory_order = ["code", "access_token"]

# TODO: This should simply be a factory like all the others too,
# even for the default case.

if refresh_token_factory:
if refresh_db:
raise ImproperlyConfigured(
"Only use one of refresh_db or refresh_token_factory"
)
self._refresh_db = None
self.token_factory["refresh_token"] = refresh_token_factory
self.token_factory_order.append("refresh_token")
elif refresh_db:
warnings.warn(
"Using `refresh_db` is deprecated, please use `refresh_token_factory`",
DeprecationWarning,
stacklevel=2,
)
self._refresh_db = refresh_db
else:
# Not configured
self._refresh_db = None
self.token_factory["refresh_token"] = None

self.access_token = self.token_factory["access_token"]
Expand Down Expand Up @@ -829,26 +702,7 @@ def upgrade_to_token(
dic["oidreq"] = oidreq

if issue_refresh:
if "authn_event" in dic:
authn_event = AuthnEvent.from_json(dic["authn_event"])
else:
authn_event = None
if authn_event:
uid = authn_event.uid
else:
uid = None

if self._refresh_db:
refresh_token = self._refresh_db.create_token(
dic["client_id"],
uid,
dic.get("scope"),
dic["sub"],
dic["authzreq"],
key,
)
dic["refresh_token"] = refresh_token
elif self.token_factory["refresh_token"] is not None:
if self.token_factory["refresh_token"] is not None:
refresh_token = self.token_factory["refresh_token"](key, sinfo=dic)
dic["refresh_token"] = refresh_token
self._db[key] = dic
Expand All @@ -865,34 +719,7 @@ def refresh_token(self, rtoken, client_id):
WrongTokenType for wrong token type
"""
# assert that it is a refresh token and that it is valid
if self._refresh_db:
if self._refresh_db.verify_token(client_id, rtoken):
# Valid refresh token
_info = self._refresh_db.get(rtoken)
try:
sid = _info["sid"]
except KeyError:
areq = json.loads(_info["authzreq"])
sid = self.token_factory["code"].key(user=_info["uid"], areq=areq)
dic = _info
dic["response_type"] = areq["response_type"].split(" ")
else:
try:
dic = self._db[sid]
except KeyError:
dic = _info

access_token = self.access_token(sid=sid, sinfo=dic)
try:
at = dic["access_token"]
except KeyError:
pass
else:
if at:
self.access_token.invalidate(at)
else:
raise ExpiredToken()
elif self.token_factory["refresh_token"] is None:
if self.token_factory["refresh_token"] is None:
raise WrongTokenType()
elif self.token_factory["refresh_token"].valid(rtoken):
if self.token_factory["refresh_token"].is_expired(rtoken):
Expand Down Expand Up @@ -931,9 +758,6 @@ def is_valid(self, token, client_id=None):
:param token: Access or refresh token
:param client_id: Client ID, needed only for Refresh token
"""
if token.startswith("Refresh_"):
return self._refresh_db.verify_token(client_id, token)

try:
typ, sid = self._get_token_type_and_key(token)
except KeyError:
Expand Down Expand Up @@ -987,9 +811,7 @@ def revoke_refresh_token(self, rtoken):
:param rtoken: Refresh token
"""
if self._refresh_db:
self._refresh_db.revoke_token(rtoken)
elif self.token_factory["refresh_token"] is not None:
if self.token_factory["refresh_token"] is not None:
self.token_factory["refresh_token"].invalidate(rtoken)

return True
Expand Down
38 changes: 0 additions & 38 deletions tests/test_sdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from oic.utils.sdb import AuthnEvent
from oic.utils.sdb import Crypt
from oic.utils.sdb import DefaultToken
from oic.utils.sdb import DictRefreshDB
from oic.utils.sdb import DictSessionBackend
from oic.utils.sdb import ExpiredToken
from oic.utils.sdb import WrongTokenType
Expand Down Expand Up @@ -86,43 +85,6 @@ def test_to_json(self):
}


class TestDictRefreshDB(object):
@pytest.fixture(autouse=True)
def create_rdb(self):
self.rdb = DictRefreshDB()

def test_verify_token(self):
token = self.rdb.create_token(
"client1", "uid", "openid", "sub1", "authzreq", "sid"
)
assert self.rdb.verify_token("client1", token)
assert self.rdb.verify_token("client2", token) is False

def test_revoke_token(self):
token = self.rdb.create_token(
"client1", "uid", "openid", "sub1", "authzreq", "sid"
)
self.rdb.remove(token)
assert self.rdb.verify_token("client1", token) is False
with pytest.raises(KeyError):
self.rdb.get(token)

def test_get_token(self):
with pytest.raises(KeyError):
self.rdb.get("token")
token = self.rdb.create_token(
"client1", "uid", ["openid"], "sub1", "authzreq", "sid"
)
assert self.rdb.get(token) == {
"client_id": "client1",
"sub": "sub1",
"scope": ["openid"],
"uid": "uid",
"authzreq": "authzreq",
"sid": "sid",
}


class TestToken(object):
@pytest.fixture(autouse=True)
def create_token(self):
Expand Down

0 comments on commit 6ac3b73

Please sign in to comment.