Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug Fix for CenterNetBoxLoss #2432

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

TillBeemelmanns
Copy link

I am currently doing some experiments with the Waymo 3D Object Detector Centernet and found out that there is a bug in the CenterNetBoxLoss function.

I trained the model for 2 epochs with the original loss function which results in the following predictions on the val set:
waymo_open_dataset_0

Notably there is a constant yaw offset caused by the false ops.floor() operation in the loss function applied on the gt heading.

This PR fixes this issue. The following image depicts the predictions after 2 epochs using the fixed loss function.
waymo_open_dataset_0
(Note, that the network has not converged and no NMS has been applied)

The problem probably raised during the conversion from the Tensorflow op tf.math.floormod to Keras ops. However,
tf.math.floormod(a, b) != ops.floor(ops.mod(a, b))
but
tf.math.floormod(a, b) = ops.mod(a, b)

Hence ops.floor() simply removes the accuracy of the gt heading during loss computation.

Additionally, I implemented a test case for the heading classification (+regression) part of the CenterNetBoxLoss.

Frameworks & Versions:

  • keras-cv==0.8.2
  • keras==3.2.0
  • tensorflow==2.16.1

@divyashreepathihalli @sampathweb

@divyashreepathihalli
Copy link
Collaborator

@TillBeemelmanns Thank you for the update!

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Jul 23, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 23, 2024
@TillBeemelmanns
Copy link
Author

Thanks for approving. I will resolve the remaining failing checks end of this week.

@TillBeemelmanns
Copy link
Author

@divyashreepathihalli I think I fixed the Pytorch pipeline. Could you rerun the workflow ?

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Aug 12, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Aug 12, 2024
@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Sep 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Sep 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants