Fitting the model sometimes terminates with NaNs, and we haven't been able to locate the issue. For example, if we fit the same dataset 8 times in separate jobs, 5 runs finish successfully but the other 3 produce NaNs at the 47th or 48th iteration of SLDS. We are initializing the parameters from the preceding ARHMM fits. If it helps, we are using `latent_dim=10` (10 PCA components explain 90% of the variance) and `nlags=3`.
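(For reference, a minimal sketch of the variance check behind the `latent_dim` choice, using plain scikit-learn rather than keypoint-moseq's own PCA utilities; `coords` is a hypothetical `(n_frames, n_features)` array of flattened, aligned keypoint coordinates:)

```python
import numpy as np
from sklearn.decomposition import PCA

# Fit a full PCA and find how many components are needed to reach
# 90% cumulative explained variance. `coords` is an assumed
# (n_frames, n_features) array of aligned keypoint coordinates.
pca = PCA().fit(coords)
cumvar = np.cumsum(pca.explained_variance_ratio_)
n_components = int(np.searchsorted(cumvar, 0.90) + 1)
print(f"{n_components} components explain >=90% of the variance")
```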
We have verified it's not an issue of single precision. I have also checked various sizes (150k frames and 2 million frames) and different samples at each size, to rule out the possibility of having too little data, but we still get NaNs about 30–35% of the time. This is `keypoint-moseq` v0.1.5 and `jax-moseq` v0.0.3. We also tested newer releases of jax-moseq with parallel LGSSM message passing, but that produced NaNs early in training (around the 20th iteration of SLDS) almost every time, on every dataset size, so we decided to revert.
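(In case anyone wants to reproduce the single-precision check: one way is to force JAX into double precision before any arrays are created, e.g.:)

```python
import jax

# Must run before the first JAX operation; arrays created earlier
# keep their float32 dtype.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.zeros(3).dtype)  # float64 once x64 is enabled
```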
Have you encountered a similar issue before? What would be a good way to assess the quality of the fit at each SLDS iteration or, more generally, to debug this issue?

cc @r-shruthi11
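(One generic way to localize this kind of failure is `jax.config.update("jax_debug_nans", True)`, which makes JAX raise at the first operation that produces a NaN. Another is to audit the model pytree after every Gibbs sweep; below is a minimal sketch of that, assuming `model` is a nested dict of JAX arrays as in jax-moseq. `find_nan_leaves` is a hypothetical helper, and the commented loop only gestures at jax-moseq's `resample_model`, whose exact signature varies by version:)

```python
import jax
import jax.numpy as jnp

def find_nan_leaves(model):
    """Return the paths of all floating-point leaves in `model` that contain NaNs."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(model)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if jnp.issubdtype(jnp.asarray(leaf).dtype, jnp.floating)
        and bool(jnp.isnan(leaf).any())
    ]

# Hypothetical use inside a hand-rolled fitting loop:
# for i in range(num_iters):
#     model = resample_model(data, **model)  # jax-moseq Gibbs update
#     bad = find_nan_leaves(model)
#     if bad:
#         raise RuntimeError(f"NaNs first appeared at iteration {i} in: {bad}")
```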
I'm sorry to hear that. Probably the easiest thing would be to send me a minimal example of a dataset that produces NaNs, along with the notebook you used for fitting it. Any file-sharing method is fine. You can use my email ([email protected]).