main_mv_wsl.py
import os

import hydra
import pytorch_lightning as pl
import wandb
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

from utils import dataset
from config.MyMVWSLConfig import MyMVWSLConfig, LogConfig, EvalConfig
from config.ModelConfig import (
    JointModelConfig,
    MixedPriorModelConfig,
    UnimodalModelConfig,
    SplitModelConfig,
)
from config.DatasetConfig import (
    PMtranslatedData75Config,
    CelebADataConfig,
    CUBDataConfig,
)
from mv_vaes.mv_joint_vae import MVJointVAE
from mv_vaes.mv_split_vae import MVSplitVAE
from mv_vaes.mv_unimodal_vae import MVunimodalVAE
from mv_vaes.mv_mixedprior_vae import MVMixedPriorVAE

cs = ConfigStore.instance()
# Register the structured configs with Hydra's ConfigStore.
cs.store(group="log", name="log", node=LogConfig)
cs.store(group="model", name="joint", node=JointModelConfig)
cs.store(group="model", name="mixedprior", node=MixedPriorModelConfig)
cs.store(group="model", name="unimodal", node=UnimodalModelConfig)
cs.store(group="model", name="split", node=SplitModelConfig)
cs.store(group="eval", name="eval", node=EvalConfig)
cs.store(group="dataset", name="PMtranslated75", node=PMtranslatedData75Config)
cs.store(group="dataset", name="CelebA", node=CelebADataConfig)
cs.store(group="dataset", name="cub", node=CUBDataConfig)
cs.store(name="base_config", node=MyMVWSLConfig)
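

# A typical run composes one option from each config group registered above;
# example invocation (override values are illustrative, the actual defaults
# come from config/config.yaml):
#
#   python main_mv_wsl.py model=joint dataset=CelebA seed=0 log.wandb_offline=true
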
@hydra.main(version_base=None, config_path="config", config_name="config")
def run_experiment(cfg: MyMVWSLConfig):
    print(cfg)
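
    # Log in to a self-hosted W&B server (URL read from the WANDB_LOCAL_URL
    # environment variable) or to the public API, unless running offline.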
    if cfg.log.wandb_local_instance:
        wandb.login(host=os.getenv("WANDB_LOCAL_URL"))
    elif not cfg.log.wandb_offline:
        wandb.login(host="https://api.wandb.ai")
    pl.seed_everything(cfg.seed, workers=True)

    # get data loaders
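    # get_dataset returns the train loader, the train dataset, the val loader,
    # and a fourth value that is unused here; the train dataset exposes the
    # label names needed for per-label logging below.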
    train_loader, train_dst, val_loader, _ = dataset.get_dataset(cfg)
    label_names = train_dst.label_names
    # init model
    model = None
    if cfg.model.name == "joint":
        model = MVJointVAE(cfg)
    elif cfg.model.name == "mixedprior":
        model = MVMixedPriorVAE(cfg)
    elif cfg.model.name == "unimodal":
        model = MVunimodalVAE(cfg)
    elif cfg.model.name == "split":
        model = MVSplitVAE(cfg)
    assert model is not None, f"unknown model name: {cfg.model.name}"
    model.assign_label_names(label_names)

    summary = ModelSummary(model, max_depth=2)
    print(summary)
    # set up the W&B logger for the training run
    wandb_logger = WandbLogger(
        name=cfg.log.wandb_run_name,
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        project=cfg.log.wandb_project_name,
        group=cfg.log.wandb_group,
        offline=cfg.log.wandb_offline,
        entity=cfg.log.wandb_entity,
        save_dir=cfg.log.dir_logs,
    )
    trainer = pl.Trainer(
        max_epochs=cfg.model.epochs,
        devices=1,
        accelerator="gpu" if cfg.model.device == "cuda" else cfg.model.device,
        logger=wandb_logger,
        check_val_every_n_epoch=1,
        deterministic=True,
    )
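    # For rapid idea iteration, standard Lightning Trainer flags such as
    # fast_dev_run=True or limit_train_batches=0.1 can be added above
    # (not used by default in this script).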
    if cfg.log.debug:
        trainer.logger.watch(model, log="all")
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    model.logger.log_metrics({"final_scores/rec_loss": model.final_scores_rec_loss})
    model.logger.log_metrics(
        {"final_scores/cond_rec_loss": model.final_scores_cond_rec_loss}
    )
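
    # Per-modality downstream classification scores ("lr", presumably a linear /
    # logistic-regression probe), evaluated once on aggregated multi-view
    # representations and once on unimodal ones; for CelebA, per-label scores
    # are logged as well.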
    for m, key in enumerate(model.modality_names):
        model.logger.log_metrics(
            {
                f"final_scores/downstream_lr/aggregated/{key}": model.final_scores_lr_aggregated[
                    m
                ]
            }
        )
        model.logger.log_metrics(
            {
                f"final_scores/downstream_lr/unimodal/{key}": model.final_scores_lr_unimodal[
                    m
                ]
            }
        )
        if cfg.dataset.name == "celeba":
            for k, l_name in enumerate(label_names):
                model.logger.log_metrics(
                    {
                        f"final_scores/downstream_lr/aggregated/{key}/{l_name}": model.final_scores_lr_aggregated_alllabels[
                            m, k
                        ]
                    }
                )
                model.logger.log_metrics(
                    {
                        f"final_scores/downstream_lr/unimodal/{key}/{l_name}": model.final_scores_lr_unimodal_alllabels[
                            m, k
                        ]
                    }
                )
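
    # Conditional-generation coherence for each ordered modality pair
    # (source -> target); final_scores_coh is assumed to be indexed
    # [source, target, label], so the mean collapses the label axis.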
    for m, key in enumerate(model.modality_names):
        for m_tilde, key_tilde in enumerate(model.modality_names):
            model.logger.log_metrics(
                {
                    f"final_scores/coherence/{key}_to_{key_tilde}": model.final_scores_coh[
                        m, m_tilde, :
                    ].mean()
                }
            )

if __name__ == "__main__":
    run_experiment()