# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the CC-BY-NC 4.0 license found in the
# LICENSE file in the root directory of this source tree.

import os

import hydra
import numpy as np
import torch
import torch.utils.data as data
from hydra.core.hydra_config import HydraConfig
from lightning.fabric import seed_everything
from omegaconf import DictConfig, OmegaConf
import wandb

from tactile_ssl.trainer import Trainer  # noqa: E402
from tactile_ssl.utils import get_local_rank
from tactile_ssl.utils.logging import get_pylogger, print_config_tree  # noqa: E402

logger = get_pylogger(__name__)
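
# Custom OmegaConf resolver so configs can compute integer products, e.g. ${int_multiply:a,b}.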
OmegaConf.register_new_resolver("int_multiply", lambda a, b: int(a * b))


def init_wandb(cfg: DictConfig):
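    """Initialize a wandb run from the `wandb` sub-config.

    The run id is suffixed with the local rank so that every process logs to its
    own run; the `wandb` module itself is returned for later logging calls.
    """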
wandb.init(
project=cfg.project,
entity=cfg.entity,
dir=cfg.save_dir,
id=f"{cfg.id}_{get_local_rank()}",
group=cfg.group,
tags=cfg.tags,
notes=cfg.notes,
)
return wandb


def get_dataloaders_magnetic_based(cfg: DictConfig):
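    """Build train/val datasets for magnetic-based tactile sensors.

    For "tdex", a single dataset is instantiated and randomly split according to
    `data.train_val_split`. For "reskin", one dataset per entry in the explicit
    train/val dataset lists is instantiated and the results are concatenated.
    """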
data_cfg = cfg.data
if data_cfg.sensor == "tdex":
dataset = hydra.utils.instantiate(data_cfg.dataset)
train_dset_size = int(len(dataset) * cfg.data.train_val_split)
train_dset, val_dset = data.random_split(
dataset, [train_dset_size, len(dataset) - train_dset_size]
)
elif data_cfg.sensor == "reskin":
train_dataset_list = data_cfg.train_dataset_list
val_dataset_list = data_cfg.val_dataset_list
train_datasets, val_datasets = [], []
for dataset_name in train_dataset_list:
data_path = os.path.join(data_cfg.dataset.data_path, dataset_name)
train_datasets.append(
hydra.utils.instantiate(data_cfg.dataset, data_path=data_path)
)
for dataset_name in val_dataset_list:
data_path = os.path.join(data_cfg.dataset.data_path, dataset_name)
val_datasets.append(
hydra.utils.instantiate(data_cfg.dataset, data_path=data_path)
)
train_dset = data.ConcatDataset(train_datasets)
val_dset = data.ConcatDataset(val_datasets)
return train_dset, val_dset


def get_dataloaders_vision_based(cfg: DictConfig):
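    """Build train/val datasets for vision-based tactile sensors.

    For "digit" and "gelsight_mini", one dataset is instantiated per object and
    per dataset id from the configured train/val id lists. For "gelsight", one
    dataset is built per recording file and the per-file datasets are split by
    `train_val_split`. Datasets from all configured sensors are concatenated
    before returning.
    """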
data_cfg = cfg.data
n_sensors = len(cfg.data.sensor)
train_dset = []
val_dset = []
for i in range(n_sensors):
sensor_cfg = data_cfg.sensor[i]
if sensor_cfg.type == "digit" or sensor_cfg.type == "gelsight_mini":
list_datasets = sensor_cfg.dataset.config.list_datasets
train_dset_ids = sensor_cfg.dataset.config.dataset_ids_train
val_dset_ids = sensor_cfg.dataset.config.dataset_ids_val
for obj in list_datasets:
for d_id in train_dset_ids:
dataset_name = obj + "/dataset_" + str(d_id)
dataset = hydra.utils.instantiate(
sensor_cfg.dataset,
sensor_type=sensor_cfg.type,
dataset_name=dataset_name,
)
train_dset.append(dataset)
for d_id in val_dset_ids:
dataset_name = obj + "/dataset_" + str(d_id)
dataset = hydra.utils.instantiate(
sensor_cfg.dataset,
sensor_type=sensor_cfg.type,
dataset_name=dataset_name,
)
val_dset.append(dataset)
elif sensor_cfg.type == "gelsight":
list_datasets = sensor_cfg.dataset.config.list_datasets
path_dataset = sensor_cfg.dataset.config.path_dataset
all_datasets = []
for obj in list_datasets:
files_list = os.listdir(os.path.join(path_dataset, obj))
for file in files_list:
dataset_name = obj + "/" + file.split(".")[0]
dataset = hydra.utils.instantiate(
sensor_cfg.dataset,
sensor_type=sensor_cfg.type,
dataset_name=dataset_name,
)
all_datasets.append(dataset)
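            # Sort the per-file datasets by size (largest first) so the biggest ones land in the train split.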
all_datasets = sorted(all_datasets, key=lambda x: len(x), reverse=True)
train_dset_size = int(
len(all_datasets) * sensor_cfg.dataset.config.train_val_split
)
train_dset = train_dset + all_datasets[:train_dset_size]
val_dset = val_dset + all_datasets[train_dset_size:]
if isinstance(train_dset, list):
train_dset = data.ConcatDataset(train_dset)
val_dset = data.ConcatDataset(val_dset)
return train_dset, val_dset


def get_dataloaders(cfg: DictConfig):
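    """Build the train/val DataLoaders.

    Currently uses the vision-based dataset builder; DataLoader kwargs come from
    `data.train_dataloader` and `data.val_dataloader`.
    """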
train_dset, val_dset = get_dataloaders_vision_based(cfg)
train_dataloader = data.DataLoader(train_dset, **cfg.data.train_dataloader)
val_dataloader = data.DataLoader(val_dset, **cfg.data.val_dataloader)
return train_dataloader, val_dataloader


def attempt_resume(cfg: DictConfig):
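    """Try to resume a previous run from the Hydra output directory.

    On a SLURM requeue, `resume_id` is set to the SLURM job id. If `resume_id`
    is set and a saved config, checkpoints, and wandb logs exist under
    `paths.output_dir`, the saved config is reloaded, `ckpt_path` is pointed at
    the checkpoints directory, and the wandb run id is rebuilt from the Hydra
    job id and the experiment name. Returns (resumed, cfg).
    """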
ckpt_path = None
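    # A set SLURM_RESTART_COUNT means the job was requeued; reuse the SLURM job id as the resume id.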
if os.environ.get('SLURM_RESTART_COUNT') is not None:
slurm_job_id = os.environ.get('SLURM_JOB_ID')
requeue_count = os.environ.get('SLURM_RESTART_COUNT')
logger.info(f"requeue count for job {slurm_job_id}: {requeue_count}")
cfg.resume_id = slurm_job_id
if os.path.exists(f"{cfg.paths.output_dir}/config.yaml") and cfg.resume_id:
job_id = HydraConfig.get().job.id
logger.info(f"Attempting to resume experiment with {cfg.resume_id}")
if not os.path.exists(f"{cfg.paths.output_dir}/checkpoints/"):
logger.warning(
f"Unable to resume: No checkpoints found for experiment with id {job_id}"
)
return False, cfg
if not os.path.exists(f"{cfg.paths.output_dir}/wandb/"):
logger.warning(
f"Unable to resume: No wandb logs found for experiment with id {job_id}"
)
return False, cfg
if not os.path.exists(f"{cfg.paths.output_dir}/config.yaml"):
logger.warning(
"Could not find a config.yaml file in the resume directory. Using the current config."
)
return False, cfg
cfg = OmegaConf.load(f"{cfg.paths.output_dir}/config.yaml")
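        # ckpt_path points at the checkpoints directory; training resumes from the latest checkpoint in it.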
ckpt_path = f"{cfg.paths.output_dir}/checkpoints/"
OmegaConf.update(cfg, "ckpt_path", ckpt_path, force_add=True)
experiment_name = cfg.experiment_name
cfg.wandb.id = f"{job_id}_{experiment_name}"
logger.info(
f"Resuming experiment {job_id} with wandb_id: {cfg.wandb.id} from latest checkpoint at {cfg.ckpt_path}"
)
return True, cfg
return False, cfg


def train(cfg: DictConfig):
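    """Run one training job: attempt to resume, initialize wandb, seed, build the
    dataloaders, instantiate the Trainer and model, and fit."""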
resume_state, cfg = attempt_resume(cfg)
logger.info(f"Resume state: {resume_state}, {cfg.ckpt_path}")
logger.info("Instantiating wandb ...")
wandb = init_wandb(cfg.wandb)
    if not resume_state:  # only persist the config for fresh (non-resumed) runs
wandb.config.update(OmegaConf.to_container(cfg, resolve=True))
OmegaConf.save(cfg, f"{cfg.paths.output_dir}/config.yaml")
print_config_tree(cfg, resolve=True, save_to_file=True)
if cfg.get("seed"):
seed_everything(cfg.seed, workers=True)
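    # data.sensor is either a list of per-sensor configs (vision-based setups) or a single sensor spec.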
n_sensors = len(cfg.data.sensor) if OmegaConf.is_list(cfg.data.sensor) else 1
sensors_type = (
[cfg.data.sensor[i].type for i in range(n_sensors)]
if n_sensors > 1
else cfg.data.sensor
)
logger.info(f"Instantiating dataset & dataloaders for <{sensors_type}>")
train_dataloader, val_dataloader = get_dataloaders(cfg)
trainer = Trainer(wandb_logger=wandb, **cfg.trainer)
logger.info(f"Instantiating model <{cfg.model._target_}>")
model = hydra.utils.instantiate(cfg.model)
trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=cfg.ckpt_path)
wandb.finish()


@hydra.main(version_base="1.3", config_path="config", config_name="default.yaml")
def main(cfg: DictConfig):
"""
Main function to train the model
"""
train(cfg)


if __name__ == "__main__":
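    # Allow TF32 / reduced-precision matmul on supported GPUs for faster float32 training.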
torch.set_float32_matmul_precision("high")
main()