From 1ac289f70ef8f731d5e4161057a1bbde04764bd3 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Tue, 9 Jul 2024 09:43:22 +0200 Subject: [PATCH] Filter course metrics by assignments --- lms/services/course.py | 15 +++++++++++++-- lms/views/dashboard/api/course.py | 5 +++++ tests/unit/lms/services/course_test.py | 15 +++++++++++++++ tests/unit/lms/views/dashboard/api/course_test.py | 6 +++++- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/lms/services/course.py b/lms/services/course.py index 9d5cc94e33..3e9114efe8 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -110,7 +110,7 @@ def _search_query( # noqa: PLR0913, PLR0917 ) if h_userids: - # Only courses where these H's h_userids belongs to + # Only courses these h_userids belong to query = ( query.join(GroupingMembership) .join(User) @@ -141,15 +141,17 @@ def search( # noqa: PLR0913, PLR0917 def get_courses( self, - instructor_h_userid: str | None, + instructor_h_userid: str | None = None, organization: Organization | None = None, h_userids: list[str] | None = None, + assignment_ids: list[str] | None = None, ) -> Select[tuple[Course]]: """Get a list of unique courses. :param organization: organization the courses belong to. :param instructor_h_userid: return only courses where instructor_h_userid is an instructor. :param h_userids: return only courses where these users are members. + :param assignment_ids: return only the courses these assignments belong to. """ courses_query = ( self._search_query( @@ -182,6 +184,15 @@ def get_courses( ) ) + if assignment_ids: + courses_query = courses_query.where( + Course.id.in_( + select(AssignmentGrouping.grouping_id).where( + AssignmentGrouping.assignment_id.in_(assignment_ids) + ) + ) + ) + return ( select(Course) .where( diff --git a/lms/views/dashboard/api/course.py b/lms/views/dashboard/api/course.py index 1140260fb8..ac599a4259 100644 --- a/lms/views/dashboard/api/course.py +++ b/lms/views/dashboard/api/course.py @@ -25,6 +25,9 @@ class CoursesMetricsSchema(PyramidRequestSchema): h_userids = fields.List(fields.Str()) """Return metrics for these users only.""" + assignment_ids = fields.List(fields.Integer()) + """Return metrics for these assignments only.""" + class CourseViews: def __init__(self, request) -> None: @@ -67,6 +70,7 @@ def courses(self) -> APICourses: ) def organization_courses(self) -> APICourses: filter_by_h_userids = self.request.parsed_params.get("h_userids") + filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids") org = self.dashboard_service.get_request_organization(self.request) courses = self.request.db.scalars( self.course_service.get_courses( @@ -75,6 +79,7 @@ def organization_courses(self) -> APICourses: if self.request.user else None, h_userids=filter_by_h_userids, + assignment_ids=filter_by_assignment_ids, ) ).all() courses_assignment_counts = ( diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 480c5614b9..1cc8b4c04b 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -391,6 +391,21 @@ def test_get_courses_by_instructor_h_userid(self, svc, db_session): svc.get_courses(instructor_h_userid=user.h_userid) ).all() == [course] + def test_get_courses_by_assignment_ids(self, svc, db_session): + course = factories.Course() + assignment = factories.Assignment() + user = factories.User() + factories.AssignmentMembership.create( + assignment=assignment, user=user, lti_role=factories.LTIRole() + ) + factories.AssignmentGrouping(grouping=course, assignment=assignment) + + db_session.flush() + + assert db_session.scalars( + svc.get_courses(assignment_ids=[assignment.id]) + ).all() == [course] + @pytest.fixture def course(self, application_instance, grouping_service): return factories.Course( diff --git a/tests/unit/lms/views/dashboard/api/course_test.py b/tests/unit/lms/views/dashboard/api/course_test.py index e7d81a85a6..9f25dec180 100644 --- a/tests/unit/lms/views/dashboard/api/course_test.py +++ b/tests/unit/lms/views/dashboard/api/course_test.py @@ -50,7 +50,10 @@ def test_get_organization_courses( dashboard_service.get_request_organization.return_value = org course_service.get_courses.return_value = select(Course).order_by(Course.id) pyramid_request.matchdict["organization_public_id"] = sentinel.public_id - pyramid_request.parsed_params = {"h_userids": sentinel.h_userids} + pyramid_request.parsed_params = { + "h_userids": sentinel.h_userids, + "assignment_ids": sentinel.assignment_ids, + } db_session.flush() response = views.organization_courses() @@ -62,6 +65,7 @@ def test_get_organization_courses( organization=org, instructor_h_userid=pyramid_request.user.h_userid, h_userids=sentinel.h_userids, + assignment_ids=sentinel.assignment_ids, ) assignment_service.get_courses_assignments_count.assert_called_once_with( [c.id for c in courses]