Skip to content

Commit

Permalink
sustech: use CAS
Browse files Browse the repository at this point in the history
  • Loading branch information
taoky committed Oct 14, 2023
1 parent 4812005 commit 708de8e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
54 changes: 42 additions & 12 deletions frontend/auth_providers/sustech.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
from datetime import timedelta
from xml.etree import ElementTree

from django.urls import path

from .base import DomainEmailValidator
from .external import ExternalLoginView, ExternalGetCodeView
from typing import Optional

from ..models import AccountLog
from .cas import CASBaseLoginView

class LoginView(ExternalLoginView):
template_context = {'provider_name': '南方科技大学'}
provider = 'sustech'
group = 'sustech'


class GetCodeView(ExternalGetCodeView):
class LoginView(CASBaseLoginView):
provider = 'sustech'
duration = timedelta(hours=1)
validate_identity = DomainEmailValidator(['sustech.edu.cn', 'mail.sustech.edu.cn'])
group = 'sustech'
service: str
ticket: str
sno: str

cas_name = 'CRA SSO / SUSTech CAS'
cas_login_url = 'https://sso.cra.ac.cn/realms/cra-service-realm/protocol/cas/login'
cas_service_validate_url = 'https://sso.cra.ac.cn/realms/cra-service-realm/protocol/cas/serviceValidate'

def check_ticket(self) -> Optional[ElementTree.Element]:
tree = super().check_ticket()
if not tree:
return None
self.identity = tree.find(self.YALE_CAS_URL + 'user').text.strip()
self.mail = tree.find(self.YALE_CAS_URL + 'attributes').find(self.YALE_CAS_URL + 'mail').text.strip()
self.name = tree.find(self.YALE_CAS_URL + 'attributes').find(self.YALE_CAS_URL + 'cn').text.strip()
return tree

def on_get_account(self, account):
def to_set(s):
return set(s.split(',')) if s else set()
def from_set(vs):
return ','.join(sorted(vs))
custom_attrs: list[tuple[str, str]] = [
('邮箱', self.mail),
('姓名', self.name)
]
for display_name, self_value in custom_attrs:
try:
o = AccountLog.objects.get(account=account, content_type=display_name)
new_value = from_set(to_set(o.contents) | {self_value})
if new_value != o.contents:
o.contents = new_value
o.save()
except AccountLog.DoesNotExist:
AccountLog.objects.create(account=account, contents=f"{self.value}", content_type=display_name)
return account


urlpatterns = [
path('sustech/login/', LoginView.as_view()),
path('sustech/get_code/', GetCodeView.as_view()),
]
28 changes: 28 additions & 0 deletions frontend/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.test import TestCase
from .auth_providers.ustc import LoginView as USTCLoginView
from .auth_providers.sustech import LoginView as SUSTECHLoginView
from unittest import mock
from contextlib import contextmanager

Expand All @@ -16,6 +17,18 @@
</cas:authenticationSuccess>
</cas:serviceResponse>"""

SUSTECH_CAS_EXAMPLE_RESPONSE = """<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
<cas:user>11899999</cas:user>
<cas:attributes>
<cas:mail>[email protected]</cas:mail>
<cas:givenName>San</cas:givenName>
<cas:sn>ZHANG</cas:sn>
<cas:cn>ZHANG San</cas:cn>
</cas:attributes>
</cas:authenticationSuccess>
</cas:serviceResponse>"""


class MockResponse:
def __init__(self, text):
Expand All @@ -32,6 +45,10 @@ def mock_urlopen(url, timeout=None):
if "serviceValidate" in url:
success = True
yield MockResponse(USTC_CAS_EXAMPLE_RESPONSE)
elif "sso.cra.ac.cn" in url:
if "serviceValidate" in url:
success = True
yield MockResponse(SUSTECH_CAS_EXAMPLE_RESPONSE)
if not success:
raise ValueError("Unknown URL")

Expand All @@ -46,3 +63,14 @@ def test_ustc(self):
self.assertEqual(tree.tag, "{http://www.yale.edu/tp/cas}authenticationSuccess")
self.assertEqual(v.identity, "2201234567")
self.assertEqual(v.sno, "SA21011000")

@mock.patch("frontend.auth_providers.cas.urlopen", new=mock_urlopen)
def test_sustech(self):
v = SUSTECHLoginView()
v.service = "http://example.com/accounts/sustech/login/"
v.ticket = "ST-1234567890"
tree = v.check_ticket()
self.assertEqual(tree.tag, "{http://www.yale.edu/tp/cas}authenticationSuccess")
self.assertEqual(v.identity, "11899999")
self.assertEqual(v.mail, "[email protected]")
self.assertEqual(v.name, "ZHANG San")

0 comments on commit 708de8e

Please sign in to comment.