From 7d5e901f0ccab413ef4b9916398dd7b3a1726d40 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 21 Jun 2024 06:30:36 +0000 Subject: [PATCH] Refactor Model Implement --- README.md | 31 + rdagent/app/model_implementation/README.md | 13 + rdagent/app/model_implementation/eval.py | 30 + rdagent/core/implementation.py | 12 +- rdagent/core/task.py | 111 +- .../factor_implementation/evolving/factor.py | 3 + .../model_implementation/benchmark/eval.py | 61 + .../benchmark/gt_code/A-DGN.py | 135 ++ .../benchmark/gt_code/dirgnn.py | 85 ++ .../benchmark/gt_code/gpsconv.py | 196 +++ .../benchmark/gt_code/linkx.py | 178 +++ .../benchmark/gt_code/pmlp.py | 112 ++ .../benchmark/gt_code/visnet.py | 1190 +++++++++++++++++ rdagent/model_implementation/conf.py | 10 + rdagent/model_implementation/gt_code.py | 4 + rdagent/model_implementation/main.py | 6 +- .../model_implementation/one_shot/__init__.py | 42 + .../model_implementation/one_shot/prompt.yaml | 18 + rdagent/model_implementation/task.py | 144 ++ rdagent/utils/__init__.py | 36 + 20 files changed, 2406 insertions(+), 11 deletions(-) create mode 100644 rdagent/app/model_implementation/README.md create mode 100644 rdagent/app/model_implementation/eval.py create mode 100644 rdagent/model_implementation/benchmark/eval.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/A-DGN.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/dirgnn.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/gpsconv.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/linkx.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/pmlp.py create mode 100644 rdagent/model_implementation/benchmark/gt_code/visnet.py create mode 100644 rdagent/model_implementation/conf.py create mode 100644 rdagent/model_implementation/one_shot/__init__.py create mode 100644 rdagent/model_implementation/one_shot/prompt.yaml create mode 100644 rdagent/model_implementation/task.py create mode 100644 rdagent/utils/__init__.py diff --git a/README.md b/README.md index 5cd7cecf..ce446af2 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,39 @@ As the maintainer of this project, please make a few updates: - Understanding the security reporting process in SECURITY.MD - Remove this section from the README +## Configuration: + +You can manually source the `.env` file in your shell before running the Python script: +Most of the workflow are controlled by the environment variables. +```sh +# Source the .env file +source .env +# Run the Python script +python your_script.py +``` + +## Naming convention + +### File naming convention + +| Name | Description | +| -- | -- | +| `conf.py` | The configuration for the module & app & project | + + + + ## Contributing +### Guidance +This project welcomes contributions and suggestions. +You can find issues in the issues list or simply running `grep -r "TODO:"`. + +Making contributions is not a hard thing. Solving an issue(maybe just answering a question raised in issues list ), fixing/issuing a bug, improving the documents and even fixing a typo are important contributions to RDAgent. + + +### Policy + This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. diff --git a/rdagent/app/model_implementation/README.md b/rdagent/app/model_implementation/README.md new file mode 100644 index 00000000..6d17eb2f --- /dev/null +++ b/rdagent/app/model_implementation/README.md @@ -0,0 +1,13 @@ + +# Preparation + +## Install Pytorch +CPU CUDA will be enough for verify the implementation + +Please install pytorch based on your system. +Here is an example on my system +```bash +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +pip3 install torch_geometric + +``` diff --git a/rdagent/app/model_implementation/eval.py b/rdagent/app/model_implementation/eval.py new file mode 100644 index 00000000..0543789f --- /dev/null +++ b/rdagent/app/model_implementation/eval.py @@ -0,0 +1,30 @@ +from pathlib import Path + +DIRNAME = Path(__file__).absolute().resolve().parent + +from rdagent.model_implementation.benchmark.eval import ModelImpValEval +from rdagent.model_implementation.conf import MODEL_IMPL_SETTINGS +from rdagent.model_implementation.one_shot import ModelTaskGen +from rdagent.model_implementation.task import ModelImpLoader, ModelTaskLoderJson + +mtl = ModelTaskLoderJson("TODO: A Path to json") + +task_l = mtl.load() + +mtg = ModelTaskGen() + +impl_l = mtg.generate(task_l) + +# TODO: Align it with the benchmark framework after @wenjun's refine the evaluation part. +# Currently, we just handcraft a workflow for fast evaluation. + +mil = ModelImpLoader(DIRNAME.parent.parent / "model_implementation" / "benchmark" / "gt_code") + +mie = ModelImpValEval() +# Evaluation: +eval_l = [] +for impl in impl_l: + gt_impl = mil.load(impl.target_task) + eval_l.append(mie.evaluate(gt_impl, impl)) + +print(eval_l) diff --git a/rdagent/core/implementation.py b/rdagent/core/implementation.py index 058f9fab..ada0da65 100644 --- a/rdagent/core/implementation.py +++ b/rdagent/core/implementation.py @@ -1,13 +1,21 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Sequence from rdagent.core.task import ( + BaseTask, TaskImplementation, ) class TaskGenerator(ABC): @abstractmethod - def generate(self, *args, **kwargs) -> List[TaskImplementation]: + def generate(self, task_l: Sequence[BaseTask]) -> Sequence[TaskImplementation]: + """ + Task Generator should take in a sequence of tasks. + + Because the schedule of different tasks is crucial for the final performance + due to it affects the learning process. + + """ raise NotImplementedError("generate method is not implemented.") def collect_feedback(self, feedback_obj_l: List[object]): diff --git a/rdagent/core/task.py b/rdagent/core/task.py index ebae4c67..8e544948 100644 --- a/rdagent/core/task.py +++ b/rdagent/core/task.py @@ -1,27 +1,121 @@ from abc import ABC, abstractmethod -from typing import Tuple +from pathlib import Path +from typing import Generic, Optional, Sequence, Tuple, TypeVar import pandas as pd - """ This file contains the all the data class for rdagent task. """ class BaseTask(ABC): - # 把name放在这里作为主键 + # TODO: 把name放在这里作为主键 + # Please refer to rdagent/model_implementation/task.py for the implementation + # I think the task version applies to the base class. pass +ASpecificTask = TypeVar("ASpecificTask", bound=BaseTask) + -class TaskImplementation(ABC): - def __init__(self, target_task: BaseTask) -> None: +class TaskImplementation(ABC, Generic[ASpecificTask]): + + def __init__(self, target_task: ASpecificTask) -> None: self.target_task = target_task @abstractmethod - def execute(self, *args, **kwargs) -> Tuple[str, pd.DataFrame]: - raise NotImplementedError("__call__ method is not implemented.") + def execute(self, data=None, config: dict = {}) -> object: + """ + The execution of the implementation can be dynamic. + + So we may passin the data and config dynamically. + """ + raise NotImplementedError("execute method is not implemented.") + + @abstractmethod + def execute_desc(self): + """ + return the description how we will execute the code in the folder. + """ + raise NotImplementedError(f"This type of input is not supported") + + # TODO: + # After execution, it should return some results. + # Some evaluators will input the results and output + + +ASpecificTaskImp = TypeVar("ASpecificTaskImp", bound=TaskImplementation) + + +class ImpLoader(ABC, Generic[ASpecificTask, ASpecificTaskImp]): + + @abstractmethod + def load(self, task: ASpecificTask) -> ASpecificTaskImp: + raise NotImplementedError("load method is not implemented.") + + +class FBTaskImplementation(TaskImplementation): + """ + File-based task implementation + + The implemented task will be a folder which contains related elements. + - Data + - Code Implementation + - Output + - After execution, it will generate the final output as file. + + A typical way to run the pipeline of FBTaskImplementation will be + (We didn't add it as a method due to that we may pass arguments into `prepare` or `execute` based on our requirements.) + + .. code-block:: python + + def run_pipline(self, **files: str): + self.prepare() + self.inject_code(**files) + self.execute() + + """ + # TODO: + # FileBasedFactorImplementation should inherient from it. + # Why not directly reuse FileBasedFactorImplementation. + # Because it has too much concerete dependencies. + # e.g. dataframe, factors + + path: Optional[Path] + + @abstractmethod + def prepare(self, *args, **kwargs): + """ + Prepare all the files except the injected code + - Data + - Documentation + - TODO: env? Env is implicitly defined by the document? + + typical usage of `*args, **kwargs`: + Different methods shares the same data. The data are passed by the arguments. + """ + + def inject_code(self, **files: str): + """ + Inject the code into the folder. + { + "model.py": "" + } + """ + for k, v in files.items(): + with open(self.path / k, "w") as f: + f.write(v) + + def get_files(self) -> list[Path]: + """ + Get the environment description. + + To be general, we only return a list of filenames. + How to summarize the environment is the responsibility of the TaskGenerator. + """ + return list(self.path.iterdir()) class TestCase: + def __init__( self, target_task: BaseTask, @@ -32,6 +126,7 @@ def __init__( class TaskLoader: + @abstractmethod - def load(self, *args, **kwargs) -> BaseTask | list[BaseTask]: + def load(self, *args, **kwargs) -> Sequence[BaseTask]: raise NotImplementedError("load method is not implemented.") diff --git a/rdagent/factor_implementation/evolving/factor.py b/rdagent/factor_implementation/evolving/factor.py index 29887af3..6fd0cf7d 100644 --- a/rdagent/factor_implementation/evolving/factor.py +++ b/rdagent/factor_implementation/evolving/factor.py @@ -30,6 +30,8 @@ class FactorImplementTask(BaseTask): + # TODO: generalized the attributes into the BaseTask + # - factor_* -> * def __init__( self, factor_name, @@ -39,6 +41,7 @@ def __init__( variables: dict = {}, resource: str = None, ) -> None: + # TODO: remove the useless factor_formulation_description self.factor_name = factor_name self.factor_description = factor_description self.factor_formulation = factor_formulation diff --git a/rdagent/model_implementation/benchmark/eval.py b/rdagent/model_implementation/benchmark/eval.py new file mode 100644 index 00000000..394b6f7e --- /dev/null +++ b/rdagent/model_implementation/benchmark/eval.py @@ -0,0 +1,61 @@ +# TODO: inherent from the benchmark base class +import torch +from rdagent.model_implementation.task import ModelTaskImpl + + +def get_data_conf(init_val): + # TODO: design this step in the workflow + in_dim = 1000 + in_channels = 128 + exec_config = {"model_eval_param_init": init_val} + node_feature = torch.randn(in_dim, in_channels) + edge_index = torch.randint(0, in_dim, (2, 2000)) + return (node_feature, edge_index), exec_config + + +class ModelImpValEval: + """ + Evaluate the similarity of the model structure by changing the input and observate the output. + + Assumption: + - If the model structure is similar, the output will change in similar way when we change the input. + - we try to initialize the model param in similar value. So only the model structure is different. + """ + + def evaluate(self, gt: ModelTaskImpl, gen: ModelTaskImpl): + round_n = 10 + + eval_pairs: list[tuple] = [] + + # run different input value + for _ in range(round_n): + # run different model initial parameters. + for init_val in [-0.2, -0.1, 0.1, 0.2]: + data, exec_config = get_data_conf(init_val) + gt_res = gt.execute(data=data, config=exec_config) + res = gen.execute(data=data, config=exec_config) + eval_pairs.append((res, gt_res)) + + # flat and concat the output + res_batch, gt_res_batch = [], [] + for res, gt_res in eval_pairs: + res_batch.append(res.reshape(-1)) + gt_res_batch.append(gt_res.reshape(-1)) + res_batch = torch.stack(res_batch) + gt_res_batch = torch.stack(gt_res_batch) + + res_batch = res_batch.detach().numpy() + gt_res_batch = gt_res_batch.detach().numpy() + + # pearson correlation of each hidden output + def norm(x): + return (x - x.mean(axis=0)) / x.std(axis=0) + dim_corr = (norm(res_batch) * norm(gt_res_batch)).mean(axis=0) # the correlation of each hidden output + + # aggregate all the correlation + avr_corr = dim_corr.mean() + # FIXME: + # It is too high(e.g. 0.944) . + # Check if it is not a good evaluation!! + # Maybe all the same initial params will results in extreamly high correlation without regard to the model structure. + return avr_corr diff --git a/rdagent/model_implementation/benchmark/gt_code/A-DGN.py b/rdagent/model_implementation/benchmark/gt_code/A-DGN.py new file mode 100644 index 00000000..4a84af36 --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/A-DGN.py @@ -0,0 +1,135 @@ +import math +from typing import Any, Callable, Dict, Optional, Union + +import torch +from torch import Tensor +from torch.nn import Parameter + +from torch_geometric.nn.conv import GCNConv, MessagePassing +from torch_geometric.nn.inits import zeros +from torch_geometric.nn.resolver import activation_resolver +from torch_geometric.typing import Adj + + +class AntiSymmetricConv(torch.nn.Module): + r"""The anti-symmetric graph convolutional operator from the + `"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{x}_i + \epsilon \cdot \sigma \left( + (\mathbf{W}-\mathbf{W}^T-\gamma \mathbf{I}) \mathbf{x}_i + + \Phi(\mathbf{X}, \mathcal{N}_i) + \mathbf{b}\right), + + where :math:`\Phi(\mathbf{X}, \mathcal{N}_i)` denotes a + :class:`~torch.nn.conv.MessagePassing` layer. + + Args: + in_channels (int): Size of each input sample. + phi (MessagePassing, optional): The message passing module + :math:`\Phi`. If set to :obj:`None`, will use a + :class:`~torch_geometric.nn.conv.GCNConv` layer as default. + (default: :obj:`None`) + num_iters (int, optional): The number of times the anti-symmetric deep + graph network operator is called. (default: :obj:`1`) + epsilon (float, optional): The discretization step size + :math:`\epsilon`. (default: :obj:`0.1`) + gamma (float, optional): The strength of the diffusion :math:`\gamma`. + It regulates the stability of the method. (default: :obj:`0.1`) + act (str, optional): The non-linear activation function :math:`\sigma`, + *e.g.*, :obj:`"tanh"` or :obj:`"relu"`. (default: :class:`"tanh"`) + act_kwargs (Dict[str, Any], optional): Arguments passed to the + respective activation function defined by :obj:`act`. + (default: :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{in})` + """ + + def __init__( + self, + in_channels: int, + phi: Optional[MessagePassing] = None, + num_iters: int = 1, + epsilon: float = 0.1, + gamma: float = 0.1, + act: Union[str, Callable, None] = "tanh", + act_kwargs: Optional[Dict[str, Any]] = None, + bias: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.num_iters = num_iters + self.gamma = gamma + self.epsilon = epsilon + self.act = activation_resolver(act, **(act_kwargs or {})) + + if phi is None: + phi = GCNConv(in_channels, in_channels, bias=False) + + self.W = Parameter(torch.empty(in_channels, in_channels)) + self.register_buffer("eye", torch.eye(in_channels)) + self.phi = phi + + if bias: + self.bias = Parameter(torch.empty(in_channels)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + torch.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) + self.phi.reset_parameters() + zeros(self.bias) + + def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor: + r"""Runs the forward pass of the module.""" + antisymmetric_W = self.W - self.W.t() - self.gamma * self.eye + + for _ in range(self.num_iters): + h = self.phi(x, edge_index, *args, **kwargs) + h = x @ antisymmetric_W.t() + h + + if self.bias is not None: + h += self.bias + + if self.act is not None: + h = self.act(h) + + x = x + self.epsilon * h + + return x + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"{self.in_channels}, " + f"phi={self.phi}, " + f"num_iters={self.num_iters}, " + f"epsilon={self.epsilon}, " + f"gamma={self.gamma})" + ) + + +model_cls = AntiSymmetricConv + + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = AntiSymmetricConv(in_channels=node_features.size(-1)) + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") diff --git a/rdagent/model_implementation/benchmark/gt_code/dirgnn.py b/rdagent/model_implementation/benchmark/gt_code/dirgnn.py new file mode 100644 index 00000000..b4430f9c --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/dirgnn.py @@ -0,0 +1,85 @@ +import copy + +import torch +from torch import Tensor + +from torch_geometric.nn.conv import MessagePassing + + +class DirGNNConv(torch.nn.Module): + r"""A generic wrapper for computing graph convolution on directed + graphs as described in the `"Edge Directionality Improves Learning on + Heterophilic Graphs" `_ paper. + :class:`DirGNNConv` will pass messages both from source nodes to target + nodes and from target nodes to source nodes. + + Args: + conv (MessagePassing): The underlying + :class:`~torch_geometric.nn.conv.MessagePassing` layer to use. + alpha (float, optional): The alpha coefficient used to weight the + aggregations of in- and out-edges as part of a convex combination. + (default: :obj:`0.5`) + root_weight (bool, optional): If set to :obj:`True`, the layer will add + transformed root node features to the output. + (default: :obj:`True`) + """ + def __init__( + self, + conv: MessagePassing, + alpha: float = 0.5, + root_weight: bool = True, + ): + super().__init__() + + self.alpha = alpha + self.root_weight = root_weight + + self.conv_in = copy.deepcopy(conv) + self.conv_out = copy.deepcopy(conv) + + if hasattr(conv, 'add_self_loops'): + self.conv_in.add_self_loops = False + self.conv_out.add_self_loops = False + if hasattr(conv, 'root_weight'): + self.conv_in.root_weight = False + self.conv_out.root_weight = False + + if root_weight: + self.lin = torch.nn.Linear(conv.in_channels, conv.out_channels) + else: + self.lin = None + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.conv_in.reset_parameters() + self.conv_out.reset_parameters() + if self.lin is not None: + self.lin.reset_parameters() + + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + """""" # noqa: D419 + x_in = self.conv_in(x, edge_index) + x_out = self.conv_out(x, edge_index.flip([0])) + + out = self.alpha * x_out + (1 - self.alpha) * x_in + + if self.root_weight: + out = out + self.lin(x) + + return out + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.conv_in}, alpha={self.alpha})' + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = DirGNNConv(MessagePassing()) + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") \ No newline at end of file diff --git a/rdagent/model_implementation/benchmark/gt_code/gpsconv.py b/rdagent/model_implementation/benchmark/gt_code/gpsconv.py new file mode 100644 index 00000000..c4821cbc --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/gpsconv.py @@ -0,0 +1,196 @@ +import inspect +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Dropout, Linear, Sequential + +from torch_geometric.nn.attention import PerformerAttention +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.inits import reset +from torch_geometric.nn.resolver import ( + activation_resolver, + normalization_resolver, +) +from torch_geometric.typing import Adj +from torch_geometric.utils import to_dense_batch + + +class GPSConv(torch.nn.Module): + r"""The general, powerful, scalable (GPS) graph transformer layer from the + `"Recipe for a General, Powerful, Scalable Graph Transformer" + `_ paper. + + The GPS layer is based on a 3-part recipe: + + 1. Inclusion of positional (PE) and structural encodings (SE) to the input + features (done in a pre-processing step via + :class:`torch_geometric.transforms`). + 2. A local message passing layer (MPNN) that operates on the input graph. + 3. A global attention layer that operates on the entire graph. + + .. note:: + + For an example of using :class:`GPSConv`, see + `examples/graph_gps.py + `_. + + Args: + channels (int): Size of each input sample. + conv (MessagePassing, optional): The local message passing layer. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + dropout (float, optional): Dropout probability of intermediate + embeddings. (default: :obj:`0.`) + act (str or Callable, optional): The non-linear activation function to + use. (default: :obj:`"relu"`) + act_kwargs (Dict[str, Any], optional): Arguments passed to the + respective activation function defined by :obj:`act`. + (default: :obj:`None`) + norm (str or Callable, optional): The normalization function to + use. (default: :obj:`"batch_norm"`) + norm_kwargs (Dict[str, Any], optional): Arguments passed to the + respective normalization function defined by :obj:`norm`. + (default: :obj:`None`) + attn_type (str): Global attention type, :obj:`multihead` or + :obj:`performer`. (default: :obj:`multihead`) + attn_kwargs (Dict[str, Any], optional): Arguments passed to the + attention layer. (default: :obj:`None`) + """ + def __init__( + self, + channels: int, + conv: Optional[MessagePassing], + heads: int = 1, + dropout: float = 0.0, + act: str = 'relu', + act_kwargs: Optional[Dict[str, Any]] = None, + norm: Optional[str] = 'batch_norm', + norm_kwargs: Optional[Dict[str, Any]] = None, + attn_type: str = 'multihead', + attn_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + + self.channels = channels + self.conv = conv + self.heads = heads + self.dropout = dropout + self.attn_type = attn_type + + attn_kwargs = attn_kwargs or {} + if attn_type == 'multihead': + self.attn = torch.nn.MultiheadAttention( + channels, + heads, + batch_first=True, + **attn_kwargs, + ) + elif attn_type == 'performer': + self.attn = PerformerAttention( + channels=channels, + heads=heads, + **attn_kwargs, + ) + else: + # TODO: Support BigBird + raise ValueError(f'{attn_type} is not supported') + + self.mlp = Sequential( + Linear(channels, channels * 2), + activation_resolver(act, **(act_kwargs or {})), + Dropout(dropout), + Linear(channels * 2, channels), + Dropout(dropout), + ) + + norm_kwargs = norm_kwargs or {} + self.norm1 = normalization_resolver(norm, channels, **norm_kwargs) + self.norm2 = normalization_resolver(norm, channels, **norm_kwargs) + self.norm3 = normalization_resolver(norm, channels, **norm_kwargs) + + self.norm_with_batch = False + if self.norm1 is not None: + signature = inspect.signature(self.norm1.forward) + self.norm_with_batch = 'batch' in signature.parameters + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + if self.conv is not None: + self.conv.reset_parameters() + self.attn._reset_parameters() + reset(self.mlp) + if self.norm1 is not None: + self.norm1.reset_parameters() + if self.norm2 is not None: + self.norm2.reset_parameters() + if self.norm3 is not None: + self.norm3.reset_parameters() + + def forward( + self, + x: Tensor, + edge_index: Adj, + batch: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tensor: + r"""Runs the forward pass of the module.""" + hs = [] + if self.conv is not None: # Local MPNN. + h = self.conv(x, edge_index, **kwargs) + h = F.dropout(h, p=self.dropout, training=self.training) + h = h + x + if self.norm1 is not None: + if self.norm_with_batch: + h = self.norm1(h, batch=batch) + else: + h = self.norm1(h) + hs.append(h) + + # Global attention transformer-style model. + h, mask = to_dense_batch(x, batch) + + if isinstance(self.attn, torch.nn.MultiheadAttention): + h, _ = self.attn(h, h, h, key_padding_mask=~mask, + need_weights=False) + elif isinstance(self.attn, PerformerAttention): + h = self.attn(h, mask=mask) + + h = h[mask] + h = F.dropout(h, p=self.dropout, training=self.training) + h = h + x # Residual connection. + if self.norm2 is not None: + if self.norm_with_batch: + h = self.norm2(h, batch=batch) + else: + h = self.norm2(h) + hs.append(h) + + out = sum(hs) # Combine local and global outputs. + + out = out + self.mlp(out) + if self.norm3 is not None: + if self.norm_with_batch: + out = self.norm3(out, batch=batch) + else: + out = self.norm3(out) + + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.channels}, ' + f'conv={self.conv}, heads={self.heads}, ' + f'attn_type={self.attn_type})') + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = GPSConv(channels=node_features.size(-1),conv=MessagePassing()) + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") \ No newline at end of file diff --git a/rdagent/model_implementation/benchmark/gt_code/linkx.py b/rdagent/model_implementation/benchmark/gt_code/linkx.py new file mode 100644 index 00000000..e543adb1 --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/linkx.py @@ -0,0 +1,178 @@ +import math + +import torch +from torch import Tensor +from torch.nn import BatchNorm1d, Parameter + +from torch_geometric.nn import inits +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.models import MLP +from torch_geometric.typing import Adj, OptTensor +from torch_geometric.utils import spmm + + +class SparseLinear(MessagePassing): + def __init__(self, in_channels: int, out_channels: int, bias: bool = True): + super().__init__(aggr='add') + self.in_channels = in_channels + self.out_channels = out_channels + + self.weight = Parameter(torch.empty(in_channels, out_channels)) + if bias: + self.bias = Parameter(torch.empty(out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + inits.kaiming_uniform(self.weight, fan=self.in_channels, + a=math.sqrt(5)) + inits.uniform(self.in_channels, self.bias) + + def forward( + self, + edge_index: Adj, + edge_weight: OptTensor = None, + ) -> Tensor: + # propagate_type: (weight: Tensor, edge_weight: OptTensor) + out = self.propagate(edge_index, weight=self.weight, + edge_weight=edge_weight) + + if self.bias is not None: + out = out + self.bias + + return out + + def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor: + if edge_weight is None: + return weight_j + else: + return edge_weight.view(-1, 1) * weight_j + + def message_and_aggregate(self, adj_t: Adj, weight: Tensor) -> Tensor: + return spmm(adj_t, weight, reduce=self.aggr) + + +class LINKX(torch.nn.Module): + r"""The LINKX model from the `"Large Scale Learning on Non-Homophilous + Graphs: New Benchmarks and Strong Simple Methods" + `_ paper. + + .. math:: + \mathbf{H}_{\mathbf{A}} &= \textrm{MLP}_{\mathbf{A}}(\mathbf{A}) + + \mathbf{H}_{\mathbf{X}} &= \textrm{MLP}_{\mathbf{X}}(\mathbf{X}) + + \mathbf{Y} &= \textrm{MLP}_{f} \left( \sigma \left( \mathbf{W} + [\mathbf{H}_{\mathbf{A}}, \mathbf{H}_{\mathbf{X}}] + + \mathbf{H}_{\mathbf{A}} + \mathbf{H}_{\mathbf{X}} \right) \right) + + .. note:: + + For an example of using LINKX, see `examples/linkx.py `_. + + Args: + num_nodes (int): The number of nodes in the graph. + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + hidden_channels (int): Size of each hidden sample. + out_channels (int): Size of each output sample. + num_layers (int): Number of layers of :math:`\textrm{MLP}_{f}`. + num_edge_layers (int, optional): Number of layers of + :math:`\textrm{MLP}_{\mathbf{A}}`. (default: :obj:`1`) + num_node_layers (int, optional): Number of layers of + :math:`\textrm{MLP}_{\mathbf{X}}`. (default: :obj:`1`) + dropout (float, optional): Dropout probability of each hidden + embedding. (default: :obj:`0.0`) + """ + def __init__( + self, + num_nodes: int, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int, + num_edge_layers: int = 1, + num_node_layers: int = 1, + dropout: float = 0.0, + ): + super().__init__() + + self.num_nodes = num_nodes + self.in_channels = in_channels + self.out_channels = out_channels + self.num_edge_layers = num_edge_layers + + self.edge_lin = SparseLinear(num_nodes, hidden_channels) + + if self.num_edge_layers > 1: + self.edge_norm = BatchNorm1d(hidden_channels) + channels = [hidden_channels] * num_edge_layers + self.edge_mlp = MLP(channels, dropout=0., act_first=True) + else: + self.edge_norm = None + self.edge_mlp = None + + channels = [in_channels] + [hidden_channels] * num_node_layers + self.node_mlp = MLP(channels, dropout=0., act_first=True) + + self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels) + self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels) + + channels = [hidden_channels] * num_layers + [out_channels] + self.final_mlp = MLP(channels, dropout=dropout, act_first=True) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.edge_lin.reset_parameters() + if self.edge_norm is not None: + self.edge_norm.reset_parameters() + if self.edge_mlp is not None: + self.edge_mlp.reset_parameters() + self.node_mlp.reset_parameters() + self.cat_lin1.reset_parameters() + self.cat_lin2.reset_parameters() + self.final_mlp.reset_parameters() + + def forward( + self, + x: OptTensor, + edge_index: Adj, + edge_weight: OptTensor = None, + ) -> Tensor: + """""" # noqa: D419 + out = self.edge_lin(edge_index, edge_weight) + + if self.edge_norm is not None and self.edge_mlp is not None: + out = out.relu_() + out = self.edge_norm(out) + out = self.edge_mlp(out) + + out = out + self.cat_lin1(out) + + if x is not None: + x = self.node_mlp(x) + out = out + x + out = out + self.cat_lin2(x) + + return self.final_mlp(out.relu_()) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, ' + f'in_channels={self.in_channels}, ' + f'out_channels={self.out_channels})') + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = LINKX(num_nodes=node_features.size(0), in_channels=node_features.size(1), hidden_channels=node_features.size(1), out_channels=node_features.size(1), num_layers=1) + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") \ No newline at end of file diff --git a/rdagent/model_implementation/benchmark/gt_code/pmlp.py b/rdagent/model_implementation/benchmark/gt_code/pmlp.py new file mode 100644 index 00000000..0129dfb6 --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/pmlp.py @@ -0,0 +1,112 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + +from torch_geometric.nn import SimpleConv +from torch_geometric.nn.dense.linear import Linear + + +class PMLP(torch.nn.Module): + r"""The P(ropagational)MLP model from the `"Graph Neural Networks are + Inherently Good Generalizers: Insights by Bridging GNNs and MLPs" + `_ paper. + :class:`PMLP` is identical to a standard MLP during training, but then + adopts a GNN architecture during testing. + + Args: + in_channels (int): Size of each input sample. + hidden_channels (int): Size of each hidden sample. + out_channels (int): Size of each output sample. + num_layers (int): The number of layers. + dropout (float, optional): Dropout probability of each hidden + embedding. (default: :obj:`0.`) + norm (bool, optional): If set to :obj:`False`, will not apply batch + normalization. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the module + will not learn additive biases. (default: :obj:`True`) + """ + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int, + dropout: float = 0., + norm: bool = True, + bias: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.dropout = dropout + self.bias = bias + + self.lins = torch.nn.ModuleList() + self.lins.append(Linear(in_channels, hidden_channels, self.bias)) + for _ in range(self.num_layers - 2): + lin = Linear(hidden_channels, hidden_channels, self.bias) + self.lins.append(lin) + self.lins.append(Linear(hidden_channels, out_channels, self.bias)) + + self.norm = None + if norm: + self.norm = torch.nn.BatchNorm1d( + hidden_channels, + affine=False, + track_running_stats=False, + ) + + self.conv = SimpleConv(aggr='mean', combine_root='self_loop') + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + for lin in self.lins: + torch.nn.init.xavier_uniform_(lin.weight, gain=1.414) + if self.bias: + torch.nn.init.zeros_(lin.bias) + + def forward( + self, + x: torch.Tensor, + edge_index: Optional[Tensor] = None, + ) -> torch.Tensor: + """""" # noqa: D419 + if not self.training and edge_index is None: + raise ValueError(f"'edge_index' needs to be present during " + f"inference in '{self.__class__.__name__}'") + + for i in range(self.num_layers): + x = x @ self.lins[i].weight.t() + if not self.training: + x = self.conv(x, edge_index) + if self.bias: + x = x + self.lins[i].bias + if i != self.num_layers - 1: + if self.norm is not None: + x = self.norm(x) + x = x.relu() + x = F.dropout(x, p=self.dropout, training=self.training) + + return x + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, num_layers={self.num_layers})') + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = PMLP(in_channels=node_features.size(-1), hidden_channels=node_features.size(-1), node_features.size(-1), num_layers=1) + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") \ No newline at end of file diff --git a/rdagent/model_implementation/benchmark/gt_code/visnet.py b/rdagent/model_implementation/benchmark/gt_code/visnet.py new file mode 100644 index 00000000..e960cb7b --- /dev/null +++ b/rdagent/model_implementation/benchmark/gt_code/visnet.py @@ -0,0 +1,1190 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.autograd import grad +from torch.nn import Embedding, LayerNorm, Linear, Parameter + +from torch_geometric.nn import MessagePassing, radius_graph +from torch_geometric.utils import scatter + + +class CosineCutoff(torch.nn.Module): + r"""Appies a cosine cutoff to the input distances. + + .. math:: + \text{cutoffs} = + \begin{cases} + 0.5 * (\cos(\frac{\text{distances} * \pi}{\text{cutoff}}) + 1.0), + & \text{if } \text{distances} < \text{cutoff} \\ + 0, & \text{otherwise} + \end{cases} + + Args: + cutoff (float): A scalar that determines the point at which the cutoff + is applied. + """ + def __init__(self, cutoff: float) -> None: + super().__init__() + self.cutoff = cutoff + + def forward(self, distances: Tensor) -> Tensor: + r"""Applies a cosine cutoff to the input distances. + + Args: + distances (torch.Tensor): A tensor of distances. + + Returns: + cutoffs (torch.Tensor): A tensor where the cosine function + has been applied to the distances, + but any values that exceed the cutoff are set to 0. + """ + cutoffs = 0.5 * ((distances * math.pi / self.cutoff).cos() + 1.0) + cutoffs = cutoffs * (distances < self.cutoff).float() + return cutoffs + + +class ExpNormalSmearing(torch.nn.Module): + r"""Applies exponential normal smearing to the input distances. + + .. math:: + \text{smeared\_dist} = \text{CosineCutoff}(\text{dist}) + * e^{-\beta * (e^{\alpha * (-\text{dist})} - \text{means})^2} + + Args: + cutoff (float, optional): A scalar that determines the point at which + the cutoff is applied. (default: :obj:`5.0`) + num_rbf (int, optional): The number of radial basis functions. + (default: :obj:`128`) + trainable (bool, optional): If set to :obj:`False`, the means and betas + of the RBFs will not be trained. (default: :obj:`True`) + """ + def __init__( + self, + cutoff: float = 5.0, + num_rbf: int = 128, + trainable: bool = True, + ) -> None: + super().__init__() + self.cutoff = cutoff + self.num_rbf = num_rbf + self.trainable = trainable + + self.cutoff_fn = CosineCutoff(cutoff) + self.alpha = 5.0 / cutoff + + means, betas = self._initial_params() + if trainable: + self.register_parameter('means', Parameter(means)) + self.register_parameter('betas', Parameter(betas)) + else: + self.register_buffer('means', means) + self.register_buffer('betas', betas) + + def _initial_params(self) -> Tuple[Tensor, Tensor]: + r"""Initializes the means and betas for the radial basis functions.""" + start_value = torch.exp(torch.tensor(-self.cutoff)) + means = torch.linspace(start_value, 1, self.num_rbf) + betas = torch.tensor([(2 / self.num_rbf * (1 - start_value))**-2] * + self.num_rbf) + return means, betas + + def reset_parameters(self): + r"""Resets the means and betas to their initial values.""" + means, betas = self._initial_params() + self.means.data.copy_(means) + self.betas.data.copy_(betas) + + def forward(self, dist: Tensor) -> Tensor: + r"""Applies the exponential normal smearing to the input distance. + + Args: + dist (torch.Tensor): A tensor of distances. + """ + dist = dist.unsqueeze(-1) + smeared_dist = self.cutoff_fn(dist) * (-self.betas * ( + (self.alpha * (-dist)).exp() - self.means)**2).exp() + return smeared_dist + + +class Sphere(torch.nn.Module): + r"""Computes spherical harmonics of the input data. + + This module computes the spherical harmonics up to a given degree + :obj:`lmax` for the input tensor of 3D vectors. + The vectors are assumed to be given in Cartesian coordinates. + See `here `_ + for mathematical details. + + Args: + lmax (int, optional): The maximum degree of the spherical harmonics. + (default: :obj:`2`) + """ + def __init__(self, lmax: int = 2) -> None: + super().__init__() + self.lmax = lmax + + def forward(self, edge_vec: Tensor) -> Tensor: + r"""Computes the spherical harmonics of the input tensor. + + Args: + edge_vec (torch.Tensor): A tensor of 3D vectors. + """ + return self._spherical_harmonics( + self.lmax, + edge_vec[..., 0], + edge_vec[..., 1], + edge_vec[..., 2], + ) + + @staticmethod + def _spherical_harmonics( + lmax: int, + x: Tensor, + y: Tensor, + z: Tensor, + ) -> Tensor: + r"""Computes the spherical harmonics up to degree :obj:`lmax` of the + input vectors. + + Args: + lmax (int): The maximum degree of the spherical harmonics. + x (torch.Tensor): The x coordinates of the vectors. + y (torch.Tensor): The y coordinates of the vectors. + z (torch.Tensor): The z coordinates of the vectors. + """ + sh_1_0, sh_1_1, sh_1_2 = x, y, z + + if lmax == 1: + return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1) + + sh_2_0 = math.sqrt(3.0) * x * z + sh_2_1 = math.sqrt(3.0) * x * y + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) + sh_2_2 = y2 - 0.5 * x2z2 + sh_2_3 = math.sqrt(3.0) * y * z + sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2)) + + if lmax == 2: + return torch.stack([ + sh_1_0, + sh_1_1, + sh_1_2, + sh_2_0, + sh_2_1, + sh_2_2, + sh_2_3, + sh_2_4, + ], dim=-1) + + raise ValueError(f"'lmax' needs to be 1 or 2 (got {lmax})") + + +class VecLayerNorm(torch.nn.Module): + r"""Applies layer normalization to the input data. + + This module applies a custom layer normalization to a tensor of vectors. + The normalization can either be :obj:`"max_min"` normalization, or no + normalization. + + Args: + hidden_channels (int): The number of hidden channels in the input. + trainable (bool): If set to :obj:`True`, the normalization weights are + trainable parameters. + norm_type (str, optional): The type of normalization to apply, one of + :obj:`"max_min"` or :obj:`None`. (default: :obj:`"max_min"`) + """ + def __init__( + self, + hidden_channels: int, + trainable: bool, + norm_type: Optional[str] = 'max_min', + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + self.norm_type = norm_type + self.eps = 1e-12 + + weight = torch.ones(self.hidden_channels) + if trainable: + self.register_parameter('weight', Parameter(weight)) + else: + self.register_buffer('weight', weight) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the normalization weights to their initial values.""" + torch.nn.init.ones_(self.weight) + + def max_min_norm(self, vec: Tensor) -> Tensor: + r"""Applies max-min normalization to the input tensor. + + .. math:: + \text{dist} = ||\text{vec}||_2 + \text{direct} = \frac{\text{vec}}{\text{dist}} + \text{max\_val} = \max(\text{dist}) + \text{min\_val} = \min(\text{dist}) + \text{delta} = \text{max\_val} - \text{min\_val} + \text{dist} = \frac{\text{dist} - \text{min\_val}}{\text{delta}} + \text{normed\_vec} = \max(0, \text{dist}) \cdot \text{direct} + + Args: + vec (torch.Tensor): The input tensor. + """ + dist = torch.norm(vec, dim=1, keepdim=True) + + if (dist == 0).all(): + return torch.zeros_like(vec) + + dist = dist.clamp(min=self.eps) + direct = vec / dist + + max_val, _ = dist.max(dim=-1) + min_val, _ = dist.min(dim=-1) + delta = (max_val - min_val).view(-1) + delta = torch.where(delta == 0, torch.ones_like(delta), delta) + dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1) + + return dist.relu() * direct + + def forward(self, vec: Tensor) -> Tensor: + r"""Applies the layer normalization to the input tensor. + + Args: + vec (torch.Tensor): The input tensor. + """ + if vec.size(1) == 3: + if self.norm_type == 'max_min': + vec = self.max_min_norm(vec) + return vec * self.weight.unsqueeze(0).unsqueeze(0) + elif vec.size(1) == 8: + vec1, vec2 = torch.split(vec, [3, 5], dim=1) + if self.norm_type == 'max_min': + vec1 = self.max_min_norm(vec1) + vec2 = self.max_min_norm(vec2) + vec = torch.cat([vec1, vec2], dim=1) + return vec * self.weight.unsqueeze(0).unsqueeze(0) + + raise ValueError(f"'{self.__class__.__name__}' only support 3 or 8 " + f"channels (got {vec.size(1)})") + + +class Distance(torch.nn.Module): + r"""Computes the pairwise distances between atoms in a molecule. + + This module computes the pairwise distances between atoms in a molecule, + represented by their positions :obj:`pos`. + The distances are computed only between points that are within a certain + cutoff radius. + + Args: + cutoff (float): The cutoff radius beyond + which distances are not computed. + max_num_neighbors (int, optional): The maximum number of neighbors + considered for each point. (default: :obj:`32`) + add_self_loops (bool, optional): If set to :obj:`False`, will not + include self-loops. (default: :obj:`True`) + """ + def __init__( + self, + cutoff: float, + max_num_neighbors: int = 32, + add_self_loops: bool = True, + ) -> None: + super().__init__() + self.cutoff = cutoff + self.max_num_neighbors = max_num_neighbors + self.add_self_loops = add_self_loops + + def forward( + self, + pos: Tensor, + batch: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Computes the pairwise distances between atoms in the molecule. + + Args: + pos (torch.Tensor): The positions of the atoms in the molecule. + batch (torch.Tensor): A batch vector, which assigns each node to a + specific example. + + Returns: + edge_index (torch.Tensor): The indices of the edges in the graph. + edge_weight (torch.Tensor): The distances between connected nodes. + edge_vec (torch.Tensor): The vector differences between connected + nodes. + """ + edge_index = radius_graph( + pos, + r=self.cutoff, + batch=batch, + loop=self.add_self_loops, + max_num_neighbors=self.max_num_neighbors, + ) + edge_vec = pos[edge_index[0]] - pos[edge_index[1]] + + if self.add_self_loops: + mask = edge_index[0] != edge_index[1] + edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device) + edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1) + else: + edge_weight = torch.norm(edge_vec, dim=-1) + + return edge_index, edge_weight, edge_vec + + +class NeighborEmbedding(MessagePassing): + r"""The :class:`NeighborEmbedding` module from the `"Enhancing Geometric + Representations for Molecules with Equivariant Vector-Scalar Interactive + Message Passing" `_ paper. + + Args: + hidden_channels (int): The number of hidden channels in the node + embeddings. + num_rbf (int): The number of radial basis functions. + cutoff (float): The cutoff distance. + max_z (int, optional): The maximum atomic numbers. + (default: :obj:`100`) + """ + def __init__( + self, + hidden_channels: int, + num_rbf: int, + cutoff: float, + max_z: int = 100, + ) -> None: + super().__init__(aggr='add') + self.embedding = Embedding(max_z, hidden_channels) + self.distance_proj = Linear(num_rbf, hidden_channels) + self.combine = Linear(hidden_channels * 2, hidden_channels) + self.cutoff = CosineCutoff(cutoff) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + self.embedding.reset_parameters() + torch.nn.init.xavier_uniform_(self.distance_proj.weight) + torch.nn.init.xavier_uniform_(self.combine.weight) + self.distance_proj.bias.data.zero_() + self.combine.bias.data.zero_() + + def forward( + self, + z: Tensor, + x: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + edge_attr: Tensor, + ) -> Tensor: + r"""Computes the neighborhood embedding of the nodes in the graph. + + Args: + z (torch.Tensor): The atomic numbers. + x (torch.Tensor): The node features. + edge_index (torch.Tensor): The indices of the edges. + edge_weight (torch.Tensor): The weights of the edges. + edge_attr (torch.Tensor): The edge features. + + Returns: + x_neighbors (torch.Tensor): The neighborhood embeddings of the + nodes. + """ + mask = edge_index[0] != edge_index[1] + if not mask.all(): + edge_index = edge_index[:, mask] + edge_weight = edge_weight[mask] + edge_attr = edge_attr[mask] + + C = self.cutoff(edge_weight) + W = self.distance_proj(edge_attr) * C.view(-1, 1) + + x_neighbors = self.embedding(z) + x_neighbors = self.propagate(edge_index, x=x_neighbors, W=W) + x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1)) + return x_neighbors + + def message(self, x_j: Tensor, W: Tensor) -> Tensor: + return x_j * W + + +class EdgeEmbedding(torch.nn.Module): + r"""The :class:`EdgeEmbedding` module from the `"Enhancing Geometric + Representations for Molecules with Equivariant Vector-Scalar Interactive + Message Passing" `_ paper. + + Args: + num_rbf (int): The number of radial basis functions. + hidden_channels (int): The number of hidden channels in the node + embeddings. + """ + def __init__(self, num_rbf: int, hidden_channels: int) -> None: + super().__init__() + self.edge_proj = Linear(num_rbf, hidden_channels) + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + torch.nn.init.xavier_uniform_(self.edge_proj.weight) + self.edge_proj.bias.data.zero_() + + def forward( + self, + edge_index: Tensor, + edge_attr: Tensor, + x: Tensor, + ) -> Tensor: + r"""Computes the edge embeddings of the graph. + + Args: + edge_index (torch.Tensor): The indices of the edges. + edge_attr (torch.Tensor): The edge features. + x (torch.Tensor): The node features. + + Returns: + out_edge_attr (torch.Tensor): The edge embeddings. + """ + x_j = x[edge_index[0]] + x_i = x[edge_index[1]] + return (x_i + x_j) * self.edge_proj(edge_attr) + + +class ViS_MP(MessagePassing): + r"""The message passing module without vertex geometric features of the + equivariant vector-scalar interactive graph neural network (ViSNet) + from the `"Enhancing Geometric Representations for Molecules with + Equivariant Vector-Scalar Interactive Message Passing" + `_ paper. + + Args: + num_heads (int): The number of attention heads. + hidden_channels (int): The number of hidden channels in the node + embeddings. + cutoff (float): The cutoff distance. + vecnorm_type (str, optional): The type of normalization to apply to the + vectors. + trainable_vecnorm (bool): Whether the normalization weights are + trainable. + last_layer (bool, optional): Whether this is the last layer in the + model. (default: :obj:`False`) + """ + def __init__( + self, + num_heads: int, + hidden_channels: int, + cutoff: float, + vecnorm_type: Optional[str], + trainable_vecnorm: bool, + last_layer: bool = False, + ) -> None: + super().__init__(aggr='add', node_dim=0) + + if hidden_channels % num_heads != 0: + raise ValueError( + f"The number of hidden channels (got {hidden_channels}) must " + f"be evenly divisible by the number of attention heads " + f"(got {num_heads})") + + self.num_heads = num_heads + self.hidden_channels = hidden_channels + self.head_dim = hidden_channels // num_heads + self.last_layer = last_layer + + self.layernorm = LayerNorm(hidden_channels) + self.vec_layernorm = VecLayerNorm( + hidden_channels, + trainable=trainable_vecnorm, + norm_type=vecnorm_type, + ) + + self.act = torch.nn.SiLU() + self.attn_activation = torch.nn.SiLU() + + self.cutoff = CosineCutoff(cutoff) + + self.vec_proj = Linear(hidden_channels, hidden_channels * 3, False) + + self.q_proj = Linear(hidden_channels, hidden_channels) + self.k_proj = Linear(hidden_channels, hidden_channels) + self.v_proj = Linear(hidden_channels, hidden_channels) + self.dk_proj = Linear(hidden_channels, hidden_channels) + self.dv_proj = Linear(hidden_channels, hidden_channels) + + self.s_proj = Linear(hidden_channels, hidden_channels * 2) + if not self.last_layer: + self.f_proj = Linear(hidden_channels, hidden_channels) + self.w_src_proj = Linear(hidden_channels, hidden_channels, False) + self.w_trg_proj = Linear(hidden_channels, hidden_channels, False) + + self.o_proj = Linear(hidden_channels, hidden_channels * 3) + + self.reset_parameters() + + @staticmethod + def vector_rejection(vec: Tensor, d_ij: Tensor) -> Tensor: + r"""Computes the component of :obj:`vec` orthogonal to :obj:`d_ij`. + + Args: + vec (torch.Tensor): The input vector. + d_ij (torch.Tensor): The reference vector. + """ + vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True) + return vec - vec_proj * d_ij.unsqueeze(2) + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + self.layernorm.reset_parameters() + self.vec_layernorm.reset_parameters() + torch.nn.init.xavier_uniform_(self.q_proj.weight) + self.q_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.k_proj.weight) + self.k_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.v_proj.weight) + self.v_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.o_proj.weight) + self.o_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.s_proj.weight) + self.s_proj.bias.data.zero_() + + if not self.last_layer: + torch.nn.init.xavier_uniform_(self.f_proj.weight) + self.f_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.w_src_proj.weight) + torch.nn.init.xavier_uniform_(self.w_trg_proj.weight) + + torch.nn.init.xavier_uniform_(self.vec_proj.weight) + torch.nn.init.xavier_uniform_(self.dk_proj.weight) + self.dk_proj.bias.data.zero_() + torch.nn.init.xavier_uniform_(self.dv_proj.weight) + self.dv_proj.bias.data.zero_() + + def forward( + self, + x: Tensor, + vec: Tensor, + edge_index: Tensor, + r_ij: Tensor, + f_ij: Tensor, + d_ij: Tensor, + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + r"""Computes the residual scalar and vector features of the nodes and + scalar featues of the edges. + + Args: + x (torch.Tensor): The scalar features of the nodes. + vec (torch.Tensor):The vector features of the nodes. + edge_index (torch.Tensor): The indices of the edges. + r_ij (torch.Tensor): The distances between connected nodes. + f_ij (torch.Tensor): The scalar features of the edges. + d_ij (torch.Tensor): The unit vectors of the edges + + Returns: + dx (torch.Tensor): The residual scalar features of the nodes. + dvec (torch.Tensor): The residual vector features of the nodes. + df_ij (torch.Tensor, optional): The residual scalar features of the + edges, or :obj:`None` if this is the last layer. + """ + x = self.layernorm(x) + vec = self.vec_layernorm(vec) + + q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim) + k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim) + v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim) + dk = self.act(self.dk_proj(f_ij)) + dk = dk.reshape(-1, self.num_heads, self.head_dim) + dv = self.act(self.dv_proj(f_ij)) + dv = dv.reshape(-1, self.num_heads, self.head_dim) + + vec1, vec2, vec3 = torch.split(self.vec_proj(vec), + self.hidden_channels, dim=-1) + vec_dot = (vec1 * vec2).sum(dim=1) + + x, vec_out = self.propagate(edge_index, q=q, k=k, v=v, dk=dk, dv=dv, + vec=vec, r_ij=r_ij, d_ij=d_ij) + + o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1) + dx = vec_dot * o2 + o3 + dvec = vec3 * o1.unsqueeze(1) + vec_out + if not self.last_layer: + df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij, + f_ij=f_ij) + return dx, dvec, df_ij + else: + return dx, dvec, None + + def message(self, q_i: Tensor, k_j: Tensor, v_j: Tensor, vec_j: Tensor, + dk: Tensor, dv: Tensor, r_ij: Tensor, + d_ij: Tensor) -> Tuple[Tensor, Tensor]: + + attn = (q_i * k_j * dk).sum(dim=-1) + attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) + + v_j = v_j * dv + v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels) + + s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels, + dim=1) + vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2) + + return v_j, vec_j + + def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor, + f_ij: Tensor) -> Tensor: + + w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij) + w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij) + w_dot = (w1 * w2).sum(dim=1) + df_ij = self.act(self.f_proj(f_ij)) * w_dot + return df_ij + + def aggregate( + self, + features: Tuple[Tensor, Tensor], + index: Tensor, + ptr: Optional[torch.Tensor], + dim_size: Optional[int], + ) -> Tuple[Tensor, Tensor]: + x, vec = features + x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) + vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) + return x, vec + + +class ViS_MP_Vertex(ViS_MP): + r"""The message passing module with vertex geometric features of the + equivariant vector-scalar interactive graph neural network (ViSNet) + from the `"Enhancing Geometric Representations for Molecules with + Equivariant Vector-Scalar Interactive Message Passing" + `_ paper. + + Args: + num_heads (int): The number of attention heads. + hidden_channels (int): The number of hidden channels in the node + embeddings. + cutoff (float): The cutoff distance. + vecnorm_type (str, optional): The type of normalization to apply to the + vectors. + trainable_vecnorm (bool): Whether the normalization weights are + trainable. + last_layer (bool, optional): Whether this is the last layer in the + model. (default: :obj:`False`) + """ + def __init__( + self, + num_heads: int, + hidden_channels: int, + cutoff: float, + vecnorm_type: Optional[str], + trainable_vecnorm: bool, + last_layer: bool = False, + ) -> None: + super().__init__(num_heads, hidden_channels, cutoff, vecnorm_type, + trainable_vecnorm, last_layer) + + if not self.last_layer: + self.f_proj = Linear(hidden_channels, hidden_channels * 2) + self.t_src_proj = Linear(hidden_channels, hidden_channels, False) + self.t_trg_proj = Linear(hidden_channels, hidden_channels, False) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + super().reset_parameters() + + if not self.last_layer: + if hasattr(self, 't_src_proj'): + torch.nn.init.xavier_uniform_(self.t_src_proj.weight) + if hasattr(self, 't_trg_proj'): + torch.nn.init.xavier_uniform_(self.t_trg_proj.weight) + + def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor, + f_ij: Tensor) -> Tensor: + + w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij) + w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij) + w_dot = (w1 * w2).sum(dim=1) + + t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij) + t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij) + t_dot = (t1 * t2).sum(dim=1) + + f1, f2 = torch.split(self.act(self.f_proj(f_ij)), self.hidden_channels, + dim=-1) + + return f1 * w_dot + f2 * t_dot + + +class ViSNetBlock(torch.nn.Module): + r"""The representation module of the equivariant vector-scalar + interactive graph neural network (ViSNet) from the `"Enhancing Geometric + Representations for Molecules with Equivariant Vector-Scalar Interactive + Message Passing" `_ paper. + + Args: + lmax (int, optional): The maximum degree of the spherical harmonics. + (default: :obj:`1`) + vecnorm_type (str, optional): The type of normalization to apply to the + vectors. (default: :obj:`None`) + trainable_vecnorm (bool, optional): Whether the normalization weights + are trainable. (default: :obj:`False`) + num_heads (int, optional): The number of attention heads. + (default: :obj:`8`) + num_layers (int, optional): The number of layers in the network. + (default: :obj:`6`) + hidden_channels (int, optional): The number of hidden channels in the + node embeddings. (default: :obj:`128`) + num_rbf (int, optional): The number of radial basis functions. + (default: :obj:`32`) + trainable_rbf (bool, optional): Whether the radial basis function + parameters are trainable. (default: :obj:`False`) + max_z (int, optional): The maximum atomic numbers. + (default: :obj:`100`) + cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`) + max_num_neighbors (int, optional): The maximum number of neighbors + considered for each atom. (default: :obj:`32`) + vertex (bool, optional): Whether to use vertex geometric features. + (default: :obj:`False`) + """ + def __init__( + self, + lmax: int = 1, + vecnorm_type: Optional[str] = None, + trainable_vecnorm: bool = False, + num_heads: int = 8, + num_layers: int = 6, + hidden_channels: int = 128, + num_rbf: int = 32, + trainable_rbf: bool = False, + max_z: int = 100, + cutoff: float = 5.0, + max_num_neighbors: int = 32, + vertex: bool = False, + ) -> None: + super().__init__() + + self.lmax = lmax + self.vecnorm_type = vecnorm_type + self.trainable_vecnorm = trainable_vecnorm + self.num_heads = num_heads + self.num_layers = num_layers + self.hidden_channels = hidden_channels + self.num_rbf = num_rbf + self.trainable_rbf = trainable_rbf + self.max_z = max_z + self.cutoff = cutoff + self.max_num_neighbors = max_num_neighbors + + self.embedding = Embedding(max_z, hidden_channels) + self.distance = Distance(cutoff, max_num_neighbors=max_num_neighbors) + self.sphere = Sphere(lmax=lmax) + self.distance_expansion = ExpNormalSmearing(cutoff, num_rbf, + trainable_rbf) + self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf, + cutoff, max_z) + self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels) + + self.vis_mp_layers = torch.nn.ModuleList() + vis_mp_kwargs = dict( + num_heads=num_heads, + hidden_channels=hidden_channels, + cutoff=cutoff, + vecnorm_type=vecnorm_type, + trainable_vecnorm=trainable_vecnorm, + ) + vis_mp_class = ViS_MP if not vertex else ViS_MP_Vertex + for _ in range(num_layers - 1): + layer = vis_mp_class(last_layer=False, **vis_mp_kwargs) + self.vis_mp_layers.append(layer) + self.vis_mp_layers.append( + vis_mp_class(last_layer=True, **vis_mp_kwargs)) + + self.out_norm = LayerNorm(hidden_channels) + self.vec_out_norm = VecLayerNorm( + hidden_channels, + trainable=trainable_vecnorm, + norm_type=vecnorm_type, + ) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + self.embedding.reset_parameters() + self.distance_expansion.reset_parameters() + self.neighbor_embedding.reset_parameters() + self.edge_embedding.reset_parameters() + for layer in self.vis_mp_layers: + layer.reset_parameters() + self.out_norm.reset_parameters() + self.vec_out_norm.reset_parameters() + + def forward( + self, + z: Tensor, + pos: Tensor, + batch: Tensor, + ) -> Tuple[Tensor, Tensor]: + r"""Computes the scalar and vector features of the nodes. + + Args: + z (torch.Tensor): The atomic numbers. + pos (torch.Tensor): The coordinates of the atoms. + batch (torch.Tensor): A batch vector, which assigns each node to a + specific example. + + Returns: + x (torch.Tensor): The scalar features of the nodes. + vec (torch.Tensor): The vector features of the nodes. + """ + x = self.embedding(z) + edge_index, edge_weight, edge_vec = self.distance(pos, batch) + edge_attr = self.distance_expansion(edge_weight) + mask = edge_index[0] != edge_index[1] + edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], + dim=1).unsqueeze(1) + edge_vec = self.sphere(edge_vec) + x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) + vec = torch.zeros(x.size(0), ((self.lmax + 1)**2) - 1, x.size(1), + dtype=x.dtype, device=x.device) + edge_attr = self.edge_embedding(edge_index, edge_attr, x) + + for attn in self.vis_mp_layers[:-1]: + dx, dvec, dedge_attr = attn(x, vec, edge_index, edge_weight, + edge_attr, edge_vec) + x = x + dx + vec = vec + dvec + edge_attr = edge_attr + dedge_attr + + dx, dvec, _ = self.vis_mp_layers[-1](x, vec, edge_index, edge_weight, + edge_attr, edge_vec) + x = x + dx + vec = vec + dvec + + x = self.out_norm(x) + vec = self.vec_out_norm(vec) + + return x, vec + + +class GatedEquivariantBlock(torch.nn.Module): + r"""Applies a gated equivariant operation to scalar features and vector + features from the `"Enhancing Geometric Representations for Molecules with + Equivariant Vector-Scalar Interactive Message Passing" + `_ paper. + + Args: + hidden_channels (int): The number of hidden channels in the node + embeddings. + out_channels (int): The number of output channels. + intermediate_channels (int, optional): The number of channels in the + intermediate layer, or :obj:`None` to use the same number as + :obj:`hidden_channels`. (default: :obj:`None`) + scalar_activation (bool, optional): Whether to apply a scalar + activation function to the output node features. + (default: obj:`False`) + """ + def __init__( + self, + hidden_channels: int, + out_channels: int, + intermediate_channels: Optional[int] = None, + scalar_activation: bool = False, + ) -> None: + super().__init__() + self.out_channels = out_channels + + if intermediate_channels is None: + intermediate_channels = hidden_channels + + self.vec1_proj = Linear(hidden_channels, hidden_channels, bias=False) + self.vec2_proj = Linear(hidden_channels, out_channels, bias=False) + + self.update_net = torch.nn.Sequential( + Linear(hidden_channels * 2, intermediate_channels), + torch.nn.SiLU(), + Linear(intermediate_channels, out_channels * 2), + ) + + self.act = torch.nn.SiLU() if scalar_activation else None + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + torch.nn.init.xavier_uniform_(self.vec1_proj.weight) + torch.nn.init.xavier_uniform_(self.vec2_proj.weight) + torch.nn.init.xavier_uniform_(self.update_net[0].weight) + self.update_net[0].bias.data.zero_() + torch.nn.init.xavier_uniform_(self.update_net[2].weight) + self.update_net[2].bias.data.zero_() + + def forward(self, x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]: + r"""Applies a gated equivariant operation to node features and vector + features. + + Args: + x (torch.Tensor): The scalar features of the nodes. + v (torch.Tensor): The vector features of the nodes. + """ + vec1 = torch.norm(self.vec1_proj(v), dim=-2) + vec2 = self.vec2_proj(v) + + x = torch.cat([x, vec1], dim=-1) + x, v = torch.split(self.update_net(x), self.out_channels, dim=-1) + v = v.unsqueeze(1) * vec2 + + if self.act is not None: + x = self.act(x) + + return x, v + + +class EquivariantScalar(torch.nn.Module): + r"""Computes final scalar outputs based on node features and vector + features. + + Args: + hidden_channels (int): The number of hidden channels in the node + embeddings. + """ + def __init__(self, hidden_channels: int) -> None: + super().__init__() + + self.output_network = torch.nn.ModuleList([ + GatedEquivariantBlock( + hidden_channels, + hidden_channels // 2, + scalar_activation=True, + ), + GatedEquivariantBlock( + hidden_channels // 2, + 1, + scalar_activation=False, + ), + ]) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + for layer in self.output_network: + layer.reset_parameters() + + def pre_reduce(self, x: Tensor, v: Tensor) -> Tensor: + r"""Computes the final scalar outputs. + + Args: + x (torch.Tensor): The scalar features of the nodes. + v (torch.Tensor): The vector features of the nodes. + + Returns: + out (torch.Tensor): The final scalar outputs of the nodes. + """ + for layer in self.output_network: + x, v = layer(x, v) + + return x + v.sum() * 0 + + +class Atomref(torch.nn.Module): + r"""Adds atom reference values to atomic energies. + + Args: + atomref (torch.Tensor, optional): A tensor of atom reference values, + or :obj:`None` if not provided. (default: :obj:`None`) + max_z (int, optional): The maximum atomic numbers. + (default: :obj:`100`) + """ + def __init__( + self, + atomref: Optional[Tensor] = None, + max_z: int = 100, + ) -> None: + super().__init__() + + if atomref is None: + atomref = torch.zeros(max_z, 1) + else: + atomref = torch.as_tensor(atomref) + + if atomref.ndim == 1: + atomref = atomref.view(-1, 1) + + self.register_buffer('initial_atomref', atomref) + self.atomref = Embedding(len(atomref), 1) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + self.atomref.weight.data.copy_(self.initial_atomref) + + def forward(self, x: Tensor, z: Tensor) -> Tensor: + r"""Adds atom reference values to atomic energies. + + Args: + x (torch.Tensor): The atomic energies. + z (torch.Tensor): The atomic numbers. + """ + return x + self.atomref(z) + + +class ViSNet(torch.nn.Module): + r"""A :pytorch:`PyTorch` module that implements the equivariant + vector-scalar interactive graph neural network (ViSNet) from the + `"Enhancing Geometric Representations for Molecules with Equivariant + Vector-Scalar Interactive Message Passing" + `_ paper. + + Args: + lmax (int, optional): The maximum degree of the spherical harmonics. + (default: :obj:`1`) + vecnorm_type (str, optional): The type of normalization to apply to the + vectors. (default: :obj:`None`) + trainable_vecnorm (bool, optional): Whether the normalization weights + are trainable. (default: :obj:`False`) + num_heads (int, optional): The number of attention heads. + (default: :obj:`8`) + num_layers (int, optional): The number of layers in the network. + (default: :obj:`6`) + hidden_channels (int, optional): The number of hidden channels in the + node embeddings. (default: :obj:`128`) + num_rbf (int, optional): The number of radial basis functions. + (default: :obj:`32`) + trainable_rbf (bool, optional): Whether the radial basis function + parameters are trainable. (default: :obj:`False`) + max_z (int, optional): The maximum atomic numbers. + (default: :obj:`100`) + cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`) + max_num_neighbors (int, optional): The maximum number of neighbors + considered for each atom. (default: :obj:`32`) + vertex (bool, optional): Whether to use vertex geometric features. + (default: :obj:`False`) + atomref (torch.Tensor, optional): A tensor of atom reference values, + or :obj:`None` if not provided. (default: :obj:`None`) + reduce_op (str, optional): The type of reduction operation to apply + (:obj:`"sum"`, :obj:`"mean"`). (default: :obj:`"sum"`) + mean (float, optional): The mean of the output distribution. + (default: :obj:`0.0`) + std (float, optional): The standard deviation of the output + distribution. (default: :obj:`1.0`) + derivative (bool, optional): Whether to compute the derivative of the + output with respect to the positions. (default: :obj:`False`) + """ + def __init__( + self, + lmax: int = 1, + vecnorm_type: Optional[str] = None, + trainable_vecnorm: bool = False, + num_heads: int = 8, + num_layers: int = 6, + hidden_channels: int = 128, + num_rbf: int = 32, + trainable_rbf: bool = False, + max_z: int = 100, + cutoff: float = 5.0, + max_num_neighbors: int = 32, + vertex: bool = False, + atomref: Optional[Tensor] = None, + reduce_op: str = "sum", + mean: float = 0.0, + std: float = 1.0, + derivative: bool = False, + ) -> None: + super().__init__() + + self.representation_model = ViSNetBlock( + lmax=lmax, + vecnorm_type=vecnorm_type, + trainable_vecnorm=trainable_vecnorm, + num_heads=num_heads, + num_layers=num_layers, + hidden_channels=hidden_channels, + num_rbf=num_rbf, + trainable_rbf=trainable_rbf, + max_z=max_z, + cutoff=cutoff, + max_num_neighbors=max_num_neighbors, + vertex=vertex, + ) + + self.output_model = EquivariantScalar(hidden_channels=hidden_channels) + self.prior_model = Atomref(atomref=atomref, max_z=max_z) + self.reduce_op = reduce_op + self.derivative = derivative + + self.register_buffer('mean', torch.tensor(mean)) + self.register_buffer('std', torch.tensor(std)) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets the parameters of the module.""" + self.representation_model.reset_parameters() + self.output_model.reset_parameters() + if self.prior_model is not None: + self.prior_model.reset_parameters() + + def forward( + self, + z: Tensor, + pos: Tensor, + batch: Tensor, + ) -> Tuple[Tensor, Optional[Tensor]]: + r"""Computes the energies or properties (forces) for a batch of + molecules. + + Args: + z (torch.Tensor): The atomic numbers. + pos (torch.Tensor): The coordinates of the atoms. + batch (torch.Tensor): A batch vector, + which assigns each node to a specific example. + + Returns: + y (torch.Tensor): The energies or properties for each molecule. + dy (torch.Tensor, optional): The negative derivative of energies. + """ + if self.derivative: + pos.requires_grad_(True) + + x, v = self.representation_model(z, pos, batch) + x = self.output_model.pre_reduce(x, v) + x = x * self.std + + if self.prior_model is not None: + x = self.prior_model(x, z) + + y = scatter(x, batch, dim=0, reduce=self.reduce_op) + y = y + self.mean + + if self.derivative: + grad_outputs = [torch.ones_like(y)] + dy = grad( + [y], + [pos], + grad_outputs=grad_outputs, + create_graph=True, + retain_graph=True, + )[0] + if dy is None: + raise RuntimeError( + "Autograd returned None for the force prediction.") + return y, -dy + + return y, None + +if __name__ == "__main__": + node_features = torch.load("node_features.pt") + edge_index = torch.load("edge_index.pt") + + # Model instantiation and forward pass + model = ViSNet() + output = model(node_features, edge_index) + + # Save output to a file + torch.save(output, "gt_output.pt") \ No newline at end of file diff --git a/rdagent/model_implementation/conf.py b/rdagent/model_implementation/conf.py new file mode 100644 index 00000000..061f5b9d --- /dev/null +++ b/rdagent/model_implementation/conf.py @@ -0,0 +1,10 @@ +from pathlib import Path +from pydantic_settings import BaseSettings + +class ModelImplSettings(BaseSettings): + workspace_path: Path = Path("./git_ignore_folder/model_imp_workspace/") # Added type annotation for work_space + + class Config: + env_prefix = 'MODEL_IMPL_' # Use MODEL_IMPL_ as prefix for environment variables + +MODEL_IMPL_SETTINGS = ModelImplSettings() diff --git a/rdagent/model_implementation/gt_code.py b/rdagent/model_implementation/gt_code.py index 7a5ed0c0..b8329fbf 100644 --- a/rdagent/model_implementation/gt_code.py +++ b/rdagent/model_implementation/gt_code.py @@ -1,3 +1,7 @@ +""" +This is just an exmaple. +It will be replaced wtih a list of ground truth tasks. +""" import math from typing import Any, Callable, Dict, Optional, Union diff --git a/rdagent/model_implementation/main.py b/rdagent/model_implementation/main.py index 8398b2a1..760deef6 100644 --- a/rdagent/model_implementation/main.py +++ b/rdagent/model_implementation/main.py @@ -1,3 +1,7 @@ +""" +This file will be removed in the future and replaced by +- rdagent/app/model_implementation/eval.py +""" from dotenv import load_dotenv from rdagent.oai.llm_utils import APIBackend @@ -77,7 +81,7 @@ average_value_eval.append(value_evaluator(llm_output, gt_output)[1]) print("Shape evaluation: ", average_shape_eval[-1]) - print("Value evaluation: ", average_value_eval[-1]) + print("Value evaluation:super().generate(task_l) ", average_value_eval[-1]) os.system("rm llm_output.pt") os.system("rm gt_output.pt") diff --git a/rdagent/model_implementation/one_shot/__init__.py b/rdagent/model_implementation/one_shot/__init__.py new file mode 100644 index 00000000..357f8c03 --- /dev/null +++ b/rdagent/model_implementation/one_shot/__init__.py @@ -0,0 +1,42 @@ +from typing import Sequence +from rdagent.oai.llm_utils import APIBackend + +from jinja2 import Template +from rdagent.core.implementation import TaskGenerator +from rdagent.core.prompts import Prompts +from rdagent.model_implementation.task import ModelImplTask, ModelTaskImpl + +from pathlib import Path +DIRNAME = Path(__file__).absolute().resolve().parent + + +class ModelTaskGen(TaskGenerator): + + def generate(self, task_l: Sequence[ModelImplTask]) -> Sequence[ModelTaskImpl]: + mti_l = [] + for t in task_l: + mti = ModelTaskImpl(t) + mti.prepare() + pr = Prompts(file_path=DIRNAME / "prompt.yaml") + + user_prompt_tpl = Template(pr["code_implement_user"]) + sys_prompt_tpl = Template(pr["code_implement_sys"]) + + user_prompt = user_prompt_tpl.render( + name=t.name, + description=t.description, + formulation=t.formulation, + variables=t.variables, + execute_desc=mti.execute_desc() + ) + system_prompt = sys_prompt_tpl.render() + + resp = APIBackend().build_messages_and_create_chat_completion( + user_prompt, system_prompt + ) + + # Extract the code part from the response + code = resp.split("```python")[1].split("```")[0] + mti.inject_code(**{"model.py": code}) + mti_l.append(mti) + return mti_l diff --git a/rdagent/model_implementation/one_shot/prompt.yaml b/rdagent/model_implementation/one_shot/prompt.yaml new file mode 100644 index 00000000..0145cf35 --- /dev/null +++ b/rdagent/model_implementation/one_shot/prompt.yaml @@ -0,0 +1,18 @@ + + +code_implement_sys: -| + You are an assistant whose job is to answer user's question." +code_implement_user: -| + With the following given information, write a python code using pytorch and torch_geometric to implement the model. + This model is in the graph learning field, only have one layer. + The input will be node_feature [num_nodes, dim_feature] and edge_index [2, num_edges] (It would be the input of the forward model) + There is not edge attribute or edge weight as input. The model should detect the node_feature and edge_index shape, if there is Linear transformation layer in the model, the input and output shape should be consistent. The in_channels is the dimension of the node features. + Implement the model forward function based on the following information:model formula information. + 1. model name:{{name}} + 2. model description:{{description}} + 3. model formulation:{{formulation}} + 4. model variables:{{variables}}. + You must complete the forward function as far as you can do. + \# Execution + Your implemented code will be exectued in the follow way + {{execute_desc}} diff --git a/rdagent/model_implementation/task.py b/rdagent/model_implementation/task.py new file mode 100644 index 00000000..13efba40 --- /dev/null +++ b/rdagent/model_implementation/task.py @@ -0,0 +1,144 @@ +import torch +from pathlib import Path +import uuid +from typing import Dict, Optional, Sequence +from rdagent.core.exception import CodeFormatException +from rdagent.core.task import BaseTask, FBTaskImplementation, ImpLoader, TaskImplementation, TaskLoader +from rdagent.model_implementation.conf import MODEL_IMPL_SETTINGS +from rdagent.utils import get_module_by_module_path + + +class ModelImplTask(BaseTask): + # TODO: it should change when the BaseTask changes. + name: str + description: str + formulation: str + variables: Dict[str, str] # map the variable name to the variable description + + def __init__(self, name: str, description: str, formulation: str, variables: Dict[str, str], key: Optional[str] = None) -> None: + """ + + Parameters + ---------- + + key : Optional[str] + Key is a string to identify the task. + It will be used to connect to other information(e.g. ground truth). + """ + self.name = name + self.description = description + self.formulation = formulation + self.variables = variables + self.key = key + + +class ModelTaskLoderJson(TaskLoader): + def __init__(self, json_uri: str) -> None: + super().__init__() + # TODO: the json should be loaded from URI. + self.json_uri = json_uri + + def load(self, *argT, **kwargs) -> Sequence[ModelImplTask]: + # TODO: we should load the tasks from json; + + formula_info = { + "name": "Anti-Symmetric Deep Graph Network (A-DGN)", + "description": "A framework for stable and non-dissipative DGN design. It ensures long-range information preservation between nodes and prevents gradient vanishing or explosion during training.", + "formulation": "x_u^{(l)} = x_u^{(l-1)} + \\epsilon \\sigma \\left( W^T x_u^{(l-1)} + \\Phi(X^{(l-1)}, N_u) + b \\right)", + "variables": { + "x_u^{(l)}": "The state of node u at layer l", + "\\epsilon": "The step size in the Euler discretization", + "\\sigma": "A monotonically non-decreasing activation function", + "W": "An anti-symmetric weight matrix", + "X^{(l-1)}": "The node feature matrix at layer l-1", + "N_u": "The set of neighbors of node u", + "b": "A bias vector", + }, + "key": "A-DGN", + } + return [ModelImplTask(**formula_info)] + + +class ModelTaskImpl(FBTaskImplementation): + """ + It is a Pytorch model implementation task; + All the things are placed in a folder. + + Folder + - data source and documents prepared by `prepare` + - Please note that new data may be passed in dynamically in `execute` + - code (file `model.py` ) injected by `inject_code` + - the `model.py` that contains a variable named `model_cls` which indicates the implemented model structure + - `model_cls` is a instance of `torch.nn.Module`; + + + We'll import the model in the implementation in file `model.py` after setting the cwd into the directory + - from model import model_cls + - initialize the model by initializing it `model_cls(input_dim=INPUT_DIM)` + - And then verify the modle. + + """ + def __init__(self, target_task: BaseTask) -> None: + super().__init__(target_task) + self.path = None + + def prepare(self) -> None: + """ + Prepare for the workspace; + """ + unique_id = uuid.uuid4() + self.path = MODEL_IMPL_SETTINGS.workspace_path / f"M{unique_id}" + # start with `M` so that it can be imported via python + self.path.mkdir(parents=True, exist_ok=True) + + def execute(self, data=None, config: dict = {}): + mod = get_module_by_module_path(str(self.path / "model.py")) + try: + model_cls = mod.model_cls + except AttributeError: + raise CodeFormatException("The model_cls is not implemented in the model.py") + # model_init = + + assert isinstance(data, tuple) + node_feature, _ = data + in_channels = node_feature.size(-1) + m = model_cls(in_channels) + + # TODO: initialize all the parameters of `m` to `model_eval_param_init` + model_eval_param_init: float = config["model_eval_param_init"] + + # initialize all parameters of `m` to `model_eval_param_init` + for _, param in m.named_parameters(): + param.data.fill_(model_eval_param_init) + + assert isinstance(data, tuple) + return m(*data) + + def execute_desc(self) -> str: + return """ +The the implemented code will be placed in a file like /model.py + +We'll import the model in the implementation in file `model.py` after setting the cwd into the directory +- from model import model_cls (So you must have a variable named `model_cls` in the file) + - So your implelemented code could follow the following pattern + ```Python + class XXXLayer(torch.nn.Module): + ... + model_cls = XXXLayer + ``` +- initialize the model by initializing it `model_cls(input_dim=INPUT_DIM)` +- And then verify the model by comparing the output tensors by feeding specific input tensor. +""" + +class ModelImpLoader(ImpLoader[ModelImplTask, ModelTaskImpl]): + def __init__(self, path: Path) -> None: + self.path = Path(path) + + def load(self, task: ModelImplTask) -> ModelTaskImpl: + assert task.key is not None + mti = ModelTaskImpl(task) + mti.prepare() + with open(self.path / f"{task.key}.py", "r") as f: + code = f.read() + mti.inject_code(**{"model.py": code}) + return mti diff --git a/rdagent/utils/__init__.py b/rdagent/utils/__init__.py new file mode 100644 index 00000000..58e8a456 --- /dev/null +++ b/rdagent/utils/__init__.py @@ -0,0 +1,36 @@ +""" +This is some common utils functions. +it is not binding to the scenarios or framework (So it is not placed in rdagent.core.utils) +""" +# TODO: merge the common utils in `rdagent.core.utils` into this folder +# TODO: split the utils in this module into different modules in the future. + +import importlib +import re +import sys +from types import ModuleType +from typing import Union + + +def get_module_by_module_path(module_path: Union[str, ModuleType]): + """Load module from path like a/b/c/d.py or a.b.c.d + + :param module_path: + :return: + :raises: ModuleNotFoundError + """ + if module_path is None: + raise ModuleNotFoundError("None is passed in as parameters as module_path") + + if isinstance(module_path, ModuleType): + module = module_path + else: + if module_path.endswith(".py"): + module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) + module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(module_spec) + sys.modules[module_name] = module + module_spec.loader.exec_module(module) + else: + module = importlib.import_module(module_path) + return module