Skip to content

Commit

Permalink
Filter course metrics by assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Jul 9, 2024
1 parent 36c4cb3 commit 1ac289f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
15 changes: 13 additions & 2 deletions lms/services/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _search_query( # noqa: PLR0913, PLR0917
)

if h_userids:
# Only courses where these H's h_userids belongs to
# Only courses these h_userids belong to
query = (
query.join(GroupingMembership)
.join(User)
Expand Down Expand Up @@ -141,15 +141,17 @@ def search( # noqa: PLR0913, PLR0917

def get_courses(
self,
instructor_h_userid: str | None,
instructor_h_userid: str | None = None,
organization: Organization | None = None,
h_userids: list[str] | None = None,
assignment_ids: list[str] | None = None,
) -> Select[tuple[Course]]:
"""Get a list of unique courses.
:param organization: organization the courses belong to.
:param instructor_h_userid: return only courses where instructor_h_userid is an instructor.
:param h_userids: return only courses where these users are members.
:param assignment_ids: return only the courses these assignments belong to.
"""
courses_query = (
self._search_query(
Expand Down Expand Up @@ -182,6 +184,15 @@ def get_courses(
)
)

if assignment_ids:
courses_query = courses_query.where(
Course.id.in_(
select(AssignmentGrouping.grouping_id).where(
AssignmentGrouping.assignment_id.in_(assignment_ids)
)
)
)

return (
select(Course)
.where(
Expand Down
5 changes: 5 additions & 0 deletions lms/views/dashboard/api/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class CoursesMetricsSchema(PyramidRequestSchema):
h_userids = fields.List(fields.Str())
"""Return metrics for these users only."""

assignment_ids = fields.List(fields.Integer())
"""Return metrics for these assignments only."""


class CourseViews:
def __init__(self, request) -> None:
Expand Down Expand Up @@ -67,6 +70,7 @@ def courses(self) -> APICourses:
)
def organization_courses(self) -> APICourses:
filter_by_h_userids = self.request.parsed_params.get("h_userids")
filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids")
org = self.dashboard_service.get_request_organization(self.request)
courses = self.request.db.scalars(
self.course_service.get_courses(
Expand All @@ -75,6 +79,7 @@ def organization_courses(self) -> APICourses:
if self.request.user
else None,
h_userids=filter_by_h_userids,
assignment_ids=filter_by_assignment_ids,
)
).all()
courses_assignment_counts = (
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/lms/services/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,21 @@ def test_get_courses_by_instructor_h_userid(self, svc, db_session):
svc.get_courses(instructor_h_userid=user.h_userid)
).all() == [course]

def test_get_courses_by_assignment_ids(self, svc, db_session):
course = factories.Course()
assignment = factories.Assignment()
user = factories.User()
factories.AssignmentMembership.create(
assignment=assignment, user=user, lti_role=factories.LTIRole()
)
factories.AssignmentGrouping(grouping=course, assignment=assignment)

db_session.flush()

assert db_session.scalars(
svc.get_courses(assignment_ids=[assignment.id])
).all() == [course]

@pytest.fixture
def course(self, application_instance, grouping_service):
return factories.Course(
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/lms/views/dashboard/api/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def test_get_organization_courses(
dashboard_service.get_request_organization.return_value = org
course_service.get_courses.return_value = select(Course).order_by(Course.id)
pyramid_request.matchdict["organization_public_id"] = sentinel.public_id
pyramid_request.parsed_params = {"h_userids": sentinel.h_userids}
pyramid_request.parsed_params = {
"h_userids": sentinel.h_userids,
"assignment_ids": sentinel.assignment_ids,
}
db_session.flush()

response = views.organization_courses()
Expand All @@ -62,6 +65,7 @@ def test_get_organization_courses(
organization=org,
instructor_h_userid=pyramid_request.user.h_userid,
h_userids=sentinel.h_userids,
assignment_ids=sentinel.assignment_ids,
)
assignment_service.get_courses_assignments_count.assert_called_once_with(
[c.id for c in courses]
Expand Down

0 comments on commit 1ac289f

Please sign in to comment.