-
-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
example workflow code for model implementation (#17)
* upload example code for model implementation * Refactor Model Implement * add export --------- Co-authored-by: Young <[email protected]>
- Loading branch information
Showing
21 changed files
with
2,712 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
|
||
# Preparation | ||
|
||
## Install Pytorch | ||
A CPU-only build (no CUDA) will be enough to verify the implementation | ||
|
||
Please install pytorch based on your system. | ||
Here is an example on my system | ||
```bash | ||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu | ||
pip3 install torch_geometric | ||
|
||
``` | ||
|
||
# Tasks | ||
|
||
## Task Extraction | ||
From paper to task. | ||
```bash | ||
python rdagent/app/model_implementation/task_extraction.py | ||
# It may be based on rdagent/document_reader/document_reader.py | ||
``` | ||
|
||
## Complete workflow | ||
From paper to implementation | ||
``` bash | ||
# Similar to | ||
# rdagent/app/factor_extraction_and_implementation/factor_extract_and_implement.py | ||
``` | ||
|
||
## Paper benchmark | ||
```bash | ||
python rdagent/app/model_implementation/eval.py | ||
|
||
TODO: | ||
- Is evaluation reasonable | ||
``` | ||
|
||
## Evolving |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from pathlib import Path | ||
|
||
DIRNAME = Path(__file__).absolute().resolve().parent | ||
|
||
from rdagent.model_implementation.benchmark.eval import ModelImpValEval | ||
from rdagent.model_implementation.one_shot import ModelTaskGen | ||
from rdagent.model_implementation.task import ModelImpLoader, ModelTaskLoderJson | ||
|
||
mtl = ModelTaskLoderJson("TODO: A Path to json") | ||
|
||
task_l = mtl.load() | ||
|
||
mtg = ModelTaskGen() | ||
|
||
impl_l = mtg.generate(task_l) | ||
|
||
# TODO: Align it with the benchmark framework after @wenjun's refine the evaluation part. | ||
# Currently, we just handcraft a workflow for fast evaluation. | ||
|
||
mil = ModelImpLoader(DIRNAME.parent.parent / "model_implementation" / "benchmark" / "gt_code") | ||
|
||
mie = ModelImpValEval() | ||
# Evaluation: | ||
eval_l = [] | ||
for impl in impl_l: | ||
gt_impl = mil.load(impl.target_task) | ||
eval_l.append(mie.evaluate(gt_impl, impl)) | ||
|
||
print(eval_l) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# TODO: inherent from the benchmark base class | ||
import torch | ||
from rdagent.model_implementation.task import ModelTaskImpl | ||
|
||
|
||
def get_data_conf(init_val): | ||
# TODO: design this step in the workflow | ||
in_dim = 1000 | ||
in_channels = 128 | ||
exec_config = {"model_eval_param_init": init_val} | ||
node_feature = torch.randn(in_dim, in_channels) | ||
edge_index = torch.randint(0, in_dim, (2, 2000)) | ||
return (node_feature, edge_index), exec_config | ||
|
||
|
||
class ModelImpValEval: | ||
""" | ||
Evaluate the similarity of the model structure by changing the input and observate the output. | ||
Assumption: | ||
- If the model structure is similar, the output will change in similar way when we change the input. | ||
- we try to initialize the model param in similar value. So only the model structure is different. | ||
""" | ||
|
||
def evaluate(self, gt: ModelTaskImpl, gen: ModelTaskImpl): | ||
round_n = 10 | ||
|
||
eval_pairs: list[tuple] = [] | ||
|
||
# run different input value | ||
for _ in range(round_n): | ||
# run different model initial parameters. | ||
for init_val in [-0.2, -0.1, 0.1, 0.2]: | ||
data, exec_config = get_data_conf(init_val) | ||
gt_res = gt.execute(data=data, config=exec_config) | ||
res = gen.execute(data=data, config=exec_config) | ||
eval_pairs.append((res, gt_res)) | ||
|
||
# flat and concat the output | ||
res_batch, gt_res_batch = [], [] | ||
for res, gt_res in eval_pairs: | ||
res_batch.append(res.reshape(-1)) | ||
gt_res_batch.append(gt_res.reshape(-1)) | ||
res_batch = torch.stack(res_batch) | ||
gt_res_batch = torch.stack(gt_res_batch) | ||
|
||
res_batch = res_batch.detach().numpy() | ||
gt_res_batch = gt_res_batch.detach().numpy() | ||
|
||
# pearson correlation of each hidden output | ||
def norm(x): | ||
return (x - x.mean(axis=0)) / x.std(axis=0) | ||
dim_corr = (norm(res_batch) * norm(gt_res_batch)).mean(axis=0) # the correlation of each hidden output | ||
|
||
# aggregate all the correlation | ||
avr_corr = dim_corr.mean() | ||
# FIXME: | ||
# It is too high(e.g. 0.944) . | ||
# Check if it is not a good evaluation!! | ||
# Maybe all the same initial params will results in extreamly high correlation without regard to the model structure. | ||
return avr_corr |
Oops, something went wrong.