diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c6c095cc5..4d34881d0 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -367,6 +367,12 @@ inference: tracking-only: +- name: batch_size + label: Batch Size + type: int + default: 4 + range: 1,512 + - name: tracking.tracker label: Tracker (cross-frame identity) Method type: stacked diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 5a807f919..184088897 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -605,6 +605,7 @@ def get_selected_frames_to_predict( def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInference: predict_frames_choice = pipeline_form_data.get("_predict_frames", "") + batch_size = pipeline_form_data.get("batch_size") frame_selection = self.get_selected_frames_to_predict(pipeline_form_data) frame_count = self.count_total_frames_for_selection_option(frame_selection) @@ -617,6 +618,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) ], total_frame_count=frame_count, + batch_size=batch_size, ) elif predict_frames_choice.startswith("suggested"): items_for_inference = runners.ItemsForInference( @@ -626,6 +628,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) ], total_frame_count=frame_count, + batch_size=batch_size, ) else: items_for_inference = runners.ItemsForInference.from_video_frames_dict( @@ -633,6 +636,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, labels_path=self.labels_filename, labels=self.labels, + batch_size=batch_size, ) return items_for_inference diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index ca60c4127..5e581a73e 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -151,6 +151,7 @@ class ItemsForInference: items: List[ItemForInference] total_frame_count: int + batch_size: int def __len__(self): return len(self.items) @@ -160,6 +161,7 @@ def from_video_frames_dict( cls, video_frames_dict: Dict[Video, List[int]], total_frame_count: int, + batch_size: int, labels: Labels, labels_path: Optional[str] = None, ): @@ -174,7 +176,9 @@ def from_video_frames_dict( video_idx=labels.videos.index(video), ) ) - return cls(items=items, total_frame_count=total_frame_count) + return cls( + items=items, total_frame_count=total_frame_count, batch_size=batch_size + ) @attr.s(auto_attribs=True)