Variance inconsistency in HeteroskedasticSingleTaskGP #933
Hi @mklpr. There's a known bug with the noise model of the HeteroskedasticSingleTaskGP, see #861. That may explain why you're running into this issue. I haven't had time to look closely into the methods you're trying, but if I am not mistaken, the issue in #861 is that the noise model is trained on log-transformed noise observations, which never get untransformed. So what you do with the 4th option may be the correct way around the bug. In addition, the bug may also lead to issues during hyperparameter training, so I'd not recommend using the packaged HeteroskedasticSingleTaskGP model right now. There are some fixes proposed in #861 that you could implement locally to get around the issue.
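A minimal sketch of such a local workaround (the helper below is illustrative, not part of botorch; it assumes the diagnosis above, i.e. that the noise model's latent predictions live in log space):

def predict_noise_variance(model, X):
    # Noise GP inside the heteroskedastic likelihood.
    noise_model = model.likelihood.noise_covar.noise_model
    # The latent posterior is in log space; exponentiate to get back to the
    # variance scale. This mirrors the "4th option" mentioned above.
    return noise_model(X).mean.exp()

For example, predict_noise_variance(model, t_X) with the model from the snippet below.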
Hi @saitcakmak, thanks for the helpful comment. Let's set that known bug aside for now; I have some other questions.
# test code, mainly from issue #861
import torch
from botorch import fit_gpytorch_model
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(1)
t_X = torch.rand(10, 2)
t_Y_var = torch.ones(10, 1) * 10  # constant observed noise variances

model = HeteroskedasticSingleTaskGP(
    train_X=t_X,
    train_Y=torch.randn(10, 1),
    train_Yvar=t_Y_var,
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)

# Two ways of predicting the noise level at the training inputs:
noise_model = model.likelihood.noise_covar.noise_model
# (1) posterior(), which untransforms through the Log outcome transform
noise_mean_predict1 = noise_model.posterior(t_X, observation_noise=False).mean
# (2) exponentiating the latent (log-space) mean directly
noise_mean_predict2 = noise_model(t_X).mean.exp()
print(noise_mean_predict1)
print(noise_mean_predict2)
# print output
tensor([[1.6492],
[2.1872],
[0.5681],
[2.0095],
[3.5275],
[3.9483],
[2.1735],
[3.1278],
[3.8957],
[0.9924]], grad_fn=<ExpBackward>)
tensor([0.2114, 0.2991, 0.0394, 0.2630, 0.7618, 0.7673, 0.3464, 0.5275, 0.8492,
0.0753], grad_fn=<ExpBackward>)
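If the Log outcome transform untransforms via the lognormal mean, the gap between the two tensors is expected: posterior() would return exp(mu + var / 2), while .mean.exp() is just exp(mu). A quick check of that reading (a sketch; the lognormal-mean behavior is an assumption here, not confirmed from the botorch source):

latent = noise_model(t_X)  # latent posterior in log space
mu, var = latent.mean, latent.variance
print((mu + 0.5 * var).exp())  # should reproduce noise_mean_predict1 (up to shape)
print(mu.exp())  # reproduces noise_mean_predict2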
Hi @mklpr.
The reasoning being that …
From a theoretical perspective, the reason for fitting them jointly is that the solution you'll obtain by solving …
I don't know much about why one is used vs. the other. I think you have the right intuition (I have the same intuition), and which one works better will ultimately depend on the particular problem instance. If you don't want to use the log transform (and avoid the bug around it), you can subclass the …
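For concreteness, such a subclass might look roughly like this (a sketch only; the class name is mine, and it reuses the same gpytorch pieces botorch's model is built from, with the noise GP fit to the raw variances and positivity left to HeteroskedasticNoise's noise constraint instead of a log transform):

from botorch.models import SingleTaskGP
from gpytorch.likelihoods.gaussian_likelihood import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise


class NoLogHeteroskedasticGP(SingleTaskGP):
    def __init__(self, train_X, train_Y, train_Yvar):
        # Noise GP fit directly to the observed variances (no Log transform).
        noise_model = SingleTaskGP(train_X=train_X, train_Y=train_Yvar)
        # HeteroskedasticNoise's GreaterThan noise constraint (softplus-style)
        # keeps the predicted noise positive.
        likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
        super().__init__(train_X=train_X, train_Y=train_Y, likelihood=likelihood)

Fitting then goes through ExactMarginalLogLikelihood as usual; the trade-off is that the softplus in the noise constraint slightly distorts noise levels near its lower bound.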
Hi @saitcakmak. I tested this in practice and iterative fitting is sufficient, so I implemented an IterativeHeteroskedasticSingleTaskGP model for convenience; see the model source code, the Gaussian process regression demo, and the Bayesian optimization demo. I hope this helps anyone who wants to use a heteroskedastic model.
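For readers without access to those links, the iterative scheme can be sketched roughly as follows (my own minimal reading of the idea, not @mklpr's actual implementation; the function name and the use of FixedNoiseGP are assumptions):

import torch
from botorch import fit_gpytorch_model
from botorch.models import FixedNoiseGP, SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood


def fit_iterative_hetero_gp(train_X, train_Y, train_Yvar, num_rounds=3):
    noise_var = train_Yvar.clamp_min(1e-6)
    for _ in range(num_rounds):
        # (1) Fit the mean GP, treating the current noise estimates as fixed.
        model = FixedNoiseGP(train_X, train_Y, noise_var)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))
        # (2) Fit a noise GP to the log squared residuals and exponentiate
        #     its predictions so the noise estimates stay positive.
        with torch.no_grad():
            resid_sq = (train_Y - model.posterior(train_X).mean) ** 2
        noise_gp = SingleTaskGP(train_X, resid_sq.clamp_min(1e-6).log())
        fit_gpytorch_model(ExactMarginalLogLikelihood(noise_gp.likelihood, noise_gp))
        with torch.no_grad():
            noise_var = noise_gp.posterior(train_X).mean.exp()
    return model, noise_gp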
The main difference is that …
The main reason is that there is no guarantee that a GP fit on non-negative data will produce non-negative predictions (as you seem to have found out yourself). Using a log transform is one straightforward way of dealing with this.
We had a PR for this a long time ago that never got wrapped up: #250. @jelena-markovic worked on updating it and we have an internal version of that. It still needs some more work, but we could probably put it out as a PR if that would be helpful.
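To illustrate the point about non-negative predictions (a sketch with synthetic data; whether the mean actually dips below zero depends on the data and the fit):

import torch
from botorch import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(0)
X = torch.rand(20, 1)
Y = (torch.randn(20, 1) * 0.5).abs() + 0.01  # strictly positive targets

gp = SingleTaskGP(X, Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))
grid = torch.linspace(0, 1, 200).unsqueeze(-1)
print(gp.posterior(grid).mean.min())  # may be negative despite Y > 0

log_gp = SingleTaskGP(X, Y.log())  # model log(Y) instead
fit_gpytorch_model(ExactMarginalLogLikelihood(log_gp.likelihood, log_gp))
print(log_gp.posterior(grid).mean.exp().min())  # positive by construction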
Closing this since the bug is being tracked in #861.
Hi,
In HeteroskedasticSingleTaskGP, different ways of computing the posterior with observation noise give me different results, and I can't explain or understand this myself, so I'm asking for help here.
I used four ways to compute the posterior with noise; see the notebook linked below.
Methods 1 and 2 give the same results, but methods 3 and 4 each differ from all the others. To my knowledge, the total posterior variance equals the noise variance from the noise_model plus the variance from the GP kernel, and I verified this with SingleTaskGP. So what is going wrong in HeteroskedasticSingleTaskGP? Does it come from the log transform, and how does mll_heter.likelihood(model_heter.posterior(scan_x, observation_noise=False), scan_x) process it internally? Thanks.
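The SingleTaskGP check mentioned above can be reproduced along these lines (a sketch with synthetic data, not the notebook's code):

import torch
from botorch import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(0)
train_X = torch.rand(10, 2)
train_Y = torch.randn(10, 1)

# outcome_transform=None keeps everything on the raw scale so the
# decomposition below is exact.
model = SingleTaskGP(train_X, train_Y, outcome_transform=None)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

test_X = torch.rand(5, 2)
var_f = model.posterior(test_X, observation_noise=False).variance
var_y = model.posterior(test_X, observation_noise=True).variance
noise = model.likelihood.noise  # homoskedastic noise level
print(torch.allclose(var_y, var_f + noise, atol=1e-5))  # expected: True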
test code
Refer to https://colab.research.google.com/drive/1dOUHQzl3aQ8hz6QUtwRrXlQBGqZadQgG#scrollTo=D0A4Cf0W_QkZ
system info