Skip to content

Commit 20a6c9f

Browse files
committed
FEATURE: allow prior arguments to be functions
1 parent c65f9f1 commit 20a6c9f

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

bilby/core/prior/base.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ def _parse_argument_string(cls, val):
408408
The string is interpreted as a call to instantiate another prior
409409
class, Bilby will attempt to recursively construct that prior,
410410
e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs).
411+
- Else If the string contains a ".":
412+
It is treated as a path to a Python function and imported, e.g.,
413+
"some_module.some_function" returns
414+
:code:`import some_module; return some_module.some_function`
411415
- Else:
412416
Try to evaluate the string using `eval`. Only built-in functions
413417
and numpy methods can be used, e.g., np.pi / 2, 1.57.
@@ -448,10 +452,17 @@ def _parse_argument_string(cls, val):
448452
try:
449453
val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi))
450454
except NameError:
451-
raise TypeError(
452-
"Cannot evaluate prior, "
453-
"failed to parse argument {}".format(val)
454-
)
455+
if "." in val:
456+
module = '.'.join(val.split('.')[:-1])
457+
func = val.split('.')[-1]
458+
new_val = getattr(import_module(module), func, val)
459+
if val == new_val:
460+
raise TypeError(
461+
"Cannot evaluate prior, "
462+
f"failed to parse argument {val}"
463+
)
464+
else:
465+
val = new_val
455466
return val
456467

457468

test/core/prior/dict_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,17 @@ def test_load_prior_with_parentheses(self):
464464
prior = bilby.core.prior.PriorDict(filename)
465465
self.assertTrue(isinstance(prior["logA"], bilby.core.prior.Uniform))
466466

467+
def test_load_prior_with_function(self):
468+
filename = os.path.join(
469+
os.path.dirname(os.path.realpath(__file__)),
470+
"prior_files/prior_with_function.prior",
471+
)
472+
prior = bilby.core.prior.ConditionalPriorDict(filename)
473+
self.assertTrue("mass_1" in prior)
474+
self.assertTrue("mass_2" in prior)
475+
samples = prior.sample(10000)
476+
self.assertTrue(all(samples["mass_1"] > samples["mass_2"]))
477+
467478

468479
class TestCreateDefaultPrior(unittest.TestCase):
469480
def test_none_behaviour(self):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\odot}$', boundary=None)
2+
mass_2 = ConditionalUniform(name="mass_1", minimum=5, maximum=100, condition_func="bilby.gw.prior.secondary_mass_condition_function")

0 commit comments

Comments
 (0)