diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 4d0630f23..d0b32afad 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -11,7 +11,10 @@ from attrs.validators import deep_iterable, instance_of import numpy as np + from sleap.util import deep_iterable_converter + +# from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer from sleap.io.video import Video @@ -408,9 +411,11 @@ class RecordingSession: not linked to a `Video`. """ + # TODO(LM): Consider implementing Observer pattern for `camera_cluster` and `labels` camera_cluster: CameraCluster = field(factory=CameraCluster) metadata: dict = field(factory=dict) _video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict) + labels: Optional["Labels"] = None @property def videos(self) -> List[Video]: @@ -516,6 +521,10 @@ def add_video(self, video: Video, camcorder: Camcorder): # Add camcorder-to-video (1-to-1) map to `RecordingSession` self._video_by_camcorder[camcorder] = video + # Update labels cache + if self.labels is not None: + self.labels.update_session(self, video) + def remove_video(self, video: Video): """Removes a `Video` from the `RecordingSession`. @@ -536,6 +545,10 @@ def remove_video(self, video: Video): # Remove camcorder-to-video map from `RecordingSession` self._video_by_camcorder.pop(camcorder) + # Update labels cache + if self.labels is not None and self.labels.get_session(video) is not None: + self.labels.remove_session_video(self, video) + def __attrs_post_init__(self): self.camera_cluster.add_session(self) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index a9782bad7..cc3090714 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -103,34 +103,85 @@ class LabelsDataCache: def __attrs_post_init__(self): self.update() - def update(self, new_frame: Optional[LabeledFrame] = None): + def rebuild_cache(self): + """(Re)builds the cache from scratch.""" + + self._lf_by_video = {video: [] for video in self.labels.videos} + self._frame_idx_map = dict() + self._track_occupancy = dict() + self._frame_count_cache = dict() + self._session_by_video: Dict[Video, RecordingSession] = dict() + + # Loop through labeled frames only once + for lf in self.labels: + self._lf_by_video[lf.video].append(lf) + + # Loop through videos a second time after _lf_by_video is created + for video in self.labels.videos: + self._frame_idx_map[video] = { + lf.frame_idx: lf for lf in self._lf_by_video[video] + } + self._track_occupancy[video] = self._make_track_occupancy(video) + + # Loop S X V times to build session-by-video map + for session in self.labels.sessions: + for video in session.videos: + self._session_by_video[video] = session + + def add_labeled_frame(self, new_frame: LabeledFrame): + """Add a new labeled frame to the cache. + + Args: + new_frame: The new labeled frame to add. + """ + new_vid = new_frame.video + + if new_vid not in self._lf_by_video: + self._lf_by_video[new_vid] = [] + if new_vid not in self._frame_idx_map: + self._frame_idx_map[new_vid] = dict() + self._lf_by_video[new_vid].append(new_frame) + self._frame_idx_map[new_vid][new_frame.frame_idx] = new_frame + + def add_recording_session(self, new_session: RecordingSession): + """Add a new recording session to the cache. + + Args: + new_session: The new recording session to add. + """ + + for video in new_session.videos: + self._session_by_video[video] = new_session + + def add_video_to_session(self, session: RecordingSession, new_video: Video): + """Add a new video to a recording session in the cache. + + Args: + new_video: The new video to add. + session: The recording session to add the video to. + """ + + self._session_by_video[new_video] = session + + def update( + self, + new_item: Optional[ + Union[LabeledFrame, RecordingSession, Tuple[RecordingSession, Video]] + ] = None, + ): """Build (or rebuilds) various caches.""" # Data structures for caching - if new_frame is None: - self._lf_by_video = {video: [] for video in self.labels.videos} - self._frame_idx_map = dict() - self._track_occupancy = dict() - self._frame_count_cache = dict() - - # Loop through labeled frames only once - for lf in self.labels: - self._lf_by_video[lf.video].append(lf) - - # Loop through videos a second time after _lf_by_video is created - for video in self.labels.videos: - self._frame_idx_map[video] = { - lf.frame_idx: lf for lf in self._lf_by_video[video] - } - self._track_occupancy[video] = self._make_track_occupancy(video) - else: - new_vid = new_frame.video + if new_item is None: + self.rebuild_cache() + + elif isinstance(new_item, LabeledFrame): + self.add_labeled_frame(new_item) + + elif isinstance(new_item, RecordingSession): + self.add_recording_session(new_item) - if new_vid not in self._lf_by_video: - self._lf_by_video[new_vid] = [] - if new_vid not in self._frame_idx_map: - self._frame_idx_map[new_vid] = dict() - self._lf_by_video[new_vid].append(new_frame) - self._frame_idx_map[new_vid][new_frame.frame_idx] = new_frame + elif isinstance(new_item, tuple): + self.add_video_to_session(*new_item) def find_frames( self, video: Video, frame_idx: Optional[Union[int, Iterable[int]]] = None @@ -218,10 +269,13 @@ def remove_frame(self, frame: LabeledFrame): def remove_video(self, video: Video): """Remove video and update cache as needed.""" + if video in self._lf_by_video: del self._lf_by_video[video] if video in self._frame_idx_map: del self._frame_idx_map[video] + if video in self._session_by_video: + del self._session_by_video[video] def track_swap( self, @@ -420,7 +474,7 @@ class Labels(MutableSequence): nodes: List[Node] = attr.ib(default=attr.Factory(list)) tracks: List[Track] = attr.ib(default=attr.Factory(list)) suggestions: List[SuggestionFrame] = attr.ib(default=attr.Factory(list)) - sessions: List[RecordingSession] = attr.ib(default=attr.Factory(list)) + _sessions: List[RecordingSession] = attr.ib(default=attr.Factory(list)) negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) provenance: Dict[Text, Union[str, int, float, bool]] = attr.ib( default=attr.Factory(dict) @@ -438,13 +492,18 @@ def __attrs_post_init__(self): # frames but not in the lists on our object self._update_from_labels() - # Update caches used to find frames by frame index + # Create cache to find frames by frame index and `RecordingSession`s by `Video`s self._cache = LabelsDataCache(self) # Create a variable to store a temporary storage directory # used when we unzip self.__temp_dir = None + # TODO(LM): Add Observer pattern between `Labels`, `RecordingSession`s, + # `LabelsDataCache`, `MainWindow`, and others. + for session in self.sessions: + session.labels = self + def _update_from_labels(self, merge: bool = False): """Updates top level attributes with data from labeled frames. @@ -585,6 +644,19 @@ def has_missing_videos(self) -> bool: """Return True if any of the video files in the labels are missing.""" return any(video.is_missing for video in self.videos) + @property + def sessions(self) -> List[RecordingSession]: + """Return a list of sessions in the labels.""" + return self._sessions + + @sessions.setter + def sessions(self, value: RecordingSession): + """Set the sessions in the labels.""" + raise ValueError( + "Direct assignment to `Labels.sessions` is not allowed. " + "Please use `Labels.add_session` to add a session." + ) + def __len__(self) -> int: """Return number of labeled frames.""" return len(self.labeled_frames) @@ -1597,7 +1669,49 @@ def add_session(self, session: RecordingSession): ) if session not in self.sessions: - self.sessions.append(session) + self._sessions.append(session) + session.labels = self + + self._cache.update(session) + + def update_session(self, session: RecordingSession, video: Video = None): + """Update `Video` to `RecordingSession` map in the `LabelsDataCache`. + + Args: + session: `RecordingSession` instance + video: `Video` instance linked to a `RecordingSession` + """ + + if session not in self.sessions: + raise KeyError("Session is not in labels.") + if video is None: + new_item = session + else: + new_item = (session, video) + self._cache.update(new_item) + + def get_session(self, video: Video) -> Optional[RecordingSession]: + """Get the recording session associated with a video. + + Args: + video: `Video` instance + + Returns: + `RecordingSession` instance + """ + return self._cache._session_by_video.get(video, None) + + def remove_session_video(self, session: RecordingSession, video: Video): + """Remove a video from a recording session. + + Args: + session: `RecordingSession` instance + video: `Video` instance + """ + + self._cache._session_by_video.pop(video, None) + if video in session.videos: + session.remove_video(video) @classmethod def from_json(cls, *args, **kwargs): diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index cb8842ddc..faedaa501 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1,18 +1,21 @@ import os + import pytest import numpy as np -from pathlib import Path, PurePath import sleap + +from pathlib import Path, PurePath + from sleap.io.cameras import RecordingSession from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo from sleap.io.dataset import Labels, load_file -from sleap.io.legacy import load_labels_json_old from sleap.io.format.ndx_pose import NDXPoseAdaptor from sleap.io.format import filehandle from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame + from tests.io.test_formats import assert_read_labels_match TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" @@ -996,8 +999,12 @@ def test_save_labels_with_sessions( for cam_1, cam_2 in zip(session, loaded_session): assert cam_1 == cam_2 + assert loaded_session.labels == loaded_labels + -def test_add_session(min_labels_slp: Labels, min_session_session: RecordingSession): +def test_add_session_and_update_session( + min_labels_slp: Labels, min_session_session: RecordingSession +): """Test that we can add a `RecordingSession` to a `Labels` object.""" labels = min_labels_slp @@ -1005,6 +1012,22 @@ def test_add_session(min_labels_slp: Labels, min_session_session: RecordingSessi labels.add_session(session) assert labels.sessions == [session] + assert labels._cache._session_by_video == dict() + + video = labels.videos[0] + session.add_video(video, session.camera_cluster.cameras[0]) + assert labels._cache._session_by_video == {video: session} + assert labels.get_session(video) == session + + labels.remove_session_video(session, video) + assert video not in session.videos + assert video not in labels._cache._session_by_video + + with pytest.raises(ValueError): + labels.sessions = [session] + + # Warning: we would like to prevent the following, but might require custom class + labels.sessions.append(session) def test_labels_hdf5(multi_skel_vid_labels, tmpdir):