Increase the number of negatives in CachedGISTEmbedLoss to 2 #2932

Open
daegonYu opened this issue Sep 12, 2024 · 6 comments

@daegonYu
Contributor

I am trying to train with CachedGISTEmbedLoss using two negatives instead of one, as shown below. Training runs without errors, but is there any theoretical problem with this approach?

    import torch
    import tqdm
    from typing import List
    from torch import Tensor
    from sentence_transformers import losses

    # No additional initialization is needed beyond the parent class
    class CachedGISTEmbedLoss_modified(losses.CachedGISTEmbedLoss):
        # def __init__(self, *args, **kargs):
        #     super().__init__(*args, **kargs)

        def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
            """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
            if len(reps) == 2:
                anchor, positive = reps
                anchor_guide, positive_guide = reps_guided
                negative = None
                negative_guide = None
            elif len(reps) == 3:
                anchor, positive, negative = reps
                anchor_guide, positive_guide, negative_guide = reps_guided
            elif len(reps) == 4:
                anchor, positive, negative, negative_extra = reps
                anchor_guide, positive_guide, negative_guide, negative_extra_guide = reps_guided
            else:
                raise ValueError("Expected 2, 3, or 4 embeddings, got {}".format(len(reps)))

            # Concatenate embeddings along the batch dimension
            anchor = torch.cat(anchor, dim=0)
            positive = torch.cat(positive, dim=0)
            anchor_guide = torch.cat(anchor_guide, dim=0)
            positive_guide = torch.cat(positive_guide, dim=0)
            
            # Handle the case where we have a negative sample
            if negative:
                negative = torch.cat(negative, dim=0)
                negative_guide = torch.cat(negative_guide, dim=0)

            # Handle the case where we have an extra negative sample (4 embeddings case)
            if len(reps) == 4:
                negative_extra = torch.cat(negative_extra, dim=0)
                negative_extra_guide = torch.cat(negative_extra_guide, dim=0)

            labels = torch.arange(anchor.size(0)).long().to(anchor.device)
            batch_size = anchor.shape[0]

            losses: List[torch.Tensor] = []
            for b in tqdm.trange(
                0,
                batch_size,
                self.mini_batch_size,
                desc="Preparing caches",
                disable=not self.show_progress_bar,
            ):
                e = b + self.mini_batch_size
                # Compute the similarity matrices for anchor and positive samples.
                guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide)
                guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide)
                guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide)
                guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

                # Compute similarity scores for current mini-batch.
                ap_sim = self.sim_matrix(anchor[b:e], positive)  # (mbsz,bsz)
                aa_sim = self.sim_matrix(anchor[b:e], anchor)
                pp_sim = self.sim_matrix(positive[b:e], positive)

                ap_sim[guided_ap_sim > guided_sim] = -torch.inf
                aa_sim[guided_aa_sim > guided_sim] = -torch.inf
                pp_sim[guided_pp_sim >= guided_sim] = -torch.inf

                scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

                # Handle the case where we have a negative sample
                if negative is not None:
                    guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide)
                    an_sim = self.sim_matrix(anchor[b:e], negative)
                    an_sim[guided_an_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, an_sim], dim=1)

                # Handle the case where we have an extra negative sample
                if len(reps) == 4:
                    guided_ane_sim = self.sim_matrix(anchor_guide[b:e], negative_extra_guide)
                    ane_sim = self.sim_matrix(anchor[b:e], negative_extra)
                    ane_sim[guided_ane_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, ane_sim], dim=1)

                scores = scores / self.temperature
                loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
                loss_mbatch.backward()
                losses.append(loss_mbatch.detach())

            loss = sum(losses).requires_grad_()

            self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)

            return loss

    loss = CachedGISTEmbedLoss_modified(model=model, guide=guide, mini_batch_size=mini_batch_size)     
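
For reference, a dataset with two negative columns could look like this (a minimal sketch; the column names and texts are illustrative, not my actual data). With four text columns per row, reps arrives as four groups, so the len(reps) == 4 branch above is taken:

    from datasets import Dataset

    train_dataset = Dataset.from_dict({
        "anchor": ["query one", "query two"],
        "positive": ["relevant passage one", "relevant passage two"],
        "negative_1": ["hard negative one", "hard negative two"],
        "negative_2": ["another hard negative one", "another hard negative two"],
    })
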
@tomaarsen
Collaborator

Hello!

This looks correct to me, well done. I think the only missed case is that if you have an eval_dataset, the evaluation will use calculate_loss rather than calculate_loss_and_cache_gradients, and the former still only accepts a single negative.
That evaluation is only used to report the loss so you can check that you're not overfitting (though in my experience, overfitting doesn't happen very often with Sentence Transformer models).

Ideally I'd update all of the in-batch negatives losses to accept any number of negatives. MultipleNegativesRankingLoss already does; it's just not very well documented.
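
For example, something along these lines should already work with MultipleNegativesRankingLoss (a rough sketch; the model name and column names are only placeholders, and every column after "positive" is treated as an additional negative):

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

    model = SentenceTransformer("all-MiniLM-L6-v2")
    train_dataset = Dataset.from_dict({
        "anchor": ["what is the capital of france?", "who wrote hamlet?"],
        "positive": ["Paris is the capital of France.", "Hamlet was written by Shakespeare."],
        "negative_1": ["Berlin is the capital of Germany.", "Macbeth was written by Shakespeare."],
        "negative_2": ["France borders Spain.", "Hamlet is set in Denmark."],
    })
    # Columns beyond (anchor, positive) are all used as extra negatives.
    loss = losses.MultipleNegativesRankingLoss(model)
    trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
    trainer.train()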

  • Tom Aarsen

@daegonYu
Contributor Author

Thank you for your advice. As you said, there was an error in calculate_loss, so I fixed it as below and it works fine.

    def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
        """Calculate the cross-entropy loss. No need to cache the gradients."""
        if len(reps) == 2:
            anchor, positive = reps
            anchor_guide, positive_guide = reps_guided
            negative = None
            negative_guide = None
        elif len(reps) == 3:
            anchor, positive, negative = reps
            anchor_guide, positive_guide, negative_guide = reps_guided
        elif len(reps) == 4:
            anchor, positive, negative, negative_extra = reps
            anchor_guide, positive_guide, negative_guide, negative_extra_guide = reps_guided
        else:
            raise ValueError("Expected 2, 3, or 4 embeddings, got {}".format(len(reps)))

        anchor = torch.cat(anchor, dim=0)
        positive = torch.cat(positive, dim=0)
        anchor_guide = torch.cat(anchor_guide, dim=0)
        positive_guide = torch.cat(positive_guide, dim=0)
        # Handle the case where we have a negative sample
        if negative:
            negative = torch.cat(negative, dim=0)
            negative_guide = torch.cat(negative_guide, dim=0)

        # Handle the case where we have an extra negative sample (4 embeddings case)
        if len(reps) == 4:
            negative_extra = torch.cat(negative_extra, dim=0)
            negative_extra_guide = torch.cat(negative_extra_guide, dim=0)

        labels = torch.arange(anchor.size(0)).long().to(anchor.device)
        batch_size = anchor.shape[0]

        losses: list[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Preparing caches",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size
            # Let's compute the similarity matrices for the combinations of anchor and positive samples.
            guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide)
            guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide)
            guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide)
            # Define the anchor threshold
            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

            # Compute similarity scores for current mini-batch.
            # anchor (mbsz,hdim), positive (bsz,hdim)
            ap_sim = self.sim_matrix(anchor[b:e], positive)  # (mbsz,bsz)
            aa_sim = self.sim_matrix(anchor[b:e], anchor)
            pp_sim = self.sim_matrix(positive[b:e], positive)

            # Find which samples cannot be used as negatives because they are
            # more similar to the query than the assigned positive as deemed by the guide model.
            # For these samples, we mask them with -inf to basically ignore their contribution to
            # the loss.

            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # Handle the case where we have a negative sample
            if negative is not None:
                guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide)
                an_sim = self.sim_matrix(anchor[b:e], negative)
                an_sim[guided_an_sim > guided_sim] = -torch.inf
                scores = torch.cat([scores, an_sim], dim=1)
                
            # Handle the case where we have an extra negative sample
            if len(reps) == 4:
                guided_ane_sim = self.sim_matrix(anchor_guide[b:e], negative_extra_guide)
                ane_sim = self.sim_matrix(anchor[b:e], negative_extra)
                ane_sim[guided_ane_sim > guided_sim] = -torch.inf
                scores = torch.cat([scores, ane_sim], dim=1)

            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            losses.append(loss_mbatch)

        loss = sum(losses)
        return loss
    

I also rewrote it to handle any number of negatives, as shown below, but the train loss is 0.000. Can you tell me what the problem is?

class CachedGISTEmbedLoss(losses.CachedGISTEmbedLoss):
    def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Preparing caches",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size

            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            loss_mbatch.backward()
            losses.append(loss_mbatch.detach())

        loss = sum(losses).requires_grad_()

        self.cache = [[r.grad for r in rs] for rs in reps]  # Cache the gradients

        return loss

    def calculate_loss(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss without caching gradients."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Calculating loss",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size

            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            losses.append(loss_mbatch)

        loss = sum(losses)
        return loss

@daegonYu
Contributor Author

In the negatives filtering part, I modified it so that guided_sim uses every anchor's threshold in the mini-batch instead of only the 0th one via guided_sim[0] (see the shape sketch after the snippets below). I will experiment soon to check whether this loss function works correctly.

After change

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

Before change

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)
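
The difference comes down to broadcasting. A toy example with made-up numbers (two anchors in the mini-batch, two negative candidates): guided_sim has shape (mini_batch, 1), so each row is masked against its own anchor's threshold, whereas guided_sim[0] applies only the first anchor's threshold to every row:

    import torch

    guided_neg_sim = torch.tensor([[0.9, 0.1],
                                   [0.6, 0.3]])
    guided_sim = torch.tensor([[0.5],
                               [0.7]])  # one threshold per anchor in the mini-batch

    print(guided_neg_sim > guided_sim)     # per-row thresholds (after the change)
    # tensor([[ True, False],
    #         [False, False]])
    print(guided_neg_sim > guided_sim[0])  # row 0's threshold applied to every row (before the change)
    # tensor([[ True, False],
    #         [ True, False]])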

@daegonYu
Contributor Author

daegonYu commented Sep 19, 2024

The problem was that the guide's anchor-positive similarity for the mini-batch was taken at a fixed diagonal offset of 0, when it should start at the mini-batch's starting index (b).

# Before change : guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)
guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)
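
The offset matters because guided_ap_sim has shape (mini_batch, full_batch), and the anchor-positive pairs of a mini-batch starting at index b sit in columns b to b + mini_batch - 1, i.e. on the diagonal with offset b. A toy illustration with made-up values (mini_batch_size = 2, second mini-batch, so b = 2):

    import torch

    guided_ap_sim = torch.tensor([[0.1, 0.2, 0.9, 0.3],
                                  [0.4, 0.1, 0.2, 0.8]])
    b = 2
    print(guided_ap_sim.diagonal(offset=0))  # tensor([0.1000, 0.1000]) -> columns of the first mini-batch (wrong pairs)
    print(guided_ap_sim.diagonal(offset=b))  # tensor([0.9000, 0.8000]) -> this mini-batch's own anchor-positive pairs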

After fixing this, it works normally. Below is the full code of the changed class.

class CachedGISTEmbedLoss(losses.CachedGISTEmbedLoss):
    def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Preparing caches",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size

            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            loss_mbatch.backward()
            losses.append(loss_mbatch.detach())

        loss = sum(losses).requires_grad_()

        self.cache = [[r.grad for r in rs] for rs in reps]  # Cache the gradients

        return loss

    def calculate_loss(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss without caching gradients."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Calculating loss",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size

            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])

            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            losses.append(loss_mbatch)

        loss = sum(losses)
        return loss

You can verify that it works by running the Colab notebook below.

https://colab.research.google.com/drive/1aU7xiepABsAG1UGk-1LuDkfGjAz3I4o3?usp=sharing

How about opening a pull request so that others can use this code?

@tomaarsen
Collaborator

Great work! Yes, I'd be very open to a PR to extend the behaviour of this class.

  • Tom Aarsen

@daegonYu
Contributor Author

I submitted a pull request.

#2946

Feel free to modify it or share your opinion.
