Skip to content

Commit

Permalink
Undo using IAF
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenoit26 committed Feb 26, 2025
1 parent 2643572 commit fba485f
Showing 1 changed file with 13 additions and 32 deletions.
45 changes: 13 additions & 32 deletions projects/online/online/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
from typing import Iterable, List, Optional, Tuple

import torch
from amplfi.train.architectures.embeddings import ResNet
from amplfi.train.architectures.flows import InverseAutoregressiveFlow
from amplfi.train.architectures.flows.base import FlowArchitecture
from architectures import Architecture
from ml4gw.nn.norm import GroupNorm1DGetter
from ml4gw.transforms import ChannelWiseScaler, SpectralDensity, Whiten
from torch.multiprocessing import Array, Process, Queue

from ledger.events import EventSet
from online.utils.buffer import InputBuffer, OutputBuffer
from online.utils.dataloading import data_iterator
from online.utils.gdb import GdbServer, gracedb_factory
from online.utils.gdb import GdbServer, authenticate, gracedb_factory
from online.utils.ngdd import data_iterator as ngdd_data_iterator
from online.utils.pastro import fit_or_load_pastro
from online.utils.pe import run_amplfi, skymap_from_samples
Expand Down Expand Up @@ -293,7 +290,7 @@ def event_creation_subprocess(
pastro_queue: Queue,
):
gdb = gracedb_factory(server, outdir)
# last_auth = time.time()
last_auth = time.time()
while True:
try:
event = event_queue.get_nowait()
Expand All @@ -320,9 +317,9 @@ def event_creation_subprocess(
# Re-authenticate every 1000 seconds so that
# the scitoken doesn't expire. Doing it in this
# loop as it's the earliest point of submission
# if last_auth - time.time() > 1000:
# authenticate()
# last_auth = time.time()
if last_auth - time.time() > 1000:
authenticate()
last_auth = time.time()


def main(
Expand Down Expand Up @@ -442,10 +439,10 @@ def main(
Device to run inference on ("cpu" or "cuda")
"""
# run kinit and htgettoken
# if server != "local":
# logging.info("Authenticating")
# authenticate()
# logging.info("Authentication complete")
if server != "local":
logging.info("Authenticating")
authenticate()
logging.info("Authentication complete")

fftlength = fftlength or kernel_length + fduration
data = torch.randn(samples_per_event * len(inference_params))
Expand Down Expand Up @@ -527,27 +524,11 @@ def main(
aframe = aframe.to(device)

logging.info(f"Loading AMPLFI from weights at path {amplfi_weights}")
# amplfi, scaler = load_amplfi(
# amplfi_architecture, amplfi_weights, len(inference_params)
# )
# amplfi = amplfi.to(device)
# scaler = scaler.to(device)
embedding_net = ResNet(
num_ifos=len(ifos),
layers=[5, 3, 3],
norm_layer=GroupNorm1DGetter(8),
context_dim=8,
amplfi, scaler = load_amplfi(
amplfi_architecture, amplfi_weights, len(inference_params)
)
amplfi = InverseAutoregressiveFlow(
num_params=len(inference_params),
hidden_features=150,
num_transforms=80,
num_blocks=6,
# context_dim=8,
embedding_net=embedding_net,
).to(device)
scaler = ChannelWiseScaler(len(inference_params)).to(device)
scaler.fit(torch.randn((8, 2048)))
amplfi = amplfi.to(device)
scaler = scaler.to(device)

spectral_density = SpectralDensity(
sample_rate=sample_rate,
Expand Down

0 comments on commit fba485f

Please sign in to comment.