snapshot works via pickle
PeterKraus committed Feb 25, 2024
1 parent 7353656 commit e7f90e7
Showing 3 changed files with 39 additions and 19 deletions.
13 changes: 7 additions & 6 deletions src/tomato/drivers/__init__.py
@@ -54,7 +54,7 @@ def tomato_job() -> None:
logfile = jobpath / f"job-{jobid}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s:%(levelname)-8s:%(processName)s:%(message)s",
format="%(asctime)s - %(levelname)8s - %(name)-30s - %(message)s",
handlers=[logging.FileHandler(logfile, mode="a"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
@@ -84,21 +84,22 @@ def tomato_job() -> None:
else:
logger.warning(f"could not contact tomato-daemon in {TIMEOUT/1000} s")

logger.info("handing off to 'job_worker'")
logger.info("==============================")
ret = job_worker(context, args.port, payload["method"], pip, jobpath)
logger.info("==============================")

output = tomato["output"]
prefix = f"results.{jobid}" if output["prefix"] is None else output["prefix"]
outpath = Path(output["path"])
snappath = outpath / f"snapshot.{jobid}.nc"
logger.debug(f"output folder is {outpath}")
if outpath.exists():
assert outpath.is_dir()
else:
logger.debug("path does not exist, creating")
os.makedirs(outpath)

logger.info("handing off to 'job_worker'")
logger.info("==============================")
ret = job_worker(context, args.port, payload, pip, jobpath, snappath)
logger.info("==============================")

merge_netcdfs(jobpath, outpath / f"{prefix}.nc")

if ret is None:
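The reordering above is what makes periodic snapshots possible: tomato_job now derives both the final output path and a per-job snapshot path before handing off, and passes the whole payload plus snappath into job_worker instead of just payload["method"]. A minimal sketch of how those paths come out of the payload; the key names mirror the diff, but the concrete values and the job id are invented for illustration:

from pathlib import Path

# hypothetical payload fragment -- key names as in the diff, values made up
payload = {
    "tomato": {
        "output": {"path": "/tmp/tomato-results", "prefix": None},
        "snapshot": {"frequency": 60},  # how often job_worker re-merges, in seconds
    },
    "method": [],  # list of task steps, collated per device inside job_worker
}

jobid = 42  # hypothetical job id
output = payload["tomato"]["output"]
prefix = f"results.{jobid}" if output["prefix"] is None else output["prefix"]
outpath = Path(output["path"])
snappath = outpath / f"snapshot.{jobid}.nc"

print(outpath / f"{prefix}.nc")  # final merged NetCDF
print(snappath)                  # rolling snapshot written while the job runs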
43 changes: 31 additions & 12 deletions src/tomato/drivers/jobfuncs.py
@@ -1,27 +1,38 @@
import time
import os
import logging
import zmq
from pathlib import Path
import xarray as xr
import pickle

from multiprocessing import Process
from tomato.models import Component, Device, Driver

logger = logging.getLogger(f"{__name__}")


def merge_netcdfs(jobpath: Path, outpath: Path):
fns = [fn for fn in os.listdir(jobpath) if fn.endswith(".nc")]
datasets = [xr.load_dataset(jobpath / fn, engine="h5netcdf") for fn in fns]
logger = logging.getLogger(f"{__name__}.merge_netcdf")
logger.debug("opening datasets")
datasets = []
for fn in jobpath.glob("*.pkl"):
with pickle.load(fn.open("rb")) as ds:
datasets.append(ds)
logger.debug("merging datasets")
if len(datasets) > 0:
ds = xr.concat(datasets, dim="uts")
ds.to_netcdf(outpath, engine="h5netcdf")


def data_to_netcdf(ds: xr.Dataset, path: Path):
def data_to_pickle(ds: xr.Dataset, path: Path):
logger = logging.getLogger(f"{__name__}.data_to_pickle")
logger.debug("checking existing")
if path.exists():
oldds = xr.load_dataset(path, engine="h5netcdf")
ds = xr.concat([oldds, ds], dim="uts")
ds.to_netcdf(path, engine="h5netcdf")
with pickle.load(path.open("rb")) as oldds:
ds = xr.concat([oldds, ds], dim="uts")
logger.debug("dumping pickle")
with path.open("wb") as out:
pickle.dump(ds, out, protocol=5)


def job_process(
@@ -42,7 +53,7 @@ def job_process(

kwargs = dict(address=component.address, channel=component.channel)

datapath = jobpath / f"{component.role}.nc"
datapath = jobpath / f"{component.role}.pkl"
logger.debug("distributing tasks:")
for task in tasks:
logger.debug(f"{task=}")
@@ -64,7 +75,7 @@ def job_process(
req.send_pyobj(dict(cmd="task_data", params={**kwargs}))
ret = req.recv_pyobj()
if ret.success:
data_to_netcdf(ret.data, datapath)
data_to_pickle(ret.data, datapath)
t0 += device.pollrate
req.send_pyobj(dict(cmd="task_status", params={**kwargs}))
ret = req.recv_pyobj()
@@ -75,15 +86,16 @@ def job_process(
req.send_pyobj(dict(cmd="task_data", params={**kwargs}))
ret = req.recv_pyobj()
if ret.success:
data_to_netcdf(ret.data, datapath)
data_to_pickle(ret.data, datapath)


def job_worker(
context: zmq.Context,
port: int,
method: dict,
payload: dict,
pipname: str,
jobpath: Path,
snappath: Path,
) -> None:
sender = f"{__name__}.job_worker"
logger = logging.getLogger(sender)
@@ -99,7 +111,7 @@ def job_worker(

# collate steps by role
plan = {}
for step in method:
for step in payload["method"]:
if step["device"] not in plan:
plan[step["device"]] = []
task = {k: v for k, v in step.items()}
@@ -124,7 +136,14 @@ def job_worker(
processes[role].start()

# wait until threads join or we're killed
snapshot = payload["tomato"].get("snapshot", None)
t0 = time.perf_counter()
while True:
tN = time.perf_counter()
if snapshot is not None and tN - t0 > snapshot["frequency"]:
logger.debug("creating snapshot")
merge_netcdfs(jobpath, snappath)
t0 += snapshot["frequency"]
joined = [proc.is_alive() is False for proc in processes.values()]
if all(joined):
break
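Taken together, jobfuncs.py now buffers per-component data as protocol-5 pickles of xr.Dataset objects (data_to_pickle), and merge_netcdfs concatenates whatever pickles exist along the "uts" dimension into a single NetCDF via the h5netcdf engine; job_worker triggers that same merge into snappath every snapshot["frequency"] seconds while the per-device processes are still running. A self-contained sketch of the round trip, with file names and data invented for illustration:

import pickle
from pathlib import Path

import numpy as np
import xarray as xr

workdir = Path("demo_job")
workdir.mkdir(exist_ok=True)

# two "polls" worth of data for one component, appended into the same pickle
for batch in range(2):
    ds = xr.Dataset(
        {"voltage": ("uts", np.random.rand(3))},
        coords={"uts": np.arange(3) + 3 * batch},
    )
    path = workdir / "worker.pkl"
    if path.exists():
        with path.open("rb") as inp:
            ds = xr.concat([pickle.load(inp), ds], dim="uts")
    with path.open("wb") as out:
        pickle.dump(ds, out, protocol=5)

# snapshot / final merge: load every pickle, concatenate, write one NetCDF
datasets = []
for fn in workdir.glob("*.pkl"):
    with fn.open("rb") as inp:
        datasets.append(pickle.load(inp))
if datasets:
    xr.concat(datasets, dim="uts").to_netcdf(workdir / "snapshot.nc", engine="h5netcdf")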
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -79,4 +79,4 @@ def stop_tomato_daemon(port: int = 12345):
pass
gone, alive = psutil.wait_procs(procs, timeout=5)
print(f"{gone=}")
print(f"{alive=}")
print(f"{alive=}")
