@@ -12,6 +12,7 @@ def solve(
12
12
dt : float = 1.0 ,
13
13
burnin : int = 0 ,
14
14
key : Array = None ,
15
+ associative_scan : bool = True ,
15
16
) -> Array :
16
17
"""
17
18
Obtain the solution of the linear system
@@ -29,6 +30,7 @@ def solve(
29
30
burnin: Time-step index corresponding to the end of the burn-in period.
30
31
Samples before this step are not collected.
31
32
key: JAX random key
33
+ associative_scan: If True, uses jax.lax.associative_scan.
32
34
33
35
Returns:
34
36
Approximate solution, x, of the linear system.
@@ -37,7 +39,9 @@ def solve(
37
39
key = random .PRNGKey (0 )
38
40
ts = jnp .arange (burnin , burnin + num_samples ) * dt
39
41
x0 = jnp .zeros_like (b )
40
- samples = sample_identity_diffusion (key , ts , x0 , A , jnp .linalg .solve (A , b ))
42
+ samples = sample_identity_diffusion (
43
+ key , ts , x0 , A , jnp .linalg .solve (A , b ), associative_scan
44
+ )
41
45
return jnp .mean (samples , axis = 0 )
42
46
43
47
@@ -47,6 +51,7 @@ def inv(
47
51
dt : float = 1.0 ,
48
52
burnin : int = 0 ,
49
53
key : Array = None ,
54
+ associative_scan : bool = True ,
50
55
) -> Array :
51
56
"""
52
57
Obtain the inverse of a matrix A by
@@ -60,6 +65,7 @@ def inv(
60
65
burnin: Time-step index corresponding to the end of the burn-in period.
61
66
Samples before this step are not collected.
62
67
key: JAX random key
68
+ associative_scan: If True, uses jax.lax.associative_scan.
63
69
64
70
Returns:
65
71
Approximate inverse of A.
@@ -69,7 +75,7 @@ def inv(
69
75
ts = jnp .arange (burnin , burnin + num_samples ) * dt
70
76
b = jnp .zeros (A .shape [0 ])
71
77
x0 = jnp .zeros_like (b )
72
- samples = sample (key , ts , x0 , A , b , 2 * jnp .eye (A .shape [0 ]))
78
+ samples = sample (key , ts , x0 , A , b , 2 * jnp .eye (A .shape [0 ]), associative_scan )
73
79
return jnp .cov (samples .T )
74
80
75
81
@@ -80,6 +86,7 @@ def expnegm(
80
86
burnin : int = 0 ,
81
87
key : Array = None ,
82
88
alpha : float = 0.0 ,
89
+ associative_scan : bool = True ,
83
90
) -> Array :
84
91
"""
85
92
Obtain the negative exponential of a matrix A by
@@ -95,6 +102,7 @@ def expnegm(
95
102
key: JAX random key
96
103
alpha: Regularization parameter to ensure diffusion matrix
97
104
is symmetric positive definite.
105
+ associative_scan: If True, uses jax.lax.associative_scan.
98
106
99
107
Returns:
100
108
Approximate negative matrix exponential, exp(-A).
@@ -108,7 +116,7 @@ def expnegm(
108
116
ts = jnp .arange (burnin , burnin + num_samples ) * dt
109
117
b = jnp .zeros (A .shape [0 ])
110
118
x0 = jnp .zeros_like (b )
111
- samples = sample (key , ts , x0 , A_shifted , b , B )
119
+ samples = sample (key , ts , x0 , A_shifted , b , B , associative_scan )
112
120
return autocovariance (samples ) * jnp .exp (alpha )
113
121
114
122
@@ -119,6 +127,7 @@ def expm(
119
127
burnin : int = 0 ,
120
128
key : Array = None ,
121
129
alpha : float = 1.0 ,
130
+ associative_scan : bool = True ,
122
131
) -> Array :
123
132
"""
124
133
Obtain the exponential of a matrix A by
@@ -134,11 +143,12 @@ def expm(
134
143
key: JAX random key
135
144
alpha: Regularization parameter to ensure diffusion matrix
136
145
is symmetric positive definite.
146
+ associative_scan: If True, uses jax.lax.associative_scan.
137
147
138
148
Returns:
139
149
Approximate matrix exponential, exp(A).
140
150
"""
141
- return expnegm (- A , num_samples , dt , burnin , key , alpha )
151
+ return expnegm (- A , num_samples , dt , burnin , key , alpha , associative_scan )
142
152
143
153
144
154
def autocovariance (samples : Array ) -> Array :
0 commit comments