diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py index 3e31fc126..b8a8e82b7 100644 --- a/examples/04_training/01_train_dynedge.py +++ b/examples/04_training/01_train_dynedge.py @@ -82,13 +82,16 @@ def main( wandb_logger.experiment.config.update(config) # Define graph/data representation, here the KNNGraph is used. - # The KNNGraph is a graph representation which uses the KNNEdges edge definition with 8 neighbours as default. - # The graph representation is defined by the detector, in this case the Prometheus detector. + # The KNNGraph is a graph representation, which uses the + # KNNEdges edge definition with 8 neighbours as default. + # The graph representation is defined by the detector, + # in this case the Prometheus detector. # The standard node definition is used, which is NodesAsPulses. graph_definition = KNNGraph(detector=Prometheus()) # Use GraphNetDataModule to load in data and create dataloaders - # The input here depends on the dataset being used, in this case the Prometheus dataset. + # The input here depends on the dataset being used, + # in this case the Prometheus dataset. dm = GraphNeTDataModule( dataset_reference=config["dataset_reference"], dataset_args={ @@ -114,16 +117,18 @@ def main( # Building model - # Define architecture of the backbone, in this example we use the DynEdge architecture - # described in detail in the Jinst paper: https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003 + # Define architecture of the backbone, in this example + # the DynEdge architecture is used. + # https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003 backbone = DynEdge( nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["min", "max", "mean", "sum"], ) # Define the task. - # In this case we are performing energy reconstruction, with a LogCoshLoss function. - # The target and prediction are transformed using the log10 function. When infering - # the prediction is transformed back to the original scale using 10^x. + # Here an energy reconstruction, with a LogCoshLoss function. + # The target and prediction are transformed using the log10 function. + # When infering the prediction is transformed back to the + # original scale using 10^x. task = EnergyReconstruction( hidden_size=backbone.nb_outputs, target_labels=config["target"], @@ -131,8 +136,9 @@ def main( transform_prediction_and_target=lambda x: torch.log10(x), transform_inference=lambda x: torch.pow(10, x), ) - # Define the full model, which includes the backbone, task(s) along with typical - # machine learning options such as learning rate optimizer and scheduler. + # Define the full model, which includes the backbone, task(s), + # along with typical machine learning options such as + # learning rate optimizers and schedulers. model = StandardModel( graph_definition=graph_definition, backbone=backbone,