-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
45 lines (38 loc) · 1.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import List
import jax.numpy as np
from jax import vmap
from functools import partial
import numpy as onp
def get_hk(k):
    """Return the normalizing factor h_k of the cosine basis function for ``k``.

    For each component k_i the per-dimension factor is
    (2 k_i + sin(2 k_i)) / (4 k_i), which equals the integral of
    cos^2(k_i x) for x over [0, 1]. Its limit as k_i -> 0 is 1, and that
    limit is used directly for zero components. h_k is the square root of
    the product of the per-dimension factors.

    Args:
        k: wave-number vector (or scalar). NumPy arrays, JAX arrays, Python
            sequences and scalars are all accepted.

    Returns:
        A NumPy scalar holding h_k.
    """
    k = onp.asarray(k)
    is_zero = k == 0
    # Divide by a safe denominator so no 0/0 NaN is produced, then put the
    # k -> 0 limit (1) back in. Using onp.where (instead of the JAX-only
    # ``.at[...]`` update) keeps this working for NumPy and plain inputs.
    safe_k = onp.where(is_zero, 1., k)
    per_dim = (2. * safe_k + onp.sin(2. * safe_k)) / (4. * safe_k)
    per_dim = onp.where(is_zero, 1., per_dim)
    return onp.sqrt(onp.prod(per_dim))
def get_ck(trajectory, basis):
    """Return the normalized Fourier coefficients c_k of a trajectory.

    ``basis.fk_vmap`` is vmapped over the leading axis of ``trajectory``,
    so every state is evaluated against the basis. The results are averaged
    over that axis (time) and divided elementwise by ``basis.hk_list``.

    Args:
        trajectory: stacked states; the leading axis is mapped over.
        basis: object exposing ``fk_vmap`` and ``hk_list`` (e.g. BasisFunc).

    Returns:
        The time-averaged basis values divided by ``basis.hk_list``.
    """
    samples = vmap(basis.fk_vmap)(trajectory)
    time_average = samples.mean(axis=0)
    return time_average / basis.hk_list
def get_phik(vals, basis):
    """Return the normalized Fourier coefficients phi_k of a sampled map.

    Args:
        vals: pair ``(phi, x)``, the map values and the sample points they
            belong to. ``basis.fk_vmap`` is vmapped over the leading axis
            of ``x``.
        basis: object exposing ``fk_vmap`` and ``hk_list`` (e.g. BasisFunc).

    Returns:
        ``np.dot(phi, F)``, where F holds the basis values at the samples.
        The result is rescaled so its first entry is 1, then divided
        elementwise by ``basis.hk_list``.
    """
    phi_vals, sample_pts = vals
    basis_at_pts = vmap(basis.fk_vmap)(sample_pts)
    raw = np.dot(phi_vals, basis_at_pts)
    # When the first basis function is the constant k = 0 mode (as in
    # BasisFunc, whose k_list starts at the all-zero vector), raw[0] is
    # sum(phi). Dividing by it is then equivalent to normalizing phi to sum 1.
    normalized = raw / raw[0]
    return normalized / basis.hk_list
def recon_from_fourier(basis_coef):
    """Placeholder for reconstructing a map from its Fourier coefficients.

    Not implemented yet: ``basis_coef`` is ignored and the function always
    returns ``None``.

    Args:
        basis_coef: coefficients to reconstruct from (currently unused).
    """
    return None
class BasisFunc:
    """Tensor-product cosine basis f_k(x) = prod_i cos(k_i * x_i).

    Wave numbers are k_i = m_i * pi for integer modes m_i in
    [0, n_basis[i]), with every combination across dimensions included.
    The multiples of pi suggest a unit domain [0, 1] per dimension; confirm
    this against callers.

    Attributes:
        n: number of dimensions, ``len(n_basis)``.
        k_list: array of shape (K, n) with K = prod(n_basis). It holds one
            wave-number vector per row, starting with the all-zero vector.
        hk_list: array of shape (K,) holding ``get_hk`` of each row of
            ``k_list``.
        fk_kvmap: ``(ks, x)`` -> value of every row of ``ks`` at point ``x``.
        fk_xvmap: ``(k, xs)`` -> value of a single ``k`` at every row of ``xs``.
        fk_vmap: ``x`` -> values of all K basis functions at point ``x``.
    """

    def __init__(self, n_basis) -> None:
        """Build the wave-number grid, normalizers and vmapped evaluators.

        Args:
            n_basis: sequence of ints, the number of modes per dimension.
        """
        # Keep meshgrid's default ("xy") indexing so the row order of
        # k_list is unchanged.
        mode_ranges = [np.arange(0, n_max, step=1) for n_max in n_basis]
        grids = np.meshgrid(*mode_ranges)
        self.n = len(n_basis)
        columns = [grid.ravel() for grid in grids]
        self.k_list = np.pi * np.stack(columns).T
        self.hk_list = np.array([get_hk(k_row) for k_row in self.k_list])

        def cos_product(k, x):
            return np.prod(np.cos(x * k))

        self._fk = cos_product
        self.fk_kvmap = vmap(cos_product, in_axes=(0, None))
        self.fk_xvmap = vmap(cos_product, in_axes=(None, 0))
        # Bind the full k_list so callers only need to supply the point x.
        self.fk_vmap = partial(self.fk_kvmap, self.k_list)