Skip to content

Commit

Permalink
MNT: change bilby MCMC to use glasflow instead of nflows
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will authored and ColmTalbot committed Feb 26, 2024
1 parent e008142 commit 164ca22
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 32 deletions.
12 changes: 6 additions & 6 deletions bilby/bilby_mcmc/flows.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import torch
from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
from nflows.nn import nets as nets
from nflows.transforms import (
from glasflow.nflows.distributions.normal import StandardNormal
from glasflow.nflows.flows.base import Flow
from glasflow.nflows.nn import nets as nets
from glasflow.nflows.transforms import (
CompositeTransform,
MaskedAffineAutoregressiveTransform,
RandomPermutation,
)
from nflows.transforms.coupling import (
from glasflow.nflows.transforms.coupling import (
AdditiveCouplingTransform,
AffineCouplingTransform,
)
from nflows.transforms.normalization import BatchNorm
from glasflow.nflows.transforms.normalization import BatchNorm
from torch.nn import functional as F

# Turn off parallelism
Expand Down
4 changes: 2 additions & 2 deletions bilby/bilby_mcmc/proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,10 @@ def propose(self, chain):

@staticmethod
def check_dependencies(warn=True):
if importlib.util.find_spec("nflows") is None:
if importlib.util.find_spec("glasflow") is None:
if warn:
logger.warning(
"Unable to utilise NormalizingFlowProposal as nflows is not installed"
"Unable to utilise NormalizingFlowProposal as glasflow is not installed"
)
return False
else:
Expand Down
3 changes: 2 additions & 1 deletion containers/env-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies:
- dill
- black
- pytest-cov
- pytest-requires
- arviz
- parameterized
- scikit-image
Expand Down Expand Up @@ -65,8 +66,8 @@ dependencies:
- jupyter
- nbconvert
- twine
- glasflow
- pip:
- autodoc
- ipykernel
- build
- nflows
2 changes: 1 addition & 1 deletion mcmc_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
scikit-learn
nflows
glasflow
1 change: 1 addition & 0 deletions optional_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
celerite
george
plotly
pytest-requires
41 changes: 19 additions & 22 deletions test/bilby_mcmc/test_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from bilby.bilby_mcmc import proposals
from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
import numpy as np
import pytest


class GivenProposal(proposals.BaseProposal):
Expand Down Expand Up @@ -164,36 +165,32 @@ def test_GMM_proposal(self):
else:
print("Unable to test GMM as sklearn is not installed")

@pytest.mark.requires("glasflow")
def test_NF_proposal(self):
priors = self.create_priors()
chain = self.create_chain(10000)
if proposals.NormalizingFlowProposal.check_dependencies():
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop)
else:
print("nflows not installed, unable to test NormalizingFlowProposal")
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop)

@pytest.mark.requires("glasflow")
def test_NF_proposal_15D(self):
ndim = 15
priors = self.create_priors(ndim)
chain = self.create_chain(10000, ndim=ndim)
if proposals.NormalizingFlowProposal.check_dependencies():
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop, ndim=ndim)
else:
print("nflows not installed, unable to test NormalizingFlowProposal")
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop, ndim=ndim)


if __name__ == "__main__":
Expand Down

0 comments on commit 164ca22

Please sign in to comment.