Skip to content

Commit

Permalink
Add batch size to GUI for inference (#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
shrivaths16 authored May 15, 2024
1 parent 18aad91 commit 43a4f13
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ inference:

tracking-only:

- name: batch_size
label: Batch Size
type: int
default: 4
range: 1,512

- name: tracking.tracker
label: Tracker (cross-frame identity) Method
type: stacked
Expand Down
4 changes: 4 additions & 0 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def get_selected_frames_to_predict(

def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInference:
predict_frames_choice = pipeline_form_data.get("_predict_frames", "")
batch_size = pipeline_form_data.get("batch_size")

frame_selection = self.get_selected_frames_to_predict(pipeline_form_data)
frame_count = self.count_total_frames_for_selection_option(frame_selection)
Expand All @@ -617,6 +618,7 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen
)
],
total_frame_count=frame_count,
batch_size=batch_size,
)
elif predict_frames_choice.startswith("suggested"):
items_for_inference = runners.ItemsForInference(
Expand All @@ -626,13 +628,15 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen
)
],
total_frame_count=frame_count,
batch_size=batch_size,
)
else:
items_for_inference = runners.ItemsForInference.from_video_frames_dict(
video_frames_dict=frame_selection,
total_frame_count=frame_count,
labels_path=self.labels_filename,
labels=self.labels,
batch_size=batch_size,
)
return items_for_inference

Expand Down
6 changes: 5 additions & 1 deletion sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class ItemsForInference:

items: List[ItemForInference]
total_frame_count: int
batch_size: int

def __len__(self):
return len(self.items)
Expand All @@ -160,6 +161,7 @@ def from_video_frames_dict(
cls,
video_frames_dict: Dict[Video, List[int]],
total_frame_count: int,
batch_size: int,
labels: Labels,
labels_path: Optional[str] = None,
):
Expand All @@ -174,7 +176,9 @@ def from_video_frames_dict(
video_idx=labels.videos.index(video),
)
)
return cls(items=items, total_frame_count=total_frame_count)
return cls(
items=items, total_frame_count=total_frame_count, batch_size=batch_size
)


@attr.s(auto_attribs=True)
Expand Down

0 comments on commit 43a4f13

Please sign in to comment.