Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(2->1) Add method to get RecordingSession via Video through Labels #1278

Open
wants to merge 11 commits into
base: liezl/acg-add-recording-session
Choose a base branch
from
Open
13 changes: 13 additions & 0 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from attrs.validators import deep_iterable, instance_of
import numpy as np


from sleap.util import deep_iterable_converter

# from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer
from sleap.io.video import Video


Expand Down Expand Up @@ -408,9 +411,11 @@ class RecordingSession:
not linked to a `Video`.
"""

# TODO(LM): Consider implementing Observer pattern for `camera_cluster` and `labels`
camera_cluster: CameraCluster = field(factory=CameraCluster)
metadata: dict = field(factory=dict)
_video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict)
labels: Optional["Labels"] = None

@property
def videos(self) -> List[Video]:
Expand Down Expand Up @@ -516,6 +521,10 @@ def add_video(self, video: Video, camcorder: Camcorder):
# Add camcorder-to-video (1-to-1) map to `RecordingSession`
self._video_by_camcorder[camcorder] = video

# Update labels cache
if self.labels is not None:
self.labels.update_session(self, video)

def remove_video(self, video: Video):
"""Removes a `Video` from the `RecordingSession`.

Expand All @@ -536,6 +545,10 @@ def remove_video(self, video: Video):
# Remove camcorder-to-video map from `RecordingSession`
self._video_by_camcorder.pop(camcorder)

# Update labels cache
if self.labels is not None and self.labels.get_session(video) is not None:
self.labels.remove_session_video(self, video)

def __attrs_post_init__(self):
self.camera_cluster.add_session(self)

Expand Down
170 changes: 142 additions & 28 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,34 +103,85 @@ class LabelsDataCache:
def __attrs_post_init__(self):
self.update()

def update(self, new_frame: Optional[LabeledFrame] = None):
def rebuild_cache(self):
"""(Re)builds the cache from scratch."""

self._lf_by_video = {video: [] for video in self.labels.videos}
self._frame_idx_map = dict()
self._track_occupancy = dict()
self._frame_count_cache = dict()
self._session_by_video: Dict[Video, RecordingSession] = dict()

# Loop through labeled frames only once
for lf in self.labels:
self._lf_by_video[lf.video].append(lf)

# Loop through videos a second time after _lf_by_video is created
for video in self.labels.videos:
self._frame_idx_map[video] = {
lf.frame_idx: lf for lf in self._lf_by_video[video]
}
self._track_occupancy[video] = self._make_track_occupancy(video)

# Loop S X V times to build session-by-video map
for session in self.labels.sessions:
for video in session.videos:
self._session_by_video[video] = session

def add_labeled_frame(self, new_frame: LabeledFrame):
"""Add a new labeled frame to the cache.

Args:
new_frame: The new labeled frame to add.
"""
new_vid = new_frame.video

if new_vid not in self._lf_by_video:
self._lf_by_video[new_vid] = []
if new_vid not in self._frame_idx_map:
self._frame_idx_map[new_vid] = dict()
self._lf_by_video[new_vid].append(new_frame)
self._frame_idx_map[new_vid][new_frame.frame_idx] = new_frame

def add_recording_session(self, new_session: RecordingSession):
"""Add a new recording session to the cache.

Args:
new_session: The new recording session to add.
"""

for video in new_session.videos:
self._session_by_video[video] = new_session

def add_video_to_session(self, session: RecordingSession, new_video: Video):
"""Add a new video to a recording session in the cache.

Args:
new_video: The new video to add.
session: The recording session to add the video to.
"""

self._session_by_video[new_video] = session

def update(
self,
new_item: Optional[
Union[LabeledFrame, RecordingSession, Tuple[RecordingSession, Video]]
] = None,
):
"""Build (or rebuilds) various caches."""
# Data structures for caching
if new_frame is None:
self._lf_by_video = {video: [] for video in self.labels.videos}
self._frame_idx_map = dict()
self._track_occupancy = dict()
self._frame_count_cache = dict()

# Loop through labeled frames only once
for lf in self.labels:
self._lf_by_video[lf.video].append(lf)

# Loop through videos a second time after _lf_by_video is created
for video in self.labels.videos:
self._frame_idx_map[video] = {
lf.frame_idx: lf for lf in self._lf_by_video[video]
}
self._track_occupancy[video] = self._make_track_occupancy(video)
else:
new_vid = new_frame.video
if new_item is None:
self.rebuild_cache()

elif isinstance(new_item, LabeledFrame):
self.add_labeled_frame(new_item)

elif isinstance(new_item, RecordingSession):
self.add_recording_session(new_item)

if new_vid not in self._lf_by_video:
self._lf_by_video[new_vid] = []
if new_vid not in self._frame_idx_map:
self._frame_idx_map[new_vid] = dict()
self._lf_by_video[new_vid].append(new_frame)
self._frame_idx_map[new_vid][new_frame.frame_idx] = new_frame
elif isinstance(new_item, tuple):
self.add_video_to_session(*new_item)

def find_frames(
self, video: Video, frame_idx: Optional[Union[int, Iterable[int]]] = None
Expand Down Expand Up @@ -218,10 +269,13 @@ def remove_frame(self, frame: LabeledFrame):

def remove_video(self, video: Video):
"""Remove video and update cache as needed."""

if video in self._lf_by_video:
del self._lf_by_video[video]
if video in self._frame_idx_map:
del self._frame_idx_map[video]
if video in self._session_by_video:
del self._session_by_video[video]

def track_swap(
self,
Expand Down Expand Up @@ -420,7 +474,7 @@ class Labels(MutableSequence):
nodes: List[Node] = attr.ib(default=attr.Factory(list))
tracks: List[Track] = attr.ib(default=attr.Factory(list))
suggestions: List[SuggestionFrame] = attr.ib(default=attr.Factory(list))
sessions: List[RecordingSession] = attr.ib(default=attr.Factory(list))
_sessions: List[RecordingSession] = attr.ib(default=attr.Factory(list))
negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict))
provenance: Dict[Text, Union[str, int, float, bool]] = attr.ib(
default=attr.Factory(dict)
Expand All @@ -438,13 +492,18 @@ def __attrs_post_init__(self):
# frames but not in the lists on our object
self._update_from_labels()

# Update caches used to find frames by frame index
# Create cache to find frames by frame index and `RecordingSession`s by `Video`s
self._cache = LabelsDataCache(self)

# Create a variable to store a temporary storage directory
# used when we unzip
self.__temp_dir = None

# TODO(LM): Add Observer pattern between `Labels`, `RecordingSession`s,
# `LabelsDataCache`, `MainWindow`, and others.
for session in self.sessions:
session.labels = self

def _update_from_labels(self, merge: bool = False):
"""Updates top level attributes with data from labeled frames.

Expand Down Expand Up @@ -585,6 +644,19 @@ def has_missing_videos(self) -> bool:
"""Return True if any of the video files in the labels are missing."""
return any(video.is_missing for video in self.videos)

@property
def sessions(self) -> List[RecordingSession]:
"""Return a list of sessions in the labels."""
return self._sessions

@sessions.setter
def sessions(self, value: RecordingSession):
"""Set the sessions in the labels."""
raise ValueError(
"Direct assignment to `Labels.sessions` is not allowed. "
"Please use `Labels.add_session` to add a session."
)

def __len__(self) -> int:
"""Return number of labeled frames."""
return len(self.labeled_frames)
Expand Down Expand Up @@ -1597,7 +1669,49 @@ def add_session(self, session: RecordingSession):
)

if session not in self.sessions:
self.sessions.append(session)
self._sessions.append(session)
session.labels = self

self._cache.update(session)

def update_session(self, session: RecordingSession, video: Video = None):
"""Update `Video` to `RecordingSession` map in the `LabelsDataCache`.

Args:
session: `RecordingSession` instance
video: `Video` instance linked to a `RecordingSession`
"""

if session not in self.sessions:
raise KeyError("Session is not in labels.")
if video is None:
new_item = session
else:
new_item = (session, video)
self._cache.update(new_item)

def get_session(self, video: Video) -> Optional[RecordingSession]:
"""Get the recording session associated with a video.

Args:
video: `Video` instance

Returns:
`RecordingSession` instance
"""
return self._cache._session_by_video.get(video, None)

def remove_session_video(self, session: RecordingSession, video: Video):
"""Remove a video from a recording session.

Args:
session: `RecordingSession` instance
video: `Video` instance
"""

self._cache._session_by_video.pop(video, None)
if video in session.videos:
session.remove_video(video)

@classmethod
def from_json(cls, *args, **kwargs):
Expand Down
29 changes: 26 additions & 3 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import os

import pytest
import numpy as np
from pathlib import Path, PurePath

import sleap

from pathlib import Path, PurePath

from sleap.io.cameras import RecordingSession
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
from sleap.io.dataset import Labels, load_file
from sleap.io.legacy import load_labels_json_old
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.io.format import filehandle
from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame

from tests.io.test_formats import assert_read_labels_match

TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5"
Expand Down Expand Up @@ -996,15 +999,35 @@ def test_save_labels_with_sessions(
for cam_1, cam_2 in zip(session, loaded_session):
assert cam_1 == cam_2

assert loaded_session.labels == loaded_labels


def test_add_session(min_labels_slp: Labels, min_session_session: RecordingSession):
def test_add_session_and_update_session(
min_labels_slp: Labels, min_session_session: RecordingSession
):
"""Test that we can add a `RecordingSession` to a `Labels` object."""

labels = min_labels_slp
session = min_session_session

labels.add_session(session)
assert labels.sessions == [session]
assert labels._cache._session_by_video == dict()

video = labels.videos[0]
session.add_video(video, session.camera_cluster.cameras[0])
assert labels._cache._session_by_video == {video: session}
assert labels.get_session(video) == session

labels.remove_session_video(session, video)
assert video not in session.videos
assert video not in labels._cache._session_by_video

with pytest.raises(ValueError):
labels.sessions = [session]

# Warning: we would like to prevent the following, but might require custom class
labels.sessions.append(session)


roomrys marked this conversation as resolved.
Show resolved Hide resolved
def test_labels_hdf5(multi_skel_vid_labels, tmpdir):
Expand Down
Loading