-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodels_jax.py
162 lines (128 loc) · 5.4 KB
/
models_jax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -------------------
# Models: JAX implementation
# -------------------
'''
This script contains functions to:
- compute distance / dissimilarity metrics from inputs
- evaluate the correlation models for a fixed set of parameters
- probabilistic correlation models that specify the
prior parameter distributions and the probabilistic structure
For a detailed explanation see notebook:
introductory_example_estimation.ipynb
The models are built in numpyro,
which in turn makes use of JAX
to enable high-performance MCMC computations.
Therefore it is necessary to use jax.numpy instead
of numpy functions.
For an application of the estimated models,
numpy implementations of the models are sufficient
(see models_numpy.py).
'''
import numpy as np
import numpyro as numpyro
import jax as jax
# Change JAX default from float32 to float64
jax.config.update("jax_enable_x64", True)
# Import some functions directly for convenience
import jax.numpy as jnp
import numpyro.distributions as dist
# -------------------
# Distance metrics
# -------------------
'''
See models_numpy.py for documentation.
'''
def getEucDistanceFromPolar(X):
    """Pairwise Euclidean distances between points given in polar coordinates.

    Column 0 of ``X`` is read as the radial coordinate r. Column 1 is read as
    the angle theta; it is passed to ``cos``, so it is expected in radians.
    Entry [i, j] follows the law of cosines:

        d_ij = sqrt(r_i**2 + r_j**2 - 2 * r_i * r_j * cos(|theta_j - theta_i|))

    Args:
        X: 2-D array with at least two columns, one row per point.

    Returns:
        (n, n) symmetric distance matrix whose diagonal is exactly zero.
    """
    radius = X[:, 0]
    theta = X[:, 1]
    r_col = jnp.reshape(radius, [-1, 1])   # (n, 1): r_i down the rows
    r_row = jnp.reshape(radius, [1, -1])   # (1, n): r_j across the columns
    angle_gap = jnp.abs(jnp.reshape(theta, [1, -1]) - jnp.reshape(theta, [-1, 1]))

    cross_term = 2 * (r_col * r_row) * jnp.cos(angle_gap)
    squared = jnp.square(r_col) + jnp.square(r_row) - cross_term
    # Round-off can push the law-of-cosines value slightly below zero;
    # clamp it so the square root never produces NaN.
    squared = jnp.clip(squared, 0, np.inf)

    distances = jnp.sqrt(squared)
    # Rounding can leave tiny non-zero self-distances; pin the diagonal to 0.
    n_points = distances.shape[0]
    return distances.at[jnp.diag_indices(n_points)].set(0.0)
def getAngDistanceFromPolar(X):
    """Pairwise angular separation, in radians within [0, pi].

    Column 1 of ``X`` is read as an angle in radians. Passing the difference
    through ``arccos(cos(.))`` wraps it onto [0, pi]. For example, 10 and 350
    degrees (expressed in radians) come out 20 degrees apart, not 340.

    Args:
        X: 2-D array with at least two columns, one row per point.

    Returns:
        (n, n) symmetric matrix of angular separations.
    """
    theta = X[:, 1]
    gap = jnp.abs(jnp.reshape(theta, [1, -1]) - jnp.reshape(theta, [-1, 1]))
    # Clip keeps arccos in-domain if the cosine rounds just outside [-1, 1].
    return jnp.arccos(jnp.clip(jnp.cos(gap), -1, 1))
def getSoilDissimilarity(X):
    """Pairwise absolute difference of the third column of ``X``.

    Entry [i, j] is |X[i, 2] - X[j, 2]|. Column 2 is presumably a site/soil
    descriptor; confirm the meaning against the calling notebook.

    Args:
        X: 2-D array with at least three columns, one row per point.

    Returns:
        (n, n) symmetric, non-negative matrix with a zero diagonal.
    """
    soil = X[:, 2]
    # Use |a - b| directly. The equivalent sqrt(square(a - b)) round trip
    # overflows to inf for large differences and underflows to 0 for tiny ones.
    return jnp.abs(jnp.reshape(soil, [-1, 1]) - jnp.reshape(soil, [1, -1]))
# -------------------
# Deterministic correlation functions
# -------------------
'''
See models_numpy.py for documentation.
'''
def rhoE(X, LEt, gammaE, nugget=1e-6):
    """Powered-exponential correlation matrix in Euclidean distance.

    K[i, j] = exp(-d_ij ** gammaE / LEt), where d_ij comes from
    ``getEucDistanceFromPolar(X)``. In the probabilistic models of this file
    LEt = LE ** gammaE, so this equals exp(-(d_ij / LE) ** gammaE).

    Args:
        X: per-point inputs accepted by ``getEucDistanceFromPolar``.
        LEt: transformed length-scale (divides the powered distance).
        gammaE: exponent applied to the distances.
        nugget: jitter added to the diagonal. The models use the result as a
            MultivariateNormal covariance.

    Returns:
        (n, n) correlation matrix with ``nugget`` added on the diagonal.
    """
    d = getEucDistanceFromPolar(X)
    scaled = jnp.power(d, gammaE) * (1.0 / LEt)
    corr = jnp.exp(-scaled)
    diag = jnp.diag_indices(X.shape[0])
    return corr.at[diag].add(nugget)
def rhoEA(X, LEt, gammaE, LA, nugget=1e-6):
    """Correlation from Euclidean distance times angular separation.

    K = K_E * K_A, where
        K_E[i, j] = exp(-d_ij ** gammaE / LEt)
        K_A[i, j] = (1 + a_ij / LA) * (1 - a_ij / 180) ** (180 / LA)
    with d from ``getEucDistanceFromPolar`` and a the angular separation from
    ``getAngDistanceFromPolar`` converted to degrees (so a lies in [0, 180]).

    Args:
        X: per-point inputs accepted by the distance helpers.
        LEt: transformed Euclidean length-scale.
        gammaE: exponent applied to the Euclidean distances.
        LA: angular length-scale, in degrees.
        nugget: jitter added to the diagonal.

    Returns:
        (n, n) correlation matrix with ``nugget`` added on the diagonal.
    """
    # Euclidean factor: powered exponential.
    euc = getEucDistanceFromPolar(X)
    k_euc = jnp.exp(-(jnp.power(euc, gammaE) * (1.0 / LEt)))

    # Angular factor, evaluated on the separation expressed in degrees.
    ang_deg = getAngDistanceFromPolar(X) * (180.0 / np.pi)
    k_ang = ((1 + ang_deg * (1.0 / LA))
             * jnp.power(1 - ang_deg * (1.0 / 180), 180 / LA))

    corr = k_euc * k_ang
    return corr.at[jnp.diag_indices(X.shape[0])].add(nugget)
def rhoEAS(X, LEt, gammaE, LA, LS, w, nugget=1e-6):
    """Euclidean correlation times a weighted angular/soil mixture.

    K = K_E * (w * K_A + (1 - w) * K_S), where
        K_E[i, j] = exp(-d_ij ** gammaE / LEt)
        K_A[i, j] = (1 + a_ij / LA) * (1 - a_ij / 180) ** (180 / LA)
        K_S[i, j] = exp(-s_ij / LS)
    d is the Euclidean distance and a the angular separation in degrees. s is
    the output of ``getSoilDissimilarity(X)``.

    Args:
        X: per-point inputs accepted by the distance helpers.
        LEt: transformed Euclidean length-scale.
        gammaE: exponent applied to the Euclidean distances.
        LA: angular length-scale, in degrees.
        LS: soil-dissimilarity length-scale.
        w: weight of the angular term; the soil term gets 1 - w.
        nugget: jitter added to the diagonal.

    Returns:
        (n, n) correlation matrix with ``nugget`` added on the diagonal.
    """
    n_sites = X.shape[0]

    euc = getEucDistanceFromPolar(X)
    k_euc = jnp.exp(-(jnp.power(euc, gammaE) * (1.0 / LEt)))

    ang_deg = getAngDistanceFromPolar(X) * (180.0 / np.pi)
    k_ang = ((1 + ang_deg * (1.0 / LA))
             * jnp.power(1 - ang_deg * (1.0 / 180), 180 / LA))

    soil = getSoilDissimilarity(X)
    k_soil = jnp.exp(-(soil * (1.0 / LS)))

    mixture = w * k_ang + (1 - w) * k_soil
    corr = k_euc * mixture
    return corr.at[jnp.diag_indices(n_sites)].add(nugget)
# -------------------
# Probabilistic correlation models
# -------------------
def modelE(X, eqids, z):
    """NumPyro model built on the Euclidean correlation ``rhoE``.

    Priors:
        LE     ~ InverseGamma(concentration=2, rate=30)
        gamma2 ~ Beta(2, 2)
    Deterministic sites:
        gammaE = 2 * gamma2   (lies in (0, 2) because gamma2 is in (0, 1))
        LEt    = LE ** gammaE
    Likelihood: for each event ``eqids[i]``, the observation ``z[i]`` is
    conditioned on MultivariateNormal(0, rhoE(X[i], LEt, gammaE)) at the
    sample site ``"z_<eqid>"``.

    Args:
        X: indexable per event; X[i] is the input passed to ``rhoE``
            (presumably one row per station; confirm with the caller).
        eqids: event identifiers, used to name the observation sites.
        z: indexable per event; z[i] is the observed vector for event i.
    """
    # Prior distributions
    LE = numpyro.sample("LE", dist.InverseGamma(concentration=2, rate=30))
    gamma2 = numpyro.sample("gamma2", dist.Beta(2, 2))
    # Transformed parameters
    gammaE = numpyro.deterministic("gammaE", 2.0 * gamma2)
    LEt = numpyro.deterministic("LEt", jnp.power(LE, gammaE))
    # Observation model: one conditioned site per event. numpyro.sample is
    # called only for its side effect of registering the site, so a plain
    # loop is used and the ``z`` argument is not rebound.
    for i, eqid in enumerate(eqids):
        numpyro.sample("z_{}".format(eqid),
                       dist.MultivariateNormal(0, rhoE(X[i], LEt, gammaE)),
                       obs=z[i])
def modelEA(X, eqids, z):
    """NumPyro model built on the Euclidean-angular correlation ``rhoEA``.

    Priors:
        LE     ~ InverseGamma(concentration=2, rate=30)
        gamma2 ~ Beta(2, 2)
        LAt    ~ Gamma(2, 0.25)
    Deterministic sites:
        gammaE = 2 * gamma2   (lies in (0, 2))
        LEt    = LE ** gammaE
        LA     = 180 / (4 + LAt)   (lies in (0, 45) because LAt > 0)
    Likelihood: for each event ``eqids[i]``, the observation ``z[i]`` is
    conditioned on MultivariateNormal(0, rhoEA(X[i], LEt, gammaE, LA)) at the
    sample site ``"z_<eqid>"``.

    Args:
        X: indexable per event; X[i] is the input passed to ``rhoEA``.
        eqids: event identifiers, used to name the observation sites.
        z: indexable per event; z[i] is the observed vector for event i.
    """
    # Prior distributions
    LE = numpyro.sample("LE", dist.InverseGamma(concentration=2, rate=30))
    gamma2 = numpyro.sample("gamma2", dist.Beta(2, 2))
    LAt = numpyro.sample("LAt", dist.Gamma(2, 0.25))
    # Transformed parameters
    gammaE = numpyro.deterministic("gammaE", 2.0 * gamma2)
    LEt = numpyro.deterministic("LEt", jnp.power(LE, gammaE))
    LA = numpyro.deterministic("LA", 180 / (4.0 + LAt))
    # Observation model: one conditioned site per event, registered for its
    # side effect; the ``z`` argument is left untouched.
    for i, eqid in enumerate(eqids):
        numpyro.sample("z_{}".format(eqid),
                       dist.MultivariateNormal(0, rhoEA(X[i], LEt, gammaE, LA)),
                       obs=z[i])
def modelEAS(X, eqids, z):
    """NumPyro model built on the Euclidean-angular-soil correlation ``rhoEAS``.

    Priors:
        LE     ~ InverseGamma(concentration=2, rate=30)
        gamma2 ~ Beta(2, 2)
        LAt    ~ Gamma(2, 0.25)
        LS     ~ InverseGamma(2, 100)
        w      ~ Beta(2, 2)
    Deterministic sites:
        gammaE = 2 * gamma2   (lies in (0, 2))
        LEt    = LE ** gammaE
        LA     = 180 / (4 + LAt)   (lies in (0, 45) because LAt > 0)
    Likelihood: for each event ``eqids[i]``, the observation ``z[i]`` is
    conditioned on MultivariateNormal(0, rhoEAS(X[i], LEt, gammaE, LA, LS, w))
    at the sample site ``"z_<eqid>"``.

    Args:
        X: indexable per event; X[i] is the input passed to ``rhoEAS``.
        eqids: event identifiers, used to name the observation sites.
        z: indexable per event; z[i] is the observed vector for event i.
    """
    # Prior distributions
    LE = numpyro.sample("LE", dist.InverseGamma(concentration=2, rate=30))
    gamma2 = numpyro.sample("gamma2", dist.Beta(2, 2))
    LAt = numpyro.sample("LAt", dist.Gamma(2, 0.25))
    LS = numpyro.sample("LS", dist.InverseGamma(2, 100))
    w = numpyro.sample("w", dist.Beta(2, 2))
    # Transformed parameters
    gammaE = numpyro.deterministic("gammaE", 2.0 * gamma2)
    LEt = numpyro.deterministic("LEt", jnp.power(LE, gammaE))
    LA = numpyro.deterministic("LA", 180 / (4.0 + LAt))
    # Observation model: one conditioned site per event, registered for its
    # side effect; the ``z`` argument is left untouched.
    for i, eqid in enumerate(eqids):
        numpyro.sample("z_{}".format(eqid),
                       dist.MultivariateNormal(0, rhoEAS(X[i], LEt, gammaE, LA, LS, w)),
                       obs=z[i])