diff --git a/lms/models/lms_course.py b/lms/models/lms_course.py index 7cc5022a3e..7a5dd124a9 100644 --- a/lms/models/lms_course.py +++ b/lms/models/lms_course.py @@ -51,7 +51,7 @@ class LMSCourse(CreatedUpdatedMixin, Base): """The start date of the course. Only for when we get this information directly from the LMS""" ends_at: Mapped[datetime | None] = mapped_column() - """The end date of the course. Only for when we get this information direclty from the LMS""" + """The end date of the course. Only for when we get this information directly from the LMS""" class LMSCourseApplicationInstance(CreatedUpdatedMixin, Base): diff --git a/lms/services/course.py b/lms/services/course.py index 8b4bacd76c..8976ea1ff3 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -1,6 +1,9 @@ import json from copy import deepcopy +from datetime import datetime +from typing import Mapping +from dateutil import parser from sqlalchemy import Select, select, union from lms.db import full_text_match @@ -311,6 +314,8 @@ def _upsert_lms_course(self, course: Course, lti_params: LTIParams) -> LMSCourse "https://purl.imsglobal.org/spec/lti-nrps/claim/namesroleservice", {} ).get("context_memberships_url") + course_starts_at, course_ends_at = self._get_course_dates(lti_params) + lms_course = bulk_upsert( self._db, LMSCourse, @@ -321,10 +326,18 @@ def _upsert_lms_course(self, course: Course, lti_params: LTIParams) -> LMSCourse "h_authority_provided_id": course.authority_provided_id, "name": course.lms_name, "lti_context_memberships_url": lti_context_membership_url, + "starts_at": course_starts_at, + "ends_at": course_ends_at, } ], index_elements=["h_authority_provided_id"], - update_columns=["updated", "name", "lti_context_memberships_url"], + update_columns=[ + "updated", + "name", + "lti_context_memberships_url", + "starts_at", + "ends_at", + ], ).one() bulk_upsert( self._db, @@ -415,6 +428,22 @@ def _get_copied_from_course(self, lti_params) -> Course | None: return None + def _get_course_dates( + self, lti_params: Mapping + ) -> tuple[datetime | None, datetime | None]: + """Get the dates for the current curse, None if not available.""" + try: + course_starts_at = parser.isoparse(lti_params.get("custom_course_starts")) + except (TypeError, ValueError): + course_starts_at = None + + try: + course_ends_at = parser.isoparse(lti_params.get("custom_course_ends")) + except (TypeError, ValueError): + course_ends_at = None + + return course_starts_at, course_ends_at + def course_service_factory(_context, request): return CourseService( diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 30615ccd86..fab386687e 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -1,4 +1,4 @@ -from datetime import date, datetime +from datetime import UTC, date, datetime from unittest.mock import call, patch, sentinel import pytest @@ -161,9 +161,29 @@ def test_get_from_launch_when_new_and_historical_course_exists( ) assert course == upsert_course.return_value + @pytest.mark.parametrize( + "custom_course_starts, course_starts_at", + [(None, None), ("2022-01-01T00:00:00Z", datetime(2022, 1, 1, tzinfo=UTC))], + ) + @pytest.mark.parametrize( + "custom_course_ends, course_ends_at", + [(None, None), ("2022-01-01T00:00:00Z", datetime(2022, 1, 1, tzinfo=UTC))], + ) def test_upsert_course( - self, svc, grouping_service, bulk_upsert, db_session, lti_params + self, + svc, + grouping_service, + bulk_upsert, + db_session, + lti_params, + custom_course_starts, + course_starts_at, + custom_course_ends, + course_ends_at, ): + lti_params["custom_course_starts"] = custom_course_starts + lti_params["custom_course_ends"] = custom_course_ends + course = svc.upsert_course( lti_params=lti_params, extra=sentinel.extra, @@ -196,10 +216,18 @@ def test_upsert_course( "h_authority_provided_id": course.authority_provided_id, "name": course.lms_name, "lti_context_memberships_url": None, + "starts_at": course_starts_at, + "ends_at": course_ends_at, } ], index_elements=["h_authority_provided_id"], - update_columns=["updated", "name", "lti_context_memberships_url"], + update_columns=[ + "updated", + "name", + "lti_context_memberships_url", + "starts_at", + "ends_at", + ], ), call().one(), call(