
Is it an error in the loss function? #2

Open
YChienHung opened this issue Dec 3, 2024 · 1 comment

Comments

@YChienHung

    def compute(self, data: Dict[str, torch.Tensor],
                num_objects: List[int]) -> Dict[str, torch.Tensor]:
        batch_size, num_frames = data['rgb'].shape[:2]
        losses = defaultdict(float)
        t_range = range(1, num_frames)

        for bi in range(batch_size):
            logits = torch.stack(
                [data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], dim=0)

            cls_gt = data['cls_gt'][bi, 1:]  # remove gt for the first frame
            soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])

            loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
            losses['loss_ce'] += loss_ce / batch_size
            losses['loss_dice'] += loss_dice / batch_size
        
        # start
        aux = [data[f'aux_{ti}'] for ti in t_range]
        if 'sensory_logits' in aux[0]:
            sensory_log = torch.stack(
                [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)
            loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
            losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
            losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight

        if 'q_logits' in aux[0]:
            num_levels = aux[0]['q_logits'].shape[2]

            for l in range(num_levels):
                query_log = torch.stack(
                    [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)

                loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
                losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight
                losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight
        # end

        losses['total_loss'] = sum(losses.values())

        return losses

I noticed that the code starting from `aux = [data[f'aux_{ti}'] for ti in t_range]` (between the `# start` and `# end` markers) is not inside the batch loop, so it only uses `bi` and `soft_gt` from the last iteration. Is this an error?
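To illustrate the concern, here is a minimal, self-contained sketch (toy values, not the repo's data): in Python, a loop variable survives past the loop body, so a block placed after the loop sees only the last index.

```python
batch_size = 4
per_sample = [10.0, 20.0, 30.0, 40.0]  # toy per-sample values

total = 0.0
for bi in range(batch_size):
    total += per_sample[bi] / batch_size  # every sample contributes here

# `bi` survives the loop, so anything down here sees only the last index,
# mirroring how the aux block above reuses the leaked `bi` and `soft_gt`.
leaked = per_sample[bi]

print(bi, total, leaked)  # 3 25.0 40.0
```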

@qinliuliuqin
Collaborator

@YChienHung Thanks for the good question. This function is inherited from Cutie. It is probably not an error, just an odd implementation. I will figure it out and get back to you soon.
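For what it's worth, a toy sketch (fabricated numbers, not the maintainers' fix) of how the placement changes the result: if the aux losses are meant to be averaged over the whole batch, the block would accumulate inside the per-sample loop rather than run once after it.

```python
batch_size = 3
aux_per_sample = [1.0, 2.0, 3.0]  # toy per-sample aux losses

# As written in the snippet: the aux block runs once, after the loop,
# using the leaked `bi`, so only the last sample contributes.
for bi in range(batch_size):
    pass
loss_outside = aux_per_sample[bi] / batch_size

# Hypothetical per-sample version: every element contributes.
loss_inside = 0.0
for bi in range(batch_size):
    loss_inside += aux_per_sample[bi] / batch_size

print(loss_outside, loss_inside)  # 1.0 2.0
```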
