@@ -11,7 +11,7 @@ def solve(
11
11
num_samples : int = 10000 ,
12
12
dt : float = 1.0 ,
13
13
burnin : int = 0 ,
14
- seed : int = 0 ,
14
+ key : Array = None ,
15
15
) -> Array :
16
16
"""
17
17
Obtain the solution of the linear system
@@ -27,12 +27,13 @@ def solve(
27
27
- num_samples: float, number of samples to be collected.
28
28
- dt: float, time step.
29
29
- burnin: burn-in, steps before which samples are not collected.
30
- - seed: random seed
30
+ - key: JAX random key
31
31
32
32
Returns:
33
33
- approximate solution, x, of the linear system.
34
34
"""
35
- key = jax .random .PRNGKey (seed )
35
+ if key is None :
36
+ key = jax .random .PRNGKey (0 )
36
37
ts = jnp .arange (burnin , burnin + num_samples ) * dt
37
38
x0 = jnp .zeros_like (b )
38
39
samples = sample_identity_diffusion (key , ts , x0 , A , jnp .linalg .solve (A , b ))
@@ -44,7 +45,7 @@ def inv(
44
45
num_samples : int = 10000 ,
45
46
dt : float = 1.0 ,
46
47
burnin : int = 0 ,
47
- seed : int = 0 ,
48
+ key : Array = None ,
48
49
) -> Array :
49
50
"""
50
51
Obtain the inverse of a matrix A by
@@ -56,12 +57,13 @@ def inv(
56
57
- num_samples: float, number of samples to be collected.
57
58
- dt: float, time step.
58
59
- burnin: burn-in, steps before which samples are not collected.
59
- - seed: random seed
60
+ - key: JAX random key
60
61
61
62
Returns:
62
63
- approximate inverse of A.
63
64
"""
64
- key = jax .random .PRNGKey (seed )
65
+ if key is None :
66
+ key = jax .random .PRNGKey (0 )
65
67
ts = jnp .arange (burnin , burnin + num_samples ) * dt
66
68
b = jnp .zeros (A .shape [0 ])
67
69
x0 = jnp .zeros_like (b )
@@ -74,7 +76,7 @@ def expnegm(
74
76
num_samples : int = 10000 ,
75
77
dt : float = 1.0 ,
76
78
burnin : int = 0 ,
77
- seed : int = 0 ,
79
+ key : Array = None ,
78
80
alpha : float = 0.0 ,
79
81
) -> Array :
80
82
"""
@@ -87,18 +89,19 @@ def expnegm(
87
89
- num_samples: float, number of samples to be collected.
88
90
- dt: float, time step.
89
91
- burnin: burn-in, steps before which samples are not collected.
90
- - seed: random seed
92
+ - key: JAX random key
91
93
- alpha: float, regularization parameter to ensure diffusion matrix
92
94
is symmetric positive definite.
93
95
94
96
Returns:
95
97
- approximate negative matrix exponential, exp(-A).
96
98
"""
99
+ if key is None :
100
+ key = jax .random .PRNGKey (0 )
101
+
97
102
A_shifted = (A + alpha * jnp .eye (A .shape [0 ])) / dt
98
103
B = A_shifted + A_shifted .T
99
104
100
- key = jax .random .PRNGKey (seed )
101
-
102
105
ts = jnp .arange (burnin , burnin + num_samples ) * dt
103
106
b = jnp .zeros (A .shape [0 ])
104
107
x0 = jnp .zeros_like (b )
@@ -111,7 +114,7 @@ def expm(
111
114
num_samples : int = 10000 ,
112
115
dt : float = 1.0 ,
113
116
burnin : int = 0 ,
114
- seed : int = 0 ,
117
+ key : Array = None ,
115
118
alpha : float = 1.0 ,
116
119
) -> Array :
117
120
"""
@@ -124,14 +127,14 @@ def expm(
124
127
- num_samples: float, number of samples to be collected.
125
128
- dt: float, time step.
126
129
- burnin: burn-in, steps before which samples are not collected.
127
- - seed: random seed
130
+ - key: JAX random key
128
131
- alpha: float, regularization parameter to ensure diffusion matrix
129
132
is symmetric positive definite.
130
133
131
134
Returns:
132
135
- approximate matrix exponential, exp(A).
133
136
"""
134
- return expnegm (- A , num_samples , dt , burnin , seed , alpha )
137
+ return expnegm (- A , num_samples , dt , burnin , key , alpha )
135
138
136
139
137
140
def autocovariance (samples : Array ) -> Array :
0 commit comments