Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes for allowed_groups and admin_groups #758

Merged
merged 12 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/tutorials/general-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ projects' authenticator classes.
- {attr}`.OAuthenticator.allow_all`
- {attr}`.OAuthenticator.allow_existing_users`
- {attr}`.OAuthenticator.allowed_users`
- {attr}`.OAuthenticator.allowed_groups`
- {attr}`.OAuthenticator.admin_users`
- {attr}`.OAuthenticator.admin_groups`

Your authenticator class may have unique config, so in the end it can look
something like this:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ be relevant to read more about in the configuration reference:
## Loading user groups

The `AzureAdOAuthenticator` can load the group-membership of users from the access token.
This is done by setting the `AzureAdOAuthenticator.groups_claim` to the name of the claim that contains the
group-membership.

```python
c.JupyterHub.authenticator_class = "azuread"

# {...} other settings (see above)

c.AzureAdOAuthenticator.manage_groups = True
c.AzureAdOAuthenticator.user_groups_claim = 'groups' # this is the default
c.AzureAdOAuthenticator.auth_state_groups_key = "user.groups" # this is the default
```

This requires Azure AD to be configured to include the group-membership in the access token.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ c.GenericOAuthenticator.userdata_url = "https://accounts.example.com/auth/realms
#
c.GenericOAuthenticator.scope = ["openid", "email", "groups"]
c.GenericOAuthenticator.username_claim = "email"
c.GenericOAuthenticator.claim_groups_key = "groups"
c.GenericOAuthenticator.auth_state_groups_key = "oauth_user.groups"

# Authorization
# -------------
Expand Down
26 changes: 14 additions & 12 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,26 @@ def _username_claim_default(self):
return "name"

user_groups_claim = Unicode(
"groups",
"",
config=True,
help="""
Name of claim containing user group memberships.
.. deprecated:: 17.0

Will populate JupyterHub groups if Authenticator.manage_groups is True.
Use :attr:`auth_state_groups_key` instead.
""",
)

@default('auth_state_groups_key')
def _auth_state_groups_key_default(self):
key = "user.groups"
if self.user_groups_claim:
key = f"{self.user_auth_state_key}.{self.user_groups_claim}"
cls = self.__class__.__name__
self.log.warning(
f"{cls}.user_groups_claim is deprecated in OAuthenticator 17. Use {cls}.auth_state_groups_key = {key!r}"
)
return key

tenant_id = Unicode(
config=True,
help="""
Expand All @@ -55,15 +66,6 @@ def _authorize_url_default(self):
def _token_url_default(self):
return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token"

async def update_auth_model(self, auth_model, **kwargs):
auth_model = await super().update_auth_model(auth_model, **kwargs)

if getattr(self, "manage_groups", False):
user_info = auth_model["auth_state"][self.user_auth_state_key]
auth_model["groups"] = user_info[self.user_groups_claim]

return auth_model

async def token_to_user(self, token_info):
id_token = token_info['id_token']
decoded = jwt.decode(
Expand Down
2 changes: 1 addition & 1 deletion oauthenticator/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def update_auth_model(self, auth_model):
to False makes it be revoked.
"""
user_groups = set()
if self.allowed_globus_groups or self.admin_globus_groups:
if self.allowed_globus_groups or self.admin_globus_groups or self.manage_groups:
tokens = self.get_globus_tokens(auth_model["auth_state"]["token_response"])
user_groups = await self._fetch_users_groups(tokens)
# sets are not JSONable, cast to list for auth_state
Expand Down
33 changes: 23 additions & 10 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def build_auth_state_dict(self, token_info, user_info):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`

.. versionchanged:: 17.0
This method be async.
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
Expand Down Expand Up @@ -1133,6 +1133,8 @@ async def get_user_groups(self, auth_state: dict):
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.

Only called when :attr:`manage_groups` is True.

- If auth_state_groups_key is a callable, it returns the list of groups directly.
Callable may be async.
- If auth_state_groups_key is a nested dictionary key like
Expand Down Expand Up @@ -1168,11 +1170,23 @@ async def update_auth_model(self, auth_model):
- `name`: the normalized username
- `admin`: the admin status (True/False/None), where None means it
should be unchanged.
- `auth_state`: the dictionary of of auth state
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
- `auth_state`: the auth state dictionary,
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
# NOTE: this base implementation should _not_ be updated to do anything
# subclasses should have full control without calling super()
return auth_model

async def _apply_managed_groups(self, auth_model):
"""Applies managed_groups logic

Called after `update_auth_model` to populate the `groups` field.
Only called if `manage_groups` is True.

The public method for subclasses to override is `.get_user_groups`.
"""
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
Expand Down Expand Up @@ -1244,7 +1258,10 @@ async def authenticate(self, handler, data=None, **kwargs):

# update the auth_model with info to later authorize the user in
# check_allowed, such as admin status and group memberships
return await self.update_auth_model(auth_model)
auth_model = await self.update_auth_model(auth_model)
if self.manage_groups:
auth_model = await self._apply_managed_groups(auth_model)
return auth_model

async def check_allowed(self, username, auth_model):
"""
Expand Down Expand Up @@ -1289,12 +1306,8 @@ async def check_allowed(self, username, auth_model):
return True

# allow users who are members of allowed_groups
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups
if any(user_groups & self.allowed_groups):
if self.manage_groups and self.allowed_groups and auth_model.get("groups"):
if set(auth_model["groups"]) & self.allowed_groups:
return True

# users should be explicitly allowed via config, otherwise they aren't
Expand Down
60 changes: 5 additions & 55 deletions oauthenticator/openshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPClient, HTTPRequest
from traitlets import Bool, Set, Unicode, default
from traitlets import Bool, Unicode, default

from oauthenticator.oauth2 import OAuthenticator


class OpenShiftOAuthenticator(OAuthenticator):
user_auth_state_key = "openshift_user"

@default("auth_state_groups_key")
def _auth_state_groups_key_default(self):
return "openshift_user.groups"

@default("scope")
def _scope_default(self):
return ["user:info"]
Expand Down Expand Up @@ -45,24 +49,6 @@ def _http_request_kwargs_default(self):
""",
)

allowed_groups = Set(
config=True,
help="""
Allow members of selected OpenShift groups to sign in.
""",
)

admin_groups = Set(
config=True,
help="""
Allow members of selected OpenShift groups to sign in and consider them
as JupyterHub admins.

If this is set and a user isn't part of one of these groups or listed in
`admin_users`, a user signing in will have their admin status revoked.
""",
)

openshift_auth_api_url = Unicode(
config=True,
help="""
Expand Down Expand Up @@ -158,42 +144,6 @@ def user_info_to_username(self, user_info):
"""
return user_info['metadata']['name']

async def update_auth_model(self, auth_model):
"""
Sets admin status to True or False if `admin_groups` is configured and
the user isn't part of `admin_users`. Note that leaving it at None makes
users able to retain an admin status while setting it to False makes it
be revoked.
"""
if auth_model["admin"]:
# auth_model["admin"] being True means the user was in admin_users
return auth_model

if self.admin_groups:
# admin status should in this case be True or False, not None
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = set(user_info["groups"])
auth_model["admin"] = bool(user_groups & self.admin_groups)

return auth_model

async def check_allowed(self, username, auth_model):
"""
Overrides OAuthenticator.check_allowed to also allow users part of
`allowed_groups`.
"""
if await super().check_allowed(username, auth_model):
return True

if self.allowed_groups:
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = set(user_info["groups"])
if user_groups & self.allowed_groups:
return True

# users should be explicitly allowed via config, otherwise they aren't
return False


class LocalOpenShiftOAuthenticator(LocalAuthenticator, OpenShiftOAuthenticator):
"""A version that mixes in local system user creation"""
47 changes: 46 additions & 1 deletion oauthenticator/tests/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def user_model():
return {
"email": "[email protected]",
"name": "user1",
"groups": ["group1"],
}


Expand Down Expand Up @@ -62,6 +63,47 @@ def user_model():
True,
True,
),
# common tests with allowed_groups and manage_groups
(
"20",
{
"allowed_groups": {"group1"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
True,
None,
),
(
"21",
{
"allowed_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
False,
None,
),
(
"22",
{
"admin_groups": {"group1"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
True,
True,
),
(
"23",
{
"admin_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
False,
False,
),
],
)
async def test_auth0(
Expand All @@ -84,7 +126,10 @@ async def test_auth0(

if expect_allowed:
assert auth_model
assert set(auth_model) == {"name", "admin", "auth_state"}
if authenticator.manage_groups:
assert set(auth_model) == {"name", "admin", "auth_state", "groups"}
else:
assert set(auth_model) == {"name", "admin", "auth_state"}
assert auth_model["admin"] == expect_admin
auth_state = auth_model["auth_state"]
assert json.dumps(auth_state)
Expand Down
Loading