Truncated Gamma #969

Open
quattro opened this issue Mar 22, 2021 · 7 comments
Labels
enhancement New feature or request

Comments

@quattro
Contributor

quattro commented Mar 22, 2021

It would be great to have a truncated gamma distribution implemented in NumPyro to cover lower-bounded variances (or upper-bounded precisions) in probabilistic programs.

I'm happy to code this up and open a pull request after some testing.

A nice example would be Mendelian Randomization from summary statistics (i.e., two-stage least squares). Residual variance greater than 1 is consistent with heterogeneity across studies, but inferred variance should be bounded below by 1.
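For reference, a sketch of the density being requested, assuming NumPyro's concentration/rate parameterization of the gamma ($\alpha$ = concentration, $\beta$ = rate) and truncation bounds $a < b$ (the symbols are introduced here for illustration and do not appear in the thread):

$$
f(x \mid \alpha, \beta, a, b) = \frac{\beta^{\alpha} x^{\alpha - 1} e^{-\beta x}}{\Gamma(\alpha)\,\bigl[P(\alpha, \beta b) - P(\alpha, \beta a)\bigr]}, \qquad a \le x \le b,
$$

where $P(\alpha, \cdot)$ is the regularized lower incomplete gamma function. The lower-bounded variance case above corresponds to $a = 1$, $b = \infty$, for which $P(\alpha, \beta b) = 1$.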

fehiepsi added the enhancement label on Mar 23, 2021
@fehiepsi
Member

@quattro I haven't used a truncated gamma before, but I think this would be nice to have. FYI, we have a TruncatedDistribution API whose base distribution can be Cauchy, Laplace, Logistic, Normal, or StudentT. It would probably be cleaner to follow that pattern: implement LeftTruncatedGamma, RightTruncatedGamma, and TwoSidedTruncatedGamma, then have TruncatedGamma dispatch to the corresponding class (based on low=None or high=None).
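The dispatch idea above could be sketched roughly as follows. This is a plain-Python illustration only: the three classes are empty placeholders rather than real NumPyro distributions, and the `concentration`/`rate` argument names and the behaviour when both bounds are missing are assumptions.

```python
from dataclasses import dataclass


# Placeholder classes standing in for real distribution implementations.
@dataclass
class LeftTruncatedGamma:
    concentration: float
    rate: float
    low: float


@dataclass
class RightTruncatedGamma:
    concentration: float
    rate: float
    high: float


@dataclass
class TwoSidedTruncatedGamma:
    concentration: float
    rate: float
    low: float
    high: float


def TruncatedGamma(concentration, rate, low=None, high=None):
    """Pick the truncated-gamma variant based on which bounds are given."""
    if low is None and high is None:
        # Assumption: a real implementation might return the plain Gamma instead.
        raise ValueError("At least one of `low` or `high` must be provided.")
    if high is None:
        return LeftTruncatedGamma(concentration, rate, low)
    if low is None:
        return RightTruncatedGamma(concentration, rate, high)
    if not low < high:
        raise ValueError("`low` must be strictly less than `high`.")
    return TwoSidedTruncatedGamma(concentration, rate, low, high)
```

With this shape, `TruncatedGamma(2.0, 1.0, low=1.0)` would produce the left-truncated variant, which is the case needed for a variance bounded below by 1.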

@quattro
Contributor Author

quattro commented Mar 29, 2021

Great, thanks for the advice @fehiepsi. I'll likely be busy with some other things for a little while but plan on coming back to this soon.

@quattro
Contributor Author

quattro commented Apr 7, 2021

Hi @fehiepsi, I made some progress on this, but stopped at the sampling implementation. I think there are two paths forward, and would appreciate your thoughts on the matter:

  1. Use inverse-CDF sampling: draw a uniform variate and map it through the gamma quantile function (inverse CDF)
  2. Use rejection sampling
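As a sketch of what the first option amounts to, assume a Gamma with concentration $\alpha$ and rate $\beta$ truncated to $[a, b]$, with untruncated CDF $F(x) = P(\alpha, \beta x)$, where $P$ is the regularized lower incomplete gamma function (this notation is introduced here, not taken from the thread):

$$
U \sim \mathrm{Uniform}(0, 1), \qquad p = F(a) + U\,\bigl[F(b) - F(a)\bigr], \qquad X = F^{-1}(p) = \frac{P^{-1}(\alpha, p)}{\beta}.
$$

Here $P^{-1}(\alpha, \cdot)$ is the inverse that `scipy.special.gammaincinv` computes. For left truncation only, $F(b) = 1$; for right truncation only, $F(a) = 0$.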

In terms of implementation:

  1. Inverse-CDF sampling requires gammaincinv to compute the quantiles of the gamma distribution, which is currently not supported by JAX (as you pointed out here: [FR] Support for scipy.special.gammaincinv jax-ml/jax#5350). It looks like TFP has gammaincinv implemented (along with gradients), but admittedly, I don't have the bandwidth to initiate a PR and port it over to JAX.
  2. For rejection sampling, there are a few papers describing sampling approaches for either the left- or right-truncated gamma, or for the more general case using latent variables, that might be worth investigating.
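For comparison, the simplest form of the rejection option can be sketched with the standard library alone. This is the naive scheme (propose from the untruncated gamma, keep draws that land inside the bounds), not one of the specialised samplers from the papers mentioned above; the function name, arguments, and `max_tries` cap are illustrative assumptions.

```python
import random


def sample_truncated_gamma(concentration, rate, low=0.0, high=float("inf"),
                           rng=None, max_tries=100_000):
    """Naive rejection sampler for a gamma truncated to [low, high]."""
    rng = rng or random.Random()
    scale = 1.0 / rate  # random.gammavariate takes (shape, scale), not rate
    for _ in range(max_tries):
        x = rng.gammavariate(concentration, scale)
        if low <= x <= high:
            return x  # accept the first proposal inside the bounds
    raise RuntimeError("No proposal accepted; the interval has very little mass.")
```

The acceptance rate equals the probability mass of the untruncated gamma inside `[low, high]`, so this degrades badly when the truncation region sits far in a tail. The data-dependent loop is also awkward to express under JAX tracing, which is part of why the inverse-CDF route is attractive.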

I'm hoping to move forward with this soon, as it looks like there is other interest in a truncated gamma distribution, as mentioned in this thread: jax-ml/jax#552.

@fehiepsi
Member

fehiepsi commented Apr 7, 2021

@quattro I think the simplest solution is to make a wrapper (with try/except import like contrib.tfp) for gammaincinv in numpyro.distributions.util and use it in the sample method. That would unblock you. What do you think?

@quattro
Contributor Author

quattro commented Apr 7, 2021

@fehiepsi that works for now. If JAX ever natively supports gammaincinv at a later date, I can come back to this to minimize core external dependencies.

@quattro
Contributor Author

quattro commented Oct 7, 2021

Just an update. I've had this implemented for a while, but was not able to get a few tests to pass. I can initiate the PR to highlight which tests specifically, and perhaps get feedback. My guess is that numerical precision in the inverse-CDF transform is the issue, but that's not 100% clear at the moment.
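One way an inverse-CDF sampler can lose precision, offered as an illustration consistent with the guess above and not as a diagnosis of the failing tests: JAX defaults to 32-bit floats, and a CDF value very close to 1 rounds to exactly 1.0 in float32, while the corresponding tail mass is still representable. The tail value below is a made-up example.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


tail = 1e-9             # hypothetical upper-tail mass beyond a truncation point
cdf_value = 1.0 - tail  # the matching CDF value at that point

# In float32 the CDF value collapses to exactly 1.0, so the tail is lost...
print(to_float32(cdf_value) == 1.0)
# ...whereas the tail mass itself survives the rounding.
print(to_float32(tail) > 0.0)
```

If this is the culprit, working with the complementary (upper-tail) probability for truncation points far in the right tail is a common remedy.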

@fehiepsi
Member

fehiepsi commented Oct 8, 2021

Sure, @quattro! We can discuss the numerical issue in detail in the PR. :)
