
Added Focal Tversky loss #932

Open · wants to merge 7 commits into main

Conversation

Aditi-Mhatre

Hi,

Added:

  • Focal Tversky loss function in Focal_Tversky.py based on the paper (sketched below for reference)
  • tests for the loss function (all passed so far)

Changed:

  • losses: updated __init__ and added Focal_Tversky
  • tests: updated test_losses

Maybe you can take a look at it.
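
For context, the focal Tversky loss modulates the Tversky index with a focal exponent. A minimal PyTorch sketch, assuming the common (1 - tversky) ** gamma formulation; note the original paper instead uses a 1/gamma exponent, and conventions differ between implementations:

import torch

def focal_tversky_loss(output, target, alpha=0.5, beta=0.5, gamma=1.0,
                       smooth=0.0, eps=1e-7):
    # Soft confusion terms on probability maps
    tp = torch.sum(output * target)
    fp = torch.sum(output * (1 - target))
    fn = torch.sum((1 - output) * target)
    # Tversky index: TP / (TP + alpha * FP + beta * FN)
    tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth).clamp_min(eps)
    # Focal modulation: the exponent gamma reshapes the loss around hard examples
    return (1.0 - tversky) ** gamma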

@qubvel (Collaborator) left a comment

Hi @Aditi-Mhatre, thanks for adding a new loss function and for adding a test!

I left some comments below. Could you also please add the loss to docs/? Thanks 🤗

        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
@qubvel (Collaborator) commented on Oct 5, 2024

Please move the class docstring under __init__, similar to the other losses.
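
A minimal sketch of the requested layout; the class name and docstring text here are illustrative, not the PR's actual content:

class FocalTverskyLoss(_Loss):
    def __init__(
        self,
        mode: str,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        """Focal Tversky loss for image segmentation.

        Args:
            mode: Loss mode: 'binary', 'multiclass' or 'multilabel'.
            alpha: Weight of false positives in the Tversky index.
            beta: Weight of false negatives in the Tversky index.
            gamma: Focal exponent applied to (1 - tversky_score).
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()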

Comment on lines +226 to +236
    output_sum = torch.sum(output, dim=dims)
    target_sum = torch.sum(target, dim=dims)
    difference = LA.vector_norm(output - target, ord=1, dim=dims)

    intersection = (output_sum + target_sum - difference) / 2  # TP
    fp = output_sum - intersection
    fn = target_sum - intersection

    tversky_score = (intersection + smooth) / (
        intersection + alpha * fp + beta * fn + smooth
    ).clamp_min(eps)
Collaborator

As far as I understand, we can just call soft_tversky_score here?
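
Something like the following, assuming the helper keeps its existing signature in segmentation_models_pytorch.losses._functional:

from segmentation_models_pytorch.losses._functional import soft_tversky_score

# Reuse the existing helper instead of recomputing TP/FP/FN by hand
tversky_score = soft_tversky_score(
    output, target, alpha=alpha, beta=beta, smooth=smooth, eps=eps, dims=dims
)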

Comment on lines +98 to +102
    [
        [[1, 1, 1, 1], [1, 1, 1, 1], 0.0, 1e-5, 0.5, 0.5, 1.0],
        [[0, 1, 1, 0], [0, 1, 1, 0], 0.0, 1e-5, 0.5, 0.5, 1.0],
        [[1, 1, 1, 1], [0, 0, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
        [[1, 0, 1, 0], [0, 1, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
Collaborator

Can you please add a test case with gamma != 1?
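
For example, assuming the row layout above (y_pred, y_true, expected_loss, eps, alpha, beta, gamma) and a (1 - tversky_score) ** gamma loss, a hand-worked case: TP = 1, FP = 1, FN = 0 gives tversky = 1 / 1.5 ≈ 0.6667, so the expected loss is (1 - 0.6667) ** 2 ≈ 0.1111:

        [[1, 1, 0, 0], [1, 0, 0, 0], 0.1111, 1e-5, 0.5, 0.5, 2.0],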

@zifuwanggg (Contributor) commented on Oct 9, 2024

The focal Tversky loss has already been implemented.
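
For reference, the existing TverskyLoss already exposes a focal exponent; a usage sketch, assuming the current constructor keeps its alpha/beta/gamma parameters:

import torch
from segmentation_models_pytorch.losses import TverskyLoss

# gamma != 1 turns the plain Tversky loss into its focal variant
loss_fn = TverskyLoss(mode="binary", alpha=0.5, beta=0.5, gamma=2.0)
logits = torch.randn(4, 1, 64, 64)                  # raw model outputs
mask = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground truth
loss = loss_fn(logits, mask)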

@qubvel (Collaborator) commented on Oct 9, 2024

Indeed 😄 I missed it, probably a documentation issue! Thanks for pointing it out, @zifuwanggg
