Merge pull request #262 from melo-gonzo/experiments-config-refactor
Refactor Experiment Configs
laserkelvin authored Jul 25, 2024
2 parents 637ce2e + 1f01427 commit 39dae9b
Showing 30 changed files with 172 additions and 96 deletions.
16 changes: 13 additions & 3 deletions experiments/README.md
@@ -2,6 +2,15 @@

Experimental workflows can be time consuming, repetitive, and complex to set up. Additionally, pytorch-lightning based cli utilities may not be able to handle specific use cases such as multi-data or multi-task training in matsciml. The experiments module of MatSciML is meant to loosely mirror the functionality of the pytorch lightning cli while allowing more flexibility in setting up complex experiments. Yaml files define the module parameters, and specific arguments may be changed via the command line if desired. A single command launches training runs, removing the need to write a new script for each experiment type.

## Config Files
Yaml files dictate how models, datasets, tasks, and trainers are set up. A default set of config files is provided under `./experiments/configs`; however, when setting up new experiments, it is recommended to create a folder of your own to better track your experiment designs. Using the cli parameters `-d`, `-m`, `-t`, and `-e`, you can specify a path to a directory or a specific file for the dataset, model, trainer, and experiment configuration. The various yaml files and their contents are explained below. An example of running a simple experiment using the predefined yaml files may look like this:

```bash
python experiments/training_script.py -e experiments/experiment_config.yaml -t experiments/configs/trainer.yaml -m ./experiments/models -d ./experiments/datasets/oqmd.yaml --debug
```

Note that a combination of full yaml file paths and directory paths may be used. In the case of the model configs (`-m`), a directory path is specified, meaning all yaml files contained in that directory will be accessible from the experiment config. In the case of multi-data training, it is only possible to point to a directory of dataset configs, as multiple datasets need to be configured; see the sketch below for how such a directory is flattened into a lookup table.
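For illustration, pointing `-m` at a directory makes every yaml file in it addressable by its file stem. A minimal sketch of that flattening, mirroring the `rglob` pattern used by the module-level loaders removed later in this commit (the helper name here is ours, not the library's):

```python
from pathlib import Path

import yaml


def load_config_dir(config_dir: Path) -> dict[str, dict]:
    """Map each yaml file's stem (e.g. 'oqmd') to its parsed contents."""
    configs = {}
    for file_path in config_dir.rglob("*.yaml"):
        with open(file_path) as handle:
            configs[file_path.stem] = yaml.safe_load(handle)
    return configs
```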

## Experiment Config
The starting point for defining an experiment is the experiment config. This is a yaml file that lays out which model, dataset(s), and task(s) will be used during training. An example config for single-task training (`single_task.yaml`) looks like this:
```yaml
...
@@ -13,11 +22,12 @@ dataset:
```
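The yaml body above is collapsed in this diff view. For reference, its parsed form should match the `single_task` dictionary used by the tests in this commit:

```python
# Parsed form of a single-task experiment config: one model name, plus a
# mapping of dataset name to the task(s) trained against it.
single_task = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
}
```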
In general, an experiment may then be launched by running:
-`python experiments/training_script.py --experiment_config ./experiments/single_task.yaml`
+`python experiments/training_script.py --experiment_config ./experiments/configs/single_task.yaml`


-* The `model` field points to a specify `model.yaml` file in `./experiments/models`.
-* The `dataset` field is a dictionary specifying which datasets to use, as well as which tasks are associated with the parent dataset.
+* The trainer used defaults to the config in `./experiments/configs/trainer.yaml`.
+* The `model` field points to a specific `model.yaml` file. Default model configs are in `./experiments/configs/models`.
+* The `dataset` field is a dictionary specifying which datasets to use, as well as which tasks are associated with the parent dataset. Default datasets are in `./experiments/configs/datasets`.
* Tasks are referred to by their class name:
```python
ScalarRegressionTask
...
```
18 files renamed without changes.
21 changes: 0 additions & 21 deletions experiments/datasets/__init__.py
@@ -1,21 +0,0 @@
import yaml

from pathlib import Path

yaml_dir = Path(__file__).parent


available_data = {
"generic": {
"experiment": {"batch_size": 32, "num_workers": 16},
"debug": {"batch_size": 4, "num_workers": 0},
},
}


for filename in yaml_dir.rglob("*.yaml"):
file_path = yaml_dir.joinpath(filename)
with open(file_path, "r") as file:
content = yaml.safe_load(file)
file_key = file_path.stem
available_data[file_key] = content
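This module-level loader (and its counterparts for models and trainer args below) is replaced by a shared `configurator` object. Its implementation is not part of this diff; judging from the call sites below, a minimal sketch might look like the following, with the `generic` defaults assumed to be seeded the way the removed loaders did:

```python
from pathlib import Path

import yaml


class Configurator:
    """Hypothetical sketch of the config registry used throughout this diff."""

    def __init__(self) -> None:
        # "generic" defaults (batch size, output head kwargs, trainer epochs)
        # are assumed to be pre-seeded, as the removed loaders did.
        self.models: dict = {}
        self.datasets: dict = {}
        self.trainer: dict = {}

    @staticmethod
    def _load(path: Path) -> dict:
        # Accept either a single yaml file or a directory of yaml files,
        # keying each parsed file by its stem.
        files = [path] if path.is_file() else sorted(path.rglob("*.yaml"))
        return {f.stem: yaml.safe_load(f.read_text()) for f in files}

    def configure_models(self, path: Path) -> None:
        self.models.update(self._load(path))

    def configure_datasets(self, path: Path) -> None:
        self.datasets.update(self._load(path))

    def configure_trainer(self, path: Path) -> None:
        # The removed trainer loader merged file contents directly rather
        # than keying by stem; the real implementation may do either.
        self.trainer.update(self._load(path))


configurator = Configurator()
```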
13 changes: 6 additions & 7 deletions experiments/datasets/data_module_config.py
@@ -13,22 +13,21 @@

from matsciml.lightning.data_utils import MultiDataModule

-from experiments.datasets import available_data
-from experiments.models import available_models
+from experiments.utils.configurator import configurator
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict


def setup_datamodule(config: dict[str, Any]) -> pl.LightningModule:
model = config["model"]
data_task_dict = config["dataset"]
run_type = config["run_type"]
-    model = instantiate_arg_dict(deepcopy(available_models[model]))
+    model = instantiate_arg_dict(deepcopy(configurator.models[model]))
model = update_arg_dict("model", model, config["cli_args"])
datasets = list(data_task_dict.keys())
if len(datasets) == 1:
-        dset = deepcopy(available_data[datasets[0]])
+        dset = deepcopy(configurator.datasets[datasets[0]])
dset = update_arg_dict("dataset", dset, config["cli_args"])
-        dm_kwargs = deepcopy(available_data["generic"]["experiment"])
+        dm_kwargs = deepcopy(configurator.datasets["generic"]["experiment"])
dm_kwargs.update(dset[run_type])
if run_type == "debug":
dm = MatSciMLDataModule.from_devset(
@@ -45,9 +44,9 @@ def setup_datamodule(config: dict[str, Any]) -> pl.LightningModule:
else:
dset_list = {"train": [], "val": [], "test": []}
for dataset in datasets:
-            dset = deepcopy(available_data[dataset])
+            dset = deepcopy(configurator.datasets[dataset])
dset = update_arg_dict("dataset", dset, config["cli_args"])
-            dm_kwargs = deepcopy(available_data["generic"]["experiment"])
+            dm_kwargs = deepcopy(configurator.datasets["generic"]["experiment"])
dset[run_type].pop("normalize_kwargs", None)
dm_kwargs.update(dset[run_type])
dataset_name = dset["dataset"]
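For context, a hypothetical invocation of `setup_datamodule`, reusing the `single_task` test config from this commit; `run_type: "debug"` selects the devset branch shown above:

```python
from experiments.datasets.data_module_config import setup_datamodule

# Config keys follow the test dictionaries in this commit; cli_args=None
# is assumed to mean no command-line overrides are applied.
config = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
    "run_type": "debug",
    "cli_args": None,
}
dm = setup_datamodule(config)
```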
10 changes: 10 additions & 0 deletions experiments/datasets/tests/test_data_module_creation.py
@@ -4,9 +4,19 @@
import pytest

import matsciml
import experiments
import matsciml.datasets.transforms # noqa: F401
from experiments.datasets.data_module_config import setup_datamodule
from experiments.utils.configurator import configurator
from pathlib import Path

base_path = Path(experiments.__file__).parent
model_path = base_path.joinpath("configs", "models")
datasets_path = base_path.joinpath("configs", "datasets")
trainer_path = base_path.joinpath("configs", "trainer")
configurator.configure_models(model_path)
configurator.configure_datasets(datasets_path)
configurator.configure_trainer(trainer_path)

single_task = {
"model": "egnn_dgl",
27 changes: 0 additions & 27 deletions experiments/models/__init__.py
@@ -1,27 +0,0 @@
import yaml

from torch.nn import LayerNorm


from pathlib import Path

yaml_dir = Path(__file__).parent
available_models = {
"generic": {
"output_kwargs": {
"norm": LayerNorm(128),
"hidden_dim": 128,
"activation": "SiLU",
"lazy": False,
"input_dim": 128,
},
"lr": 0.0001,
},
}

for filename in yaml_dir.rglob("*.yaml"):
file_path = yaml_dir.joinpath(filename)
with open(file_path, "r") as file:
content = yaml.safe_load(file)
file_key = file_path.stem
available_models[file_key] = content
16 changes: 11 additions & 5 deletions experiments/models/tests/test_model_creation.py
@@ -2,15 +2,21 @@

import pytest

from copy import deepcopy
from experiments.utils.utils import instantiate_arg_dict

-from experiments.models import available_models
+from experiments.utils.configurator import configurator
import experiments
from pathlib import Path

-models = list(available_models.keys())
-models.remove("generic")
+base_path = Path(experiments.__file__).parent
+model_path = base_path.joinpath("configs", "models")
+configurator.configure_models(model_path)
+
+models = list(configurator.models.keys())


@pytest.mark.parametrize("model", models)
def test_instantiate_model_dict(model):
-    model_dict = available_models[model]
-    instantiate_arg_dict(model_dict)
+    model_dict = configurator.models[model]
+    instantiate_arg_dict(deepcopy(model_dict))
11 changes: 5 additions & 6 deletions experiments/task_config/task_config.py
@@ -8,24 +8,23 @@
from matsciml.common.registry import registry
from matsciml.models.base import MultiTaskLitModule

-from experiments.datasets import available_data
-from experiments.models import available_models
+from experiments.utils.configurator import configurator
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict


def setup_task(config: dict[str, Any]) -> pl.LightningModule:
model = config["model"]
data_task_dict = config["dataset"]
-    model = instantiate_arg_dict(deepcopy(available_models[model]))
+    model = instantiate_arg_dict(deepcopy(configurator.models[model]))
model = update_arg_dict("model", model, config["cli_args"])
configured_tasks = []
data_task_list = []
for dataset_name, tasks in data_task_dict.items():
-        dset_args = deepcopy(available_data[dataset_name])
+        dset_args = deepcopy(configurator.datasets[dataset_name])
dset_args = update_arg_dict("dataset", dset_args, config["cli_args"])
for task in tasks:
task_class = registry.get_task_class(task["task"])
-            task_args = deepcopy(available_models["generic"])
+            task_args = deepcopy(configurator.models["generic"])
task_args.update(model)
task_args.update({"task_keys": task["targets"]})
additonal_task_args = dset_args.get("task_args", None)
@@ -34,7 +33,7 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
configured_task = task_class(**task_args)
configured_tasks.append(configured_task)
data_task_list.append(
-                [available_data[dataset_name]["dataset"], configured_task]
+                [configurator.datasets[dataset_name]["dataset"], configured_task]
)

if len(configured_tasks) > 1:
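Similarly, a hypothetical invocation of `setup_task`; the fold above ends with the branch that wraps more than one configured task into a `MultiTaskLitModule`:

```python
from experiments.task_config.task_config import setup_task

# A single dataset/task pair yields one configured task module; multiple
# configured tasks would be wrapped in a MultiTaskLitModule instead.
config = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
    "cli_args": None,
}
task = setup_task(config)
```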
7 changes: 4 additions & 3 deletions experiments/task_config/tests/test_task_creation.py
@@ -8,10 +8,11 @@
)
from experiments.utils.utils import instantiate_arg_dict

from copy import deepcopy

import matsciml
import matsciml.datasets.transforms # noqa: F401


single_task = {
"model": "egnn_dgl",
"dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
@@ -20,7 +21,7 @@
"dataset": {
"s2ef": [
{"task": "ScalarRegressionTask", "targets": ["energy"]},
{"task": "ForceRegressionTask", "targets": ["force"]},
{"task": "ForceRegressionTask", "targets": ["energy", "force"]},
]
}
}
@@ -64,7 +65,7 @@ def test_build_model() -> dict[str, Any]:
],
}

-    output = instantiate_arg_dict(input_dict)
+    output = instantiate_arg_dict(deepcopy(input_dict))
assert isinstance(
output["transforms"][0],
matsciml.datasets.transforms.PeriodicPropertiesTransform,
16 changes: 0 additions & 16 deletions experiments/trainer_config/__init__.py
@@ -1,17 +1 @@
from experiments.trainer_config.trainer_config import setup_trainer # noqa: F401

import yaml
from pathlib import Path


yaml_dir = Path(__file__).parent
trainer_args = {
"generic": {"min_epochs": 15, "max_epochs": 100},
}

for filename in yaml_dir.rglob("*.yaml"):
file_path = yaml_dir.joinpath(filename)
with open(file_path, "r") as file:
content = yaml.safe_load(file)
file_key = file_path.stem
trainer_args.update(content)
35 changes: 31 additions & 4 deletions experiments/training_script.py
@@ -1,13 +1,13 @@
import os
import yaml
from typing import Any

from pathlib import Path
from experiments.datasets.data_module_config import setup_datamodule
from experiments.task_config.task_config import setup_task
from experiments.trainer_config.trainer_config import setup_trainer
-from experiments.trainer_config import trainer_args

from experiments.utils.utils import setup_log_dir, config_help
+from experiments.utils.configurator import configurator

from argparse import ArgumentParser

@@ -17,7 +17,7 @@ def main(config: dict[str, Any]) -> None:

dm = setup_datamodule(config)
task = setup_task(config)
-    trainer = setup_trainer(config, trainer_args=trainer_args)
+    trainer = setup_trainer(config, trainer_args=configurator.trainer)
trainer.fit(task, datamodule=dm)


@@ -30,16 +30,37 @@ def main(config: dict[str, Any]) -> None:
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="Uses debug config with devsets and only a few batches per epoch.",
action="store_true",
)
parser.add_argument(
"-e",
"--experiment_config",
type=Path,
help="Experiment config yaml file to use.",
)
parser.add_argument(
"-d",
"--dataset_config",
type=Path,
default=Path(__file__).parent.joinpath("configs", "datasets"),
help="Dataset config folder or yaml file to use.",
)
parser.add_argument(
"-t",
"--trainer_config",
type=Path,
default=Path(__file__).parent.joinpath("configs", "trainer"),
help="Trainer config folder or yaml file to use.",
)
parser.add_argument(
"-m",
"--model_config",
type=Path,
default=Path(__file__).parent.joinpath("configs", "models"),
help="Model config folder or yaml file to use.",
)
parser.add_argument(
"-c",
"--cli_args",
@@ -48,9 +69,15 @@
default=None,
)
args = parser.parse_args()

configurator.configure_models(args.model_config)
configurator.configure_datasets(args.dataset_config)
configurator.configure_trainer(args.trainer_config)

if args.options:
config_help()
os._exit(0)

config = yaml.safe_load(open(args.experiment_config))
config["cli_args"] = (
[arg.split(".") for arg in args.cli_args] if args.cli_args else None
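The dotted `--cli_args` entries are split into key paths before being handed to the `update_arg_dict` helpers; a small sketch of that parsing, with the argument format assumed for illustration:

```python
# Each --cli_args entry is split on "." into a key path; the exact
# override grammar is an assumption here (hypothetical example value).
cli_args = ["model.encoder_only.True"]
parsed = [arg.split(".") for arg in cli_args]
print(parsed)  # [['model', 'encoder_only', 'True']]
```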
(Diffs for the remaining changed files were not loaded.)
