Skip to content

Commit

Permalink
Merge pull request #25 from deepskies/issue/modules
Browse files Browse the repository at this point in the history
Issue/modules
  • Loading branch information
beckynevin authored Oct 11, 2023
2 parents 9b1030e + 3f1e39c commit 8e4d2c5
Show file tree
Hide file tree
Showing 13 changed files with 2,841 additions and 5,970 deletions.
140 changes: 46 additions & 94 deletions notebooks/SBI.ipynb

Large diffs are not rendered by default.

461 changes: 461 additions & 0 deletions notebooks/SBI_hierarchical_csv.ipynb

Large diffs are not rendered by default.

99 changes: 39 additions & 60 deletions notebooks/error_propagation_demonstration.ipynb

Large diffs are not rendered by default.

4,291 changes: 366 additions & 3,925 deletions notebooks/pendulum_error_one_moment_in_time_DeepEnsemble.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2,356 changes: 532 additions & 1,824 deletions notebooks/pendulum_one_time_hierarchical.ipynb

Large diffs are not rendered by default.

379 changes: 378 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ seaborn = "^0.12.2"
torch = "^2.0.1"
sbi = "^0.21.0"
pytest-cov = "^4.1.0"
deepbench = "^0.2.2"


[build-system]
Expand Down
26 changes: 26 additions & 0 deletions src/scripts/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

# How should the error propagate?
# It's all partial derivatives (first-order error propagation).
def calc_error_prop(true_L, true_theta, true_a, dthing, time, wrt='theta_0'):
    """Return the first-order propagated error on the pendulum position.

    The partial derivatives used here are those of the position model
    ``x = L * sin(theta_0 * cos(sqrt(a_g / L) * t))``.

    Parameters
    ----------
    true_L
        Pendulum length ``L``.
    true_theta
        Starting angle ``theta_0``.
    true_a
        Gravitational acceleration ``a_g``.
    dthing
        Uncertainty on the selected parameter.  When ``wrt == 'all'`` it must
        be indexable, with ``dthing[0]`` the uncertainty on ``L``,
        ``dthing[1]`` on ``theta_0`` and ``dthing[2]`` on ``a_g``.
    time
        Time(s) at which the position is evaluated.
    wrt
        Which parameter to propagate: ``'theta_0'``, ``'L'``, ``'a_g'`` or
        ``'all'``.

    Returns
    -------
    The absolute value of ``dx/dp * dp``.  For ``'all'`` the three signed
    contributions are summed linearly (not in quadrature) before the absolute
    value is taken.

    Raises
    ------
    ValueError
        If ``wrt`` is not one of the supported options.
    """
    supported = ('theta_0', 'L', 'a_g', 'all')
    if wrt not in supported:
        # Previously an unknown option fell through every branch and crashed
        # with UnboundLocalError at the return statement.
        raise ValueError(
            f"Unsupported wrt={wrt!r}; expected one of {supported}"
        )

    # Sub-expressions shared by all three partial derivatives.
    omega = np.sqrt(true_a / true_L)
    phase = omega * time
    cos_phase = np.cos(phase)
    sin_phase = np.sin(phase)
    cos_angle = np.cos(true_theta * cos_phase)

    # Each partial is evaluated lazily so that, as before, only the requested
    # derivative is computed (e.g. sqrt(L / a_g) is never touched for 'L').
    def wrt_theta():
        return true_L * cos_angle * cos_phase

    def wrt_length():
        return (0.5 * true_theta * time * omega * sin_phase * cos_angle
                + np.sin(true_theta * cos_phase))

    def wrt_gravity():
        return (-0.5 * np.sqrt(true_L / true_a) * true_theta * time
                * sin_phase * cos_angle)

    if wrt == 'theta_0':
        dx_dthing = wrt_theta() * dthing
    elif wrt == 'L':
        dx_dthing = wrt_length() * dthing
    elif wrt == 'a_g':
        dx_dthing = wrt_gravity() * dthing
    else:  # wrt == 'all'
        dx_dthing = (wrt_theta() * dthing[1]
                     + wrt_length() * dthing[0]
                     + wrt_gravity() * dthing[2])
    return abs(dx_dthing)
153 changes: 153 additions & 0 deletions src/scripts/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import numpyro
import numpyro.distributions as dist
import numpy as np
import jax
import jax.numpy as jnp # yes i know this is confusing
import torch.nn as nn

# The networks below are built in PyTorch
# (as an alternative to a TensorFlow implementation).


class de_no_var(nn.Module):
    """Fully connected network with three hidden stages and one output unit.

    Takes 3 input features, applies three ``Linear -> ReLU -> Dropout(0.1)``
    stages of width 100, and ends with a ``Linear(100, 1)`` head.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 100
        dropout_rate = 0.1
        fan_in = 3
        # Register ln_<i>, act<i>, drop<i> stage by stage so the sub-module
        # names (and hence the state_dict keys) come out in a fixed order.
        for stage in (1, 2, 3):
            setattr(self, f"ln_{stage}", nn.Linear(fan_in, hidden_width))
            setattr(self, f"act{stage}", nn.ReLU())
            setattr(self, f"drop{stage}", nn.Dropout(dropout_rate))
            fan_in = hidden_width
        # Single output unit; this would need to be 2 if using GaussianNLLLoss.
        self.ln_4 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Pass ``x`` through the hidden stages and return the head output."""
        hidden_stages = (
            (self.ln_1, self.act1, self.drop1),
            (self.ln_2, self.act2, self.drop2),
            (self.ln_3, self.act3, self.drop3),
        )
        for linear, activation, dropout in hidden_stages:
            x = dropout(activation(linear(x)))
        return self.ln_4(x)


class de_var(nn.Module):
    """Fully connected network with three hidden stages and two output units.

    Takes 3 input features, applies three ``Linear -> ReLU -> Dropout(0.1)``
    stages of width 100, and ends with a ``Linear(100, 2)`` head.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 100
        dropout_rate = 0.1
        fan_in = 3
        # Register ln_<i>, act<i>, drop<i> stage by stage so the sub-module
        # names (and hence the state_dict keys) come out in a fixed order.
        for stage in (1, 2, 3):
            setattr(self, f"ln_{stage}", nn.Linear(fan_in, hidden_width))
            setattr(self, f"act{stage}", nn.ReLU())
            setattr(self, f"drop{stage}", nn.Dropout(dropout_rate))
            fan_in = hidden_width
        # Two output units, as needed when using GaussianNLLLoss.
        self.ln_4 = nn.Linear(hidden_width, 2)

    def forward(self, x):
        """Pass ``x`` through the hidden stages and return the head output."""
        hidden_stages = (
            (self.ln_1, self.act1, self.drop1),
            (self.ln_2, self.act2, self.drop2),
            (self.ln_3, self.act3, self.drop3),
        )
        for linear, activation, dropout in hidden_stages:
            x = dropout(activation(linear(x)))
        return self.ln_4(x)

## in numpyro, you must specify the number of sampling chains you will use up front

# words of wisdom from Tian Li and crew:
# on gpu, don't use conda, use pip install
# HMC after SBI to look at degeneracies between params
# different guides (some are slower but better at showing degeneracies)

## define the platform and number of cores (one chain per core)
## These statements run at import time and configure numpyro for the whole
## process: CPU backend, with `core_num` host devices requested.
## NOTE(review): presumably this has to execute before JAX initializes its
## backend for the device count to take effect -- confirm against numpyro docs.
numpyro.set_platform('cpu')
core_num = 4  # number of host devices requested below
numpyro.set_host_device_count(core_num)

def hierarchical_model(planet_code,
                       pendulum_code,
                       times,
                       exponential,
                       pos_obs=None):
    """Hierarchical numpyro model of pendulum positions.

    Inputs are rows from a dataframe, one entry per observation.

    Parameters
    ----------
    planet_code
        Array of integer codes saying which planet each row belongs to; used
        to index the per-planet ``a_g`` draws.
    pendulum_code
        Array of integer codes saying which pendulum each row belongs to;
        used to index the per-pendulum ``L`` and ``theta`` draws.
    times
        Moments in time (s) at which the position is evaluated.
    exponential
        Rate parameter of the exponential prior on the position error ``σ``
        (the mean of that prior is ``1 / rate``).
    pos_obs
        Optional observed positions.  When given, the ``"obs"`` site is
        conditioned on them; when ``None`` the site is sampled.

    Sample sites, in order: ``"μ_a_g"``, ``"σ_a_g"``, ``"a_g"``, ``"L"``,
    ``"theta"``, ``"σ"``, ``"obs"``.
    """
    planet_count = len(np.unique(planet_code))
    pendulum_count = len(np.unique(pendulum_code))

    ## Global parameters: mean and sigma of the normal from which each
    ## planet's a_g is drawn.  (A LogUniform(5.0, 15.0) prior on the mean was
    ## tried previously.)
    a_g_mean = numpyro.sample(
        "μ_a_g", dist.TruncatedNormal(12.5, 5, low=0.01)
    )
    ## NOTE: the author's remark is that scale parameters should ideally be
    ## log-uniform (1 / x in linear space) so they cannot go negative; a
    ## truncated normal bounded below at 0.01 is what is used here.
    a_g_sigma = numpyro.sample(
        "σ_a_g", dist.TruncatedNormal(0.1, 0.01, low=0.01)
    )

    ## Plates are the numpyro context manager for conditional independence:
    ## a_g is modelled independently for each planet, all drawn from the same
    ## truncated normal defined by the global parameters above.
    with numpyro.plate("planet_i", planet_count):
        planet_prior = dist.TruncatedNormal(a_g_mean, a_g_sigma, low=0.01)
        a_g = numpyro.sample("a_g", planet_prior)

    ## L and theta are modelled independently for each pendulum, each with a
    ## truncated-normal prior.
    with numpyro.plate("pend_i", pendulum_count):
        L = numpyro.sample("L", dist.TruncatedNormal(5, 2, low=0.01))
        theta_prior = dist.TruncatedNormal(
            jnp.pi / 100, jnp.pi / 500, low=0.00001
        )
        theta = numpyro.sample("theta", theta_prior)

    ## σ is the error on the position measurement; an exponential prior is
    ## never negative, which is good for an error term.
    ## (Possible extension: model the error on each parameter independently.)
    noise_sigma = numpyro.sample("σ", dist.Exponential(exponential))

    ## Bracket indexing expands the per-pendulum / per-planet draws to one
    ## value per observation.  Use jnp (not np) so jax can trace the maths.
    length_per_row = L[pendulum_code]
    angle_per_row = theta[pendulum_code]
    omega_per_row = jnp.sqrt(a_g[planet_code] / length_per_row)
    modelx = length_per_row * jnp.sin(
        angle_per_row * jnp.cos(omega_per_row * times)
    )

    ## Likelihood: each observed position is compared to the modelled
    ## position under a normal with scale σ.
    with numpyro.plate("data", len(pendulum_code)):
        numpyro.sample("obs", dist.Normal(modelx, noise_sigma), obs=pos_obs)


def unpooled_model(planet_code,
                   pendulum_code,
                   times,
                   exponential,
                   pos_obs=None):
    """Unpooled numpyro model of pendulum positions.

    Each planet's ``a_g`` gets its own fixed prior rather than being drawn
    from shared, sampled hyper-parameters.

    Parameters
    ----------
    planet_code
        Array of integer codes selecting the planet for each row; used to
        index the per-planet ``a_g`` draws.
    pendulum_code
        Array of integer codes selecting the pendulum for each row; used to
        index the per-pendulum ``L`` and ``theta`` draws.
    times
        Times at which the position is evaluated.
    exponential
        Rate parameter of the exponential prior on the position error ``σ``.
    pos_obs
        Optional observed positions; when given, the ``"obs"`` site is
        conditioned on them.

    Sample sites, in order: ``"a_g"``, ``"L"``, ``"theta"``, ``"σ"``,
    ``"obs"``.
    """
    planet_count = len(np.unique(planet_code))
    pendulum_count = len(np.unique(pendulum_code))

    # One independent a_g per planet, bounded to [0, 25].
    with numpyro.plate("planet_i", planet_count):
        gravity_prior = dist.TruncatedNormal(12.5, 5, low=0, high=25)
        a_g = numpyro.sample("a_g", gravity_prior)

    # One independent length and starting angle per pendulum.
    with numpyro.plate("pend_i", pendulum_count):
        L = numpyro.sample("L", dist.TruncatedNormal(5, 2, low=0.01))
        angle_prior = dist.TruncatedNormal(
            jnp.pi / 100, jnp.pi / 500, low=0.00001
        )
        theta = numpyro.sample("theta", angle_prior)

    # Shared error on the position measurement.
    noise_sigma = numpyro.sample("σ", dist.Exponential(exponential))

    # Expand per-pendulum / per-planet draws to one value per observation.
    length_per_row = L[pendulum_code]
    angle_per_row = theta[pendulum_code]
    omega_per_row = jnp.sqrt(a_g[planet_code] / length_per_row)
    modelx = length_per_row * jnp.sin(
        angle_per_row * jnp.cos(omega_per_row * times)
    )

    # Likelihood of the observed positions given the modelled ones.
    with numpyro.plate("data", len(pendulum_code)):
        numpyro.sample("obs", dist.Normal(modelx, noise_sigma), obs=pos_obs)
Loading

0 comments on commit 8e4d2c5

Please sign in to comment.