Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Definition #558

Merged
merged 60 commits into from
Aug 14, 2023
Merged

Conversation

RasmusOrsoe
Copy link
Collaborator

@RasmusOrsoe RasmusOrsoe commented Jul 18, 2023

This PR addresses the ongoing discussion in #462 and #521 (which has been a roadblock for a while) by changing Model such that it now consists of the modules

Model = [GraphDefinition, GNN, Task]

Where GraphDefinition is a single, problem/model dependent class that contains all the code responsible for data representations.

  • Add GraphDefinition
  • Implement our default k-nn graph as importable graph definition
  • Redefine Model to depend on GraphDefinition
  • Refactor Detector
  • Update example scripts
  • Update config files
  • Consider default values
  • Delete redundant modules; GraphBuilder etc.
  • Update getting_started.md
  • Refactor Deployment modules

TLDR: Model, Dataset and GraphNeTI3Module now depends on GraphDefinition, which allows us to easily represent data as sequences, images, or whatever your heart desires. This change is breaking; older config files and pickled models are not compatible with these changes, but state_dicts are.

Conceptually, GraphDefinition contains all the code that alters the raw data from Dataset before it's passed to GNN. It's a single, swapable module that can be passed to Dataset and deployment modules. GraphDefinition consists of multiple submodules, and the data flow is GraphDefinition = [Detector, NodeDefinition, EdgeDefinition] and can be seen here. The definition exists at a point in the dataflow where events are unbatched, meaning that the construction of data representations can be done on CPU and in parallel, before it's batched and sent to the GPU. That means that the sequence creation included in #521 becomes much simpler and likely faster @Aske-Rosted, and should also be useful for the transformer exploration by @MoustHolmes.

The modules are defined as

NodeDefinition :
A generic class that defines what a node represents. Problem-specific versions can be implemented by overwriting the abstract method

def _construct_nodes(self, x: torch.tensor) -> Data:
        """Construct nodes from raw node features ´x´.

        Args:
            x: standardized node features with shape ´[num_pulses, d]´,
            where ´d´ is the number of node features.

        Returns:
            graph: graph without edges.
        """

_construct_nodes is the playground we've been missing for a while; it gives us the freedom to fully define exactly how we want the data to be structured for our Models. Here, one can use nodes to represent DOMs (by using Coarsening or some other method), create images for CNNs, define sequences or other forms of data representations. Our standard of representing pulses as nodes is just

class NodesAsPulses(NodeDefinition):
    """Represent each measured pulse of Cherenkov Radiation as a node."""

    def _construct_nodes(self, x: torch.tensor) -> Data:
        return Data(x=x)

EdgeDefinition:
A generic class that defines how edges are drawn between nodes in the graph. This is essentially a refactor of our GraphBuilder. One can create problem-specific implementations by overwriting the abstract method

def _construct_edges(self, graph: Data) -> Data:
        """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.

        Args:
            graph: graph without edges

        Returns:
            graph: graph with edges assigned.
        """

Detector:
Virtually unchanged from it's known form. In charge of standardizing data and is now able to work on a subset of the feature space that it is defined on. I cleaned the class up a little bit. In the future, it will hold detector-specific geometry tables as mentioned in #462.

Our usual k-nn graph with nodes representing pulses can then be created like so:

from graphnet.models.graphs import GraphDefinition
from graphnet.models.graphs.nodes import NodesAsPulses
from graphnet.models.graphs.edges import KNNEdges
from graphnet.models.detector.prometheus import Prometheus

graph_definition = GraphDefinition(node_definiton = NodesAsPulses(nb_nearest_neighbours=8),
                                   edge_definiton = KNNEdges(),
                                   detector = Prometheus(),
                                     )
                               

Alternatively, you can also just import this graph definition directly, as it's included in the PR:

from graphnet.models.graphs import KNNGraph

graph_definition = KNNGraph(
        detector=Prometheus(),
        node_definition=NodesAsPulses(),
        nb_nearest_neighbours=8,
        node_feature_names=features,
    )

It is the problem-specific implementation of a graph definition that defines the number of input parameters to our GNNs, available through graph_definition.nb_outputs. When we instantiate a Model, the syntax is now:

 gnn = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )
    task = ..
model = StandardModel(
        graph_definition=graph_definition,
        gnn=gnn,
        tasks=[task],
        ...)

Other things to note:

  1. Dataset is now simpler, as graph-altering code has been moved to GraphDefinition
  2. Changes are compatible with the Config system.
  3. Some folder structure has been changed in graphnet.data to avoid circular imports.

@RasmusOrsoe RasmusOrsoe marked this pull request as draft July 18, 2023 13:39
@RasmusOrsoe
Copy link
Collaborator Author

I have addressed the initial comments from @AMHermansen, refactored the last few detectors and introduced GraphDefinition in the deployment code. In relation to this, I've created #560 which contains a list of to-do items for the deployment modules which fall outside the scope of this PR.

What's left is to update getting_started.md and considerations on default values. I'll wait with this last step until @Aske-Rosted has had a chance to give this a look after the ICRC conference.

features: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
graph_definition:
arguments:
columns: [0, 1, 2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the naming of this "columns" arguments (which I understand as the columns included in the "distance" calculation of your edge definition) becomes very vague. At first glance it hard to tell what this feature does. Something like "edge_defining_columns" would be more descriptive I think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the name of that argument is not unambiguous, but I also fail to find better alternatives that are not very long. The name of this argument is the same as what we used to call it in the KNNGraphBuilder (see here) and I do think that the doc string for this argument (see here) is pretty clear. So even though it's a bit challenging to understand what that argument does when one reads the config file, I think we should keep it as-is and refer users to the docs instead (which is the intended usage anyway). Is that OK with you? @Aske-Rosted

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely fine.


torch.multiprocessing.set_sharing_strategy("file_system")

del has_torch_package
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for deleting the has_torch_package here but keeping it in src/graphnet/data/dataset/sqlite/__init__.py?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't know. I did not add this line; the difference you point out here is present in the main branch currently also. I have added the del statement to the sqlite part now too.

@@ -133,7 +133,7 @@ def _construct_model(
fn_kwargs={"trust": trust},
)

# Construct model based on arguments
# Construct model based on
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this change intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nopes. Fixed.

@Aske-Rosted
Copy link
Collaborator

I have just finished looking through everything, I think it looks great! I also believe that this will add the necessary flexibility for a lot of the things I've been trying to implement, where I previously had to go make changes in the dataset class e.g. #521 . I am looking forward to trying it out.

Copy link
Collaborator

@AMHermansen AMHermansen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @RasmusOrsoe I've noticed some more minor things.

) -> Data:
for idx, feature in enumerate(node_feature_names):
try:
node_features[:, idx] = self.feature_map()[feature]( # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if you're supposed to call self.feature_map, since it is classified as a property?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey. Thanks for pointing this out. In fact, the property decorator in detector.py is poorly placed - it is actually completely redundant. I've removed it.

# Assume all features in Detector is used.
node_feature_names = list(self._detector.feature_map().keys()) # type: ignore
self._node_feature_names = node_feature_names
if dtype is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this code correctly you're just getting the same behavior as if you changed the default value of dtype in the constructor to torch.float instead of None. But in a slightly more more complicated way :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right. Fixed :-)

@RasmusOrsoe
Copy link
Collaborator Author

Thank you very much for your comments. I believe I have now addressed all of them, and I have updated the GETTING_STARTED.MD accordingly. @Aske-Rosted @AMHermansen.

Copy link
Collaborator

@AMHermansen AMHermansen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!

@RasmusOrsoe RasmusOrsoe marked this pull request as ready for review August 14, 2023 07:42
@RasmusOrsoe RasmusOrsoe merged commit 268a17d into graphnet-team:main Aug 14, 2023
11 checks passed
RasmusOrsoe added a commit to RasmusOrsoe/graphnet that referenced this pull request Oct 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants