From e7b71776576acbdcfa293c189d3898dd1d44e042 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Fri, 5 Jul 2024 16:02:30 +0200 Subject: [PATCH] Filter assignment metrics by users --- lms/services/assignment.py | 40 ++++++++++++++----- lms/views/dashboard/api/assignment.py | 20 +++++++++- tests/unit/lms/services/assignment_test.py | 21 +++++++++- .../views/dashboard/api/assignment_test.py | 5 ++- 4 files changed, 72 insertions(+), 14 deletions(-) diff --git a/lms/services/assignment.py b/lms/services/assignment.py index 786cf9c77c..7d2482268d 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -9,6 +9,8 @@ AssignmentMembership, Grouping, LTIRole, + RoleScope, + RoleType, User, ) from lms.services.upsert import bulk_upsert @@ -210,21 +212,39 @@ def is_member(self, assignment: Assignment, h_userid: str) -> bool: ) def get_assignments( - self, h_userid: str | None = None, course_id: int | None = None + self, + instructor_h_userid: str | None = None, + course_id: int | None = None, + h_userids: list[str] | None = None, ) -> Select[tuple[Assignment]]: """Get a query to fetch assignments. - :params: h_userid only return assignments the users is a member of. - :params: course_id only return assignments that belong to this course. + :param instructor_h_userid: return only assignments where instructor_h_userid is an instructor. + :param course_id: only return assignments that belong to this course. + :param h_userids: return only assignments where these users are members. """ - assignments_query = select(Assignment) + query = select(Assignment) + + if instructor_h_userid: + query = query.where( + Assignment.id.in_( + select(AssignmentMembership.assignment_id) + .join(User) + .join(LTIRole) + .where( + User.h_userid == instructor_h_userid, + LTIRole.scope == RoleScope.COURSE, + LTIRole.type == RoleType.INSTRUCTOR, + ) + ) + ) - if h_userid: - assignments_query = ( - assignments_query.join(AssignmentMembership) + if h_userids: + query = ( + query.join(AssignmentMembership) .join(User) - .where(User.h_userid == h_userid) + .where(User.h_userid.in_(h_userids)) ) if course_id: @@ -232,13 +252,13 @@ def get_assignments( self._deduplicated_course_assigments_query([course_id]).subquery() ) - assignments_query = assignments_query.where( + query = query.where( # Get only assignment from the candidates above Assignment.id == deduplicated_course_assignments.c.assignment_id, deduplicated_course_assignments.c.grouping_id == course_id, ) - return assignments_query.order_by(Assignment.title, Assignment.id).distinct() + return query.order_by(Assignment.title, Assignment.id).distinct() def _deduplicated_course_assigments_query(self, course_ids: list[int]): # Get all assignment IDs we recorded from this course diff --git a/lms/views/dashboard/api/assignment.py b/lms/views/dashboard/api/assignment.py index 5aa391526d..503770fa86 100644 --- a/lms/views/dashboard/api/assignment.py +++ b/lms/views/dashboard/api/assignment.py @@ -11,6 +11,7 @@ from lms.security import Permissions from lms.services import UserService from lms.services.h_api import HAPI +from lms.validation import PyramidRequestSchema from lms.views.dashboard.pagination import PaginationParametersMixin, get_page @@ -21,6 +22,15 @@ class ListAssignmentsSchema(PaginationParametersMixin): """Return assignments that belong to the course with this ID.""" +class AssignmentsMetricsSchema(PyramidRequestSchema): + """Query parameters to fetch metrics for assignments.""" + + location = "querystring" + + h_userids = fields.List(fields.Str()) + """Return metrics for these users only.""" + + class AssignmentViews: def __init__(self, request) -> None: self.request = request @@ -39,7 +49,9 @@ def __init__(self, request) -> None: ) def assignments(self) -> APIAssignments: assignments = self.assignment_service.get_assignments( - h_userid=self.request.user.h_userid if self.request.user else None, + instructor_h_userid=self.request.user.h_userid + if self.request.user + else None, course_id=self.request.parsed_params.get("course_id"), ) assignments, pagination = get_page( @@ -72,9 +84,11 @@ def assignment(self) -> APIAssignment: request_method="GET", renderer="json", permission=Permissions.DASHBOARD_VIEW, + schema=AssignmentsMetricsSchema, ) def course_assignments_metrics(self) -> APIAssignments: current_h_userid = self.request.user.h_userid if self.request.user else None + filter_by_h_userids = self.request.parsed_params.get("h_userids") course = self.dashboard_service.get_request_course(self.request) course_students = self.request.db.scalars( self.user_service.get_users( @@ -82,12 +96,14 @@ def course_assignments_metrics(self) -> APIAssignments: role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, instructor_h_userid=current_h_userid, + h_userids=filter_by_h_userids, ) ).all() assignments = self.request.db.scalars( self.assignment_service.get_assignments( course_id=course.id, - h_userid=self.request.user.h_userid if self.request.user else None, + instructor_h_userid=current_h_userid, + h_userids=filter_by_h_userids, ) ).all() diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index f702976001..71e8aa66c7 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -251,7 +251,7 @@ def test_get_assignments_by_course(self, svc, db_session, assignment): assignment ] - def test_get_assignments_with_h_userid(self, svc, db_session): + def test_get_assignments_with_instructor_h_userid(self, svc, db_session): factories.User() # User not in assignment assignment = factories.Assignment() user = factories.User() @@ -270,6 +270,25 @@ def test_get_assignments_with_h_userid(self, svc, db_session): assignment ] + def test_get_assignments_with_h_userids(self, svc, db_session): + factories.User() # User not in assignment + assignment = factories.Assignment() + user = factories.User() + lti_role = factories.LTIRole(scope=RoleScope.COURSE) + factories.AssignmentMembership.create( + assignment=assignment, user=user, lti_role=lti_role + ) + # Other membership record, with a different role + factories.AssignmentMembership.create( + assignment=assignment, user=user, lti_role=factories.LTIRole() + ) + + db_session.flush() + + assert db_session.scalars( + svc.get_assignments(h_userids=[user.h_userid]) + ).all() == [assignment] + def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc): course = factories.Course() other_course = factories.Course() diff --git a/tests/unit/lms/views/dashboard/api/assignment_test.py b/tests/unit/lms/views/dashboard/api/assignment_test.py index b221fe6c18..36326f1668 100644 --- a/tests/unit/lms/views/dashboard/api/assignment_test.py +++ b/tests/unit/lms/views/dashboard/api/assignment_test.py @@ -23,7 +23,8 @@ def test_get_assignments( response = views.assignments() assignment_service.get_assignments.assert_called_once_with( - pyramid_request.user.h_userid, course_id=sentinel.course_id + instructor_h_userid=pyramid_request.user.h_userid, + course_id=sentinel.course_id, ) get_page.assert_called_once_with( pyramid_request, @@ -65,6 +66,7 @@ def test_course_assignments( user_service, ): pyramid_request.matchdict["course_id"] = sentinel.id + pyramid_request.parsed_params = {"h_userids": sentinel.h_userids} course = factories.Course() section = factories.CanvasSection(parent=course) dashboard_service.get_request_course.return_value = course @@ -100,6 +102,7 @@ def test_course_assignments( role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, instructor_h_userid=pyramid_request.user.h_userid, + h_userids=sentinel.h_userids, ) h_api.get_annotation_counts.assert_called_once_with( [course.authority_provided_id, section.authority_provided_id],