diff --git a/docs/api.yml b/docs/api.yml index aeb198a7..df187bef 100644 --- a/docs/api.yml +++ b/docs/api.yml @@ -3665,7 +3665,7 @@ paths: /users/researchgroups/: get: operationId: users_researchgroups_list - description: Manage user membership in research groups + description: Manage user membership in research groups. parameters: - in: query name: id @@ -3761,7 +3761,7 @@ paths: description: '' post: operationId: users_researchgroups_create - description: Manage user membership in research groups + description: Manage user membership in research groups. tags: - users requestBody: @@ -3788,7 +3788,7 @@ paths: /users/researchgroups/{id}/: get: operationId: users_researchgroups_retrieve - description: Manage user membership in research groups + description: Manage user membership in research groups. parameters: - in: path name: id @@ -3809,7 +3809,7 @@ paths: description: '' put: operationId: users_researchgroups_update - description: Manage user membership in research groups + description: Manage user membership in research groups. parameters: - in: path name: id @@ -3842,7 +3842,7 @@ paths: description: '' patch: operationId: users_researchgroups_partial_update - description: Manage user membership in research groups + description: Manage user membership in research groups. parameters: - in: path name: id @@ -3874,7 +3874,7 @@ paths: description: '' delete: operationId: users_researchgroups_destroy - description: Manage user membership in research groups + description: Manage user membership in research groups. parameters: - in: path name: id @@ -3892,7 +3892,7 @@ paths: /users/users/: get: operationId: users_users_list - description: Read only access to user data + description: Manage user account data. parameters: - in: query name: date_joined @@ -4355,7 +4355,7 @@ paths: description: '' post: operationId: users_users_create - description: Read only access to user data + description: Manage user account data. tags: - users requestBody: @@ -4382,7 +4382,7 @@ paths: /users/users/{id}/: get: operationId: users_users_retrieve - description: Read only access to user data + description: Manage user account data. parameters: - in: path name: id @@ -4403,7 +4403,7 @@ paths: description: '' put: operationId: users_users_update - description: Read only access to user data + description: Manage user account data. parameters: - in: path name: id @@ -4436,7 +4436,7 @@ paths: description: '' patch: operationId: users_users_partial_update - description: Read only access to user data + description: Manage user account data. parameters: - in: path name: id @@ -4468,7 +4468,7 @@ paths: description: '' delete: operationId: users_users_destroy - description: Read only access to user data + description: Manage user account data. parameters: - in: path name: id @@ -4887,7 +4887,7 @@ components: type: integer PatchedResearchGroup: type: object - description: Object serializer for the `ResearchGroup` class + description: Object serializer for the `ResearchGroup` model. properties: id: type: integer @@ -4908,7 +4908,7 @@ components: PatchedRestrictedUser: type: object description: Object serializer for the `User` class with administrative fields - marked as read only + marked as read only. properties: id: type: integer @@ -5053,7 +5053,7 @@ components: - time ResearchGroup: type: object - description: Object serializer for the `ResearchGroup` class + description: Object serializer for the `ResearchGroup` model. properties: id: type: integer @@ -5078,7 +5078,7 @@ components: RestrictedUser: type: object description: Object serializer for the `User` class with administrative fields - marked as read only + marked as read only. properties: id: type: integer diff --git a/keystone_api/apps/allocations/models.py b/keystone_api/apps/allocations/models.py index e78ec3b6..9911e230 100644 --- a/keystone_api/apps/allocations/models.py +++ b/keystone_api/apps/allocations/models.py @@ -49,7 +49,7 @@ def get_research_group(self) -> ResearchGroup: return self.request.group - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return a human-readable summary of the allocation""" return f'{self.cluster} allocation for {self.request.group}' @@ -91,7 +91,7 @@ def get_research_group(self) -> ResearchGroup: return self.group - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return the request title as a string""" return truncatechars(self.title, 100) @@ -120,7 +120,7 @@ def get_research_group(self) -> ResearchGroup: return self.request.group - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return a human-readable identifier for the allocation request""" return f'{self.reviewer} review for \"{self.request.title}\"' @@ -142,7 +142,7 @@ class Cluster(models.Model): description = models.TextField(max_length=150, null=True, blank=True) enabled = models.BooleanField(default=True) - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return the cluster name as a string""" return str(self.name) diff --git a/keystone_api/apps/research_products/models.py b/keystone_api/apps/research_products/models.py index ecd600eb..96b27dd7 100644 --- a/keystone_api/apps/research_products/models.py +++ b/keystone_api/apps/research_products/models.py @@ -30,7 +30,7 @@ class Grant(models.Model): objects = GrantManager() - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return the grant title truncated to 50 characters""" return truncatechars(self.title, 100) @@ -49,7 +49,7 @@ class Publication(models.Model): objects = PublicationManager() - def __str__(self) -> str: + def __str__(self) -> str: # pragma: nocover """Return the publication title truncated to 50 characters""" return truncatechars(self.title, 100) diff --git a/keystone_api/apps/users/admin.py b/keystone_api/apps/users/admin.py index 836a3d34..e0f0b731 100644 --- a/keystone_api/apps/users/admin.py +++ b/keystone_api/apps/users/admin.py @@ -22,17 +22,17 @@ @admin.register(User) class UserAdmin(auth.admin.UserAdmin): - """Admin interface for managing user accounts""" + """Admin interface for managing user accounts.""" @admin.action def activate_selected_users(self, request, queryset) -> None: - """Mark selected users as active""" + """Mark selected users as active.""" queryset.update(is_active=True) @admin.action def deactivate_selected_users(self, request, queryset) -> None: - """Mark selected users as inactive""" + """Mark selected users as inactive.""" queryset.update(is_active=False) @@ -52,12 +52,12 @@ def deactivate_selected_users(self, request, queryset) -> None: @admin.register(ResearchGroup) class ResearchGroupAdmin(admin.ModelAdmin): - """Admin interface for managing research group delegates""" + """Admin interface for managing research group delegates.""" @staticmethod @admin.display def pi(obj: ResearchGroup) -> str: - """Return the username of the research group PI""" + """Return the username of the research group PI.""" return obj.pi.username diff --git a/keystone_api/apps/users/managers.py b/keystone_api/apps/users/managers.py index 38ce4836..3b196b10 100644 --- a/keystone_api/apps/users/managers.py +++ b/keystone_api/apps/users/managers.py @@ -12,22 +12,22 @@ from django.contrib.auth.base_user import BaseUserManager from django.db import models -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: nocover from apps.users.models import User __all__ = ['ResearchGroupManager', 'UserManager'] class UserManager(BaseUserManager): - """Object manager for the `User` database model""" + """Object manager for the `User` database model.""" def create_user( - self, - username: str, - password: str, - **extra_fields + self, + username: str, + password: str, + **extra_fields ) -> 'User': - """Create and a new user account + """Create a new user account. Args: username: The account username @@ -54,7 +54,7 @@ def create_superuser( password: str, **extra_fields ) -> 'User': - """Create and a new user account with superuser privileges + """Create a new user account with superuser privileges. Args: username: The account username @@ -79,7 +79,7 @@ def create_superuser( class ResearchGroupManager(models.Manager): - """Object manager for the `ResearchGroup` database model""" + """Object manager for the `ResearchGroup` database model.""" def groups_for_user(self, user: 'User') -> models.QuerySet: """Get all research groups the user is affiliated with. diff --git a/keystone_api/apps/users/models.py b/keystone_api/apps/users/models.py index 6d64171b..e7d21242 100644 --- a/keystone_api/apps/users/models.py +++ b/keystone_api/apps/users/models.py @@ -17,7 +17,7 @@ class User(auth_models.AbstractBaseUser, auth_models.PermissionsMixin): - """Proxy model for the built-in django `User` model""" + """Proxy model for the built-in django `User` model.""" # These values should always be defined when extending AbstractBaseUser USERNAME_FIELD = 'username' @@ -44,7 +44,7 @@ class User(auth_models.AbstractBaseUser, auth_models.PermissionsMixin): class ResearchGroup(models.Model): - """A user research group tied to a slurm account""" + """A user research group tied to a slurm account.""" name = models.CharField(max_length=255, unique=True) pi = models.ForeignKey(User, on_delete=models.CASCADE, related_name='research_group_pi') @@ -54,16 +54,16 @@ class ResearchGroup(models.Model): objects = ResearchGroupManager() def get_all_members(self) -> tuple[User, ...]: - """Return all research group members""" + """Return all research group members.""" return (self.pi,) + tuple(self.admins.all()) + tuple(self.members.all()) def get_privileged_members(self) -> tuple[User, ...]: - """Return all research group members with admin privileges""" + """Return all research group members with admin privileges.""" return (self.pi,) + tuple(self.admins.all()) - def __str__(self) -> str: - """Return the research group's account name""" + def __str__(self) -> str: # pragma: nocover # pragma: nocover + """Return the research group's account name.""" return str(self.name) diff --git a/keystone_api/apps/users/permissions.py b/keystone_api/apps/users/permissions.py index 912a446d..90c07938 100644 --- a/keystone_api/apps/users/permissions.py +++ b/keystone_api/apps/users/permissions.py @@ -16,13 +16,13 @@ class IsGroupAdminOrReadOnly(permissions.BasePermission): - """Grant read-only access is granted to all authenticated users. + """Grant read-only access to all authenticated users. Staff users retain all read/write permissions. """ def has_permission(self, request: Request, view: View) -> bool: - """Return whether the request has permissions to access the requested resource""" + """Return whether the request has permissions to access the requested resource.""" if request.method == 'TRACE': return request.user.is_staff @@ -30,7 +30,7 @@ def has_permission(self, request: Request, view: View) -> bool: return True def has_object_permission(self, request: Request, view: View, obj: ResearchGroup): - """Return whether the incoming HTTP request has permission to access a database record""" + """Return whether the incoming HTTP request has permission to access a database record.""" # Read permissions are allowed to any request if request.method in permissions.SAFE_METHODS: @@ -41,10 +41,10 @@ def has_object_permission(self, request: Request, view: View, obj: ResearchGroup class IsSelfOrReadOnly(permissions.BasePermission): - """Gives read-only permissions to everyone but limits write access to staff users and record owners""" + """Grant read-only permissions to everyone and limit write access to staff and record owners.""" def has_permission(self, request: Request, view: View) -> bool: - """Return whether the request has permissions to access the requested resource""" + """Return whether the request has permissions to access the requested resource.""" # Allow all users to read/update existing records # Rely on object level permissions for further refinement of update permissions @@ -55,7 +55,7 @@ def has_permission(self, request: Request, view: View) -> bool: return request.user.is_staff def has_object_permission(self, request: Request, view: View, obj: User) -> bool: - """Return whether the incoming HTTP request has permission to access a database record""" + """Return whether the incoming HTTP request has permission to access a database record.""" # Write operations are restricted to staff and user's modifying their own data is_record_owner = obj == request.user diff --git a/keystone_api/apps/users/serializers.py b/keystone_api/apps/users/serializers.py index 95d96ffd..05f575e7 100644 --- a/keystone_api/apps/users/serializers.py +++ b/keystone_api/apps/users/serializers.py @@ -20,20 +20,20 @@ class ResearchGroupSerializer(serializers.ModelSerializer): - """Object serializer for the `ResearchGroup` class""" + """Object serializer for the `ResearchGroup` model.""" class Meta: - """Serializer settings""" + """Serializer settings.""" model = ResearchGroup fields = '__all__' class PrivilegeUserSerializer(serializers.ModelSerializer): - """Object serializer for the `User` class""" + """Object serializer for the `User` model.""" class Meta: - """Serializer settings""" + """Serializer settings.""" model = User fields = '__all__' @@ -41,7 +41,7 @@ class Meta: extra_kwargs = {'password': {'write_only': True}} def validate(self, attrs: dict) -> None: - """Validate user attributes match the ORM data model + """Validate user attributes match the ORM data model. Args: attrs: Dictionary of user attributes @@ -56,10 +56,10 @@ def validate(self, attrs: dict) -> None: class RestrictedUserSerializer(PrivilegeUserSerializer): - """Object serializer for the `User` class with administrative fields marked as read only""" + """Object serializer for the `User` class with administrative fields marked as read only.""" class Meta: - """Serializer settings""" + """Serializer settings.""" model = User fields = '__all__' @@ -67,7 +67,7 @@ class Meta: extra_kwargs = {'password': {'write_only': True}} def create(self, validated_data: dict) -> None: - """Raises an error when attempting to create a new record + """Raises an error when attempting to create a new record. Raises: RuntimeError: Every time the function is called diff --git a/keystone_api/apps/users/tasks.py b/keystone_api/apps/users/tasks.py index cc78e72d..c777cdb7 100644 --- a/keystone_api/apps/users/tasks.py +++ b/keystone_api/apps/users/tasks.py @@ -15,7 +15,7 @@ def get_ldap_connection() -> ldap.ldapobject.LDAPObject: - """Establish a new LDAP connection""" + """Establish a new LDAP connection.""" conn = ldap.initialize(settings.AUTH_LDAP_SERVER_URI) if settings.AUTH_LDAP_BIND_DN: @@ -29,14 +29,14 @@ def get_ldap_connection() -> ldap.ldapobject.LDAPObject: @shared_task() -def ldap_update_users(prune=False) -> None: - """Update the user database with the latest data from LDAP +def ldap_update_users(prune: bool = False) -> None: + """Update the user database with the latest data from LDAP. This function performs no action if the `AUTH_LDAP_SERVER_URI` setting is not configured in the application settings. Args: - prune: Optionally delete old LDAP accounts with usernames no longer found in LDAP + prune: Optionally delete accounts with usernames no longer found in LDAP """ if not settings.AUTH_LDAP_SERVER_URI: diff --git a/keystone_api/apps/users/tests/test_managers/test_ResearchGroupManager.py b/keystone_api/apps/users/tests/test_managers/test_ResearchGroupManager.py index f7610db8..3c403eaf 100644 --- a/keystone_api/apps/users/tests/test_managers/test_ResearchGroupManager.py +++ b/keystone_api/apps/users/tests/test_managers/test_ResearchGroupManager.py @@ -1,4 +1,4 @@ -"""Tests for the `ResearchGroupManager` class""" +"""Tests for the `ResearchGroupManager` class.""" from django.test import TestCase @@ -7,10 +7,10 @@ class GroupsForUser(TestCase): - """Test fetching group affiliations via the `groups_for_user` method""" + """Test fetching group affiliations via the `groups_for_user` method.""" def setUp(self): - """Create temporary users and groups""" + """Create temporary users and groups.""" self.test_user = create_test_user(username='test_user') other_user = create_test_user(username='other_user') @@ -30,7 +30,7 @@ def setUp(self): self.group4 = ResearchGroup.objects.create(name='Group4', pi=other_user) def test_groups_for_user(self) -> None: - """Test all groups are returned for a test user""" + """Test all groups are returned for a test user.""" result = ResearchGroup.objects.groups_for_user(self.test_user).all() self.assertCountEqual(result, [self.group1, self.group2, self.group3]) diff --git a/keystone_api/apps/users/tests/test_managers/test_UserManager.py b/keystone_api/apps/users/tests/test_managers/test_UserManager.py index 5a705f34..93c8ed53 100644 --- a/keystone_api/apps/users/tests/test_managers/test_UserManager.py +++ b/keystone_api/apps/users/tests/test_managers/test_UserManager.py @@ -1,4 +1,4 @@ -"""Tests for the `UserManager` class""" +"""Tests for the `UserManager` class.""" from django.core.exceptions import ValidationError from django.test import TestCase @@ -7,10 +7,10 @@ class UserCreation(TestCase): - """Test the creation of user accounts""" + """Test the creation of user accounts.""" def test_create_user(self) -> None: - """Test the creation of generic user accounts""" + """Test the creation of generic user accounts.""" user = User.objects.create_user( username='foobar', @@ -28,7 +28,7 @@ def test_create_user(self) -> None: self.assertFalse(user.is_superuser) def test_create_superuser(self) -> None: - """Test the creation of superuser accounts""" + """Test the creation of superuser accounts.""" admin_user = User.objects.create_superuser( username='foobar', @@ -46,7 +46,7 @@ def test_create_superuser(self) -> None: self.assertTrue(admin_user.is_superuser) def test_superusers_must_be_staff(self) -> None: - """Test superusers are required to be staff users""" + """Test superusers are required to be staff users.""" with self.assertRaisesRegex(ValueError, 'must set `is_staff=True`.'): User.objects.create_superuser( @@ -58,7 +58,7 @@ def test_superusers_must_be_staff(self) -> None: is_staff=False) def test_superusers_must_be_superusers(self) -> None: - """Test superusers are required to have superuser permissions""" + """Test superusers are required to have superuser permissions.""" with self.assertRaisesRegex(ValueError, 'must set `is_superuser=True`'): User.objects.create_superuser( @@ -70,7 +70,7 @@ def test_superusers_must_be_superusers(self) -> None: is_superuser=False) def test_passwords_are_validated(self) -> None: - """Test passwords are required to meet security criteria""" + """Test passwords are required to meet security criteria.""" with self.assertRaisesRegex(ValidationError, 'This password is too short'): User.objects.create_user( diff --git a/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py b/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py index 48511682..61f41e1d 100644 --- a/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py +++ b/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py @@ -1,4 +1,4 @@ -"""Tests for the `ResearchGroup` model""" +"""Tests for the `ResearchGroup` model.""" from django.test import TestCase @@ -7,10 +7,10 @@ class GetAllMembers(TestCase): - """Test fetching all group members via the `get_all_members` member""" + """Test fetching all group members via the `get_all_members` member.""" - def setUp(self): - """Create temporary user accounts for use in tests""" + def setUp(self) -> None: + """Create temporary user accounts for use in tests.""" self.pi = create_test_user(username='pi') self.admin1 = create_test_user(username='admin1') @@ -19,7 +19,7 @@ def setUp(self): self.member2 = create_test_user(username='unprivileged2') def test_all_accounts_returned(self) -> None: - """Test all group members are included in the returned list""" + """Test all group members are included in the returned list.""" group = ResearchGroup.objects.create(pi=self.pi) group.admins.add(self.admin1) @@ -32,10 +32,10 @@ def test_all_accounts_returned(self) -> None: class GetPrivilegedMembers(TestCase): - """Test fetching group members via the `get_privileged_members` member""" + """Test fetching group members via the `get_privileged_members` member.""" - def setUp(self): - """Create temporary user accounts for use in tests""" + def setUp(self) -> None: + """Create temporary user accounts for use in tests.""" self.pi = create_test_user(username='pi') self.admin1 = create_test_user(username='admin1') @@ -44,14 +44,14 @@ def setUp(self): self.member2 = create_test_user(username='member2') def test_pi_only(self) -> None: - """Test returned group members for a group with a PI only""" + """Test returned group members for a group with a PI only.""" group = ResearchGroup.objects.create(pi=self.pi) expected_members = (self.pi,) self.assertEqual(expected_members, group.get_privileged_members()) def test_pi_with_admins(self) -> None: - """Test returned group members for a group with a PI and admins""" + """Test returned group members for a group with a PI and admins.""" group = ResearchGroup.objects.create(pi=self.pi) group.admins.add(self.admin1) @@ -61,7 +61,7 @@ def test_pi_with_admins(self) -> None: self.assertEqual(expected_members, group.get_privileged_members()) def test_pi_with_members(self) -> None: - """Test returned group members for a group with a PI and unprivileged members""" + """Test returned group members for a group with a PI and unprivileged members.""" group = ResearchGroup.objects.create(pi=self.pi) group.members.add(self.member1) @@ -71,7 +71,7 @@ def test_pi_with_members(self) -> None: self.assertEqual(expected_members, group.get_privileged_members()) def test_pi_with_admin_and_members(self) -> None: - """Test returned group members for a group with a PI, admins, and unprivileged members""" + """Test returned group members for a group with a PI, admins, and unprivileged members.""" group = ResearchGroup.objects.create(pi=self.pi) group.admins.add(self.admin1) diff --git a/keystone_api/apps/users/tests/test_models/test_User.py b/keystone_api/apps/users/tests/test_models/test_User.py index fec8f666..52d607dc 100644 --- a/keystone_api/apps/users/tests/test_models/test_User.py +++ b/keystone_api/apps/users/tests/test_models/test_User.py @@ -1,4 +1,4 @@ -"""Tests for the `User` class""" +"""Tests for the `User` class.""" from django.contrib.auth import get_user_model from django.test import TestCase @@ -7,9 +7,9 @@ class UserModelRegistration(TestCase): - """Test the registration of the model with the Django authentication system""" + """Test the registration of the model with the Django authentication system.""" def test_registered_as_default_user_model(self) -> None: - """Test the `User` class is returned by the built-in `get_user_model` method""" + """Test the `User` class is returned by the built-in `get_user_model` method.""" self.assertIs(User, get_user_model()) diff --git a/keystone_api/apps/users/tests/test_serializers/__init__.py b/keystone_api/apps/users/tests/test_serializers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/users/tests/test_serializers/test_PrivilegeUserSerializer.py b/keystone_api/apps/users/tests/test_serializers/test_PrivilegeUserSerializer.py new file mode 100644 index 00000000..26c2cb59 --- /dev/null +++ b/keystone_api/apps/users/tests/test_serializers/test_PrivilegeUserSerializer.py @@ -0,0 +1,42 @@ +"""Tests for the `PrivilegeUserSerializer` class.""" + +from django.contrib.auth.hashers import check_password +from django.test import TestCase +from rest_framework.exceptions import ValidationError as DRFValidationError + +from apps.users.serializers import PrivilegeUserSerializer + + +class Validate(TestCase): + """Test data validation via the `validate`.""" + + def setUp(self) -> None: + """Define dummy user data.""" + + self.user_data = { + 'username': 'testuser', + 'password': 'Password123!', + 'email': 'testuser@example.com', + } + + def test_validate_password_is_hashed(self) -> None: + """Test the password is hashed during validation.""" + + serializer = PrivilegeUserSerializer(data=self.user_data) + self.assertTrue(serializer.is_valid()) + self.assertTrue(check_password('Password123!', serializer.validated_data['password'])) + + def test_validate_password_invalid(self) -> None: + """Test an invalid password raises a `ValidationError`.""" + + self.user_data['password'] = '123' # Too short + serializer = PrivilegeUserSerializer(data=self.user_data) + with self.assertRaises(DRFValidationError): + serializer.is_valid(raise_exception=True) + + def test_validate_without_password(self) -> None: + """Test validation fails when a password is not provided.""" + + del self.user_data['password'] + serializer = PrivilegeUserSerializer(data=self.user_data) + self.assertFalse(serializer.is_valid()) diff --git a/keystone_api/apps/users/tests/test_serializers/test_RestrictedUserSerializer.py b/keystone_api/apps/users/tests/test_serializers/test_RestrictedUserSerializer.py new file mode 100644 index 00000000..908eee73 --- /dev/null +++ b/keystone_api/apps/users/tests/test_serializers/test_RestrictedUserSerializer.py @@ -0,0 +1,16 @@ +"""Tests for the `RestrictedUserSerializer` class.""" + +from django.test import TestCase + +from apps.users.serializers import RestrictedUserSerializer + + +class Create(TestCase): + """Test the `create` method.""" + + def test_create_raises_not_permitted(self) -> None: + """Test that the create method raises a `RuntimeError`.""" + + serializer = RestrictedUserSerializer() + with self.assertRaises(RuntimeError): + serializer.create({'username': 'testuser', 'password': 'Password123!'}) diff --git a/keystone_api/apps/users/tests/test_tasks/__init__.py b/keystone_api/apps/users/tests/test_tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/users/tests/test_tasks/test_get_ldap_connection.py b/keystone_api/apps/users/tests/test_tasks/test_get_ldap_connection.py new file mode 100644 index 00000000..0461fae9 --- /dev/null +++ b/keystone_api/apps/users/tests/test_tasks/test_get_ldap_connection.py @@ -0,0 +1,64 @@ +"""Tests for the `get_ldap_connection` function.""" + +from unittest.mock import Mock, patch + +import ldap +from django.conf import settings +from django.test import TestCase + +from apps.users.tasks import get_ldap_connection + + +class TLSConfiguration(TestCase): + """Test the configuration of TLS based on application settings.""" + + @patch('ldap.initialize') + @patch('ldap.set_option') + @patch('ldap.ldapobject.LDAPObject') + def test_get_ldap_connection(self, mock_ldap: Mock, mock_set_option: Mock, mock_initialize: Mock) -> None: + """Test an LDAP connection is correctly configured with TLS enabled.""" + + # Set up mock objects + mock_conn = mock_ldap.return_value + mock_initialize.return_value = mock_conn + mock_set_option.return_value = None + + # Configure settings for testing + settings.AUTH_LDAP_SERVER_URI = 'ldap://testserver' + settings.AUTH_LDAP_BIND_DN = 'cn=admin,dc=example,dc=com' + settings.AUTH_LDAP_BIND_PASSWORD = 'password123' + settings.AUTH_LDAP_START_TLS = True + + # Call the function to test + conn = get_ldap_connection() + self.assertEqual(conn, mock_conn) + + # Check the calls + mock_initialize.assert_called_once_with('ldap://testserver') + mock_conn.bind.assert_called_once_with('cn=admin,dc=example,dc=com', 'password123') + mock_set_option.assert_called_once_with(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) + mock_conn.start_tls_s.assert_called_once() + + @patch('ldap.initialize') + @patch('ldap.ldapobject.LDAPObject') + def test_get_ldap_connection_without_tls(self, mock_ldap: Mock, mock_initialize: Mock) -> None: + """Test an LDAP connection is correctly configured with TLS disabled.""" + + # Set up mock objects + mock_conn = mock_ldap.return_value + mock_initialize.return_value = mock_conn + + # Configure settings for testing + settings.AUTH_LDAP_SERVER_URI = 'ldap://testserver' + settings.AUTH_LDAP_BIND_DN = 'cn=admin,dc=example,dc=com' + settings.AUTH_LDAP_BIND_PASSWORD = 'password' + settings.AUTH_LDAP_START_TLS = False + + # Call the function to test + conn = get_ldap_connection() + self.assertEqual(conn, mock_conn) + + # Check the calls + mock_initialize.assert_called_once_with('ldap://testserver') + mock_conn.bind.assert_called_once_with('cn=admin,dc=example,dc=com', 'password') + mock_conn.start_tls_s.assert_not_called() diff --git a/keystone_api/apps/users/tests/test_views/__init__.py b/keystone_api/apps/users/tests/test_views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/users/tests/test_views/test_UserViewSet.py b/keystone_api/apps/users/tests/test_views/test_UserViewSet.py new file mode 100644 index 00000000..d34d721b --- /dev/null +++ b/keystone_api/apps/users/tests/test_views/test_UserViewSet.py @@ -0,0 +1,36 @@ +"""Tests for the `UserViewSet` class.""" + +from django.test import RequestFactory, TestCase + +from apps.users.models import User +from apps.users.serializers import PrivilegeUserSerializer, RestrictedUserSerializer +from apps.users.views import UserViewSet + + +class GetSerializerClass(TestCase): + """Test the `get_serializer_class` method.""" + + def setUp(self) -> None: + self.factory = RequestFactory() + self.staff_user = User.objects.create(username='staffuser', is_staff=True) + self.regular_user = User.objects.create(username='regularuser', is_staff=False) + + def test_get_serializer_class_for_staff_user(self) -> None: + """Test the `PrivilegeUserSerializer` serializer is returned for a staff user.""" + + request = self.factory.get('/users/') + request.user = self.staff_user + view = UserViewSet(request=request) + + serializer_class = view.get_serializer_class() + self.assertEqual(serializer_class, PrivilegeUserSerializer) + + def test_get_serializer_class_for_regular_user(self) -> None: + """Test the `RestrictedUserSerializer` serializer is returned for a staff user.""" + + request = self.factory.get('/users/') + request.user = self.regular_user + view = UserViewSet(request=request) + + serializer_class = view.get_serializer_class() + self.assertEqual(serializer_class, RestrictedUserSerializer) diff --git a/keystone_api/apps/users/tests/utils.py b/keystone_api/apps/users/tests/utils.py index 35a1ad1a..c9496bc4 100644 --- a/keystone_api/apps/users/tests/utils.py +++ b/keystone_api/apps/users/tests/utils.py @@ -1,4 +1,4 @@ -"""Testing utilities specific to dealing with user accounts""" +"""Testing utilities specific to dealing with user accounts.""" from apps.users.models import User @@ -11,7 +11,7 @@ def create_test_user( email: str = "foo@bar.com", **kwargs ) -> User: - """Create a user account for testing purposes + """Create a user account for testing purposes. Args: username: The account username @@ -25,4 +25,10 @@ def create_test_user( The saved user account """ - return User.objects.create_user(username, password, first_name=first_name, last_name=last_name, email=email, **kwargs) + return User.objects.create_user( + username=username, + password=password, + first_name=first_name, + last_name=last_name, + email=email, + **kwargs) diff --git a/keystone_api/apps/users/urls.py b/keystone_api/apps/users/urls.py index c04eb424..48fe3ca5 100644 --- a/keystone_api/apps/users/urls.py +++ b/keystone_api/apps/users/urls.py @@ -1,4 +1,4 @@ -"""URL routing for the parent application""" +"""URL routing for the parent application.""" from rest_framework.routers import DefaultRouter diff --git a/keystone_api/apps/users/views.py b/keystone_api/apps/users/views.py index 6da1b13d..bd9fc8d1 100644 --- a/keystone_api/apps/users/views.py +++ b/keystone_api/apps/users/views.py @@ -5,6 +5,7 @@ """ from rest_framework import permissions, viewsets +from rest_framework.serializers import Serializer from .models import * from .permissions import IsGroupAdminOrReadOnly, IsSelfOrReadOnly @@ -17,7 +18,7 @@ class ResearchGroupViewSet(viewsets.ModelViewSet): - """Manage user membership in research groups""" + """Manage user membership in research groups.""" queryset = ResearchGroup.objects.all() permission_classes = [permissions.IsAuthenticated, IsGroupAdminOrReadOnly] @@ -25,13 +26,13 @@ class ResearchGroupViewSet(viewsets.ModelViewSet): class UserViewSet(viewsets.ModelViewSet): - """Read only access to user data""" + """Manage user account data.""" queryset = User.objects.all() permission_classes = [permissions.IsAuthenticated, IsSelfOrReadOnly] - def get_serializer_class(self): - """Return the appropriate data serializer""" + def get_serializer_class(self) -> type[Serializer]: + """Return the appropriate data serializer based on user roles/permissions.""" if self.request.user.is_staff: return PrivilegeUserSerializer