Can we support class weight in the CEWithChunkedOutputLoss class #1746
This seems like a good idea, thanks @ye-jin-shop! Just so I fully understand what you're trying to do, can you provide a small code example of how you would use this weight argument?
@joecummings Thank you for the response! For example, say I have three classes [0, 1, 2] and want to put a lower weight on the first class. With `CrossEntropyLoss` from `torch.nn`, I can pass a per-class `weight` tensor to the constructor.
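A minimal sketch of that usage with `torch.nn.CrossEntropyLoss`. The weight values (0.2 for class 0, 1.0 for classes 1 and 2) and the toy batch are illustrative assumptions, not values from this thread:

```python
import torch
import torch.nn as nn

# Three classes [0, 1, 2]; class 0 gets a lower weight than classes 1 and 2.
# The specific values are illustrative only.
class_weights = torch.tensor([0.2, 1.0, 1.0])
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
logits = torch.randn(4, 3)             # [batch, num_classes]
targets = torch.tensor([0, 1, 2, 0])   # class indices
loss = loss_fn(logits, targets)        # scalar, weighted mean over the batch
```

With the default `reduction="mean"`, PyTorch divides the weighted sum of per-sample losses by the sum of the weights of the targets, not by the batch size. Any chunked reimplementation has to reproduce that denominator to give the same value.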
While calculating the loss, we could apply different weights to different classes. I wonder whether we could add a `weight` argument to `CEWithChunkedOutputLoss` as well.
This should be easy. We can add `*args, **kwargs` to `__init__` and forward them to the underlying cross-entropy call. Something along those lines should work, but it needs some testing to check for conflicts with `ignore_index` and `reduction`.
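A minimal sketch of the idea, written as a standalone hypothetical class `WeightedChunkedCELoss` rather than torchtune's actual `CEWithChunkedOutputLoss`, whose internals may differ. It assumes logits arrive as a list of chunks split along the sequence dimension and labels as one `[batch, seq_len]` tensor. It takes an explicit `weight` argument instead of `*args, **kwargs`, sums the loss per chunk, and normalizes once by the total weight of the non-ignored targets, so the result matches `F.cross_entropy(..., weight=..., reduction="mean")` on the unchunked tensors:

```python
from typing import List, Optional

import torch
import torch.nn.functional as F


class WeightedChunkedCELoss(torch.nn.Module):
    """Hypothetical chunked cross-entropy with optional per-class weights.

    A sketch only; this is not torchtune's CEWithChunkedOutputLoss.
    """

    def __init__(
        self,
        num_output_chunks: int = 8,
        ignore_index: int = -100,
        weight: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        # Registered as a buffer so .to(device) moves the weights with the module.
        self.register_buffer("weight", weight.float() if weight is not None else None)

    def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        # logits: list of [batch, chunk_len, num_classes] tensors
        # labels: [batch, seq_len], split the same way along dim=1
        label_chunks = labels.chunk(self.num_output_chunks, dim=1)
        total_loss = torch.zeros((), device=labels.device)
        total_weight = torch.zeros((), device=labels.device)
        for logit_chunk, label_chunk in zip(logits, label_chunks):
            flat_logits = logit_chunk.reshape(-1, logit_chunk.size(-1)).float()
            flat_labels = label_chunk.reshape(-1)
            # Weighted *sum* per chunk; normalization happens once at the end.
            total_loss = total_loss + F.cross_entropy(
                flat_logits,
                flat_labels,
                weight=self.weight,
                ignore_index=self.ignore_index,
                reduction="sum",
            )
            # Denominator matching reduction="mean": the summed weights of the
            # non-ignored targets (just their count when weight is None).
            valid = flat_labels != self.ignore_index
            if self.weight is None:
                total_weight = total_weight + valid.sum()
            else:
                total_weight = total_weight + self.weight[flat_labels[valid]].sum()
        return total_loss / total_weight
```

Usage would look like `WeightedChunkedCELoss(num_output_chunks=3, weight=torch.tensor([0.2, 1.0, 1.0]))(list(logits.chunk(3, dim=1)), labels)`. The explicit `weight` argument is a deliberate choice in this sketch: forwarding arbitrary kwargs would also let callers pass `reduction`, which would clash with the per-chunk `reduction="sum"` that chunking relies on, and dividing by the token count instead of the summed weights would silently change the loss value.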
I can test something like this next week, but if you want to contribute directly, I would be glad to review your PR. Let me know and I can give some pointers. Otherwise, I will post here when I get back to this.
Hey @felipemello1, I have other priorities right now, so I will leave this to you (not urgent, but nice to have). Thank you for your support!
I am trying to add class weights to the loss function, and I think it would be nice to support them in the class itself, since the original `CrossEntropyLoss` class takes `weight` as an argument.