diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index 89c18f67..af1a7e6a 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, cast import torch import torch.nn.functional as F @@ -209,8 +209,11 @@ def forward( def _preprocess_target(self, target: Tensor, batch_size: int, scale_tensor: Tensor): """Preprocess target in shape [batch_size, N, 5] where N is maximum number of instances in one image.""" - sample_ids, counts = torch.unique(target[:, 0].int(), return_counts=True) - out_target = torch.zeros(batch_size, counts.max(), 5, device=target.device) + sample_ids, counts = cast( + tuple[Tensor, Tensor], torch.unique(target[:, 0].int(), return_counts=True) + ) + c_max = int(counts.max()) if counts.numel() > 0 else 0 + out_target = torch.zeros(batch_size, c_max, 5, device=target.device) out_target[:, :, 0] = -1 for id, count in zip(sample_ids, counts): out_target[id, :count] = target[target[:, 0] == id][:, 1:]