-
-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add needed dependency * add extract_model_and_implement pipeline * add merge_file_to_model_dict_to_model_dict * implement `rdagent\app\model_implementation\eval.py` * Running benchmark * refine import --------- Co-authored-by: Young <[email protected]>
- Loading branch information
Showing
14 changed files
with
345 additions
and
46 deletions.
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
rdagent/app/model_extraction_and_implementation/model_extraction_and_implementation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# %% | ||
from dotenv import load_dotenv | ||
from rdagent.components.task_implementation.model_implementation.one_shot import ModelTaskGen | ||
from rdagent.components.task_implementation.model_implementation.task_extraction import ModelImplementationTaskLoaderFromPDFfiles | ||
|
||
|
||
|
||
def extract_models_and_implement(report_file_path: str="../test_doc") -> None: | ||
factor_tasks = ModelImplementationTaskLoaderFromPDFfiles().load(report_file_path) | ||
implementation_result = ModelTaskGen().generate(factor_tasks) | ||
return implementation_result | ||
|
||
|
||
import fire | ||
if __name__ == "__main__": | ||
fire.Fire(extract_models_and_implement) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
rdagent/components/task_implementation/model_implementation/benchmark/eval.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
rdagent/components/task_implementation/model_implementation/benchmark/model_dict.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
{ | ||
"PMLP": { | ||
"description": "`PMLP` is identical to a standard MLP during training, but then adopts a GNN architecture (add message passing) during testing.", | ||
"formulation": "\\hat{y}_u = \\psi(\\text{MP}(\\{h^{(l-1)}_v\\}_{v \\in N_u \\cup \\{u\\}}))", | ||
"variables": { | ||
"\\hat{y}_u": "The predicted output for node u", | ||
"\\psi": "A function representing the feed-forward process, consisting of a linear feature transformation followed by a non-linear activation", | ||
"\\text{MP}": "Message Passing operation that aggregates neighbored information", | ||
"h^{(l-1)}_v": "The feature representation of node v at layer (l-1)", | ||
"N_u": "The set of neighbored nodes centered at node u" | ||
}, | ||
"key": "pmlp" | ||
}, | ||
"LINKX": { | ||
"description": "A scalable model for node classification that separately embeds adjacency and node features, combines them with MLPs, and applies simple transformations.", | ||
"formulation": "Y = MLP_f(\\sigma(W[h_A; h_X] + h_A + h_X))", | ||
"variables": { | ||
"Y": "The output predictions", | ||
"\\sigma": "Non-linear activation function", | ||
"W": "Learned weight matrix", | ||
"h_A": "Embedding of the adjacency matrix", | ||
"h_X": "Embedding of the node features", | ||
"MLP_f": "Final multilayer perceptron for prediction" | ||
}, | ||
"key": "linkx" | ||
}, | ||
"GPSConv": { | ||
"description": "A scalable and powerful graph transformer with linear complexity, capable of handling large graphs with state-of-the-art results across diverse benchmarks.", | ||
"formulation": "X^{(l+1)} = \\text{MPNN}^{(l)}(X^{(l)}, A) + \\text{GlobalAttn}^{(l)}(X^{(l)})", | ||
"variables": { | ||
"X^{(l)}": "The node features at layer l", | ||
"A": "The adjacency matrix of the graph", | ||
"X^{(l+1)}": "The updated node features at layer l+1", | ||
"MPNN^{(l)}": "The message-passing neural network function at layer l", | ||
"GlobalAttn^{(l)}": "The global attention function at layer l" | ||
}, | ||
"key": "gpsconv" | ||
}, | ||
"ViSNet": { | ||
"description": "ViSNet is an equivariant geometry-enhanced graph neural network designed for efficient molecular modeling[^1^][1][^2^][2]. It utilizes a Vector-Scalar interactive message passing mechanism to extract and utilize geometric features with low computational costs, achieving state-of-the-art performance on multiple molecular dynamics benchmarks.", | ||
"formulation": "\\text{ViSNet}(G) = \\sum_{u \\in G} f(\\mathbf{h}_u, \\mathbf{e}_u, \\mathbf{v}_u)", | ||
"variables": { | ||
"\\mathbf{h}_u": "Node embedding for atom u", | ||
"\\mathbf{e}_u": "Edge embedding associated with atom u", | ||
"\\mathbf{v}_u": "Direction unit vector for atom u" | ||
}, | ||
"key": "visnet" | ||
}, | ||
"Dir-GNN": { | ||
"description": "A framework for deep learning on directed graphs that extends MPNNs to incorporate edge directionality.", | ||
"formulation": "x^{(k)}_i = COM^{(k)}\\left(x^{(k-1)}_i, m^{(k)}_{i,\\leftarrow}, m^{(k)}_{i,\\rightarrow}\\right)", | ||
"variables": { | ||
"x^{(k)}_i": "The feature representation of node i at layer k", | ||
"m^{(k)}_{i,\\leftarrow}": "The aggregated incoming messages to node i at layer k", | ||
"m^{(k)}_{i,\\rightarrow}": "The aggregated outgoing messages from node i at layer k" | ||
}, | ||
"key": "dirgnn" | ||
}, | ||
"A-DGN": { | ||
"description": "A framework for stable and non-dissipative DGN design, conceived through the lens of ordinary differential equations (ODEs). It ensures long-range information preservation between nodes and prevents gradient vanishing or explosion during training.", | ||
"formulation": "\\frac{\\partial x_u(t)}{\\partial t} = \\sigma(W^T x_u(t) + \\Phi(X(t), N_u) + b)", | ||
"variables": { | ||
"x_u(t)": "The state of node u at time t", | ||
"\\frac{\\partial x_u(t)}{\\partial t}": "The rate of change of the state of node u at time t", | ||
"\\sigma": "A monotonically non-decreasing activation function", | ||
"W": "A weight matrix", | ||
"b": "A bias vector", | ||
"\\Phi(X(t), N_u)": "The aggregation function for the states of the nodes in the neighborhood of u", | ||
"X(t)": "The node feature matrix of the whole graph at time t", | ||
"N_u": "The set of neighboring nodes of u" | ||
}, | ||
"key": "A-DGN" | ||
} | ||
} |
8 changes: 8 additions & 0 deletions
8
rdagent/components/task_implementation/model_implementation/prompts.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
extract_model_formulation_system: |- | ||
offer description of the proposed model in this paper, write a latex formula with variable of the model. the format should be like " "Model Name": { | ||
"description": "", | ||
"formulation": "", | ||
"variables": { | ||
"\\hat{y}_u": "The predicted output for node u", | ||
}" | ||
such format content should be begin with ```json and end with ``` and the content should be in json format. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.