From c1b40998b07cecb7cbcbb65351e4ba275f0aef9c Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 19:24:09 +0200 Subject: [PATCH 01/39] add NormalizingFlow --- src/graphnet/models/__init__.py | 1 + src/graphnet/models/easy_model.py | 4 + .../models/graphs/graph_definition.py | 21 +++- src/graphnet/models/graphs/graphs.py | 4 + src/graphnet/models/normalizing_flow.py | 115 ++++++++++++++++++ src/graphnet/models/task/task.py | 102 +++++++++------- src/graphnet/models/utils.py | 9 ++ 7 files changed, 209 insertions(+), 47 deletions(-) create mode 100644 src/graphnet/models/normalizing_flow.py diff --git a/src/graphnet/models/__init__.py b/src/graphnet/models/__init__.py index a2e63befb..a7e0a064b 100644 --- a/src/graphnet/models/__init__.py +++ b/src/graphnet/models/__init__.py @@ -11,3 +11,4 @@ from .model import Model from .standard_model import StandardModel from .standard_averaged_model import StandardAveragedModel +from .normalizing_flow import NormalizingFlow \ No newline at end of file diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index d26d88fa0..b1c51c087 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -292,6 +292,7 @@ def predict( dataloader: DataLoader, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", + **trainer_kwargs, ) -> List[Tensor]: """Return predictions for `dataloader`.""" self.inference() @@ -305,6 +306,7 @@ def predict( gpus=gpus, distribution_strategy=distribution_strategy, callbacks=callbacks, + **trainer_kwargs, ) predictions_list = inference_trainer.predict(self, dataloader) @@ -325,6 +327,7 @@ def predict_as_dataframe( additional_attributes: Optional[List[str]] = None, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", + **trainer_kwargs, ) -> pd.DataFrame: """Return predictions for `dataloader` as a DataFrame. @@ -357,6 +360,7 @@ def predict_as_dataframe( dataloader=dataloader, gpus=gpus, distribution_strategy=distribution_strategy, + **trainer_kwargs, ) predictions = ( torch.cat(predictions_torch, dim=1).detach().cpu().numpy() diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index e384425f9..6c9a0a419 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -34,6 +34,7 @@ def __init__( sensor_mask: Optional[List[int]] = None, string_mask: Optional[List[int]] = None, sort_by: str = None, + repeat_labels: bool =False, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -62,9 +63,14 @@ def __init__( add_inactive_sensors: If True, inactive sensors will be appended to the graph with padded pulse information. Defaults to False. sensor_mask: A list of sensor id's to be masked from the graph. Any - sensor listed here will be removed from the graph. Defaults to None. - string_mask: A list of string id's to be masked from the graph. Defaults to None. + sensor listed here will be removed from the graph. + Defaults to None. + string_mask: A list of string id's to be masked from the graph. + Defaults to None. sort_by: Name of node feature to sort by. Defaults to None. + repeat_labels: If True, labels will be repeated to match the + the number of rows in the output of the GraphDefinition. + Defaults to False. 
""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -80,6 +86,7 @@ def __init__( self._sensor_mask = sensor_mask self._string_mask = string_mask self._add_inactive_sensors = add_inactive_sensors + self._repeat_labels = repeat_labels self._resolve_masks() @@ -411,7 +418,10 @@ def _add_truth( for truth_dict in truth_dicts: for key, value in truth_dict.items(): try: - graph[key] = torch.tensor(value) + label = torch.tensor(value) + if self._repeat_labels: + label = label.repeat(graph.x.shape[0],1) + graph[key] = label except TypeError: # Cannot convert `value` to Tensor due to its data type, # e.g. `str`. @@ -448,5 +458,8 @@ def _add_custom_labels( ) -> Data: # Add custom labels to the graph for key, fn in custom_label_functions.items(): - graph[key] = fn(graph) + label = fn(graph) + if self._repeat_labels: + label = label.repeat(graph.x.shape[0],1) + graph[key] = label return graph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 0289b943d..6e2ac086d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -23,6 +23,7 @@ def __init__( seed: Optional[Union[int, Generator]] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], + **kwargs ) -> None: """Construct k-nn graph representation. @@ -53,6 +54,7 @@ def __init__( input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, + **kwargs ) @@ -70,6 +72,7 @@ def __init__( dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, seed: Optional[Union[int, Generator]] = None, + **kwargs ) -> None: """Construct isolated nodes graph representation. @@ -94,4 +97,5 @@ def __init__( input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, + **kwargs ) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py new file mode 100644 index 000000000..42266f79c --- /dev/null +++ b/src/graphnet/models/normalizing_flow.py @@ -0,0 +1,115 @@ +"""Standard model class(es).""" + +from typing import Any, Dict, List, Optional, Union, Type +import torch +from torch import Tensor +from torch_geometric.data import Data +from torch.optim import Adam + +from graphnet.models.gnn.gnn import GNN +from .easy_model import EasySyntax +from graphnet.models.task import StandardFlowTask +from graphnet.models.graphs import GraphDefinition +from graphnet.models.utils import get_fields + + +class NormalizingFlow(EasySyntax): + """A Standard way of combining model components in GraphNeT. + + This model is compatible with the vast majority of supervised learning + tasks such as regression, binary and multi-label classification. + + Capable of producing both event-level and pulse-level predictions. 
+ """ + + def __init__( + self, + graph_definition: GraphDefinition, + target_labels: str, + backbone: GNN = None, + condition_on: Union[str, List[str], None] = None, + flow_layers: str = 'gggt', + optimizer_class: Type[torch.optim.Optimizer] = Adam, + optimizer_kwargs: Optional[Dict] = None, + scheduler_class: Optional[type] = None, + scheduler_kwargs: Optional[Dict] = None, + scheduler_config: Optional[Dict] = None, + ) -> None: + """Construct `NormalizingFlow`.""" + + # Handle args + if backbone is not None: + assert isinstance(backbone, GNN) + hidden_size = backbone.nb_outputs + else: + if isinstance(condition_on, str): + condition_on = [condition_on] + hidden_size = len(condition_on) + + # Build Flow Task + task = StandardFlowTask(hidden_size=hidden_size, + flow_layers=flow_layers, + target_labels = target_labels) + + + # Base class constructor + super().__init__( + tasks=task, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + scheduler_class=scheduler_class, + scheduler_kwargs=scheduler_kwargs, + scheduler_config=scheduler_config, + ) + + # Member variable(s) + self._graph_definition = graph_definition + self.backbone = backbone + self._condition_on = condition_on + + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + if self.backbone is not None: + x = self._backbone(data) + elif self._condition_on is not None: + x = get_fields(data = data, + fields = self._condition_on) + return self._tasks[0](x, data) + + def _backbone( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + if isinstance(data, Data): + data = [data] + x_list = [] + for d in data: + x = self.backbone(d) + x_list.append(x) + x = torch.cat(x_list, dim=0) + return x + + + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: + """Perform shared step. + + Applies the forward pass and the following loss calculation, shared + between the training and validation step. 
+ """ + loss = self(batch) + return torch.mean(loss, dim = 0) + + def validate_tasks(self) -> None: + """Verify that self._tasks contain compatible elements.""" + accepted_tasks = (StandardFlowTask) + for task in self._tasks: + assert isinstance(task, accepted_tasks) + + def sample(self, data, n_samples, target_range = [0,1000]): + self._sample = True + self._n_samples = n_samples + self._target_range = target_range + labels, nllh = self(data) + self._sample = False + return labels, nllh diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index cd750f35d..b33636d11 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -4,11 +4,14 @@ from typing import Any, TYPE_CHECKING, List, Tuple, Union from typing import Callable, Optional import numpy as np +from copy import deepcopy import torch from torch import Tensor from torch.nn import Linear from torch_geometric.data import Data +import jammy_flows +from torch.distributions.uniform import Uniform if TYPE_CHECKING: # Avoid cyclic dependency @@ -16,6 +19,7 @@ from graphnet.models import Model from graphnet.utilities.decorators import final +from graphnet.models.utils import get_fields class Task(Model): @@ -39,7 +43,6 @@ def default_prediction_labels(self) -> List[str]: def __init__( self, *, - loss_function: "LossFunction", target_labels: Optional[Union[str, List[str]]] = None, prediction_labels: Optional[Union[str, List[str]]] = None, transform_prediction_and_target: Optional[Callable] = None, @@ -51,7 +54,6 @@ def __init__( """Construct `Task`. Args: - loss_function: Loss function appropriate to the task. target_labels: Name(s) of the quantity/-ies being predicted, used to extract the target tensor(s) from the `Data` object in `.compute_loss(...)`. @@ -101,7 +103,6 @@ def __init__( self._regularisation_loss: Optional[float] = None self._target_labels = target_labels self._prediction_labels = prediction_labels - self._loss_function = loss_function self._inference = False self._loss_weight = loss_weight @@ -229,6 +230,7 @@ class LearnedTask(Task): def __init__( self, hidden_size: int, + loss_function: "LossFunction", **task_kwargs: Any, ): """Construct `LearnedTask`. @@ -237,11 +239,14 @@ def __init__( hidden_size: The number of columns in the output of the last latent layer of `Model` using this Task. Available through `Model.nb_outputs` + loss_function: Loss function appropriate to the task. + """ # Base class constructor super().__init__(**task_kwargs) # Mapping from last hidden layer to required size of input + self._loss_function = loss_function self._affine = Linear(hidden_size, self.nb_inputs) @abstractmethod @@ -384,62 +389,73 @@ class StandardFlowTask(Task): def __init__( self, - target_labels: List[str], + hidden_size: Union[int, None], + flow_layers: str = "gggt", **task_kwargs: Any, ): """Construct `StandardLearnedTask`. Args: target_labels: A list of names for the targets of this Task. + flow_layers: A string indicating the flow layer types. hidden_size: The number of columns in the output of the last latent layer of `Model` using this Task. 
- Available through `Model.nb_outputs` + Available through `Model.nb_outputs` """ # Base class constructor - super().__init__(target_labels=target_labels, **task_kwargs) + + + # Member variables + self._default_prediction_labels = ["nllh"] + self._hidden_size = hidden_size + super().__init__(**task_kwargs) + self._flow = jammy_flows.pdf(f"e{len(self._target_labels)}", + flow_layers, + conditional_input_dim = hidden_size) + self._initialized = False + + @property + def default_prediction_labels(self) -> List[str]: + """Return default prediction labels.""" + return self._default_prediction_labels def nb_inputs(self) -> int: """Return number of inputs assumed by task.""" - return len(self._target_labels) - - def _forward(self, x: Tensor, jacobian: Tensor) -> Tensor: # type: ignore - # Leave it as is. - return x + return self._hidden_size + + def _forward(self, x: Tensor, y: Tensor) -> Tensor: # type: ignore + if x is not None: + if x.shape[0] != y.shape[0]: + raise AssertionError(f"Targets {self._target_labels} have " + f"{y.shape[0]} rows while conditional " + f"inputs have {x.shape[0]} rows. " + "The number of rows must match.") + log_pdf, _,_ = self._flow(y, conditional_input = x) + else: + log_pdf, _,_ = self._flow(y) + return -log_pdf.reshape(-1,1) @final def forward( - self, x: Union[Tensor, Data], jacobian: Optional[Tensor] - ) -> Union[Tensor, Data]: + self, x: Union[Tensor, Data], data: List[Data]) -> Union[Tensor, Data]: """Forward pass.""" - self._regularisation_loss = 0 # Reset - x = self._forward(x, jacobian) + # Manually cast pdf to correct dtype - is there a better way? + self._flow = self._flow.to(x.dtype) + # Get target values + labels = get_fields(data = data, + fields = self._target_labels).to(x.dtype) + # Set the initial parameters of flow close to truth + # This speeds up training and helps with NaN + if self._initialized is False: + self._flow.init_params(data=deepcopy(labels).cpu()) + self._flow.to(self.device) + self._initialized = True # This is only done once + # Compute nllh + x = self._forward(x, labels) return self._transform_prediction(x) - @final - def compute_loss( - self, prediction: Tensor, jacobian: Tensor, data: Data - ) -> Tensor: - """Compute loss for normalizing flow tasks. - - Args: - prediction: transformed sample in latent distribution space. - jacobian: the jacobian associated with the transformation. - data: the graph object. - - Returns: - the loss associated with the transformation. 
- """ - if self._loss_weight is not None: - weights = data[self._loss_weight] - else: - weights = None - loss = ( - self._loss_function( - prediction=prediction, - jacobian=jacobian, - weights=weights, - target=None, - ) - + self._regularisation_loss - ) - return loss + def sample(self, x, data, n_samples, target_range): + self.inference() + with torch.no_grad(): + labels = Uniform(target_range[0], target_range[1]).sample((n_samples, 1)) + return labels, self._forward(y= labels, x = x.repeat(n_samples,1)) diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py index d05e8223f..73a4f56f3 100644 --- a/src/graphnet/models/utils.py +++ b/src/graphnet/models/utils.py @@ -7,6 +7,7 @@ from torch import Tensor, LongTensor from torch_geometric.utils import homophily +from torch_geometric.data import Data def calculate_xyzt_homophily( @@ -103,3 +104,11 @@ def array_to_sequence( mask = torch.ne(x[:, :, 1], excluding_value) x[~mask] = padding_value return x, mask, seq_length + +def get_fields(data: List[Data], fields: List[str]) -> Tensor: + labels = [] + if not isinstance(data, list): + data = [data] + for label in list(fields): + labels.append(torch.cat([d[label].reshape(-1,1) for d in data], dim=0)) + return torch.cat(labels, dim = 1) \ No newline at end of file From 1e14f543984dbd3f86550bf380a08c5870c52e20 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 19:52:31 +0200 Subject: [PATCH 02/39] check --- .pre-commit-config.yaml | 4 ++++ src/graphnet/models/task/task.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4794b3745..fd6bae19e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,16 +10,20 @@ repos: rev: 4.0.1 hooks: - id: flake8 + language_version: python3 - repo: https://github.com/pycqa/docformatter rev: v1.5.0 hooks: - id: docformatter + language_version: python3 - repo: https://github.com/pycqa/pydocstyle rev: 6.1.1 hooks: - id: pydocstyle + language_version: python3 - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.982 hooks: - id: mypy args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls] + language_version: python3 \ No newline at end of file diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index b33636d11..df6d81948 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -454,7 +454,7 @@ def forward( x = self._forward(x, labels) return self._transform_prediction(x) - def sample(self, x, data, n_samples, target_range): + def sample(self, x, data: int, n_samples, target_range): self.inference() with torch.no_grad(): labels = Uniform(target_range[0], target_range[1]).sample((n_samples, 1)) From 0c135b72302ee65505846050ebe8c3e50e0b0e97 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 19:54:02 +0200 Subject: [PATCH 03/39] hooks --- src/graphnet/models/task/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index df6d81948..d69412daa 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -454,7 +454,7 @@ def forward( x = self._forward(x, labels) return self._transform_prediction(x) - def sample(self, x, data: int, n_samples, target_range): + def sample(self, x, data: float, n_samples, target_range): self.inference() with torch.no_grad(): labels = Uniform(target_range[0], 
target_range[1]).sample((n_samples, 1)) From ce1223d853306855eb4a739979255646171614d0 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 19:57:31 +0200 Subject: [PATCH 04/39] hooks --- src/graphnet/models/task/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index d69412daa..df6d81948 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -454,7 +454,7 @@ def forward( x = self._forward(x, labels) return self._transform_prediction(x) - def sample(self, x, data: float, n_samples, target_range): + def sample(self, x, data: int, n_samples, target_range): self.inference() with torch.no_grad(): labels = Uniform(target_range[0], target_range[1]).sample((n_samples, 1)) From 88863769bc1a45a71c4917e329a9aa4a0b2b3ca7 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 21:37:01 +0200 Subject: [PATCH 05/39] hooks --- src/graphnet/models/task/task.py | 48 +++++++++++++++----------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index df6d81948..bb1842191 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -240,7 +240,6 @@ def __init__( the last latent layer of `Model` using this Task. Available through `Model.nb_outputs` loss_function: Loss function appropriate to the task. - """ # Base class constructor super().__init__(**task_kwargs) @@ -400,18 +399,19 @@ def __init__( flow_layers: A string indicating the flow layer types. hidden_size: The number of columns in the output of the last latent layer of `Model` using this Task. - Available through `Model.nb_outputs` + Available through `Model.nb_outputs` """ # Base class constructor - - + # Member variables self._default_prediction_labels = ["nllh"] self._hidden_size = hidden_size super().__init__(**task_kwargs) - self._flow = jammy_flows.pdf(f"e{len(self._target_labels)}", - flow_layers, - conditional_input_dim = hidden_size) + self._flow = jammy_flows.pdf( + f"e{len(self._target_labels)}", + flow_layers, + conditional_input_dim=hidden_size, + ) self._initialized = False @property @@ -419,43 +419,39 @@ def default_prediction_labels(self) -> List[str]: """Return default prediction labels.""" return self._default_prediction_labels - def nb_inputs(self) -> int: - """Return number of inputs assumed by task.""" + def nb_inputs(self) -> Union[int, None]: # type: ignore + """Return number of conditional inputs assumed by task.""" return self._hidden_size def _forward(self, x: Tensor, y: Tensor) -> Tensor: # type: ignore if x is not None: if x.shape[0] != y.shape[0]: - raise AssertionError(f"Targets {self._target_labels} have " - f"{y.shape[0]} rows while conditional " - f"inputs have {x.shape[0]} rows. " - "The number of rows must match.") - log_pdf, _,_ = self._flow(y, conditional_input = x) + raise AssertionError( + f"Targets {self._target_labels} have " + f"{y.shape[0]} rows while conditional " + f"inputs have {x.shape[0]} rows. " + "The number of rows must match." 
+ ) + log_pdf, _, _ = self._flow(y, conditional_input=x) else: - log_pdf, _,_ = self._flow(y) - return -log_pdf.reshape(-1,1) + log_pdf, _, _ = self._flow(y) + return -log_pdf.reshape(-1, 1) @final def forward( - self, x: Union[Tensor, Data], data: List[Data]) -> Union[Tensor, Data]: + self, x: Union[Tensor, Data], data: List[Data] + ) -> Union[Tensor, Data]: """Forward pass.""" # Manually cast pdf to correct dtype - is there a better way? self._flow = self._flow.to(x.dtype) # Get target values - labels = get_fields(data = data, - fields = self._target_labels).to(x.dtype) + labels = get_fields(data=data, fields=self._target_labels).to(x.dtype) # Set the initial parameters of flow close to truth # This speeds up training and helps with NaN if self._initialized is False: self._flow.init_params(data=deepcopy(labels).cpu()) self._flow.to(self.device) - self._initialized = True # This is only done once + self._initialized = True # This is only done once # Compute nllh x = self._forward(x, labels) return self._transform_prediction(x) - - def sample(self, x, data: int, n_samples, target_range): - self.inference() - with torch.no_grad(): - labels = Uniform(target_range[0], target_range[1]).sample((n_samples, 1)) - return labels, self._forward(y= labels, x = x.repeat(n_samples,1)) From 9d8b5608e873865f3a0fe5e6eb877d2c3126ed53 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 21:39:11 +0200 Subject: [PATCH 06/39] black --- src/graphnet/models/easy_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index b1c51c087..d3ed4f419 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -16,7 +16,6 @@ from pytorch_lightning.loggers import Logger as LightningLogger from graphnet.training.callbacks import ProgressBar -from graphnet.models.graphs import GraphDefinition from graphnet.models.model import Model from graphnet.models.task import StandardLearnedTask @@ -292,7 +291,7 @@ def predict( dataloader: DataLoader, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", - **trainer_kwargs, + **trainer_kwargs: Any, ) -> List[Tensor]: """Return predictions for `dataloader`.""" self.inference() @@ -327,7 +326,7 @@ def predict_as_dataframe( additional_attributes: Optional[List[str]] = None, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", - **trainer_kwargs, + **trainer_kwargs: Any, ) -> pd.DataFrame: """Return predictions for `dataloader` as a DataFrame. 
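The commits up to this point assemble the complete `NormalizingFlow` API: the model class itself, the `StandardFlowTask` wrapper around `jammy_flows.pdf`, the `get_fields` helper, the `repeat_labels` option in `GraphDefinition`, and the `backbone`/`condition_on` alternatives for conditioning the learned pdf. The sketch below illustrates the `condition_on` branch, i.e. conditioning the density directly on fields already present in the `Data` objects instead of on a GNN backbone. It is a minimal illustration only: the database path and the conditioning truth column are placeholders, not values taken from this patch, and the call signatures follow the example script added later in this same PR.

from torch.optim.adam import Adam

from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import NormalizingFlow
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.training.utils import make_train_validation_dataloader

# Graph representation and dataloaders (path is a placeholder).
graph_definition = KNNGraph(detector=Prometheus())
train_dl, val_dl = make_train_validation_dataloader(
    db="/path/to/prometheus-events.db",
    graph_definition=graph_definition,
    pulsemaps="total",
    features=FEATURES.PROMETHEUS,
    truth=TRUTH.PROMETHEUS,
    batch_size=16,
    num_workers=2,
    truth_table="mc_truth",
    selection=None,
)

# Learn p(total_energy | conditioning field): no GNN backbone, the flow is
# conditioned directly on an existing truth field via `condition_on`.
# "injection_zenith" is a placeholder column name.
model = NormalizingFlow(
    graph_definition=graph_definition,
    target_labels="total_energy",
    condition_on=["injection_zenith"],
    flow_layers="gggt",
    optimizer_class=Adam,
    optimizer_kwargs={"lr": 1e-3},
)

model.fit(train_dl, val_dl, max_epochs=1, gpus=None)

# Each returned row is the negative log-likelihood ("nllh") of the event's
# target under the learned conditional pdf.
results = model.predict_as_dataframe(
    val_dl, additional_attributes=["total_energy", "event_no"]
)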
From dbb02c4042fa2ca28d126c766690a6f79cc7dec0 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 21:45:38 +0200 Subject: [PATCH 07/39] black --- src/graphnet/models/normalizing_flow.py | 34 ++++++++++--------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index 42266f79c..f84caa881 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -28,7 +28,7 @@ def __init__( target_labels: str, backbone: GNN = None, condition_on: Union[str, List[str], None] = None, - flow_layers: str = 'gggt', + flow_layers: str = "gggt", optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, scheduler_class: Optional[type] = None, @@ -36,21 +36,23 @@ def __init__( scheduler_config: Optional[Dict] = None, ) -> None: """Construct `NormalizingFlow`.""" - # Handle args if backbone is not None: assert isinstance(backbone, GNN) hidden_size = backbone.nb_outputs - else: + elif condition_on is not None: if isinstance(condition_on, str): condition_on = [condition_on] hidden_size = len(condition_on) + else: + hidden_size = None # Build Flow Task - task = StandardFlowTask(hidden_size=hidden_size, - flow_layers=flow_layers, - target_labels = target_labels) - + task = StandardFlowTask( + hidden_size=hidden_size, + flow_layers=flow_layers, + target_labels=target_labels, + ) # Base class constructor super().__init__( @@ -74,13 +76,14 @@ def forward( if self.backbone is not None: x = self._backbone(data) elif self._condition_on is not None: - x = get_fields(data = data, - fields = self._condition_on) + assert isinstance(self._condition_on, list) + x = get_fields(data=data, fields=self._condition_on) return self._tasks[0](x, data) def _backbone( self, data: Union[Data, List[Data]] ) -> List[Union[Tensor, Data]]: + assert self.backbone is not None if isinstance(data, Data): data = [data] x_list = [] @@ -89,7 +92,6 @@ def _backbone( x_list.append(x) x = torch.cat(x_list, dim=0) return x - def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: """Perform shared step. @@ -98,18 +100,10 @@ def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: between the training and validation step. 
""" loss = self(batch) - return torch.mean(loss, dim = 0) + return torch.mean(loss, dim=0) def validate_tasks(self) -> None: """Verify that self._tasks contain compatible elements.""" - accepted_tasks = (StandardFlowTask) + accepted_tasks = StandardFlowTask for task in self._tasks: assert isinstance(task, accepted_tasks) - - def sample(self, data, n_samples, target_range = [0,1000]): - self._sample = True - self._n_samples = n_samples - self._target_range = target_range - labels, nllh = self(data) - self._sample = False - return labels, nllh From 9b90af40a9726714b2c0b14a171bb01b504e1275 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 21:47:12 +0200 Subject: [PATCH 08/39] black --- src/graphnet/models/utils.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py index 73a4f56f3..11b73d06f 100644 --- a/src/graphnet/models/utils.py +++ b/src/graphnet/models/utils.py @@ -1,6 +1,6 @@ """Utility functions for `graphnet.models`.""" -from typing import List, Tuple, Any +from typing import List, Tuple, Any, Union from torch_geometric.nn import knn_graph from torch_geometric.data import Batch import torch @@ -105,10 +105,14 @@ def array_to_sequence( x[~mask] = padding_value return x, mask, seq_length -def get_fields(data: List[Data], fields: List[str]) -> Tensor: - labels = [] - if not isinstance(data, list): - data = [data] - for label in list(fields): - labels.append(torch.cat([d[label].reshape(-1,1) for d in data], dim=0)) - return torch.cat(labels, dim = 1) \ No newline at end of file + +def get_fields(data: Union[Data, List[Data]], fields: List[str]) -> Tensor: + """Extract named fields in Data object.""" + labels = [] + if not isinstance(data, list): + data = [data] + for label in list(fields): + labels.append( + torch.cat([d[label].reshape(-1, 1) for d in data], dim=0) + ) + return torch.cat(labels, dim=1) From 382651ffb2348b2757beda77ea4405a02d088555 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 28 May 2024 21:48:28 +0200 Subject: [PATCH 09/39] black --- src/graphnet/models/graphs/graph_definition.py | 11 ++++++----- src/graphnet/models/graphs/graphs.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6c9a0a419..0338225b8 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -34,7 +34,7 @@ def __init__( sensor_mask: Optional[List[int]] = None, string_mask: Optional[List[int]] = None, sort_by: str = None, - repeat_labels: bool =False, + repeat_labels: bool = False, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -63,9 +63,9 @@ def __init__( add_inactive_sensors: If True, inactive sensors will be appended to the graph with padded pulse information. Defaults to False. sensor_mask: A list of sensor id's to be masked from the graph. Any - sensor listed here will be removed from the graph. + sensor listed here will be removed from the graph. Defaults to None. - string_mask: A list of string id's to be masked from the graph. + string_mask: A list of string id's to be masked from the graph. Defaults to None. sort_by: Name of node feature to sort by. Defaults to None. repeat_labels: If True, labels will be repeated to match the @@ -415,12 +415,13 @@ def _add_truth( """ # Write attributes, either target labels, truth info or original # features. 
+ for truth_dict in truth_dicts: for key, value in truth_dict.items(): try: label = torch.tensor(value) if self._repeat_labels: - label = label.repeat(graph.x.shape[0],1) + label = label.repeat(graph.x.shape[0], 1) graph[key] = label except TypeError: # Cannot convert `value` to Tensor due to its data type, @@ -460,6 +461,6 @@ def _add_custom_labels( for key, fn in custom_label_functions.items(): label = fn(graph) if self._repeat_labels: - label = label.repeat(graph.x.shape[0],1) + label = label.repeat(graph.x.shape[0], 1) graph[key] = label return graph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 6e2ac086d..525675ca7 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -1,6 +1,6 @@ """A module containing different graph representations in GraphNeT.""" -from typing import List, Optional, Dict, Union +from typing import List, Optional, Dict, Union, Any import torch from numpy.random import Generator @@ -23,7 +23,7 @@ def __init__( seed: Optional[Union[int, Generator]] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], - **kwargs + **kwargs: Any, ) -> None: """Construct k-nn graph representation. @@ -54,7 +54,7 @@ def __init__( input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, - **kwargs + **kwargs, ) @@ -72,7 +72,7 @@ def __init__( dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, seed: Optional[Union[int, Generator]] = None, - **kwargs + **kwargs: Any, ) -> None: """Construct isolated nodes graph representation. @@ -97,5 +97,5 @@ def __init__( input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, - **kwargs + **kwargs, ) From e299aac4db008c066d35d1a6673e47a2c85c20f4 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 09:02:19 +0200 Subject: [PATCH 10/39] polish dtype assignment --- src/graphnet/models/normalizing_flow.py | 31 +++++++++++++------------ src/graphnet/models/task/task.py | 7 +++--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index f84caa881..2d351bd5a 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -69,30 +69,31 @@ def __init__( self.backbone = backbone self._condition_on = condition_on - def forward( - self, data: Union[Data, List[Data]] - ) -> List[Union[Tensor, Data]]: + def forward(self, data: Union[Data, List[Data]]) -> Tensor: """Forward pass, chaining model components.""" - if self.backbone is not None: - x = self._backbone(data) - elif self._condition_on is not None: - assert isinstance(self._condition_on, list) - x = get_fields(data=data, fields=self._condition_on) - return self._tasks[0](x, data) - - def _backbone( - self, data: Union[Data, List[Data]] - ) -> List[Union[Tensor, Data]]: - assert self.backbone is not None if isinstance(data, Data): data = [data] x_list = [] for d in data: - x = self.backbone(d) + if self.backbone is not None: + x = self._backbone(d) + elif self._condition_on is not None: + assert isinstance(self._condition_on, list) + x = get_fields(data=d, fields=self._condition_on) + else: + # Unconditional flow + x = None + x = self._tasks[0](x, d) x_list.append(x) x = torch.cat(x_list, dim=0) return x + def _backbone( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + assert self.backbone is not None + return 
self.backbone(data) + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: """Perform shared step. diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index bb1842191..99514b11d 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -423,7 +423,7 @@ def nb_inputs(self) -> Union[int, None]: # type: ignore """Return number of conditional inputs assumed by task.""" return self._hidden_size - def _forward(self, x: Tensor, y: Tensor) -> Tensor: # type: ignore + def _forward(self, x: Optional[Tensor], y: Tensor) -> Tensor: # type: ignore if x is not None: if x.shape[0] != y.shape[0]: raise AssertionError( @@ -443,9 +443,10 @@ def forward( ) -> Union[Tensor, Data]: """Forward pass.""" # Manually cast pdf to correct dtype - is there a better way? - self._flow = self._flow.to(x.dtype) + self._flow = self._flow.to(self.dtype) # Get target values - labels = get_fields(data=data, fields=self._target_labels).to(x.dtype) + labels = get_fields(data=data, fields=self._target_labels) + labels = labels.to(self.dtype) # Set the initial parameters of flow close to truth # This speeds up training and helps with NaN if self._initialized is False: From 9c0ad64980f5034cb2d9fa3ff857502b43c90ca8 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 09:24:36 +0200 Subject: [PATCH 11/39] add warning --- src/graphnet/models/normalizing_flow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index 2d351bd5a..a52a54e66 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -36,6 +36,16 @@ def __init__( scheduler_config: Optional[Dict] = None, ) -> None: """Construct `NormalizingFlow`.""" + # Checks + if (backbone is not None) & (condition_on is not None): + # If user wants to condition on both + raise ValueError( + f"{self.__class__.__name__} got values for both " + "`backbone` and `condition_on`, but can only" + "condition on one of those. Please specify just " + "one of these arguments." 
+ ) + # Handle args if backbone is not None: assert isinstance(backbone, GNN) From f53bc1dc0c0f2b97d31f72827ed6a2a756478df5 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 09:46:48 +0200 Subject: [PATCH 12/39] add check for flow package --- src/graphnet/models/__init__.py | 7 ++++--- src/graphnet/models/task/task.py | 6 ++++-- src/graphnet/utilities/imports.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/graphnet/models/__init__.py b/src/graphnet/models/__init__.py index a7e0a064b..12d4cbcc5 100644 --- a/src/graphnet/models/__init__.py +++ b/src/graphnet/models/__init__.py @@ -6,9 +6,10 @@ existing, purpose-built components and chain them together to form a complete GNN """ - - +from graphnet.utilities.imports import has_jammy_flows_package from .model import Model from .standard_model import StandardModel from .standard_averaged_model import StandardAveragedModel -from .normalizing_flow import NormalizingFlow \ No newline at end of file + +if has_jammy_flows_package(): + from .normalizing_flow import NormalizingFlow diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 99514b11d..441484838 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -10,8 +10,6 @@ from torch import Tensor from torch.nn import Linear from torch_geometric.data import Data -import jammy_flows -from torch.distributions.uniform import Uniform if TYPE_CHECKING: # Avoid cyclic dependency @@ -20,6 +18,10 @@ from graphnet.models import Model from graphnet.utilities.decorators import final from graphnet.models.utils import get_fields +from graphnet.utilities.imports import has_jammy_flows_package + +if has_jammy_flows_package(): + import jammy_flows class Task(Model): diff --git a/src/graphnet/utilities/imports.py b/src/graphnet/utilities/imports.py index a490f413c..ae59d3b98 100644 --- a/src/graphnet/utilities/imports.py +++ b/src/graphnet/utilities/imports.py @@ -33,6 +33,20 @@ def has_torch_package() -> bool: return False +def has_jammy_flows_package() -> bool: + """Check if the `jammy_flows` package is available.""" + try: + import jammmy_flows # pyright: reportMissingImports=false + + return True + except ImportError: + Logger(log_folder=None).warning_once( + "`jammy_flows` not available. Normalizing Flow functionality is " + "missing." + ) + return False + + def requires_icecube(test_function: Callable) -> Callable: """Decorate `test_function` for use only if `icecube` module is present.""" From a0afcc3e547b6910190196b649cd541d15447832 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 09:57:24 +0200 Subject: [PATCH 13/39] expand docstrings --- src/graphnet/models/normalizing_flow.py | 10 +++++----- src/graphnet/models/task/task.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index a52a54e66..e0f11be3e 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -14,12 +14,12 @@ class NormalizingFlow(EasySyntax): - """A Standard way of combining model components in GraphNeT. + """A model for building (conditional) normalizing flows in GraphNeT. - This model is compatible with the vast majority of supervised learning - tasks such as regression, binary and multi-label classification. - - Capable of producing both event-level and pulse-level predictions. 
+ This model relies on `jammy_flows` for building and evaluating + normalizing flows. + https://thoglu.github.io/jammy_flows/usage/introduction.html + for details. """ def __init__( diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 441484838..45f6f8a15 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -386,7 +386,11 @@ def _forward(self, x: Union[Tensor, Data]) -> Tensor: # type: ignore class StandardFlowTask(Task): - """A `Task` for `NormalizingFlow`s in GraphNeT.""" + """A `Task` for `NormalizingFlow`s in GraphNeT. + + This Task requires the support package`jammy_flows` for constructing and + evaluating normalizing flows. + """ def __init__( self, @@ -394,14 +398,15 @@ def __init__( flow_layers: str = "gggt", **task_kwargs: Any, ): - """Construct `StandardLearnedTask`. + """Construct `StandardFlowTask`. Args: target_labels: A list of names for the targets of this Task. - flow_layers: A string indicating the flow layer types. - hidden_size: The number of columns in the output of - the last latent layer of `Model` using this Task. - Available through `Model.nb_outputs` + flow_layers: A string indicating the flow layer types. See + https://thoglu.github.io/jammy_flows/usage/introduction.html + for details. + hidden_size: The number of columns on which the normalizing flow + is conditioned on. May be `None`, indicating non-conditional flow. """ # Base class constructor From 845293d029b677d8c0998b4aa6f3e2979b7cf29a Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 10:04:43 +0200 Subject: [PATCH 14/39] update workflow to install jammy_flows --- .github/actions/install/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index b2d6d2896..19e23be01 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -38,4 +38,5 @@ runs: run: | echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} + pip install git+https://github.com/thoglu/jammy_flows.git shell: bash From a71765cc30c919c74fa937cd748db76e618ac4bd Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 10:30:53 +0200 Subject: [PATCH 15/39] add example --- .../04_training/07_train_normalizing_flow.py | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 examples/04_training/07_train_normalizing_flow.py diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py new file mode 100644 index 000000000..e6de97920 --- /dev/null +++ b/examples/04_training/07_train_normalizing_flow.py @@ -0,0 +1,225 @@ +"""Example of training a conditional NormalizingFlow.""" + +import os +from typing import Any, Dict, List, Optional + +from pytorch_lightning.loggers import WandbLogger +import torch +from torch.optim.adam import Adam + +from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.models import NormalizingFlow +from graphnet.models.detector.prometheus import Prometheus +from graphnet.models.gnn import DynEdge +from graphnet.models.graphs import KNNGraph +from graphnet.models.task.task import StandardFlowTask +from graphnet.training.callbacks import PiecewiseLinearLR +from graphnet.training.utils import make_train_validation_dataloader +from 
graphnet.utilities.argparse import ArgumentParser +from graphnet.utilities.logging import Logger + +# Constants +features = FEATURES.PROMETHEUS +truth = TRUTH.PROMETHEUS + + +def main( + path: str, + pulsemap: str, + target: str, + truth_table: str, + gpus: Optional[List[int]], + max_epochs: int, + early_stopping_patience: int, + batch_size: int, + num_workers: int, + wandb: bool = False, +) -> None: + """Run example.""" + # Construct Logger + logger = Logger() + + # Initialise Weights & Biases (W&B) run + if wandb: + # Make sure W&B output directory exists + wandb_dir = "./wandb/" + os.makedirs(wandb_dir, exist_ok=True) + wandb_logger = WandbLogger( + project="example-script", + entity="graphnet-team", + save_dir=wandb_dir, + log_model=True, + ) + + logger.info(f"features: {features}") + logger.info(f"truth: {truth}") + + # Configuration + config: Dict[str, Any] = { + "path": path, + "pulsemap": pulsemap, + "batch_size": batch_size, + "num_workers": num_workers, + "target": target, + "early_stopping_patience": early_stopping_patience, + "fit": { + "gpus": gpus, + "max_epochs": max_epochs, + }, + } + + archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs") + run_name = "dynedge_{}_example".format(config["target"]) + if wandb: + # Log configuration to W&B + wandb_logger.experiment.config.update(config) + + # Define graph representation + graph_definition = KNNGraph(detector=Prometheus()) + + ( + training_dataloader, + validation_dataloader, + ) = make_train_validation_dataloader( + db=config["path"], + graph_definition=graph_definition, + pulsemaps=config["pulsemap"], + features=features, + truth=truth, + batch_size=config["batch_size"], + num_workers=config["num_workers"], + truth_table=truth_table, + selection=None, + ) + + # Building model + + backbone = DynEdge( + nb_inputs=graph_definition.nb_outputs, + global_pooling_schemes=["min", "max", "mean", "sum"], + ) + + model = NormalizingFlow( + graph_definition=graph_definition, + backbone=backbone, + optimizer_class=Adam, + target_labels=config["target"], + optimizer_kwargs={"lr": 1e-03, "eps": 1e-03}, + scheduler_class=PiecewiseLinearLR, + scheduler_kwargs={ + "milestones": [ + 0, + len(training_dataloader) / 2, + len(training_dataloader) * config["fit"]["max_epochs"], + ], + "factors": [1e-2, 1, 1e-02], + }, + scheduler_config={ + "interval": "step", + }, + ) + + # Training model + model.fit( + training_dataloader, + validation_dataloader, + early_stopping_patience=config["early_stopping_patience"], + logger=wandb_logger if wandb else None, + **config["fit"], + ) + + # Get predictions + additional_attributes = model.target_labels + assert isinstance(additional_attributes, list) # mypy + + results = model.predict_as_dataframe( + validation_dataloader, + additional_attributes=additional_attributes + ["event_no"], + gpus=config["fit"]["gpus"], + ) + + # Save predictions and model to file + db_name = path.split("/")[-1].split(".")[0] + path = os.path.join(archive, db_name, run_name) + logger.info(f"Writing results to {path}") + os.makedirs(path, exist_ok=True) + + # Save results as .csv + results.to_csv(f"{path}/results.csv") + + # Save full model (including weights) to .pth file - not version safe + # Note: Models saved as .pth files in one version of graphnet + # may not be compatible with a different version of graphnet. + model.save(f"{path}/model.pth") + + # Save model config and state dict - Version safe save method. + # This method of saving models is the safest way. 
+ model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + + +if __name__ == "__main__": + + # Parse command-line arguments + parser = ArgumentParser( + description=""" +Train conditional NormalizingFlow without the use of config files. +""" + ) + + parser.add_argument( + "--path", + help="Path to dataset file (default: %(default)s)", + default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", + ) + + parser.add_argument( + "--pulsemap", + help="Name of pulsemap to use (default: %(default)s)", + default="total", + ) + + parser.add_argument( + "--target", + help=( + "Name of feature to use as regression target (default: " + "%(default)s)" + ), + default="total_energy", + ) + + parser.add_argument( + "--truth-table", + help="Name of truth table to be used (default: %(default)s)", + default="mc_truth", + ) + + parser.with_standard_arguments( + "gpus", + ("max-epochs", 1), + "early-stopping-patience", + ("batch-size", 16), + "num-workers", + ) + + parser.add_argument( + "--wandb", + action="store_true", + help="If True, Weights & Biases are used to track the experiment.", + ) + + args, unknown = parser.parse_known_args() + + main( + args.path, + args.pulsemap, + args.target, + args.truth_table, + args.gpus, + args.max_epochs, + args.early_stopping_patience, + args.batch_size, + args.num_workers, + args.wandb, + ) From eb159328d789eac623cead20e5eef5be6be36de3 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 10:33:15 +0200 Subject: [PATCH 16/39] check in example --- examples/04_training/07_train_normalizing_flow.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index e6de97920..187e30dfc 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -13,11 +13,21 @@ from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdge from graphnet.models.graphs import KNNGraph -from graphnet.models.task.task import StandardFlowTask from graphnet.training.callbacks import PiecewiseLinearLR from graphnet.training.utils import make_train_validation_dataloader from graphnet.utilities.argparse import ArgumentParser from graphnet.utilities.logging import Logger +from graphnet.utilities.imports import has_jammy_flows_package + +# Make sure the jammy flows is installed +try: + assert has_jammy_flows_package +except AssertionError: + raise AssertionError( + "This example requires the package`jammy_flow` " + " to be installed. It appears that the package is " + " not installed. Please install the package." 
+ ) # Constants features = FEATURES.PROMETHEUS From 4150f144de8f145f1adfbb2cc61a24ac3617a845 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 11:36:08 +0200 Subject: [PATCH 17/39] update example --- examples/04_training/07_train_normalizing_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index 187e30dfc..c94c2821f 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -21,7 +21,7 @@ # Make sure the jammy flows is installed try: - assert has_jammy_flows_package + assert has_jammy_flows_package() except AssertionError: raise AssertionError( "This example requires the package`jammy_flow` " From 8116c29ae4cffe463f714c8cc5719ea7e4c43cc1 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 11:53:15 +0200 Subject: [PATCH 18/39] update example --- examples/04_training/07_train_normalizing_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index c94c2821f..1de4b349b 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -9,7 +9,6 @@ from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR from graphnet.data.constants import FEATURES, TRUTH -from graphnet.models import NormalizingFlow from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdge from graphnet.models.graphs import KNNGraph @@ -22,6 +21,7 @@ # Make sure the jammy flows is installed try: assert has_jammy_flows_package() + from graphnet.models import NormalizingFlow except AssertionError: raise AssertionError( "This example requires the package`jammy_flow` " From d51f02c108a2994e1d5db40c62f93791279c7f11 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 12:00:41 +0200 Subject: [PATCH 19/39] actions --- .github/actions/install/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index 19e23be01..2941789d4 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -38,5 +38,5 @@ runs: run: | echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} - pip install git+https://github.com/thoglu/jammy_flows.git + pip install git+https://github.com/thoglu/jammy_flows.git ${{ env.PIP_FLAGS }} shell: bash From 210ef2847498a80a1153f919a239af0bac01dad0 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 12:02:07 +0200 Subject: [PATCH 20/39] update icetray action --- .github/workflows/build.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8f2762e77..a17bf4b8f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -63,6 +63,14 @@ jobs: uses: ./.github/actions/install with: editable: true + - name: Print packages in pip + run: | + pip show torch + pip show torch-geometric + pip show torch-cluster + pip show torch-sparse + pip show torch-scatter + pip show jammy_flows - name: Run unit tests and generate coverage report run: | coverage run --source=graphnet -m pytest tests/ 
--ignore=tests/examples/04_training --ignore=tests/utilities @@ -109,6 +117,7 @@ jobs: pip show torch-cluster pip show torch-sparse pip show torch-scatter + pip show jammy_flows - name: Run unit tests and generate coverage report run: | set -o pipefail # To propagate exit code from pytest From c32ffd107fcfe8bfaf2b1251d7cacd358944181a Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 12:07:49 +0200 Subject: [PATCH 21/39] update install action --- .github/actions/install/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index 2941789d4..19e23be01 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -38,5 +38,5 @@ runs: run: | echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} - pip install git+https://github.com/thoglu/jammy_flows.git ${{ env.PIP_FLAGS }} + pip install git+https://github.com/thoglu/jammy_flows.git shell: bash From 59870dd5e4eb1525291c713f7fe97d423733a354 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 12:17:02 +0200 Subject: [PATCH 22/39] fix `has_jammy_flows_package` --- src/graphnet/utilities/imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/utilities/imports.py b/src/graphnet/utilities/imports.py index ae59d3b98..1c143280a 100644 --- a/src/graphnet/utilities/imports.py +++ b/src/graphnet/utilities/imports.py @@ -36,7 +36,7 @@ def has_torch_package() -> bool: def has_jammy_flows_package() -> bool: """Check if the `jammy_flows` package is available.""" try: - import jammmy_flows # pyright: reportMissingImports=false + import jammy_flows # pyright: reportMissingImports=false return True except ImportError: From b953ff4dece1dac067190683a79cbda0a5b3cf5a Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 17:16:12 +0200 Subject: [PATCH 23/39] polish --- src/graphnet/models/normalizing_flow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index e0f11be3e..3f61ea294 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -96,7 +96,7 @@ def forward(self, data: Union[Data, List[Data]]) -> Tensor: x = self._tasks[0](x, d) x_list.append(x) x = torch.cat(x_list, dim=0) - return x + return [x] def _backbone( self, data: Union[Data, List[Data]] @@ -111,6 +111,9 @@ def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: between the training and validation step. 
""" loss = self(batch) + if isinstance(loss, list): + assert len(loss) == 1 + loss = loss[0] return torch.mean(loss, dim=0) def validate_tasks(self) -> None: From 5ab298a8579fa99fc019753866cda9810177a880 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 17:30:04 +0200 Subject: [PATCH 24/39] add doc string --- src/graphnet/models/normalizing_flow.py | 32 ++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index 3f61ea294..528001ac8 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -35,7 +35,37 @@ def __init__( scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, ) -> None: - """Construct `NormalizingFlow`.""" + """Build NormalizingFlow to learn (conditional) normalizing flows. + + NormalizingFlow is able to build, train and evaluate a wide suite of + normalizing flows. Instead of optimizing a loss function, flows + minimize a learned pdf of your data, providing you with a posterior + distribution for every example instead of point-like predictions. + + `NormalizingFlow` can be conditioned on existing fields in the + DataRepresentation or latent representations from `Models`. + + Args: + graph_definition: The `GraphDefinition` to train the model on. + target_labels: Name of target(s) to learn the pdf of. + backbone: Architecture used to produce latent representations of + the input data on which the pdf will be conditioned. + Defaults to None. + condition_on: List of fields in Data objects to condition the + pdf on. Defaults to None. + flow_layers: A string defining the flow layers. + See https://thoglu.github.io/jammy_flows/usage/introduction.html + for details. Defaults to "gggt". + optimizer_class: Optimizer to use. Defaults to Adam. + optimizer_kwargs: Optimzier arguments. Defaults to None. + scheduler_class: Learning rate scheduler to use. Defaults to None. + scheduler_kwargs: Arguments to learning rate scheduler. + Defaults to None. + scheduler_config: Defaults to None. + + Raises: + ValueError: if both `backbone` and `condition_on` is specified. + """ # Checks if (backbone is not None) & (condition_on is not None): # If user wants to condition on both From 3bc33a3718106bfcba66896bf9c5f79dc708fc84 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 17:53:43 +0200 Subject: [PATCH 25/39] update docstring --- src/graphnet/models/normalizing_flow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index 528001ac8..59e6e3961 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -45,6 +45,9 @@ def __init__( `NormalizingFlow` can be conditioned on existing fields in the DataRepresentation or latent representations from `Models`. + NormalizingFlow is built upon https://github.com/thoglu/jammy_flows, + and we refer to their documentation for details on the flows. + Args: graph_definition: The `GraphDefinition` to train the model on. target_labels: Name of target(s) to learn the pdf of. 
From b74d9b2658c7889eeeed3b443686551047f44690 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 29 May 2024 17:56:44 +0200 Subject: [PATCH 26/39] update installation instruction --- docs/source/installation/quick-start.html | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/installation/quick-start.html b/docs/source/installation/quick-start.html index aff34659e..e80fd5b8d 100644 --- a/docs/source/installation/quick-start.html +++ b/docs/source/installation/quick-start.html @@ -107,20 +107,20 @@ } if (os == "linux" && cuda != "cpu" && torch != "no_torch"){ - $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`); + $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } else if (os == "linux" && cuda == "cpu" && torch != "no_torch"){ - $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`); + $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } else if (os == "linux" && cuda == "cpu" && torch == "no_torch"){ - $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]`); + $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } if (os == "macos" && cuda == "cpu" && torch != "no_torch"){ - $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]`); + $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } if (os == "macos" && cuda == "cpu" && torch == "no_torch"){ - $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]`); + $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } } From 
49576cb323aeeb240fef931f5c3d488060818569 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Sat, 13 Jul 2024 10:09:03 +0200 Subject: [PATCH 27/39] add normalization --- src/graphnet/models/normalizing_flow.py | 2 ++ src/graphnet/models/task/task.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/graphnet/models/normalizing_flow.py b/src/graphnet/models/normalizing_flow.py index 59e6e3961..d62cf7c42 100644 --- a/src/graphnet/models/normalizing_flow.py +++ b/src/graphnet/models/normalizing_flow.py @@ -111,6 +111,7 @@ def __init__( self._graph_definition = graph_definition self.backbone = backbone self._condition_on = condition_on + self._norm = torch.nn.BatchNorm1d(hidden_size) def forward(self, data: Union[Data, List[Data]]) -> Tensor: """Forward pass, chaining model components.""" @@ -120,6 +121,7 @@ def forward(self, data: Union[Data, List[Data]]) -> Tensor: for d in data: if self.backbone is not None: x = self._backbone(d) + x = self._norm(x) elif self._condition_on is not None: assert isinstance(self._condition_on, list) x = get_fields(data=d, fields=self._condition_on) diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 45f6f8a15..a1c3c52ed 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -396,6 +396,7 @@ def __init__( self, hidden_size: Union[int, None], flow_layers: str = "gggt", + target_norm: float = 1000.0, **task_kwargs: Any, ): """Construct `StandardFlowTask`. @@ -405,6 +406,8 @@ def __init__( flow_layers: A string indicating the flow layer types. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details. + target_norm: A normalization constant used to divide the target + values. Value is applied to all targets. Defaults to 1000. hidden_size: The number of columns on which the normalizing flow is conditioned on. May be `None`, indicating non-conditional flow. 
""" @@ -420,6 +423,7 @@ def __init__( conditional_input_dim=hidden_size, ) self._initialized = False + self._norm = target_norm @property def default_prediction_labels(self) -> List[str]: @@ -431,6 +435,7 @@ def nb_inputs(self) -> Union[int, None]: # type: ignore return self._hidden_size def _forward(self, x: Optional[Tensor], y: Tensor) -> Tensor: # type: ignore + y = y / self._norm if x is not None: if x.shape[0] != y.shape[0]: raise AssertionError( From 53eefb46c1de0432b72d985a3a46545e3fb65c7f Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 6 Aug 2024 09:52:26 +0200 Subject: [PATCH 28/39] increase batch size to avoid single event batch --- examples/04_training/07_train_normalizing_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index 1de4b349b..baa3eec85 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -209,7 +209,7 @@ def main( "gpus", ("max-epochs", 1), "early-stopping-patience", - ("batch-size", 16), + ("batch-size", 50), "num-workers", ) From dfee76d90be66a02675f35016025b8dbb2130f4e Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Tue, 6 Aug 2024 09:53:48 +0200 Subject: [PATCH 29/39] revert change --- examples/04_training/07_train_normalizing_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index baa3eec85..1de4b349b 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -209,7 +209,7 @@ def main( "gpus", ("max-epochs", 1), "early-stopping-patience", - ("batch-size", 50), + ("batch-size", 16), "num-workers", ) From dd416599a417d035477b5d0c8a8759d3b0f813f8 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Tue, 6 Aug 2024 09:55:34 +0200 Subject: [PATCH 30/39] increase batch size --- examples/04_training/07_train_normalizing_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_training/07_train_normalizing_flow.py b/examples/04_training/07_train_normalizing_flow.py index 1de4b349b..baa3eec85 100644 --- a/examples/04_training/07_train_normalizing_flow.py +++ b/examples/04_training/07_train_normalizing_flow.py @@ -209,7 +209,7 @@ def main( "gpus", ("max-epochs", 1), "early-stopping-patience", - ("batch-size", 16), + ("batch-size", 50), "num-workers", ) From bd44bd6bfbe19d9c16b399daceb73cbc545222ea Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 21:03:40 -0400 Subject: [PATCH 31/39] Relabel "batches" in ParquetDataset, add samplers * Renamed the batch variables in ParquetDataset to chunk variables * Implemented RandomChunkSampler and LenMatchBatchSampler w/ modifications --- src/graphnet/data/dataset/__init__.py | 1 + .../data/dataset/parquet/parquet_dataset.py | 28 ++- src/graphnet/data/dataset/samplers.py | 232 ++++++++++++++++++ 3 files changed, 250 insertions(+), 11 deletions(-) create mode 100644 src/graphnet/data/dataset/samplers.py diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index f6eafee94..eeb3123d9 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -5,6 +5,7 @@ if has_torch_package(): import torch.multiprocessing from .dataset import EnsembleDataset, Dataset, ColumnMissingException + from .samplers import RandomChunkSampler, 
LenMatchBatchSampler from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py index 3561c591a..2df6ed16e 100644 --- a/src/graphnet/data/dataset/parquet/parquet_dataset.py +++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py @@ -5,6 +5,7 @@ List, Optional, Union, + Any, ) import numpy as np @@ -92,7 +93,7 @@ def __init__( `"10000 random events ~ event_no % 5 > 0"` or `"20% random events ~ event_no % 5 > 0"`). graph_definition: Method that defines the graph representation. - cache_size: Number of batches to cache in memory. + cache_size: Number of files to cache in memory. Must be at least 1. Defaults to 1. labels: Dictionary of labels to be added to the dataset. """ @@ -123,8 +124,8 @@ def __init__( self._path: str = self._path # Member Variables self._cache_size = cache_size - self._batch_sizes = self._calculate_sizes() - self._batch_cumsum = np.cumsum(self._batch_sizes) + self._chunk_sizes = self._calculate_sizes() + self._chunk_cumsum = np.cumsum(self._chunk_sizes) self._file_cache = self._initialize_file_cache( truth_table=truth_table, node_truth_table=node_truth_table, @@ -179,9 +180,14 @@ def _get_event_index(self, sequential_index: int) -> int: ) return event_index + @property + def chunk_sizes(self) -> List[int]: + """Return a list of the chunk sizes.""" + return self._chunk_sizes + def __len__(self) -> int: """Return length of dataset, i.e. number of training examples.""" - return sum(self._batch_sizes) + return sum(self._chunk_sizes) def _get_all_indices(self) -> List[int]: """Return a list of all unique values in `self._index_column`.""" @@ -189,22 +195,22 @@ def _get_all_indices(self) -> List[int]: return np.arange(0, len(files), 1) def _calculate_sizes(self) -> List[int]: - """Calculate the number of events in each batch.""" + """Calculate the number of events in each chunk.""" sizes = [] - for batch_id in self._indices: + for chunk_id in self._indices: path = os.path.join( self._path, self._truth_table, - f"{self.truth_table}_{batch_id}.parquet", + f"{self.truth_table}_{chunk_id}.parquet", ) sizes.append(len(pol.read_parquet(path))) return sizes def _get_row_idx(self, sequential_index: int) -> int: """Return the row index corresponding to a `sequential_index`.""" - file_idx = bisect_right(self._batch_cumsum, sequential_index) + file_idx = bisect_right(self._chunk_cumsum, sequential_index) if file_idx > 0: - idx = int(sequential_index - self._batch_cumsum[file_idx - 1]) + idx = int(sequential_index - self._chunk_cumsum[file_idx - 1]) else: idx = sequential_index return idx @@ -241,9 +247,9 @@ def query_table( # type: ignore columns = [columns] if sequential_index is None: - file_idx = np.arange(0, len(self._batch_cumsum), 1) + file_idx = np.arange(0, len(self._chunk_cumsum), 1) else: - file_idx = [bisect_right(self._batch_cumsum, sequential_index)] + file_idx = [bisect_right(self._chunk_cumsum, sequential_index)] file_indices = [self._indices[idx] for idx in file_idx] diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py new file mode 100644 index 000000000..0d102898b --- /dev/null +++ b/src/graphnet/data/dataset/samplers.py @@ -0,0 +1,232 @@ +"""`Sampler` and `BatchSampler` objects for `graphnet`.""" +from typing import ( + Any, + List, + Optional, + Tuple, + Iterator, + Sequence, +) + +from collections import defaultdict +from multiprocessing import Pool, cpu_count, 
get_context + +import numpy as np +import torch +from torch.utils.data import Sampler, BatchSampler + + +class RandomChunkSampler(Sampler[int]): + """A `Sampler` that randomly selects chunks. + + MIT License + + Copyright (c) 2023 DrHB + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + + def __init__( + self, + data_source: Sequence[Any], + chunks: List[int], + num_samples: Optional[int] = None, + generator: Optional[torch.Generator] = None, + ) -> None: + """Construct `RandomChunkSampler`.""" + # chunks - a list of chunk sizes + self._data_source = data_source + self._num_samples = num_samples + self._chunks = chunks + + # Create a random number generator if one was not provided + if generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._generator = torch.Generator() + self._generator.manual_seed(seed) + else: + self._generator = generator + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + "num_samples should be a positive integer " + "value, but got num_samples={}".format(self.num_samples) + ) + + @property + def data_source(self) -> Sequence[Any]: + """Return the data source.""" + return self._data_source + + @property + def num_samples(self) -> int: + """Return the number of samples in the data source.""" + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __len__(self) -> int: + """Return the number of sampled.""" + return self.num_samples + + @property + def chunks(self) -> List[int]: + """Return the list of chunks.""" + return self._chunks + + def __iter__(self) -> Iterator[List[int]]: + """Return a list of indices from a randomly sampled chunk.""" + cumsum = np.cumsum(self.chunks) + chunk_list = torch.randperm( + len(self.chunks), generator=self.generator + ).tolist() + + # sample indexes chunk by chunk + yield_samples = 0 + for i in chunk_list: + chunk_len = self.chunks[i] + offset = cumsum[i - 1] if i > 0 else 0 + samples = ( + offset + torch.randperm(chunk_len, generator=self.generator) + ).tolist() + if len(samples) <= self.num_samples - yield_samples: + yield_samples += len(samples) + else: + samples = samples[: self.num_samples - yield_samples] + yield_samples = self.num_samples + yield from samples + + +def gather_buckets( + params: Tuple[List[int], Sequence[Any], int], +) -> Tuple[List[List[int]], List[List[int]]]: + """Gather buckets of events. + + The function that will be used to gather buckets of events by the + `LenMatchBatchSampler`. 
When using multiprocessing, each worker will call + this function. + + Args: + params: A tuple containg the list of indices to process, + the data_source (typically a `Dataset`), and the batch size. + + Returns: + batches: A list containing batches. + remaining_batches: Incomplete batches. + """ + indices, data_source, batch_size = params + buckets = defaultdict(list) + batches = [] + + for idx in indices: + s = data_source[idx] + L = max(1, s.num_nodes // 16) + buckets[L].append(idx) + if len(buckets[L]) == batch_size: + batches.append(list(buckets[L])) + buckets[L] = [] + + # Include any remaining items in partially filled buckets + remaining_batches = [b for b in buckets.values() if b] + return batches, remaining_batches + + +class LenMatchBatchSampler(BatchSampler): + """A `BatchSampler` that batches similar length events. + + MIT License + + Copyright (c) 2023 DrHB + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+ """ + + def __init__( + self, + sampler: Sampler, + batch_size: int, + drop_last: Optional[bool] = False, + ) -> None: + """Construct `LenMatchBatchSampler`.""" + super().__init__( + sampler=sampler, batch_size=batch_size, drop_last=drop_last + ) + + def __iter__(self) -> Iterator[List[int]]: + """Return length-matched batches.""" + indices = list(self.sampler) + data_source = self.sampler.data_source + + n_workers = min(cpu_count(), 6) + chunk_size = len(indices) // n_workers + + # Split indices into nearly equal-sized chunks + chunks = [ + indices[i * chunk_size : (i + 1) * chunk_size] + for i in range(n_workers) + ] + if len(indices) % n_workers != 0: + chunks.append(indices[n_workers * chunk_size :]) + + yielded = 0 + with get_context("spawn").Pool(processes=n_workers) as pool: + results = pool.map( + gather_buckets, + [(chunk, data_source, self.batch_size) for chunk in chunks], + ) + + merged_batches = [] + remaining_indices = [] + for batches, remaining in results: + merged_batches.extend(batches) + remaining_indices.extend(remaining) + + for batch in merged_batches: + yield batch + yielded += 1 + + # Process any remaining indices + leftover = [idx for batch in remaining_indices for idx in batch] + batch = [] + for idx in leftover: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + yielded += 1 + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch + yielded += 1 From 0702c474e8202f0119f905bdaa694a11c00912ab Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:18:37 -0400 Subject: [PATCH 32/39] Implementation of Samplers and BatchSamplers into GraphNeTDataModule --- src/graphnet/data/datamodule.py | 40 ++++++++++++++++++- src/graphnet/data/dataset/samplers.py | 55 +++++++++++++++++---------- 2 files changed, 74 insertions(+), 21 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 33f31c5fe..7ab2bbe06 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -195,6 +195,13 @@ def setup(self, stage: str) -> None: if self._val_selection is not None: self._val_dataset = self._create_dataset(self._val_selection) + # if self._len_match_batch: # TODO: the same for val -PW + # batch_size = self._train_batch_sampler_kwargs["batch_size"] + # self._train_random_chunk_sampler = RandomChunkSampler(self._train_dataset, + # chunks=self._train_dataset.chunk_sizes) + # self._train_len_match_batch_sampler = LenMatchBatchSampler(self._train_random_chunk_sampler, + # batch_size=batch_size, + # drop_last=True) return @property @@ -273,6 +280,38 @@ def _create_dataloader( "Unknown dataset encountered during dataloader creation." ) + if "sampler" in dataloader_args.keys(): + # If there were no kwargs provided, set it to empty dict + if "sampler_kwargs" not in dataloader_args.keys(): + dataloader_args["sampler_kwargs"] = {} + dataloader_args["sampler"] = dataloader_args["sampler"]( + dataset, **dataloader_args["sampler_kwargs"] + ) + del dataloader_args["sampler_kwargs"] + + if "batch_sampler" in dataloader_args.keys(): + if "sampler" not in dataloader_args.keys(): + raise KeyError( + "When specifying a `batch_sampler`, you must also provide `sampler`." 
+ ) + # If there were no kwargs provided, set it to empty dict + if "batch_sampler_kwargs" not in dataloader_args.keys(): + dataloader_args["batch_sampler_kwargs"] = {} + + batch_sampler = dataloader_args["batch_sampler"]( + dataloader_args["sampler"], + **dataloader_args["batch_sampler_kwargs"], + ) + dataloader_args["batch_sampler"] = batch_sampler + # Remove extra keys + for key in [ + "batch_sampler_kwargs", + "drop_last", + "sampler", + "shuffle", + ]: + dataloader_args.pop(key, None) + if dataloader_args is None: raise AttributeError("Dataloader arguments not provided.") @@ -479,7 +518,6 @@ def _infer_selections_on_single_dataset( .sample(frac=1, replace=False, random_state=self._rng) .values.tolist() ) # shuffled list - return self._split_selection(all_events) def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 0d102898b..7a264088a 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -14,6 +14,7 @@ import numpy as np import torch from torch.utils.data import Sampler, BatchSampler +from graphnet.data.dataset import Dataset class RandomChunkSampler(Sampler[int]): @@ -40,20 +41,21 @@ class RandomChunkSampler(Sampler[int]): LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + _____________________ + + Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( self, - data_source: Sequence[Any], - chunks: List[int], + data_source: Dataset, num_samples: Optional[int] = None, generator: Optional[torch.Generator] = None, ) -> None: """Construct `RandomChunkSampler`.""" - # chunks - a list of chunk sizes self._data_source = data_source self._num_samples = num_samples - self._chunks = chunks + self._chunks = data_source.chunk_sizes # Create a random number generator if one was not provided if generator is None: @@ -94,7 +96,7 @@ def __iter__(self) -> Iterator[List[int]]: """Return a list of indices from a randomly sampled chunk.""" cumsum = np.cumsum(self.chunks) chunk_list = torch.randperm( - len(self.chunks), generator=self.generator + len(self.chunks), generator=self._generator ).tolist() # sample indexes chunk by chunk @@ -103,7 +105,7 @@ def __iter__(self) -> Iterator[List[int]]: chunk_len = self.chunks[i] offset = cumsum[i - 1] if i > 0 else 0 samples = ( - offset + torch.randperm(chunk_len, generator=self.generator) + offset + torch.randperm(chunk_len, generator=self._generator) ).tolist() if len(samples) <= self.num_samples - yield_samples: yield_samples += len(samples) @@ -114,29 +116,33 @@ def __iter__(self) -> Iterator[List[int]]: def gather_buckets( - params: Tuple[List[int], Sequence[Any], int], + params: Tuple[List[int], Sequence[Any], int, int], ) -> Tuple[List[List[int]], List[List[int]]]: """Gather buckets of events. - The function that will be used to gather buckets of events by the + The function that will be used to gather batches of events for the `LenMatchBatchSampler`. When using multiprocessing, each worker will call - this function. + this function. Given indices, this function will group events based on + their length. If the length of event is N, then it will go into the + (N // bucket_width) bucket. This returns completed batches and a + list of incomplete batches that did not fill to batch_size at the end. 
Args: params: A tuple containg the list of indices to process, - the data_source (typically a `Dataset`), and the batch size. + the data_source (typically a `Dataset`), the batch size, and the + bucket width. Returns: batches: A list containing batches. remaining_batches: Incomplete batches. """ - indices, data_source, batch_size = params + indices, data_source, batch_size, bucket_width = params buckets = defaultdict(list) batches = [] for idx in indices: s = data_source[idx] - L = max(1, s.num_nodes // 16) + L = max(1, s.num_nodes // bucket_width) buckets[L].append(idx) if len(buckets[L]) == batch_size: batches.append(list(buckets[L])) @@ -171,40 +177,49 @@ class LenMatchBatchSampler(BatchSampler): LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + _____________________ + + Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( self, sampler: Sampler, - batch_size: int, + batch_size: int = 1, + num_workers: int = 1, + bucket_width: int = 16, drop_last: Optional[bool] = False, ) -> None: """Construct `LenMatchBatchSampler`.""" super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) + self._bucket_width = bucket_width + self._num_workers = num_workers def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" indices = list(self.sampler) data_source = self.sampler.data_source - n_workers = min(cpu_count(), 6) - chunk_size = len(indices) // n_workers + chunk_size = len(indices) // self._num_workers # Split indices into nearly equal-sized chunks chunks = [ indices[i * chunk_size : (i + 1) * chunk_size] - for i in range(n_workers) + for i in range(self._num_workers) ] - if len(indices) % n_workers != 0: - chunks.append(indices[n_workers * chunk_size :]) + if len(indices) % self._num_workers != 0: + chunks.append(indices[self._num_workers * chunk_size :]) yielded = 0 - with get_context("spawn").Pool(processes=n_workers) as pool: + with get_context("spawn").Pool(processes=self._num_workers) as pool: results = pool.map( gather_buckets, - [(chunk, data_source, self.batch_size) for chunk in chunks], + [ + (chunk, data_source, self.batch_size, self._bucket_width) + for chunk in chunks + ], ) merged_batches = [] From a570f7b93fd24fd7cbe1627275f365969d056875 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:20:48 -0400 Subject: [PATCH 33/39] Remove old comment block --- src/graphnet/data/datamodule.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 7ab2bbe06..ae3737ebd 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -195,13 +195,6 @@ def setup(self, stage: str) -> None: if self._val_selection is not None: self._val_dataset = self._create_dataset(self._val_selection) - # if self._len_match_batch: # TODO: the same for val -PW - # batch_size = self._train_batch_sampler_kwargs["batch_size"] - # self._train_random_chunk_sampler = RandomChunkSampler(self._train_dataset, - # chunks=self._train_dataset.chunk_sizes) - # self._train_len_match_batch_sampler = LenMatchBatchSampler(self._train_random_chunk_sampler, - # batch_size=batch_size, - # drop_last=True) return @property From 3132ac04854aae99fbe2ab47098b7fb73ec0895e Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 15 Aug 2024 22:31:57 -0400 Subject: [PATCH 34/39] Add multiprocessing_context for 
LenMatchBatchSampler --- src/graphnet/data/dataset/samplers.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 7a264088a..a9938a968 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -188,14 +188,33 @@ def __init__( batch_size: int = 1, num_workers: int = 1, bucket_width: int = 16, + multiprocessing_context: str = "spawn", drop_last: Optional[bool] = False, ) -> None: - """Construct `LenMatchBatchSampler`.""" + """Construct `LenMatchBatchSampler`. + + This `BatchSampler` groups data with similar lengths to be more efficient + in operations like masking for MultiHeadAttention. Since batch samplers + run on the main process and can result in a CPU bottleneck, `num_workers` + can be specified to use multiprocessing for creating the batches. The + `bucket_width` argument specifies how wide the bins are for grouping batches. + For example, with `bucket_width=16`, data with length [1, 16] and grouped into + a bucket and data with length [17, 32] in another. + + Args: + sampler: A `Sampler` object that selects/draws data in some way. + batch_size: Batch size. + num_workers: Number of workers to spawn to create batches. + bucket_width: Size of length buckets for grouping data. + multiprocessing_context: Start method for multiprocessing. + drop_last: (Optional) Drop the last incomplete batch. + """ super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) - self._bucket_width = bucket_width self._num_workers = num_workers + self._bucket_width = bucket_width + self._multiprocessing_context = multiprocessing_context def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" @@ -213,7 +232,9 @@ def __iter__(self) -> Iterator[List[int]]: chunks.append(indices[self._num_workers * chunk_size :]) yielded = 0 - with get_context("spawn").Pool(processes=self._num_workers) as pool: + with get_context(self._multiprocessing_context).Pool( + processes=self._num_workers + ) as pool: results = pool.map( gather_buckets, [ From c261fc922042635cec3aa9ddd6e94a95370e58ac Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Fri, 16 Aug 2024 00:12:07 -0400 Subject: [PATCH 35/39] Improved LenMatchBatchSampler --- src/graphnet/data/dataset/samplers.py | 53 +++++++++++++-------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index a9938a968..45619b41b 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -198,8 +198,8 @@ def __init__( run on the main process and can result in a CPU bottleneck, `num_workers` can be specified to use multiprocessing for creating the batches. The `bucket_width` argument specifies how wide the bins are for grouping batches. - For example, with `bucket_width=16`, data with length [1, 16] and grouped into - a bucket and data with length [17, 32] in another. + For example, with `bucket_width=16`, data with length [1, 16] are grouped into + a bucket, data with length [17, 32] into another, etc. Args: sampler: A `Sampler` object that selects/draws data in some way. @@ -212,6 +212,10 @@ def __init__( super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) + assert ( + num_workers >= 1 + ), "Need at least one worker to use LenMatchBatchSampler!" 
+ self._num_workers = num_workers self._bucket_width = bucket_width self._multiprocessing_context = multiprocessing_context @@ -221,43 +225,38 @@ def __iter__(self) -> Iterator[List[int]]: indices = list(self.sampler) data_source = self.sampler.data_source - chunk_size = len(indices) // self._num_workers + segments_size = len(indices) // self._num_workers - # Split indices into nearly equal-sized chunks - chunks = [ - indices[i * chunk_size : (i + 1) * chunk_size] + # Split indices into nearly equal-sized segments amonst the workers + segments = [ + indices[i * segments_size : (i + 1) * segments_size] for i in range(self._num_workers) ] + + # Collect the leftovers into another segment if len(indices) % self._num_workers != 0: - chunks.append(indices[self._num_workers * chunk_size :]) + segments.append(indices[self._num_workers * segments_size :]) yielded = 0 - with get_context(self._multiprocessing_context).Pool( - processes=self._num_workers - ) as pool: - results = pool.map( + remaining_indices = [] + with get_context("spawn").Pool(processes=self._num_workers) as pool: + for result in pool.imap_unordered( gather_buckets, [ - (chunk, data_source, self.batch_size, self._bucket_width) - for chunk in chunks + (segment, data_source, self.batch_size, self._bucket_width) + for segment in segments ], - ) - - merged_batches = [] - remaining_indices = [] - for batches, remaining in results: - merged_batches.extend(batches) - remaining_indices.extend(remaining) - - for batch in merged_batches: - yield batch - yielded += 1 + ): + batches, leftovers = result + for batch in batches: + yield batch + yielded += 1 + remaining_indices.extend(leftovers) # Process any remaining indices - leftover = [idx for batch in remaining_indices for idx in batch] batch = [] - for idx in leftover: - batch.append(idx) + for incomplete_batch in remaining_indices: + batch.extend(incomplete_batch) if len(batch) == self.batch_size: yield batch yielded += 1 From 9b6e5688e0e06856d465567a1b99cad459e4f1e9 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Thu, 22 Aug 2024 08:17:36 -0400 Subject: [PATCH 36/39] Minor changes, more settings for samplers --- src/graphnet/data/dataset/__init__.py | 5 +- src/graphnet/data/dataset/samplers.py | 133 +++++++++++++++++--------- 2 files changed, 92 insertions(+), 46 deletions(-) diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index eeb3123d9..ed1c55ef5 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -5,7 +5,10 @@ if has_torch_package(): import torch.multiprocessing from .dataset import EnsembleDataset, Dataset, ColumnMissingException - from .samplers import RandomChunkSampler, LenMatchBatchSampler + from .samplers import ( + RandomChunkSampler, + LenMatchBatchSampler, + ) from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index 45619b41b..ae8f728fb 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -15,6 +15,7 @@ import torch from torch.utils.data import Sampler, BatchSampler from graphnet.data.dataset import Dataset +from graphnet.utilities.logging import Logger class RandomChunkSampler(Sampler[int]): @@ -115,10 +116,10 @@ def __iter__(self) -> Iterator[List[int]]: yield from samples -def gather_buckets( - params: Tuple[List[int], Sequence[Any], int, int], +def gather_len_matched_buckets( + params: Tuple[range, 
Sequence[Any], int, int], ) -> Tuple[List[List[int]], List[List[int]]]: - """Gather buckets of events. + """Gather length-matched buckets of events. The function that will be used to gather batches of events for the `LenMatchBatchSampler`. When using multiprocessing, each worker will call @@ -153,7 +154,7 @@ def gather_buckets( return batches, remaining_batches -class LenMatchBatchSampler(BatchSampler): +class LenMatchBatchSampler(BatchSampler, Logger): """A `BatchSampler` that batches similar length events. MIT License @@ -188,6 +189,7 @@ def __init__( batch_size: int = 1, num_workers: int = 1, bucket_width: int = 16, + chunks_per_segment: int = 4, multiprocessing_context: str = "spawn", drop_last: Optional[bool] = False, ) -> None: @@ -206,62 +208,103 @@ def __init__( batch_size: Batch size. num_workers: Number of workers to spawn to create batches. bucket_width: Size of length buckets for grouping data. + chunks_per_segment: Number of chunks to group together for processing. multiprocessing_context: Start method for multiprocessing. drop_last: (Optional) Drop the last incomplete batch. """ + Logger.__init__(self) super().__init__( sampler=sampler, batch_size=batch_size, drop_last=drop_last ) - assert ( - num_workers >= 1 - ), "Need at least one worker to use LenMatchBatchSampler!" + assert num_workers >= 0, "`num_workers` must be >= 0!" self._num_workers = num_workers self._bucket_width = bucket_width + self._chunks_per_segment = chunks_per_segment self._multiprocessing_context = multiprocessing_context + self.info( + f"Setting up batch sampler with {self._num_workers} workers." + ) + def __iter__(self) -> Iterator[List[int]]: """Return length-matched batches.""" indices = list(self.sampler) data_source = self.sampler.data_source - segments_size = len(indices) // self._num_workers - - # Split indices into nearly equal-sized segments amonst the workers - segments = [ - indices[i * segments_size : (i + 1) * segments_size] - for i in range(self._num_workers) - ] - - # Collect the leftovers into another segment - if len(indices) % self._num_workers != 0: - segments.append(indices[self._num_workers * segments_size :]) - - yielded = 0 - remaining_indices = [] - with get_context("spawn").Pool(processes=self._num_workers) as pool: - for result in pool.imap_unordered( - gather_buckets, - [ - (segment, data_source, self.batch_size, self._bucket_width) - for segment in segments - ], - ): - batches, leftovers = result - for batch in batches: - yield batch - yielded += 1 - remaining_indices.extend(leftovers) - - # Process any remaining indices - batch = [] - for incomplete_batch in remaining_indices: - batch.extend(incomplete_batch) - if len(batch) == self.batch_size: + if self._num_workers > 0: + + n_chunks = len(self.sampler.chunks) + n_segments = n_chunks // self._chunks_per_segment + + # Split indices into nearly equal-sized segments amongst the workers + segments = [ + range( + sum(self.sampler.chunks[: i * self._chunks_per_segment]), + sum( + self.sampler.chunks[ + : (i + 1) * self._chunks_per_segment + ] + ), + ) + for i in range(n_segments) + ] + segments.extend( + [range(segments[-1][-1], len(indices) - 1)] + ) # Make a segment w/ the leftover indices + + remaining_indices = [] + with get_context(self._multiprocessing_context).Pool( + processes=self._num_workers + ) as pool: + results = pool.imap_unordered( + gather_len_matched_buckets, + [ + ( + segments[i], + data_source, + self.batch_size, + self._bucket_width, + ) + for i in range(n_segments) + ], + ) + for result in results: + batches, 
leftovers = result + for batch in batches: + yield batch + remaining_indices.extend(leftovers) + + # Process any remaining indices + batch = [] + for incomplete_batch in remaining_indices: + batch.extend(incomplete_batch) + if len(batch) >= self.batch_size: + yield batch[: self.batch_size] + batch = batch[self.batch_size :] + + if len(batch) > 0 and not self.drop_last: yield batch - yielded += 1 - batch = [] + else: # n_workers = 0, no multiprocessing + buckets = defaultdict(list) + + for idx in self.sampler: + s = self.sampler.data_source[idx] + L = max(1, s.num_nodes // self._bucket_width) + buckets[L].append(idx) + if len(buckets[L]) == self.batch_size: + batch = list(buckets[L]) + yield batch + buckets[L] = [] - if len(batch) > 0 and not self.drop_last: - yield batch - yielded += 1 + batch = [] + leftover = [idx for bucket in buckets for idx in bucket] + + for idx in leftover: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch From d6e7bed3af1734b7356f3ca5f8a3fed8d3eed4d2 Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Mon, 2 Sep 2024 11:09:22 -0400 Subject: [PATCH 37/39] Fix docstrings --- src/graphnet/data/datamodule.py | 3 +- src/graphnet/data/dataset/samplers.py | 98 +++++++++++---------------- 2 files changed, 42 insertions(+), 59 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index ae3737ebd..802a64a7d 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -285,7 +285,8 @@ def _create_dataloader( if "batch_sampler" in dataloader_args.keys(): if "sampler" not in dataloader_args.keys(): raise KeyError( - "When specifying a `batch_sampler`, you must also provide `sampler`." + "When specifying a `batch_sampler`," + "you must also provide `sampler`." ) # If there were no kwargs provided, set it to empty dict if "batch_sampler_kwargs" not in dataloader_args.keys(): diff --git a/src/graphnet/data/dataset/samplers.py b/src/graphnet/data/dataset/samplers.py index ae8f728fb..c43455447 100644 --- a/src/graphnet/data/dataset/samplers.py +++ b/src/graphnet/data/dataset/samplers.py @@ -1,4 +1,29 @@ -"""`Sampler` and `BatchSampler` objects for `graphnet`.""" +"""`Sampler` and `BatchSampler` objects for `graphnet`. + +MIT License + +Copyright (c) 2023 DrHB + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+_____________________ +""" + from typing import ( Any, List, @@ -21,30 +46,8 @@ class RandomChunkSampler(Sampler[int]): """A `Sampler` that randomly selects chunks. - MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -157,30 +160,8 @@ def gather_len_matched_buckets( class LenMatchBatchSampler(BatchSampler, Logger): """A `BatchSampler` that batches similar length events. - MIT License - - Copyright (c) 2023 DrHB - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - _____________________ - - Original implementation: https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py + Original implementation: + https://github.com/DrHB/icecube-2nd-place/blob/main/src/dataset.py """ def __init__( @@ -195,20 +176,21 @@ def __init__( ) -> None: """Construct `LenMatchBatchSampler`. - This `BatchSampler` groups data with similar lengths to be more efficient - in operations like masking for MultiHeadAttention. Since batch samplers - run on the main process and can result in a CPU bottleneck, `num_workers` - can be specified to use multiprocessing for creating the batches. The - `bucket_width` argument specifies how wide the bins are for grouping batches. 
- For example, with `bucket_width=16`, data with length [1, 16] are grouped into - a bucket, data with length [17, 32] into another, etc. + This `BatchSampler` groups data with similar lengths to be more + efficient in operations like masking for MultiHeadAttention. Since + batch samplers run on the main process and can result in a CPU + bottleneck, `num_workers` can be specified to use multiprocessing for + creating the batches. The `bucket_width` argument specifies how wide + the bins are for grouping batches. For example, with `bucket_width=16`, + data with length [1, 16] are grouped into a bucket, data with length + [17, 32] into another, etc. Args: sampler: A `Sampler` object that selects/draws data in some way. batch_size: Batch size. num_workers: Number of workers to spawn to create batches. bucket_width: Size of length buckets for grouping data. - chunks_per_segment: Number of chunks to group together for processing. + chunks_per_segment: Number of chunks to group together. multiprocessing_context: Start method for multiprocessing. drop_last: (Optional) Drop the last incomplete batch. """ @@ -237,7 +219,7 @@ def __iter__(self) -> Iterator[List[int]]: n_chunks = len(self.sampler.chunks) n_segments = n_chunks // self._chunks_per_segment - # Split indices into nearly equal-sized segments amongst the workers + # Split indices into nearly equal-sized segments amongst workers segments = [ range( sum(self.sampler.chunks[: i * self._chunks_per_segment]), From 5276d4d33124e7623fd56ee5bc8bc40f5396108e Mon Sep 17 00:00:00 2001 From: niklasmei Date: Sat, 14 Sep 2024 16:20:49 +0200 Subject: [PATCH 38/39] changed the import path of SQLiteDataset to the currently correct location. At the same time did the same change for ParquetDataset, issue will follow --- docs/source/datasets/datasets.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/datasets/datasets.rst b/docs/source/datasets/datasets.rst index 8716d6113..ae8f78ee2 100644 --- a/docs/source/datasets/datasets.rst +++ b/docs/source/datasets/datasets.rst @@ -176,7 +176,7 @@ After that, you can construct your :code:`Dataset` from a SQLite database with j .. code-block:: python - from graphnet.data.sqlite import SQLiteDataset + from graphnet.data.dataset.sqlite.sqlite_dataset import SQLiteDataset from graphnet.models.detector.prometheus import Prometheus from graphnet.models.graphs import KNNGraph from graphnet.models.graphs.nodes import NodesAsPulses @@ -203,7 +203,7 @@ Or similarly for Parquet files: .. 
code-block:: python - from graphnet.data.parquet import ParquetDataset + from graphnet.data.dataset.parquet.parquet_dataset import ParquetDataset from graphnet.models.detector.prometheus import Prometheus from graphnet.models.graphs import KNNGraph from graphnet.models.graphs.nodes import NodesAsPulses From 9aa693624065150901d2cc1d8b7d6e73e53e6a3f Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Mon, 16 Sep 2024 15:54:05 +0200 Subject: [PATCH 39/39] only initialize if training --- src/graphnet/models/task/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index a1c3c52ed..0b9101107 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -461,7 +461,7 @@ def forward( labels = labels.to(self.dtype) # Set the initial parameters of flow close to truth # This speeds up training and helps with NaN - if self._initialized is False: + if (self._initialized is False) & (self.training): self._flow.init_params(data=deepcopy(labels).cpu()) self._flow.to(self.device) self._initialized = True # This is only done once
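Putting the sampler-related patches together, the intended wiring looks roughly like the sketch below. It is an illustration rather than repository code: the `sampler`/`batch_sampler` handling follows the `_create_dataloader` logic added in PATCH 32, the sampler options follow PATCH 36, and everything else (paths, table and column names, the `GraphNeTDataModule` keyword names) is an assumption to be checked against the actual datamodule.

# Illustrative sketch only; argument names outside the sampler handling are
# assumptions, not taken from the patches.
from graphnet.data.datamodule import GraphNeTDataModule
from graphnet.data.dataset import (
    ParquetDataset,
    RandomChunkSampler,
    LenMatchBatchSampler,
)
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

graph_definition = KNNGraph(detector=Prometheus(), node_definition=NodesAsPulses())

# Samplers are passed as classes: the datamodule instantiates the sampler with
# the dataset (RandomChunkSampler reads the dataset's `chunk_sizes`), wraps it
# in the batch sampler, and drops `shuffle`/`drop_last` before building the
# DataLoader, as implemented in PATCH 32.
train_dataloader_kwargs = {
    "num_workers": 4,
    "sampler": RandomChunkSampler,
    "batch_sampler": LenMatchBatchSampler,
    "batch_sampler_kwargs": {
        "batch_size": 64,
        "bucket_width": 16,
        "chunks_per_segment": 4,
        "num_workers": 4,
    },
}

datamodule = GraphNeTDataModule(
    dataset_reference=ParquetDataset,
    dataset_args={
        "path": "/data/prometheus_parquet",  # placeholder path
        "pulsemaps": "total",                # placeholder table name
        "truth_table": "mc_truth",           # placeholder table name
        "features": ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
        "truth": ["total_energy"],
        "graph_definition": graph_definition,
    },
    train_dataloader_kwargs=train_dataloader_kwargs,
)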