Commit a71765c

add example
RasmusOrsoe committed May 29, 2024
1 parent 845293d commit a71765c
Showing 1 changed file with 225 additions and 0 deletions.
examples/04_training/07_train_normalizing_flow.py (225 additions, 0 deletions)
@@ -0,0 +1,225 @@
"""Example of training a conditional NormalizingFlow."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import NormalizingFlow
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdge
from graphnet.models.graphs import KNNGraph
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.utils import make_train_validation_dataloader
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger

# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "dynedge_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # Define graph representation
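    # KNNGraph builds a k-nearest-neighbours graph over the pulses in each
    # event, with node features standardised via the Prometheus detector
    # definition.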
    graph_definition = KNNGraph(detector=Prometheus())

    (
        training_dataloader,
        validation_dataloader,
    ) = make_train_validation_dataloader(
        db=config["path"],
        graph_definition=graph_definition,
        pulsemaps=config["pulsemap"],
        features=features,
        truth=truth,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        truth_table=truth_table,
        selection=None,
    )

    # Building model
    backbone = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )
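
    # The flow below is conditional: DynEdge's pooled summary vector
    # conditions the learned density over the regression target. With
    # interval="step", the scheduler milestones are in optimiser steps: the
    # learning rate ramps from 1e-02 x lr to lr over the first half epoch,
    # then decays linearly back to 1e-02 x lr by the end of training.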

    model = NormalizingFlow(
        graph_definition=graph_definition,
        backbone=backbone,
        optimizer_class=Adam,
        target_labels=config["target"],
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-02, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
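    # fit() wraps a PyTorch Lightning Trainer; early_stopping_patience halts
    # training if the validation loss stops improving for that many epochs.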
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    # Save results as .csv
    results.to_csv(f"{path}/results.csv")

    # Save full model (including weights) to .pth file - not version safe.
    # Models saved as .pth files in one version of graphnet may not be
    # compatible with a different version.
    model.save(f"{path}/model.pth")

    # Save model config and state dict - the version-safe way to save models.
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save_config(f"{path}/model_config.yml")
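
    # To reload the model later, a minimal sketch (assumes graphnet's
    # Model.from_config / load_state_dict API; not exercised in this example):
    #
    #   from graphnet.models import Model
    #   model = Model.from_config(f"{path}/model_config.yml", trust=True)
    #   model.load_state_dict(f"{path}/state_dict.pth")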


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
Train a conditional NormalizingFlow without the use of config files.
"""
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If set, use Weights & Biases to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
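
# Example invocation (a sketch; flag names come from the parser's standard
# arguments above, and the default --path points at the bundled example data):
#
#   python examples/04_training/07_train_normalizing_flow.py \
#       --gpus 0 --max-epochs 5 --batch-size 16 --num-workers 2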
