Skip to content

Commit

Permalink
fixed crashing when no labels are present in the batch
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Jan 12, 2024
1 parent 270ec4f commit c559448
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, cast

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -209,8 +209,11 @@ def forward(
def _preprocess_target(self, target: Tensor, batch_size: int, scale_tensor: Tensor):
"""Preprocess target in shape [batch_size, N, 5] where N is maximum number of
instances in one image."""
sample_ids, counts = torch.unique(target[:, 0].int(), return_counts=True)
out_target = torch.zeros(batch_size, counts.max(), 5, device=target.device)
sample_ids, counts = cast(
tuple[Tensor, Tensor], torch.unique(target[:, 0].int(), return_counts=True)
)
c_max = int(counts.max()) if counts.numel() > 0 else 0
out_target = torch.zeros(batch_size, c_max, 5, device=target.device)
out_target[:, :, 0] = -1
for id, count in zip(sample_ids, counts):
out_target[id, :count] = target[target[:, 0] == id][:, 1:]
Expand Down

0 comments on commit c559448

Please sign in to comment.