Skip to content

Commit

Permalink
code climate fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Oct 23, 2024
1 parent 1993eae commit 3cde04d
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions examples/04_training/01_train_dynedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,16 @@ def main(
wandb_logger.experiment.config.update(config)

# Define graph/data representation, here the KNNGraph is used.
# The KNNGraph is a graph representation which uses the KNNEdges edge definition with 8 neighbours as default.
# The graph representation is defined by the detector, in this case the Prometheus detector.
# The KNNGraph is a graph representation, which uses the
# KNNEdges edge definition with 8 neighbours as default.
# The graph representation is defined by the detector,
# in this case the Prometheus detector.
# The standard node definition is used, which is NodesAsPulses.
graph_definition = KNNGraph(detector=Prometheus())

# Use GraphNetDataModule to load in data and create dataloaders
# The input here depends on the dataset being used, in this case the Prometheus dataset.
# The input here depends on the dataset being used,
# in this case the Prometheus dataset.
dm = GraphNeTDataModule(
dataset_reference=config["dataset_reference"],
dataset_args={
Expand All @@ -114,25 +117,28 @@ def main(

# Building model

# Define architecture of the backbone, in this example we use the DynEdge architecture
# described in detail in the Jinst paper: https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
# Define architecture of the backbone, in this example
# the DynEdge architecture is used.
# https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
backbone = DynEdge(
nb_inputs=graph_definition.nb_outputs,
global_pooling_schemes=["min", "max", "mean", "sum"],
)
# Define the task.
# In this case we are performing energy reconstruction, with a LogCoshLoss function.
# The target and prediction are transformed using the log10 function. When infering
# the prediction is transformed back to the original scale using 10^x.
# Here an energy reconstruction, with a LogCoshLoss function.
# The target and prediction are transformed using the log10 function.
# When infering the prediction is transformed back to the
# original scale using 10^x.
task = EnergyReconstruction(
hidden_size=backbone.nb_outputs,
target_labels=config["target"],
loss_function=LogCoshLoss(),
transform_prediction_and_target=lambda x: torch.log10(x),
transform_inference=lambda x: torch.pow(10, x),
)
# Define the full model, which includes the backbone, task(s) along with typical
# machine learning options such as learning rate optimizer and scheduler.
# Define the full model, which includes the backbone, task(s),
# along with typical machine learning options such as
# learning rate optimizers and schedulers.
model = StandardModel(
graph_definition=graph_definition,
backbone=backbone,
Expand Down

0 comments on commit 3cde04d

Please sign in to comment.