Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support groups query customization #352

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 55 additions & 34 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@


class AdfsBaseBackend(ModelBackend):
def exchange_auth_code(self, authorization_code, request):
logger.debug("Received authorization code: %s", authorization_code)
data = {
'grant_type': 'authorization_code',
'client_id': settings.CLIENT_ID,
'redirect_uri': provider_config.redirect_uri(request),
'code': authorization_code,
}
if settings.CLIENT_SECRET:
data['client_secret'] = settings.CLIENT_SECRET

logger.debug("Getting access token at: %s", provider_config.token_endpoint)
response = provider_config.session.post(provider_config.token_endpoint, data, timeout=settings.TIMEOUT)
def _ms_request(self, action, url, data=None, **kwargs):
"""
Make a Microsoft Entra/GraphQL request


Args:
action (callable): The callable for making a request.
url (str): The URL the request should be sent to.
data (dict): Optional dictionary of data to be sent in the request.

Returns:
response: The response from the server. If it's not a 200, a
PermissionDenied is raised.
"""
response = action(url, data=data, timeout=settings.TIMEOUT, **kwargs)
# 200 = valid token received
# 400 = 'something' is wrong in our request
if response.status_code == 400:
Expand All @@ -39,7 +42,21 @@ def exchange_auth_code(self, authorization_code, request):
if response.status_code != 200:
logger.error("Unexpected ADFS response: %s", response.content.decode())
raise PermissionDenied
return response

def exchange_auth_code(self, authorization_code, request):
logger.debug("Received authorization code: %s", authorization_code)
data = {
'grant_type': 'authorization_code',
'client_id': settings.CLIENT_ID,
'redirect_uri': provider_config.redirect_uri(request),
'code': authorization_code,
}
if settings.CLIENT_SECRET:
data['client_secret'] = settings.CLIENT_SECRET

logger.debug("Getting access token at: %s", provider_config.token_endpoint)
response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data)
adfs_response = response.json()
return adfs_response

Expand All @@ -66,21 +83,30 @@ def get_obo_access_token(self, access_token):
else:
data["resource"] = 'https://graph.microsoft.com'

response = provider_config.session.get(provider_config.token_endpoint, data=data, timeout=settings.TIMEOUT)
# 200 = valid token received
# 400 = 'something' is wrong in our request
if response.status_code == 400:
logger.error("ADFS server returned an error: %s", response.json()["error_description"])
raise PermissionDenied

if response.status_code != 200:
logger.error("Unexpected ADFS response: %s", response.content.decode())
raise PermissionDenied

response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data)
obo_access_token = response.json()["access_token"]
logger.debug("Received OBO access token: %s", obo_access_token)
return obo_access_token

def get_group_memberships_from_ms_graph_params(self):
"""
Return the parameters to be used in the querystring
when fetching the user's group memberships.

Possible keys to be used:
- $count
- $expand
- $filter
- $orderby
- $search
- $select
- $top

Docs:
https://learn.microsoft.com/en-us/graph/api/group-list-transitivememberof?view=graph-rest-1.0&tabs=python#http-request
"""
return {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm reading this on my phone, but I'm a bit confused here. This function does nothing?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I just read the docs. Should've scrolled before I commented!

I wouldn't mind a default implementation here, but I'm also fine with this. ☺️

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, I lean towards not having a default to avoid breaking existing integrations. I don't hold that strongly though. If you suggest a default, I'm happy to include it.


def get_group_memberships_from_ms_graph(self, obo_access_token):
"""
Looks up a users group membership from the MS Graph API
Expand All @@ -95,17 +121,12 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
provider_config.msgraph_endpoint
)
headers = {"Authorization": "Bearer {}".format(obo_access_token)}
response = provider_config.session.get(graph_url, headers=headers, timeout=settings.TIMEOUT)
# 200 = valid token received
# 400 = 'something' is wrong in our request
if response.status_code in [400, 401]:
logger.error("MS Graph server returned an error: %s", response.json()["message"])
raise PermissionDenied

if response.status_code != 200:
logger.error("Unexpected MS Graph response: %s", response.content.decode())
raise PermissionDenied

response = self._ms_request(
action=provider_config.session.get,
url=graph_url,
data=self.get_group_memberships_from_ms_graph_params(),
headers=headers,
)
claim_groups = []
for group_data in response.json()["value"]:
if group_data["displayName"] is None:
Expand Down
7 changes: 6 additions & 1 deletion django_auth_adfs/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
BaseAuthentication, get_authorization_header
)

from django_auth_adfs.exceptions import MFARequired


class AdfsAccessTokenAuthentication(BaseAuthentication):
"""
Expand Down Expand Up @@ -33,7 +35,10 @@ def authenticate(self, request):
# Authenticate the user
# The AdfsAuthCodeBackend authentication backend will notice the "access_token" parameter
# and skip the request for an access token using the authorization code
user = authenticate(access_token=auth[1])
try:
user = authenticate(access_token=auth[1])
except MFARequired as e:
raise exceptions.AuthenticationFailed('MFA auth is required.') from e

if user is None:
raise exceptions.AuthenticationFailed('Invalid access token.')
Expand Down
6 changes: 6 additions & 0 deletions docs/settings_ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ GROUPS_CLAIM
Name of the claim in the JWT access token from ADFS that contains the groups the user is member of.
If an entry in this claim matches a group configured in Django, the user will join it automatically.

If using Azure AD and there are too many groups to fit in the JWT access token, the application will
make a request to the Microsoft GraphQL API to find the groups. If you have many groups but only
need a specific few, you can customize the request by overriding
``AdfsBaseBackend.get_group_memberships_from_ms_graph_params`` and specifying the
`OData query parameters <https://learn.microsoft.com/en-us/graph/api/group-list-transitivememberof?view=graph-rest-1.0&tabs=python#http-request>`_.

Set this setting to ``None`` to disable automatic group handling. The group memberships of the user
will not be touched.

Expand Down
Loading