From 5ed08e07c3eca790dea292b000baf961e499d145 Mon Sep 17 00:00:00 2001
From: Saladino93 <saladino_93@hotmail.it>
Date: Fri, 15 Sep 2023 12:14:41 +0200
Subject: [PATCH 1/2] phas: use new numpy rng generator without setting/getting
 states manually

---
 plancklens/sims/phas.py | 39 +++++++++++++++++----------------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/plancklens/sims/phas.py b/plancklens/sims/phas.py
index 68791dd..6f9d877 100644
--- a/plancklens/sims/phas.py
+++ b/plancklens/sims/phas.py
@@ -26,13 +26,12 @@ def __init__(self, fname, idtype="INTEGER"):
 
         self.con = sqlite3.connect(fname, timeout=3600., detect_types=sqlite3.PARSE_DECLTYPES)
 
-    def add(self, idx, state):
+    def add(self, idx):
         idx = int(idx)
         try:
             assert (self.get(idx) is None)
-            keys_string = '_'.join(str(s) for s in state[1])
-            self.con.execute("INSERT INTO rngdb (id, type, pos, has_gauss, cached_gaussian, keys) VALUES (?,?,?,?,?,?)",
-                             (idx, state[0], state[2], state[3], state[4], keys_string))
+            self.con.execute("INSERT INTO rngdb (id) VALUES (?)",
+                             (idx))
             self.con.commit()
         except:
             print("rng_db::rngdb add failed!")
@@ -40,16 +39,16 @@ def add(self, idx, state):
     def get(self, idx):
         idx = int(idx)
         cur = self.con.cursor()
-        cur.execute("SELECT type, pos, has_gauss, cached_gaussian, keys FROM rngdb WHERE id=?", (idx,))
+        cur.execute("SELECT id FROM rngdb WHERE id=?", (idx,))#probably won't be necessary anymore
         data = cur.fetchone()
         cur.close()
         if data is None:
             return None
         else:
-            assert (len(data) == 5)
-            typ, pos, has_gauss, cached_gaussian, keys = data
+            assert (len(data) == 1)
+            id = data
             keys = np.array([int(a) for a in keys.split('_')], dtype=np.uint32)
-            return [typ, keys, pos, has_gauss, cached_gaussian]
+            return [id]
 
     def delete(self, idx):
         idx = int(idx)
@@ -63,14 +62,11 @@ def delete(self, idx):
 
 
 class sim_lib(object):
-    """Generic class for simulations where only rng state is stored.
-
-    np.random rng states are stored in a sqlite3 database. By default the rng state function is np.random.get_state.
-    The rng_db class is tuned for this state fct, you may need to adapt this.
+    """Generic class for simulations. We store the index idx, and then use np.random.RandomState
 
     """
 
-    def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
+    def __init__(self, lib_dir, nsims_max=None):
         if not os.path.exists(lib_dir) and mpi.rank == 0:
             os.makedirs(lib_dir)
         self.nmax = nsims_max
@@ -83,13 +79,12 @@ def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
         utils.hash_check(hsh, self.hashdict(), ignore=['lib_dir'], fn=fn_hash)
 
         self._rng_db = rng_db(os.path.join(lib_dir, 'rngdb.db'), idtype='INTEGER')
-        self._get_rng_state = get_state_func
 
     def get_sim(self, idx, **kwargs):
         """Returns sim number idx and caches random number generator state. """
         if self.has_nmax(): assert idx < self.nmax
         if not self.is_stored(idx):
-            self._rng_db.add(idx, self._get_rng_state())
+            self._rng_db.add(idx)
         return self._build_sim_from_rng(self._rng_db.get(idx), **kwargs)
 
     def has_nmax(self):
@@ -117,7 +112,7 @@ def hashdict(self):
         """Override this """
         assert 0
 
-    def _build_sim_from_rng(self, rng_state):
+    def _build_sim_from_rng(self, idx):
         """Override this """
         assert 0
 
@@ -127,9 +122,9 @@ def __init__(self, lib_dir, shape, **kwargs):
         self.shape = shape
         super(_pix_lib_phas, self).__init__(lib_dir, **kwargs)
 
-    def _build_sim_from_rng(self, rng_state, **kwargs):
-        np.random.set_state(rng_state)
-        return np.random.standard_normal(self.shape)
+    def _build_sim_from_rng(self, idx, **kwargs):
+        rng = np.random.RandomState(idx)
+        return rng.standard_normal(self.shape)
 
     def hashdict(self):
         return {'shape': self.shape}
@@ -159,9 +154,9 @@ def __init__(self, lib_dir,lmax, **kwargs):
         self.lmax = lmax
         super(_lib_phas, self).__init__(lib_dir, **kwargs)
 
-    def _build_sim_from_rng(self, rng_state, phas_only=False):
-        np.random.set_state(rng_state)
-        alm = (np.random.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * np.random.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
+    def _build_sim_from_rng(self, idx, phas_only=False):
+        rng = np.random.RandomState(idx)
+        alm = (rng.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * rng.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
         if phas_only: return
         m0 = hp.Alm.getidx(self.lmax, np.arange(self.lmax + 1,dtype = int),0)
         alm[m0] = np.sqrt(2.) * alm[m0].real

From e384dea9eb12867d5d3667a2d972cec67f63aba7 Mon Sep 17 00:00:00 2001
From: Saladino93 <saladino_93@hotmail.it>
Date: Tue, 19 Sep 2023 14:28:39 +0200
Subject: [PATCH 2/2] phas: restore stored rng states; spawn for parallel still TODO

---
 plancklens/sims/phas.py | 54 ++++++++++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 17 deletions(-)

diff --git a/plancklens/sims/phas.py b/plancklens/sims/phas.py
index 6f9d877..bf573ee 100644
--- a/plancklens/sims/phas.py
+++ b/plancklens/sims/phas.py
@@ -26,12 +26,13 @@ def __init__(self, fname, idtype="INTEGER"):
 
         self.con = sqlite3.connect(fname, timeout=3600., detect_types=sqlite3.PARSE_DECLTYPES)
 
-    def add(self, idx):
+    def add(self, idx, state):
         idx = int(idx)
         try:
             assert (self.get(idx) is None)
-            self.con.execute("INSERT INTO rngdb (id) VALUES (?)",
-                             (idx))
+            keys_string = '_'.join(str(s) for s in state[1])
+            self.con.execute("INSERT INTO rngdb (id, type, pos, has_gauss, cached_gaussian, keys) VALUES (?,?,?,?,?,?)",
+                             (idx, state[0], state[2], state[3], state[4], keys_string))
             self.con.commit()
         except:
             print("rng_db::rngdb add failed!")
@@ -39,16 +40,16 @@ def add(self, idx):
     def get(self, idx):
         idx = int(idx)
         cur = self.con.cursor()
-        cur.execute("SELECT id FROM rngdb WHERE id=?", (idx,))#probably won't be necessary anymore
+        cur.execute("SELECT type, pos, has_gauss, cached_gaussian, keys FROM rngdb WHERE id=?", (idx,))
         data = cur.fetchone()
         cur.close()
         if data is None:
             return None
         else:
-            assert (len(data) == 1)
-            id = data
+            assert (len(data) == 5)
+            typ, pos, has_gauss, cached_gaussian, keys = data
             keys = np.array([int(a) for a in keys.split('_')], dtype=np.uint32)
-            return [id]
+            return [typ, keys, pos, has_gauss, cached_gaussian]
 
     def delete(self, idx):
         idx = int(idx)
@@ -62,11 +63,14 @@ def delete(self, idx):
 
 
 class sim_lib(object):
-    """Generic class for simulations. We store the index idx, and then use np.random.RandomState
+    """Generic class for simulations where only rng state is stored.
+
+    np.random rng states are stored in a sqlite3 database. By default the rng state function is np.random.get_state.
+    The rng_db class is tuned for this state fct, you may need to adapt this.
 
     """
 
-    def __init__(self, lib_dir, nsims_max=None):
+    def __init__(self, lib_dir, get_state_func=np.random.get_state, nsims_max=None):
         if not os.path.exists(lib_dir) and mpi.rank == 0:
             os.makedirs(lib_dir)
         self.nmax = nsims_max
@@ -79,12 +83,27 @@ def __init__(self, lib_dir, nsims_max=None):
         utils.hash_check(hsh, self.hashdict(), ignore=['lib_dir'], fn=fn_hash)
 
         self._rng_db = rng_db(os.path.join(lib_dir, 'rngdb.db'), idtype='INTEGER')
+        self._get_rng_state = get_state_func
+
+    @staticmethod
+    def get_state(idx):
+        """Returns a random number generator state from a seed. """
+        #sg = np.random.SeedSequence(idx)
+        #mt19937 = np.random.MT19937(sg)
+        #rs = np.random.RandomState(mt19937)
+        rs = np.random.Generator(np.random.MT19937())
+        dictionary = rs.__getstate__()
+        l = [dictionary[k] for k in dictionary.keys()]
+        return [l[0], l[1]['key'], l[1]['pos'], 0, 0.0]
+        #return rs.get_state()
 
     def get_sim(self, idx, **kwargs):
         """Returns sim number idx and caches random number generator state. """
         if self.has_nmax(): assert idx < self.nmax
+
         if not self.is_stored(idx):
-            self._rng_db.add(idx)
+            #self._rng_db.add(idx, self._get_rng_state())
+            self._rng_db.add(idx, self.get_state(idx))
         return self._build_sim_from_rng(self._rng_db.get(idx), **kwargs)
 
     def has_nmax(self):
@@ -112,7 +131,7 @@ def hashdict(self):
         """Override this """
         assert 0
 
-    def _build_sim_from_rng(self, idx):
+    def _build_sim_from_rng(self, rng_state):
         """Override this """
         assert 0
 
@@ -122,9 +141,9 @@ def __init__(self, lib_dir, shape, **kwargs):
         self.shape = shape
         super(_pix_lib_phas, self).__init__(lib_dir, **kwargs)
 
-    def _build_sim_from_rng(self, idx, **kwargs):
-        rng = np.random.RandomState(idx)
-        return rng.standard_normal(self.shape)
+    def _build_sim_from_rng(self, rng_state, **kwargs):
+        np.random.set_state(rng_state)
+        return np.random.standard_normal(self.shape)
 
     def hashdict(self):
         return {'shape': self.shape}
@@ -154,9 +173,9 @@ def __init__(self, lib_dir,lmax, **kwargs):
         self.lmax = lmax
         super(_lib_phas, self).__init__(lib_dir, **kwargs)
 
-    def _build_sim_from_rng(self, idx, phas_only=False):
-        rng = np.random.RandomState(idx)
-        alm = (rng.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * rng.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
+    def _build_sim_from_rng(self, rng_state, phas_only=False):
+        np.random.set_state(rng_state)
+        alm = (np.random.standard_normal(hp.Alm.getsize(self.lmax)) + 1j * np.random.standard_normal(hp.Alm.getsize(self.lmax))) / np.sqrt(2.)
         if phas_only: return
         m0 = hp.Alm.getidx(self.lmax, np.arange(self.lmax + 1,dtype = int),0)
         alm[m0] = np.sqrt(2.) * alm[m0].real
@@ -188,3 +207,4 @@ def get_sim(self, idx, idf=None, phas_only=False):
 
     def hashdict(self):
         return {'nfields': self.nfields, 'lmax':self.lmax}
+