diff --git a/projects/online/online/main.py b/projects/online/online/main.py index debf6659..a7929bd9 100644 --- a/projects/online/online/main.py +++ b/projects/online/online/main.py @@ -5,18 +5,15 @@ from typing import Iterable, List, Optional, Tuple import torch -from amplfi.train.architectures.embeddings import ResNet -from amplfi.train.architectures.flows import InverseAutoregressiveFlow from amplfi.train.architectures.flows.base import FlowArchitecture from architectures import Architecture -from ml4gw.nn.norm import GroupNorm1DGetter from ml4gw.transforms import ChannelWiseScaler, SpectralDensity, Whiten from torch.multiprocessing import Array, Process, Queue from ledger.events import EventSet from online.utils.buffer import InputBuffer, OutputBuffer from online.utils.dataloading import data_iterator -from online.utils.gdb import GdbServer, gracedb_factory +from online.utils.gdb import GdbServer, authenticate, gracedb_factory from online.utils.ngdd import data_iterator as ngdd_data_iterator from online.utils.pastro import fit_or_load_pastro from online.utils.pe import run_amplfi, skymap_from_samples @@ -293,7 +290,7 @@ def event_creation_subprocess( pastro_queue: Queue, ): gdb = gracedb_factory(server, outdir) - # last_auth = time.time() + last_auth = time.time() while True: try: event = event_queue.get_nowait() @@ -320,9 +317,9 @@ def event_creation_subprocess( # Re-authenticate every 1000 seconds so that # the scitoken doesn't expire. Doing it in this # loop as it's the earliest point of submission - # if last_auth - time.time() > 1000: - # authenticate() - # last_auth = time.time() + if last_auth - time.time() > 1000: + authenticate() + last_auth = time.time() def main( @@ -442,10 +439,10 @@ def main( Device to run inference on ("cpu" or "cuda") """ # run kinit and htgettoken - # if server != "local": - # logging.info("Authenticating") - # authenticate() - # logging.info("Authentication complete") + if server != "local": + logging.info("Authenticating") + authenticate() + logging.info("Authentication complete") fftlength = fftlength or kernel_length + fduration data = torch.randn(samples_per_event * len(inference_params)) @@ -527,27 +524,11 @@ def main( aframe = aframe.to(device) logging.info(f"Loading AMPLFI from weights at path {amplfi_weights}") - # amplfi, scaler = load_amplfi( - # amplfi_architecture, amplfi_weights, len(inference_params) - # ) - # amplfi = amplfi.to(device) - # scaler = scaler.to(device) - embedding_net = ResNet( - num_ifos=len(ifos), - layers=[5, 3, 3], - norm_layer=GroupNorm1DGetter(8), - context_dim=8, + amplfi, scaler = load_amplfi( + amplfi_architecture, amplfi_weights, len(inference_params) ) - amplfi = InverseAutoregressiveFlow( - num_params=len(inference_params), - hidden_features=150, - num_transforms=80, - num_blocks=6, - # context_dim=8, - embedding_net=embedding_net, - ).to(device) - scaler = ChannelWiseScaler(len(inference_params)).to(device) - scaler.fit(torch.randn((8, 2048))) + amplfi = amplfi.to(device) + scaler = scaler.to(device) spectral_density = SpectralDensity( sample_rate=sample_rate,