Commit

Merge pull request #289 from vallis/pulsardeflate - add partially shared Pulsar objects

Tested, checks OK, merge!
vallis authored Sep 8, 2021
2 parents 9ad4a29 + e92596b commit d7bc0f6
Showing 5 changed files with 168 additions and 27 deletions.
7 changes: 7 additions & 0 deletions .coveragerc
@@ -3,3 +3,10 @@ source = enterprise
omit =
    enterprise/__init__*
    enterprise/signals/__init__*
plugins =
    coverage_conditional_plugin

[coverage_conditional_plugin]
rules =
    "sys_version_info >= (3, 8)": py-gte-38
    "sys_version_info < (3, 8)": py-lt-38
90 changes: 69 additions & 21 deletions enterprise/pulsar.py
@@ -5,6 +5,7 @@
import json
import logging
import os
import pickle

import astropy.constants as const
import astropy.units as u
@@ -14,10 +15,7 @@
import enterprise
from enterprise.signals import utils

try:
    import cPickle as pickle
except:
    import pickle
from enterprise.pulsar_inflate import PulsarInflater

logger = logging.getLogger(__name__)

@@ -36,7 +34,6 @@
    logger.warning("PINT not installed. Will use libstempo instead.")  # pragma: no cover
    pint = None


if pint is None and t2 is None:
    err_msg = "Must have either PINT or libstempo timing package installed"
    raise ImportError(err_msg)
@@ -141,8 +138,11 @@ def filter_data(self, start_time=None, end_time=None):
        dmx_mask = np.sum(self._designmatrix, axis=0) != 0.0
        self._designmatrix = self._designmatrix[:, dmx_mask]

        if isinstance(self._flags, np.ndarray):
            self._flags = self._flags[mask]
        else:
            for key in self._flags:
                self._flags[key] = self._flags[key][mask]

        if self._planetssb is not None:
            self._planetssb = self.planetssb[mask, :, :]
@@ -233,26 +233,37 @@ def dmx(self):
    def flags(self):
        """Return a dictionary of tim-file flags."""

        return dict((k, v[self._isort]) for k, v in self._flags.items())
        flagnames = self._flags.dtype.names if isinstance(self._flags, np.ndarray) else self._flags.keys()

        return {flag: self._flags[flag][self._isort] for flag in flagnames}

    @property
    def backend_flags(self):
        """Return array of backend flags.

        Not all TOAs carry the same flags across data sets, so backend flags
        are determined by a ranked search. The order of preference is
        `group`, `g`, `sys`, `i`, `f`, then `fe`+`be`.
        """

        nobs = len(self._toas)
        bflags = ["flag"] * nobs
        flags = [["group"], ["g"], ["sys"], ["i"], ["f"], ["fe", "be"]]
        for ii in range(nobs):
            # TODO: make this cleaner
            for f in flags:
                if np.all([x in self._flags and self._flags[x][ii] != "" for x in f]):
                    bflags[ii] = "_".join(self._flags[x][ii] for x in f)
                    break
        return np.array(bflags)[self._isort]
        # collect flag names
        flagnames = self._flags.dtype.names if isinstance(self._flags, np.ndarray) else list(self._flags.keys())

        # allocate array with widest dtype
        ret = np.zeros(len(self._toas), dtype=max([self._flags[name].dtype for name in flagnames]))

        # go through the flags in reverse order of preference,
        # setting or replacing values for each TOA

        if "fe" in flagnames and "be" in flagnames:
            ret[:] = [(a + "_" + b if (a and b) else "") for a, b in zip(self._flags["fe"], self._flags["be"])]

        for flag in ["f", "i", "sys", "g", "group"]:
            if flag in flagnames:
                ret[:] = np.where(self._flags[flag] == "", ret, self._flags[flag])

        return ret
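As an illustration of the new record-array flag storage and the ranked selection above (not part of the diff; the flag values below are made up):

import numpy as np

# toy flags table: the first TOA has a "group" flag, the others fall back to fe + be
flags = np.zeros(3, dtype=[("group", "U16"), ("f", "U16"), ("fe", "U16"), ("be", "U16")])
flags["group"] = ["430_ASP", "", ""]
flags["f"] = ["", "L-wide_PUPPI", ""]
flags["fe"] = ["430", "L-wide", "430"]
flags["be"] = ["ASP", "PUPPI", "GUPPI"]

# start from the lowest-preference choice: fe and be combined
backend = np.array([a + "_" + b if (a and b) else "" for a, b in zip(flags["fe"], flags["be"])], dtype="U16")

# let higher-preference flags overwrite it, with the most preferred ("group") applied last
for name in ["f", "i", "sys", "g", "group"]:
    if name in flags.dtype.names:
        backend = np.where(flags[name] == "", backend, flags[name])

print(backend)  # ['430_ASP' 'L-wide_PUPPI' '430_GUPPI']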

    @property
    def theta(self):
@@ -431,9 +442,14 @@ def __init__(self, t2pulsar, sort=True, drop_t2pulsar=True, planets=True):
        spars = [str(p) for p in t2pulsar.pars(which="set")]
        self.setpars = [sp for sp in spars if sp not in self.fitpars]

        self._flags = {}
        flags = {}
        for key in t2pulsar.flags():
            self._flags[key] = t2pulsar.flagvals(key)
            flags[key] = t2pulsar.flagvals(key)

        # new-style storage of flags as a numpy record array (previously, psr._flags = flags)
        self._flags = np.zeros(len(self._toas), dtype=[(key, val.dtype) for key, val in flags.items()])
        for key, val in flags.items():
            self._flags[key] = val

        self._pdist = self._get_pdist()
        self._raj, self._decj = self._get_radec(t2pulsar)
@@ -524,9 +540,41 @@ def _get_sunssb(self, t2pulsar):
        sunssb[:, 3:] = utils.ecl2eq_vec(sunssb[:, 3:])
        return sunssb

    # infrastructure for sharing Pulsar objects among processes
    # (currently Tempo2Pulsar only)
    # the Pulsar deflater will copy select numpy arrays to SharedMemory,
    # then replace them with pickleable objects that can be inflated
    # to numpy arrays with SharedMemory storage

    _todeflate = ["_designmatrix", "_planetssb", "_sunssb", "_flags"]
    _deflated = "pristine"

    def deflate(psr):  # pragma: py-lt-38
        if psr._deflated == "pristine":
            for attr in psr._todeflate:
                if isinstance(getattr(psr, attr), np.ndarray):
                    setattr(psr, attr, PulsarInflater(getattr(psr, attr)))

            psr._deflated = "deflated"

    def inflate(psr):  # pragma: py-lt-38
        if psr._deflated == "deflated":
            for attr in psr._todeflate:
                if isinstance(getattr(psr, attr), PulsarInflater):
                    setattr(psr, attr, getattr(psr, attr).inflate())

            psr._deflated = "inflated"

    def destroy(psr):  # pragma: py-lt-38
        if psr._deflated == "deflated":
            for attr in psr._todeflate:
                if isinstance(getattr(psr, attr), PulsarInflater):
                    getattr(psr, attr).destroy()

            psr._deflated = "destroyed"


def Pulsar(*args, **kwargs):
    ephem = kwargs.get("ephem", None)
    clk = kwargs.get("clk", None)
    bipm_version = kwargs.get("bipm_version", None)
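Taken together, the new methods enable the multi-process pattern described in the comments above: one process deflates and pickles, workers unpickle and inflate, and a final process destroys the shared buffers. A sketch of that flow (Python >= 3.8 only; the par/tim file names are placeholders, and the pickle name follows the pulsar name as in the new test below):

import pickle
from enterprise.pulsar import Pulsar

# creator process: move the large arrays into SharedMemory, then pickle
psr = Pulsar("B1855+09.par", "B1855+09.tim")   # placeholder par/tim files
psr.deflate()        # _designmatrix, _planetssb, _sunssb, _flags -> SharedMemory
psr.to_pickle()      # writes B1855+09.pkl containing lightweight PulsarInflater stubs

# worker process(es): unpickle, then re-attach the shared arrays as views
with open("B1855+09.pkl", "rb") as f:
    shared_psr = pickle.load(f)
shared_psr.inflate()

# destroyer process: unlink the SharedMemory once all workers are finished
with open("B1855+09.pkl", "rb") as f:
    owner = pickle.load(f)
owner.destroy()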
60 changes: 60 additions & 0 deletions enterprise/pulsar_inflate.py
@@ -0,0 +1,60 @@
# pulsar_inflate.py
"""Defines PulsarInflater class: instances copy a numpy array to shared memory,
and (after pickling) will reinflate to a numpy array that refers to the shared
data.
"""

import numpy as np

try:
    from multiprocessing import shared_memory, resource_tracker
except:
    # shared_memory unavailable in Python < 3.8
    pass


class memmap(np.ndarray):
    def __del__(self):
        if self.base is None and hasattr(self, "shm"):
            self.shm.close()


# lifecycle of shared pulsar arrays:
# - begin life as numpy arrays in Pulsar object
# - upon psr.deflate(), replaced by PulsarInflater objects
# - these objects save the array metadata, create a SharedMemory buffer, copy the arrays into it
# - the PulsarInflater objects cannot be used as arrays until re-inflated
# - upon psr.inflate, the PulsarInflater objects are replaced with ndarray views of the SharedMemory buffers
# - the views are special memmap objects that hold a reference to the SharedMemory, and close it on destruction
# - upon psr.destroy, the SharedMemory objects are unlinked and the arrays become unusable
# - standard usage requires 3+ processes:
# - a creator, who calls deflate then pickle
# - one or more users, who unpickle then inflate
# - a destroyer, who unpickles then destroys


class PulsarInflater:
    def __init__(self, array):
        self.dtype, self.shape, self.nbytes = array.dtype, array.shape, array.nbytes

        shm = shared_memory.SharedMemory(create=True, size=array.nbytes)
        self.shmname = shm.name

        # shm.buf[:array.nbytes] = array.view(dtype='uint8').flatten()

        b = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
        b[:] = array[:]

        resource_tracker.unregister(shm._name, "shared_memory")

    def inflate(self):
        shm = shared_memory.SharedMemory(self.shmname)

        c = np.ndarray(self.shape, dtype=self.dtype, buffer=shm.buf).view(memmap)
        c.shm = shm

        return c

    def destroy(self):
        shm = shared_memory.SharedMemory(self.shmname)
        shm.unlink()
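At a lower level, a PulsarInflater can round-trip a single array through SharedMemory on its own. A minimal sketch (Python >= 3.8; shown in one process for brevity, although the intended use is to unpickle and inflate in other processes):

import pickle
import numpy as np
from enterprise.pulsar_inflate import PulsarInflater

arr = np.arange(12, dtype=float).reshape(3, 4)

deflated = PulsarInflater(arr)       # copies arr into a SharedMemory block
blob = pickle.dumps(deflated)        # pickles only dtype/shape/nbytes and the segment name

view = pickle.loads(blob).inflate()  # ndarray (memmap subclass) backed by the shared buffer
assert np.array_equal(view, arr)

del view                             # drop the view before unlinking the segment
deflated.destroy()                   # SharedMemory is unlinked; later inflate() calls will fail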
1 change: 1 addition & 0 deletions requirements_dev.txt
@@ -12,5 +12,6 @@ PyYAML>=4.2b1
pytest>=4.0.0
sphinx-rtd-theme>=0.4.0
pytest-cov>=2.7.0
coverage-conditional-plugin>=0.4.0
jupyter>=1.0.0
build==0.3.1.post1
37 changes: 31 additions & 6 deletions tests/test_pulsar.py
@@ -9,21 +9,18 @@
for time slicing, PINT integration and pickling.
"""


import sys
import os
import shutil
import unittest
import pickle
import pytest

import numpy as np

from enterprise.pulsar import Pulsar
from tests.enterprise_test_data import datadir

try:
    import cPickle as pickle
except:
    import pickle


class TestPulsar(unittest.TestCase):
    @classmethod
@@ -133,6 +130,31 @@ def test_to_pickle(self):

        assert np.allclose(self.psr.residuals, pkl_psr.residuals, rtol=1e-10)

    @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python >= 3.8")
    def test_deflate_inflate(self):
        psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")

        dm = psr._designmatrix.copy()

        psr.deflate()
        psr.to_pickle()

        with open("B1855+09.pkl", "rb") as f:
            pkl_psr = pickle.load(f)
        pkl_psr.inflate()

        assert np.allclose(dm, pkl_psr._designmatrix)

        del pkl_psr

        psr.destroy()

        with open("B1855+09.pkl", "rb") as f:
            pkl_psr = pickle.load(f)

        with self.assertRaises(FileNotFoundError):
            pkl_psr.inflate()

    def test_wrong_input(self):
        """Test exception when incorrect par(tim) file given."""

@@ -182,3 +204,6 @@ def test_model(self):

    def test_pint_toas(self):
        assert hasattr(self.psr, "pint_toas")

    def test_deflate_inflate(self):
        pass
