Skip to content

Commit

Permalink
Inference Fix (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Oct 8, 2024
1 parent a6c99f9 commit e525054
Show file tree
Hide file tree
Showing 19 changed files with 306 additions and 199 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ jobs:
with:
ref: ${{ github.head_ref }}

- name: Install pre-commit
run: python -m pip install 'pre-commit<4.0.0'

- name: Run pre-commit
uses: pre-commit/[email protected]

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,5 @@ mlruns
wandb
tests/_data
tests/integration/save-directory
tests/integration/infer-save-directory
data
9 changes: 6 additions & 3 deletions luxonis_train/attached_modules/base_attached_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_input_tensors(
return inputs[self.node_tasks[self.required_labels[0]]]

def prepare(
self, inputs: Packet[Tensor], labels: Labels
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Unpack[Ts]]:
"""Prepares node outputs for the forward pass of the module.
Expand All @@ -287,8 +287,9 @@ def prepare(
@type inputs: L{Packet}[Tensor]
@param inputs: Output from the node, inputs to the attached module.
@type labels: L{Labels}
@param labels: Labels from the dataset.
@type labels: L{Labels} | None
@param labels: Labels from the dataset. If not provided, empty labels are used.
This is useful in visualizers for working with standalone images.
@rtype: tuple[Unpack[Ts]]
@return: Prepared inputs. Should allow the following usage with the
Expand Down Expand Up @@ -325,6 +326,8 @@ def prepare(
set(self.supported_tasks) & set(self.node_tasks)
)
x = self.get_input_tensors(inputs)
if labels is None:
return x, None # type: ignore
label, task_type = self._get_label(labels)
if task_type in [TaskType.CLASSIFICATION, TaskType.SEGMENTATION]:
if len(x) == 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class ObjectKeypointSimilarity(

def __init__(
self,
n_keypoints: int | None = None,
sigmas: list[float] | None = None,
area_factor: float | None = None,
use_cocoeval_oks: bool = True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(
label_canvas: Tensor,
prediction_canvas: Tensor,
inputs: Packet[Tensor],
labels: Labels,
labels: Labels | None,
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, list[Tensor]]:
return self(
label_canvas, prediction_canvas, *self.prepare(inputs, labels)
Expand Down
25 changes: 14 additions & 11 deletions luxonis_train/attached_modules/visualizers/bbox_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: list[Tensor],
targets: Tensor,
) -> tuple[Tensor, Tensor]:
targets: Tensor | None,
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the bounding box predictions and
labels.
Expand All @@ -189,26 +189,29 @@ def forward(
@type targets: Tensor
@param targets: The target bounding boxes.
"""
targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
color_dict=self.colors,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
predictions_viz = self.draw_predictions(
prediction_canvas,
predictions,
label_dict=self.bbox_labels,
if targets is None:
return predictions_viz

targets_viz = self.draw_targets(
label_canvas,
targets,
color_dict=self.colors,
label_dict=self.bbox_labels,
draw_labels=self.draw_labels,
fill=self.fill,
font=self.font,
font_size=self.font_size,
width=self.width,
)
return targets_viz, predictions_viz.to(targets_viz.device)
return targets_viz, predictions_viz
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import Tensor

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_visualizer import BaseVisualizer
from .utils import figure_to_torch, numpy_to_torch_img, torch_img_to_numpy
Expand Down Expand Up @@ -56,29 +57,38 @@ def _generate_plot(
ax.grid(True)
return figure_to_torch(fig, width, height)

def prepare(
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Tensor, Tensor]:
predictions, targets = super().prepare(inputs, labels)
if isinstance(predictions, list):
predictions = predictions[0]
return predictions, targets

def forward(
self,
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: Tensor,
labels: Tensor,
targets: Tensor | None,
) -> Tensor | tuple[Tensor, Tensor]:
overlay = torch.zeros_like(label_canvas)
plots = torch.zeros_like(prediction_canvas)
for i in range(len(overlay)):
prediction = predictions[i]
gt = self._get_class_name(labels[i])
arr = torch_img_to_numpy(label_canvas[i].clone())
curr_class = self._get_class_name(prediction)
arr = cv2.putText(
arr,
f"GT: {gt}",
(5, 10),
cv2.FONT_HERSHEY_SIMPLEX,
self.font_scale,
self.color,
self.thickness,
)
if targets is not None:
gt = self._get_class_name(targets[i])
arr = cv2.putText(
arr,
f"GT: {gt}",
(5, 10),
cv2.FONT_HERSHEY_SIMPLEX,
self.font_scale,
self.color,
self.thickness,
)
arr = cv2.putText(
arr,
f"Pred: {curr_class}",
Expand Down
21 changes: 12 additions & 9 deletions luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,9 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: list[Tensor],
targets: Tensor,
targets: Tensor | None,
**kwargs,
) -> tuple[Tensor, Tensor]:
target_viz = self.draw_targets(
label_canvas,
targets,
colors=self.visible_color,
connectivity=self.connectivity,
**kwargs,
)
) -> tuple[Tensor, Tensor] | Tensor:
pred_viz = self.draw_predictions(
prediction_canvas,
predictions,
Expand All @@ -113,4 +106,14 @@ def forward(
visibility_threshold=self.visibility_threshold,
**kwargs,
)
if targets is None:
return pred_viz

target_viz = self.draw_targets(
label_canvas,
targets,
colors=self.visible_color,
connectivity=self.connectivity,
**kwargs,
)
return target_viz, pred_viz
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def forward(
label_canvas: Tensor,
prediction_canvas: Tensor,
outputs: Packet[Tensor],
labels: Labels,
) -> tuple[Tensor, Tensor]:
labels: Labels | None,
) -> tuple[Tensor, Tensor] | Tensor:
for visualizer in self.visualizers:
match visualizer.run(
label_canvas, prediction_canvas, outputs, labels
Expand All @@ -57,4 +57,6 @@ def forward(
raise NotImplementedError(
"Unexpected return type from visualizer."
)
if labels is None:
return prediction_canvas
return label_canvas, prediction_canvas
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_visualizer import BaseVisualizer
from .utils import (
Expand Down Expand Up @@ -95,14 +96,22 @@ def draw_targets(

return viz

def prepare(
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Tensor, Tensor]:
predictions, targets = super().prepare(inputs, labels)
if isinstance(predictions, list):
predictions = predictions[0]
return predictions, targets

def forward(
self,
label_canvas: Tensor,
prediction_canvas: Tensor,
predictions: Tensor,
targets: Tensor,
targets: Tensor | None,
**kwargs,
) -> tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates a visualization of the segmentation predictions and
labels.
Expand All @@ -118,18 +127,21 @@ def forward(
@return: A tuple of the label and prediction visualizations.
"""

targets_vis = self.draw_targets(
label_canvas,
targets,
predictions_vis = self.draw_predictions(
prediction_canvas,
predictions,
colors=self.colors,
alpha=self.alpha,
background_class=self.background_class,
background_color=self.background_color,
**kwargs,
)
predictions_vis = self.draw_predictions(
prediction_canvas,
predictions,
if targets is None:
return predictions_vis

targets_vis = self.draw_targets(
label_canvas,
targets,
colors=self.colors,
alpha=self.alpha,
background_class=self.background_class,
Expand Down
6 changes: 1 addition & 5 deletions luxonis_train/attached_modules/visualizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,7 @@ def get_unnormalized_images(cfg: Config, inputs: dict[str, Tensor]) -> Tensor:
if cfg.trainer.preprocessing.normalize.active:
mean = normalize_params.get("mean", [0.485, 0.456, 0.406])
std = normalize_params.get("std", [0.229, 0.224, 0.225])
return preprocess_images(
images,
mean=mean,
std=std,
)
return preprocess_images(images, mean=mean, std=std)


def number_to_hsl(seed: int) -> tuple[float, float, float]:
Expand Down
35 changes: 20 additions & 15 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
from .utils.infer_utils import (
IMAGE_FORMATS,
VIDEO_FORMATS,
process_dataset_images,
process_images,
process_video,
infer_from_dataset,
infer_from_directory,
infer_from_video,
)
from .utils.train_utils import create_trainer

Expand Down Expand Up @@ -466,25 +466,30 @@ def infer(
weights = weights or self.cfg.model.weights

with replace_weights(self.lightning_module, weights):
if source_path:
source_path_obj = Path(source_path)
if source_path_obj.suffix.lower() in VIDEO_FORMATS:
process_video(self, source_path_obj, view, save_dir)
elif source_path_obj.is_file():
process_images(self, [source_path_obj], view, save_dir)
elif source_path_obj.is_dir():
image_files = [
if save_dir is not None:
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
if source_path is not None:
source_path = Path(source_path)
if source_path.suffix.lower() in VIDEO_FORMATS:
infer_from_video(
self, video_path=source_path, save_dir=save_dir
)
elif source_path.is_file():
infer_from_directory(self, [source_path], save_dir)
elif source_path.is_dir():
image_files = (
f
for f in source_path_obj.iterdir()
for f in source_path.iterdir()
if f.suffix.lower() in IMAGE_FORMATS
]
process_images(self, image_files, view, save_dir)
)
infer_from_directory(self, image_files, save_dir)
else:
raise ValueError(
f"Source path {source_path} is not a valid file or directory."
)
else:
process_dataset_images(self, view, save_dir)
infer_from_dataset(self, view, save_dir)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
Expand Down
Loading

0 comments on commit e525054

Please sign in to comment.