Skip to content

Commit

Permalink
Merge pull request #29 from blab/switching-to-jax.Array
Browse files Browse the repository at this point in the history
Updating Array types
  • Loading branch information
marlinfiggins authored Sep 26, 2023
2 parents 24f8bdb + 8959de2 commit 27e3eec
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion evofr/models/piantham_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def compute_frequency_piantham(ga, q0, gen_rev, T):
Returns
-------
Simulated frequencies as DeviceArray.
Simulated frequencies as Array.
"""
_ga = jnp.append(ga, 1.0)
max_age = gen_rev.shape[-1]
Expand Down
4 changes: 2 additions & 2 deletions evofr/plotting/plot_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def prep_posterior_for_plot(
samples:
Dictionary with keys being site or variable names.
Values are DeviceArrays containing posterior samples
Values are Arrays containing posterior samples
with shape (sample_number, site_shape).
ps:
Expand Down Expand Up @@ -65,7 +65,7 @@ def plot_posterior_time(
Median values.
quants:
Quantiles to be plotted. Organized as a list of CIs as DeviceArrays.
Quantiles to be plotted. Organized as a list of CIs as Arrays.
alphas:
Transparency for each quantile.
Expand Down
2 changes: 1 addition & 1 deletion evofr/posterior/posterior_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
----------
samples:
optional dictionary with keys corresponding to variable names.
Values will be DeviceArrays containing posterior samples.
Values will be Arrays containing posterior samples.
data:
optional DataSpec instance containing underlying data from analysis
Expand Down
7 changes: 4 additions & 3 deletions evofr/posterior/posterior_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import Dict, List, Optional
import jax
import jax.numpy as jnp
import json
import numpy as np
Expand All @@ -16,7 +17,7 @@ def get_quantile(samples: Dict, p, site):
----------
samples:
Dictionary with keys being site or variable names.
Values are DeviceArrays with shape (sample_number, site_shape).
Values are Arrays with shape (sample_number, site_shape).
p:
Percent credible interval to return.
Expand All @@ -26,7 +27,7 @@ def get_quantile(samples: Dict, p, site):
Returns
-------
DeviceArray of shape (site_shape).
Array of shape (site_shape).
"""
q = jnp.array([0.5 * (1 - p), 0.5 * (1 + p)])
return jnp.quantile(samples[site], q=q, axis=0)
Expand Down Expand Up @@ -156,7 +157,7 @@ def default(self, obj):
return round(float(obj), 3)
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, jnp.DeviceArray):
if isinstance(obj, jax.Array):
return self.default(np.array(obj))
if isinstance(obj, pd.Timestamp):
return obj.strftime("%Y-%m-%d")
Expand Down

0 comments on commit 27e3eec

Please sign in to comment.