diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index 71e8aa66c7..5561383c79 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -4,7 +4,7 @@ import pytest from h_matchers import Any -from lms.models import AssignmentGrouping, AssignmentMembership, RoleScope +from lms.models import AssignmentGrouping, AssignmentMembership, RoleScope, RoleType from lms.services.assignment import AssignmentService, factory from tests import factories @@ -239,55 +239,52 @@ def test_is_member(self, svc, db_session): assert svc.is_member(assignment, user.h_userid) assert not svc.is_member(assignment, other_user.h_userid) - def test_get_assignments(self, svc, db_session): - assert db_session.scalars(svc.get_assignments()).all() - - def test_get_assignments_by_course(self, svc, db_session, assignment): + @pytest.mark.parametrize("instructor_h_userid", [True, False]) + @pytest.mark.parametrize("course_id", [True, False]) + @pytest.mark.parametrize("h_userids", [True, False]) + def test_get_assignments( + self, + svc, + db_session, + instructor_h_userid, + assignment, + with_assignment_noise, + course_id, + h_userids, + ): + factories.User() course = factories.Course() - factories.AssignmentGrouping.create(assignment=assignment, grouping=course) - db_session.flush() - - assert db_session.scalars(svc.get_assignments(course_id=course.id)).all() == [ - assignment - ] - - def test_get_assignments_with_instructor_h_userid(self, svc, db_session): - factories.User() # User not in assignment - assignment = factories.Assignment() user = factories.User() - lti_role = factories.LTIRole(scope=RoleScope.COURSE) + lti_role = factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR) factories.AssignmentMembership.create( assignment=assignment, user=user, lti_role=lti_role ) - # Other membership record, with a different role factories.AssignmentMembership.create( assignment=assignment, user=user, lti_role=factories.LTIRole() ) - + factories.AssignmentGrouping.create(assignment=assignment, grouping=course) db_session.flush() - assert db_session.scalars(svc.get_assignments(user.h_userid)).all() == [ - assignment - ] + query_parameters = {} - def test_get_assignments_with_h_userids(self, svc, db_session): - factories.User() # User not in assignment - assignment = factories.Assignment() - user = factories.User() - lti_role = factories.LTIRole(scope=RoleScope.COURSE) - factories.AssignmentMembership.create( - assignment=assignment, user=user, lti_role=lti_role - ) - # Other membership record, with a different role - factories.AssignmentMembership.create( - assignment=assignment, user=user, lti_role=factories.LTIRole() - ) + if instructor_h_userid: + query_parameters["instructor_h_userid"] = user.h_userid - db_session.flush() + if course_id: + query_parameters["course_id"] = course.id - assert db_session.scalars( - svc.get_assignments(h_userids=[user.h_userid]) - ).all() == [assignment] + if h_userids: + query_parameters["h_userids"] = [user.h_userid] + + query = svc.get_assignments(**query_parameters) + + if not query_parameters: + assert set(db_session.scalars(query).all()) == set( + [assignment] + with_assignment_noise + ) + + else: + assert db_session.scalars(query).all() == [assignment] def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc): course = factories.Course() @@ -357,14 +354,16 @@ def non_matching_params(self, request, matching_params): @pytest.fixture(autouse=True) def with_assignment_noise(self, assignment): - factories.Assignment( - tool_consumer_instance_guid=assignment.tool_consumer_instance_guid, - resource_link_id="noise_resource_link_id", - ) - factories.Assignment( - tool_consumer_instance_guid="noise_tool_consumer_instance_guid", - resource_link_id=assignment.resource_link_id, - ) + return [ + factories.Assignment( + tool_consumer_instance_guid=assignment.tool_consumer_instance_guid, + resource_link_id="noise_resource_link_id", + ), + factories.Assignment( + tool_consumer_instance_guid="noise_tool_consumer_instance_guid", + resource_link_id=assignment.resource_link_id, + ), + ] @pytest.fixture def create_assignment(self, svc):