
Custom Loss and Metrics for SegFormer #128

Open
ebgoldstein opened this issue Mar 28, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

@ebgoldstein
Member

See this discussion regarding specifying a loss (and/or metric) for the SegFormer model:

huggingface/transformers#22092

These lines of code provide a template for upsampling the logits to the label size: https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/segformer/modeling_tf_segformer.py#L793-L811
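The upsampling step on its own is small. A minimal sketch, assuming the logits arrive channels-last as `(batch, h, w, num_classes)` (as they appear to inside `hf_compute_loss` in the linked version) and the labels are integer masks of shape `(batch, H, W)`; `upsample_logits` is a hypothetical helper name:

```python
import tensorflow as tf


def upsample_logits(logits, labels):
    """Bilinearly resize SegFormer logits to the spatial size of the label mask.

    Assumes channels-last logits of shape (batch, h, w, num_classes), produced
    at a lower resolution than the input, and integer labels of shape
    (batch, H, W).
    """
    # Target (H, W) is taken from the labels, mirroring the linked template
    label_size = tf.shape(labels)[1:3]
    return tf.image.resize(logits, size=label_size, method="bilinear")


# Example: low-resolution logits for a 64x64 mask with 4 classes
logits = tf.random.normal((2, 16, 16, 4))
labels = tf.zeros((2, 64, 64), dtype=tf.int32)
print(upsample_logits(logits, labels).shape)  # (2, 64, 64, 4)
```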

ebgoldstein added the enhancement (New feature or request) label Mar 28, 2023
ebgoldstein self-assigned this Mar 28, 2023
dbuscombe-usgs self-assigned this May 17, 2023
@dbuscombe-usgs
Member

I want to help here. I'm ready to start experimenting with custom losses for SegFormers. We simply need to make our own version of TFSegformerForSemanticSegmentation with a new loss, or with loss options passed as arguments, correct?

@ebgoldstein
Member Author

Correct, we just need to write a custom loss that upsamples the logits to the mask size.
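One way to structure such a loss is a wrapper that turns any per-pixel loss into a function with the `(logits, labels)` signature used by `hf_compute_loss`. This is a sketch under stated assumptions, not the library's code: `make_segformer_loss` is a hypothetical name, logits are assumed channels-last, `ignore_index=255` is assumed to match the SegFormer config's `semantic_loss_ignore_index`, and the masked mean with a `(1,)`-shaped result follows the reduction in the linked template.

```python
import tensorflow as tf


def make_segformer_loss(pixel_loss_fn, ignore_index=255):
    """Wrap a per-pixel loss (labels, logits) -> (batch, H, W) for SegFormer.

    The returned function takes low-resolution channels-last logits and
    full-resolution integer labels, like hf_compute_loss(logits, labels).
    """

    def loss(logits, labels):
        labels = tf.cast(labels, tf.int32)
        # 1) Upsample logits to the label size
        upsampled = tf.image.resize(logits, size=tf.shape(labels)[1:3], method="bilinear")
        # 2) Mark pixels that should contribute to the loss
        valid = tf.not_equal(labels, ignore_index)
        # Replace ignored ids with class 0 so the per-pixel loss never sees out-of-range labels
        safe_labels = tf.where(valid, labels, tf.zeros_like(labels))
        per_pixel = pixel_loss_fn(safe_labels, upsampled)  # (batch, H, W)
        mask = tf.cast(valid, per_pixel.dtype)
        # 3) Mean over valid pixels across the whole batch; guard against an all-ignored batch
        reduced = tf.reduce_sum(per_pixel * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)
        return tf.reshape(reduced, (1,))

    return loss


# Masked cross-entropy, similar in spirit to the default loss
ce_loss = make_segformer_loss(
    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
)
```

Any other per-pixel loss with the same `(labels, logits)` call can be swapped in. To wire it up, one option (untested here) is to subclass TFSegformerForSemanticSegmentation and override its `hf_compute_loss(self, logits, labels)` method to return this function's result.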

@dbuscombe-usgs
Member

dbuscombe-usgs commented May 18, 2023

I'm looking at our existing Dice loss function to see whether it is consistent with the existing class and its loss function call. Key differences:

  1. The built-in loss uses from_logits=True, reduction="none". Our Dice comes "already reduced", using (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth). But that's OK: the mask reduction occurs over the batch.
  2. The loss is called as loss = self.hf_compute_loss(logits=logits, labels=labels), so we need a function that accepts those inputs.
  3. Our function requires 'nclasses', which would need to be passed to the new class.
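The three points above can be addressed together with a Dice loss that accepts `(logits, labels)`, upsamples internally, and receives `nclasses` when it is constructed. A minimal sketch, assuming channels-last logits and an ignore id of 255. This is a generic soft Dice reduced over the whole batch and all classes, not necessarily identical to our existing function, and `make_dice_loss` and the `smooth` default are hypothetical:

```python
import tensorflow as tf


def make_dice_loss(nclasses, smooth=1e-5, ignore_index=255):
    """Build a soft Dice loss with the hf_compute_loss(logits, labels) signature."""

    def dice_loss(logits, labels):
        labels = tf.cast(labels, tf.int32)
        # Upsample channels-last logits to the mask size, then convert to probabilities
        upsampled = tf.image.resize(logits, size=tf.shape(labels)[1:3], method="bilinear")
        probs = tf.nn.softmax(upsampled, axis=-1)
        # Zero out ignored pixels in both prediction and target
        valid = tf.cast(tf.not_equal(labels, ignore_index), probs.dtype)[..., tf.newaxis]
        # tf.one_hot yields an all-zero row for out-of-range ids such as 255
        y_true = tf.one_hot(labels, depth=nclasses, dtype=probs.dtype) * valid
        y_pred = probs * valid
        # Flatten everything, so the result is "already reduced" over the batch
        y_true_f = tf.reshape(y_true, [-1])
        y_pred_f = tf.reshape(y_pred, [-1])
        intersection = tf.reduce_sum(y_true_f * y_pred_f)
        dice = (2.0 * intersection + smooth) / (
            tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
        )
        # Return shape (1,), matching the default loss in the linked template
        return tf.reshape(1.0 - dice, (1,))

    return dice_loss
```

Inside a subclass of TFSegformerForSemanticSegmentation, `nclasses` could come from `self.config.num_labels` rather than being passed separately.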

@dbuscombe-usgs
Member

I'm putting this on the back burner again; too busy.
