Descriptor loss train #287

Open
ericzzj1989 opened this issue Jan 16, 2023 · 10 comments
Comments

@ericzzj1989

ericzzj1989 commented Jan 16, 2023

Hi, first thank you for this great work which really helped me a lot!
I want to use the Superpoint model to train on my own data. The detection loss part seems normal but the descriptor loss is oscillating and cannot converge.
The input are 256*256 image and warped image with homography, and the model and loss function are the same with your repo.
The detector and descriptor loss are as below with 300 epochs.
[Figure: detector loss curve]
[Figure: descriptor loss curve]

Can you give me some advice?

@rpautrat
Owner

Hi, are you using exactly the code from this repo or did you plug in some parts of this repo into your own code? Are you using the same parameters as in the master branch?

@ericzzj1989
Author

Thanks for your prompt reply.
I use the paired data generation (exactly as in coco.py), the SuperPoint model, and the loss function from your repo, reimplemented in PyTorch. The data augmentation and model parameters in the config file are the same as yours.

@rpautrat
Owner

I see. Then it might be a bit tricky for me to help you, as it is different code... It could be an implementation bug, or simply that your reimplementation needs different parameter tuning than this repo.

Note that there is also a Pytorch reimplementation of SuperPoint (partially based on this repo) that you might want to check out: https://github.com/eric-yyjau/pytorch-superpoint

@ericzzj1989
Author

Thanks for your advice.
I referred to the repo https://github.com/shaofengzeng/SuperPoint-Pytorch (based on your repo) and compared it with yours. I think it is essentially the same code and has the same issue. I also found that many issues mention this problem with the descriptor loss.
I played with the parameters and tried RGB input, but the descriptor loss kept oscillating and did not converge. I have no idea how to solve this issue.

@rpautrat
Owner

Tuning the descriptor loss was quite tricky in my case, but training with a triplet loss is rather tricky in general.

One thing that usually helps in my experience of training with triplet losses is to pre-train the network with a "relaxed" definition of the negative samples. In SuperPoint, given one cell at position (h, w), the corresponding cell (h', w') in the other image is used as the positive sample, while all the other cells are considered negatives. But the descriptor at, say, (h'+1, w') is also very close to the one at (h, w), so forcing these two descriptors to be far apart confuses the network (at least at the beginning of the training). What you could do instead is to ignore the neighboring cells of each positive cell in the descriptor loss, which is equivalent to making the negative samples less hard. Once the training has converged with this easier loss, you can fine-tune with the actual SuperPoint loss to get the best performance.
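A rough NumPy sketch of this relaxed loss, for a single image pair. It assumes 8×8 cells, a homography `H` mapping pixel (x, y) coordinates of the first image into the warped one, and the hinge margins from the SuperPoint paper (positive margin 1, negative margin 0.2, λ_d = 250). The helper names, the `ignore_cells` parameter, the -0.5 tie-break, and the averaging over participating pairs are all hypothetical choices for illustration, not code from this repo. Cells within `ignore_cells` cells of the true match are neither positives nor negatives, and `ignore_cells=0` falls back to the usual setup where every other cell is a negative.

```python
import numpy as np


def cell_centers(hc, wc, cell=8):
    """Pixel (x, y) centre of every cell, row-major order, shape [hc * wc, 2]."""
    ys, xs = np.meshgrid(np.arange(hc), np.arange(wc), indexing="ij")
    return (np.stack([xs, ys], axis=-1).reshape(-1, 2) * cell + cell // 2).astype(float)


def warp_points(points, H):
    """Apply a 3x3 homography H to [N, 2] points given in (x, y) order."""
    homog = np.concatenate([points, np.ones((len(points), 1))], axis=1)
    warped = homog @ H.T
    return warped[:, :2] / warped[:, 2:3]


def relaxed_masks(hc, wc, H, cell=8, ignore_cells=1):
    """Positive and negative masks of shape [hc * wc, hc * wc] (hypothetical helper).

    positive[i, j] = 1 if cell i, warped by H, lands on cell j.
    negative[i, j] = 1 only if cell j is more than `ignore_cells` cells away
    from the warped cell i; the cells in between are left out of the loss.
    ignore_cells=0 recovers the usual "every other cell is a negative" setup.
    """
    centers = cell_centers(hc, wc, cell)
    warped = warp_points(centers, H)
    # dist[i, j]: pixel distance between warped centre of cell i and centre of cell j
    dist = np.linalg.norm(warped[:, None, :] - centers[None, :, :], axis=-1)
    # The -0.5 keeps cells exactly `cell` pixels away from counting as positives
    positive = dist <= cell - 0.5
    negative = dist > cell * (ignore_cells + 1) - 0.5
    return positive.astype(float), negative.astype(float)


def hinge_descriptor_loss(desc, warped_desc, positive, negative,
                          pos_margin=1.0, neg_margin=0.2, lambda_d=250.0):
    """SuperPoint-style hinge loss over all cell pairs.

    desc, warped_desc: [hc * wc, D] L2-normalised descriptors of the image
    and of its warped version.
    """
    dot = desc @ warped_desc.T
    pos_term = lambda_d * positive * np.maximum(0.0, pos_margin - dot)
    neg_term = negative * np.maximum(0.0, dot - neg_margin)
    # Average over the pairs that take part in the loss (ignored pairs excluded)
    n_pairs = max((positive + negative).sum(), 1.0)
    return (pos_term + neg_term).sum() / n_pairs
```

Following the suggestion above, one would pre-train with something like `ignore_cells=1` or `2` and then fine-tune with `ignore_cells=0`, which corresponds to the standard loss; the exact radius would need tuning.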

@ericzzj1989
Author

ericzzj1989 commented Jan 17, 2023

Thanks for sharing your experience.
Is there any reference code for what you described? If so, could you please share it with me?

@rpautrat
Owner

For SuperPoint, unfortunately not. I did not have to use this trick when I trained it.

But I had to use it for other works requiring a triplet loss; one example is here: https://github.com/mihaidusmanu/d2-net/blob/master/lib/loss.py. The loss there is a bit different from the SuperPoint one, though: it is a triplet loss with hardest negative mining. But maybe you can get the idea and apply it to the SuperPoint loss. The safe_radius parameter controls how close to the positive match a negative can be. You should thus start training with a large safe_radius and then, ideally, fine-tune with a smaller one.
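To illustrate the idea, here is a simplified NumPy triplet loss with hardest-negative mining and a safe radius. It is only written in the spirit of the linked D2-Net loss, not a copy of it: it ignores the detection scores that D2-Net uses to weight pairs, compares positions with the Chebyshev (max-coordinate) distance, and the function name and defaults (`safe_radius=4`, `margin=1.0`) are assumptions for illustration.

```python
import numpy as np


def hardest_negative_triplet_loss(desc1, desc2, pos1, pos2,
                                  safe_radius=4, margin=1.0):
    """Triplet margin loss with hardest-negative mining and a safe radius.

    desc1, desc2: [N, D] L2-normalised descriptors; desc1[k] and desc2[k]
        describe the same point (the k-th positive pair).
    pos1, pos2: [N, 2] positions of those points in their feature maps.
    A point only counts as a negative for pair k if it lies more than
    `safe_radius` (Chebyshev distance) away from point k in its own image.
    """
    # For unit vectors, ||a - b|| = sqrt(2 - 2 a.b); clip tiny negative values
    dist = np.sqrt(np.maximum(2.0 - 2.0 * desc1 @ desc2.T, 0.0))
    pos_dist = np.diag(dist)

    def far_from(pos):
        # far[k, j] is True when point j is outside the safe radius of point k
        cheb = np.abs(pos[:, None, :] - pos[None, :, :]).max(axis=-1)
        return cheb > safe_radius

    # Hardest negative for desc1[k] among the image-2 descriptors
    neg_in_2 = np.where(far_from(pos2), dist, np.inf).min(axis=1)
    # Hardest negative for desc2[k] among the image-1 descriptors
    neg_in_1 = np.where(far_from(pos1), dist.T, np.inf).min(axis=1)
    hardest = np.minimum(neg_in_1, neg_in_2)

    # Pairs without any valid negative give margin + d - inf = -inf, i.e. zero loss
    return np.maximum(0.0, margin + pos_dist - hardest).mean()
```

With the schedule suggested above, one would first train with a large radius (say 8) and then fine-tune with a small one (say 1 or 2); these values are only illustrative. A larger radius removes the hardest, nearby negatives, which is what makes the early training easier.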

@ericzzj1989
Author

ericzzj1989 commented Jan 17, 2023

Thanks very much for your help.
Following #164, I commented out the following lines:

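```python
import tensorflow as tf

# Excerpt from the repo's descriptor loss; batch_size, Hc, Wc and
# dot_product_desc ([batch_size, Hc, Wc, Hc, Wc]) come from the surrounding function.
# Clip negative descriptor similarities to zero.
dot_product_desc = tf.nn.relu(dot_product_desc)
# For each cell of the first image, L2-normalize over all cells of the warped image.
dot_product_desc = tf.reshape(tf.nn.l2_normalize(
    tf.reshape(dot_product_desc, [batch_size, Hc, Wc, Hc * Wc]),
    3), [batch_size, Hc, Wc, Hc, Wc])
# For each cell of the warped image, L2-normalize over all cells of the first image.
dot_product_desc = tf.reshape(tf.nn.l2_normalize(
    tf.reshape(dot_product_desc, [batch_size, Hc * Wc, Hc, Wc]),
    1), [batch_size, Hc, Wc, Hc, Wc])
```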

The training and validation descriptor losses with a small amount of data (100 samples) are shown below.
[Figure: training descriptor loss]
[Figure: validation descriptor loss]

These three lines normalize dot_product_desc. Is this normalization necessary after the descriptor dot product?
I'm not sure whether this is correct, or whether the amount of data (100 samples) is simply too small.
Could you please give me some advice?

@rpautrat
Owner

I am not sure I fully understand your last question, but these lines with the L2 normalization are a trick to make the correspondences between points more discriminative (i.e. so that there is at most one correspondence rather than several similar candidates). The original SuperPoint did not have this trick, and the code should also work if you comment it out. But personally, I observed empirically better results with it.
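To see what these lines do, here is a NumPy transcription for a single image pair (batch dimension dropped, Hc × Wc cells flattened), using the same formula as `tf.nn.l2_normalize` with its default epsilon. The function names and the toy similarity matrix are made up for illustration: image-1 cell 0 is equally similar (0.9) to warped cells 0 and 1, but warped cell 1 also has a better partner (image-1 cell 1), so after normalization the (0, 0) score stays at 1 while the (0, 1) score drops to about 0.577.

```python
import numpy as np


def l2_normalize(x, axis, epsilon=1e-12):
    """Same formula as tf.nn.l2_normalize: x / sqrt(max(sum(x**2), epsilon))."""
    return x / np.sqrt(np.maximum((x ** 2).sum(axis=axis, keepdims=True), epsilon))


def double_l2_normalize(dot):
    """NumPy version of the relu + two l2_normalize lines, for one image pair.

    dot: [N1, N2] descriptor similarities; rows are image-1 cells,
    columns are warped-image cells (N1 = N2 = Hc * Wc).
    """
    dot = np.maximum(dot, 0.0)        # tf.nn.relu
    dot = l2_normalize(dot, axis=1)   # over warped cells, per image-1 cell
    dot = l2_normalize(dot, axis=0)   # over image-1 cells, per warped cell
    return dot


sim = np.array([[0.9, 0.9, 0.0],    # image-1 cell 0 resembles warped cells 0 and 1
                [0.0, 0.9, 0.0],    # image-1 cell 1 only resembles warped cell 1
                [0.0, 0.0, 0.9]])   # image-1 cell 2 has a single clear match
out = double_l2_normalize(sim)
print(out[0, 0], out[0, 1])  # roughly 1.0 and 0.577
```

In other words, the two normalizations act somewhat like a soft mutual-nearest-neighbor check: a similarity only stays high if it is large compared with the competing candidates in both its row and its column.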

On the graphs you show, the model clearly overfits the small training set, due to the small number of samples.

@ericzzj1989
Author

Thanks very much for your help.
With this L2 normalization trick, the descriptor loss does not converge, as shown in the graphs in my original post.
As for the overfitting, I will try adding more training data and observe the loss.
