Skip to content

Commit

Permalink
Add parameters to the students API endpoint to allow filtering
Browse files Browse the repository at this point in the history
- Filter by course_id
- Filter by assignment_id
  • Loading branch information
marcospri committed Jul 3, 2024
1 parent 5d8d4da commit e0ad070
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 2 deletions.
25 changes: 23 additions & 2 deletions lms/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from sqlalchemy.exc import NoResultFound
from sqlalchemy.sql import Select

from lms.models import AssignmentMembership, LTIRole, LTIUser, RoleScope, RoleType, User
from lms.models import (
AssignmentGrouping,
AssignmentMembership,
LTIRole,
LTIUser,
RoleScope,
RoleType,
User,
)


class UserNotFound(Exception):
Expand Down Expand Up @@ -94,18 +102,22 @@ def _user_search_query(self, application_instance_id, user_id) -> Select:

return query

def get_users(
def get_users( # noqa: PLR0913
self,
role_scope: RoleScope,
role_type: RoleType,
instructor_h_userid: str | None = None,
course_id: str | None = None,
assignment_id: str | None = None,
) -> Select[tuple[User]]:
"""
Get a query to fetch users.
:param role_scope: return only users with this LTI role scope.
:param role_type: return only users with this LTI role type.
:param instructor_h_userid: return only users that belongs to courses/assignments where the user instructor_h_userid is an instructor.
:param course_id: return only users that belong to course_id.
:param assignment_id: return only users that belong to assignment_id.
"""
query = (
select(User.id)
Expand All @@ -128,6 +140,15 @@ def get_users(
)
)

if course_id:
query = query.join(
AssignmentGrouping,
AssignmentGrouping.assignment_id == AssignmentMembership.assignment_id,
).where(AssignmentGrouping.grouping_id == course_id)

if assignment_id:
query = query.where(AssignmentMembership.assignment_id == assignment_id)

# Deduplicate based on the row's h_userid taking the last updated one
query = query.distinct(User.h_userid).order_by(
User.h_userid, User.updated.desc()
Expand Down
9 changes: 9 additions & 0 deletions lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from marshmallow import fields, validate
from pyramid.view import view_config

from lms.js_config_types import APIStudent, APIStudents
Expand All @@ -14,6 +15,12 @@
class ListUsersSchema(PaginationParametersMixin):
"""Query parameters to fetch a list of users."""

course_id = fields.Integer(required=False, validate=validate.Range(min=1))
"""Return users that belong to the course with this ID."""

assignment_id = fields.Integer(required=False, validate=validate.Range(min=1))
"""Return users that belong to the assignment with this ID."""


class UserViews:
def __init__(self, request) -> None:
Expand All @@ -34,6 +41,8 @@ def students(self) -> APIStudents:
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
course_id=self.request.parsed_params.get("course_id"),
assignment_id=self.request.parsed_params.get("assignment_id"),
)
students, pagination = get_page(
self.request, students_query, [User.display_name, User.id]
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/lms/services/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,60 @@ def test_get_users(self, service, db_session):

assert db_session.scalars(query).all() == [student]

def test_get_users_by_course_id(self, service, db_session):
assignment = factories.Assignment()
course = factories.Course()
student = factories.User()
factories.User(h_userid=student.h_userid) # Duplicated student
teacher = factories.User()
factories.AssignmentMembership.create(
assignment=assignment,
user=student,
lti_role=factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.LEARNER),
)
factories.AssignmentMembership.create(
assignment=assignment,
user=teacher,
lti_role=factories.LTIRole(
scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR
),
)
factories.AssignmentGrouping(assignment=assignment, grouping=course)
db_session.flush()

query = service.get_users(
role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, course_id=course.id
)

assert db_session.scalars(query).all() == [student]

def test_get_users_by_assigment_id(self, service, db_session):
assignment = factories.Assignment()
student = factories.User()
factories.User(h_userid=student.h_userid) # Duplicated student
teacher = factories.User()
factories.AssignmentMembership.create(
assignment=assignment,
user=student,
lti_role=factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.LEARNER),
)
factories.AssignmentMembership.create(
assignment=assignment,
user=teacher,
lti_role=factories.LTIRole(
scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR
),
)
db_session.flush()

query = service.get_users(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
assignment_id=assignment.id,
)

assert db_session.scalars(query).all() == [student]

def test_get_users_by_h_userid(self, service, db_session):
# Assignment the h_userid belongs to as a teacher
assignment = factories.Assignment()
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/lms/views/dashboard/api/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

class TestUserViews:
def test_get_students(self, user_service, pyramid_request, views, get_page):
pyramid_request.parsed_params = {
"course_id": sentinel.course_id,
"assignment_id": sentinel.assignment_id,
}
students = factories.User.create_batch(5)
get_page.return_value = students, sentinel.pagination

Expand All @@ -19,6 +23,8 @@ def test_get_students(self, user_service, pyramid_request, views, get_page):
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
instructor_h_userid=pyramid_request.user.h_userid,
course_id=sentinel.course_id,
assignment_id=sentinel.assignment_id,
)
get_page.assert_called_once_with(
pyramid_request,
Expand Down

0 comments on commit e0ad070

Please sign in to comment.