diff --git a/sleap/gui/app.py b/sleap/gui/app.py index becc1d83a..3d1f7c443 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -788,6 +788,12 @@ def new_instance_menu_action(): "Delete Predictions beyond Max Instances...", self.commands.deleteInstanceLimitPredictions, ) + add_menu_item( + labelMenu, + "delete frame limit predictions", + "Delete Predictions beyond Frame Limit...", + self.commands.deleteFrameLimitPredictions, + ) ### Tracks Menu ### diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 1a64a071c..bd0997a59 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -494,6 +494,10 @@ def deleteInstanceLimitPredictions(self): """Gui for deleting instances beyond some number in each frame.""" self.execute(DeleteInstanceLimitPredictions) + def deleteFrameLimitPredictions(self): + """Gui for deleting instances beyond some frame number.""" + self.execute(DeleteFrameLimitPredictions) + def completeInstanceNodes(self, instance: Instance): """Adds missing nodes to given instance.""" self.execute(AddMissingInstanceNodes, instance=instance) @@ -2470,6 +2474,39 @@ def ask(cls, context: CommandContext, params: dict) -> bool: return super().ask(context, params) +class DeleteFrameLimitPredictions(InstanceDeleteCommand): + @staticmethod + def get_frame_instance_list(context: CommandContext, params: Dict): + """Called from the parent `InstanceDeleteCommand.ask` method. + + Returns: + List of instances to be deleted. + """ + predicted_instances = [] + # Select the instances to be deleted + for lf in context.labels.labeled_frames: + if lf.frame_idx >= params["frame_idx_threshold"]: + predicted_instances.extend( + [(lf, inst) for inst in lf.predicted_instances] + ) + return predicted_instances + + @classmethod + def ask(cls, context: CommandContext, params: Dict) -> bool: + current_video = context.state["video"] + frame_idx_thresh, okay = QtWidgets.QInputDialog.getInt( + context.app, + "Delete Instance beyond Frame Number...", + "Frame number after which instances to be deleted:", + 1, + 1, + len(current_video), + ) + if okay: + params["frame_idx_threshold"] = frame_idx_thresh + return super().ask(context, params) + + class TransposeInstances(EditCommand): topics = [UpdateTopic.project_instances, UpdateTopic.tracks] diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 899b1f4a0..cc9267858 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -16,6 +16,7 @@ ReplaceVideo, OpenSkeleton, SaveProjectAs, + DeleteFrameLimitPredictions, get_new_version_filename, ) from sleap.instance import Instance, LabeledFrame @@ -847,6 +848,27 @@ def load_and_assert_changes(new_video_path: Path): shutil.move(new_video_path, expected_video_path) +def test_DeleteFrameLimitPredictions( + centered_pair_predictions: Labels, centered_pair_vid: Video +): + """Test deleting instances beyond a certain frame limit.""" + labels = centered_pair_predictions + + # Set-up command context + context = CommandContext.from_labels(labels) + context.state["video"] = centered_pair_vid + + # Set-up params for the command + params = {"frame_idx_threshold": 900} + + expected_instances = 423 + predicted_instances = DeleteFrameLimitPredictions.get_frame_instance_list( + context, params + ) + + assert len(predicted_instances) == expected_instances + + @pytest.mark.parametrize("export_extension", [".json.zip", ".slp"]) def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir): def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False):