Skip to content

Commit

Permalink
Fix FAR calc bug and added authentication to loop
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenoit26 committed Feb 11, 2025
1 parent 3676a85 commit a7842f3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
59 changes: 34 additions & 25 deletions projects/online/online/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import time
from pathlib import Path
from queue import Empty
from typing import Iterable, List, Optional, Tuple

import torch
Expand All @@ -11,12 +13,11 @@
from ledger.events import EventSet
from online.utils.buffer import InputBuffer, OutputBuffer
from online.utils.dataloading import data_iterator
from online.utils.gdb import GdbServer, gracedb_factory # , authenticate,
from online.utils.gdb import GdbServer, authenticate, gracedb_factory
from online.utils.pastro import fit_or_load_pastro
from online.utils.pe import run_amplfi, skymap_from_samples
from online.utils.searcher import Searcher
from online.utils.snapshotter import OnlineSnapshotter
from utils.logging import configure_logging
from utils.preprocessing import BatchWhitener

# seconds of data per update
Expand Down Expand Up @@ -221,7 +222,6 @@ def pastro_subprocess(
astro_event_rate: float,
outdir: Path,
):
configure_logging(outdir / "log" / "pastro.log")
gdb = gracedb_factory(server, outdir)

logging.info("Fitting p_astro model or loading from cache")
Expand Down Expand Up @@ -253,7 +253,6 @@ def amplfi_subprocess(
shared_samples: Array,
nside: int = 32,
):
configure_logging(outdir / "log" / "amplfi.log")
gdb = gracedb_factory(server, outdir)

while True:
Expand Down Expand Up @@ -294,27 +293,37 @@ def event_creation_subprocess(
pastro_queue: Queue,
):
gdb = gracedb_factory(server, outdir)
last_auth = time.time()
while True:
event = event_queue.get()
logging.info("Putting event in pastro queue")
pastro_queue.put(event)

# write event information to disk
# and submit it to gracedb
event.write(outdir)
response = gdb.submit(event)

# Get the event's graceid for submitting
# further data products
if isinstance(response, str):
# If the response is a string, then we are using the local gracedb
# client, which just returns the filename
graceid = response
else:
graceid = response.json()["graceid"]
logging.info("Putting graceid in amplfi and pastro queues")
amplfi_queue.put(graceid)
pastro_queue.put(graceid)
try:
event = event_queue.get_nowait()
logging.info("Putting event in pastro queue")
pastro_queue.put(event)

# write event information to disk
# and submit it to gracedb
event.write(outdir)
response = gdb.submit(event)

# Get the event's graceid for submitting
# further data products
if isinstance(response, str):
# If the response is a string, then we are using the
# local gracedb client, which just returns the filename
graceid = response
else:
graceid = response.json()["graceid"]
logging.info("Putting graceid in amplfi and pastro queues")
amplfi_queue.put(graceid)
pastro_queue.put(graceid)
except Empty:
time.sleep(1e-3)
# Re-authenticate every 1000 seconds so that
# the scitoken doesn't expire. Doing it in this
# loop as it's the earliest point of submission
if last_auth - time.time() > 1000:
authenticate()
last_auth = time.time()


def main(
Expand Down Expand Up @@ -429,7 +438,7 @@ def main(
device:
Device to run inference on ("cpu" or "cuda")
"""
# run htgettoken and kinit
# run kinit and htgettoken
# if server != "local":
# logging.info("Authenticating")
# authenticate()
Expand Down
14 changes: 2 additions & 12 deletions projects/online/online/utils/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,9 @@ def __init__(
# calculate the detection statistic threshold
# corresponding to the requested FAR threshold
self.threshold = background.threshold_at_far(far_threshold)
# Speed up FAR calculation by excluding below-threshold events,
# and just record the total number of them
# Speed up FAR calculation by excluding below-threshold events
mask = background.detection_statistic >= self.threshold
self.background = background[mask]
self.total_below_threshold = np.sum(~mask)

def check_refractory(self, value):
time_since_last = time.time() - self.last_detection_time
Expand All @@ -149,16 +147,8 @@ def build_event(self, value: float, t0: float, idx: int):
timestamp = t0 + idx / self.inference_sampling_rate
logging.info("Computing FAR")
far = self.background.far(value)
logging.info("FAR computed")
# Add back in the below-threshold events if
# the FAR is below the minimum FAR
if far != self.background.min_far:
far += (
SECONDS_PER_YEAR
* self.total_below_threshold
/ self.background.Tb
)
far /= SECONDS_PER_YEAR
logging.info("FAR computed")

logging.info(
"Event coalescence time found to be {:0.3f} "
Expand Down

0 comments on commit a7842f3

Please sign in to comment.