
Can we support class weight in the CEWithChunkedOutputLoss class #1746

Open
ye-jin-shop opened this issue Oct 2, 2024 · 4 comments

@ye-jin-shop

I am trying to add class weights to the loss function, and I think it would be nice to support them in this class. The original CrossEntropyLoss class has weight as an arg.

@joecummings
Contributor

This seems like a good idea - thanks @ye-jin-shop!

Just so I fully understand what you're trying to do, can you provide a small code example of how you would be using this weight argument?

@ye-jin-shop
Author

@joecummings Thank you for the response!

For example, suppose I have three classes [0, 1, 2] and I want to put a lower weight on the first class. Using CrossEntropyLoss from torch.nn, I can write:

import torch
import torch.nn as nn

weights = [0.5, 1.0, 1.0]  # lower weight for class 0
class_weights = torch.FloatTensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)
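
As a quick sanity check of the weighting behaviour (hypothetical numbers; reduction="sum" makes the scaling directly visible):

# Hypothetical example: with weight, each sample's loss is scaled by the
# weight of its target class before reduction.
logits = torch.tensor([[2.0, 0.5, 0.3]])
target = torch.tensor([0])
w = torch.tensor([0.5, 1.0, 1.0])
unweighted = nn.CrossEntropyLoss(reduction="sum")(logits, target)
weighted = nn.CrossEntropyLoss(weight=w, reduction="sum")(logits, target)
assert torch.isclose(weighted, 0.5 * unweighted)  # class 0 contribution halved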

While calculating the loss, we could then apply different weights to different classes. I wonder whether we could add the weight arg to this CEWithChunkedOutputLoss class as well. It could be passed from line 30, I think.

@joecummings added the enhancement label Oct 2, 2024
@felipemello1
Contributor

felipemello1 commented Oct 3, 2024

This should be easy. We can add *args, **kwargs to the init, and something like the following should work, but it needs some testing to check whether there is any conflict with ignore_index and reduction:

import torch
import torch.nn.functional as F


class CEWithChunkedOutputLoss(torch.nn.Module):
    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100, *args, **kwargs):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        # Stash any extra args (e.g. weight) to forward to F.cross_entropy.
        self.args = args
        self.kwargs = kwargs

    def compute_cross_entropy(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Upcast logits to fp32 and compute cross entropy loss.
        """
        return F.cross_entropy(
            logits.float(),
            labels,
            *self.args,
            ignore_index=self.ignore_index,
            reduction="sum",
            **self.kwargs,
        )
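
For reference, a hypothetical usage of the sketch above, assuming weight simply rides through **kwargs into F.cross_entropy (the weight tensor should be fp32 to match the upcast logits):

# Hypothetical usage: `weight` is captured in **kwargs at init and
# forwarded to F.cross_entropy on every chunk.
loss_fn = CEWithChunkedOutputLoss(num_output_chunks=2, weight=torch.tensor([0.5, 1.0, 1.0]))
logits = torch.randn(4, 3)          # (num_tokens, num_classes)
labels = torch.randint(0, 3, (4,))  # (num_tokens,)
loss = loss_fn.compute_cross_entropy(logits, labels)  # weighted sum over tokens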

I can test something like this next week, but if you want to contribute directly, I would be glad to review your PR. Let me know and I can give you some pointers; otherwise, I will post here when I get back to this.

@ye-jin-shop
Author

Hey @felipemello1. I have other priorities right now, so I will leave this to you (not urgent, but nice to have). Thank you for your support!
