# lightning_model.py
import pytorch_lightning as pl
import torch
import torch.nn as nn

from models import *  # provides U_Net
import dataloader
class Model(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = U_Net(img_ch=3, output_ch=3)
        self.mse = nn.MSELoss()
        '''
        TODO:
        1. Resolve the model class via __getattr__ instead of hard-coding U_Net
        2. Implement transfer learning
        3. Assert when the transfer-learning checkpoint does not match the model
        '''
        if hparams.transfer_learning:
            try:
                state_dict = torch.load(hparams.transfer_learning)
                self.model.load_state_dict(state_dict)
            except RuntimeError:
                print("Load error: this checkpoint does not match the model. Try another state dict.")
    def load_dict(self, target_dict):
        # TODO: load the checkpoint specified in the hparameter.yaml file.
        raise NotImplementedError
    def forward(self, x):
        '''
        INPUT:
            x -> [B, W, H, C]
        OUTPUT:
            out -> [B, W, H, C]
        '''
        out = self.model(x)
        return out
    def common_step(self, x, y, is_train=True):
        output = self(x)
        # TODO: apply activation / thresholding here (may differ when is_train is False).
        loss = self.mse(output, y)
        return loss, output
    def training_step(self, batch, batch_idx):
        x, y = batch  # coordinate [B, 2], rgb [B, 3]
        loss, _ = self.common_step(x, y)
        self.log('loss', loss)
        return loss
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss, output = self.common_step(x, y, is_train=False)
        self.log('val_loss', loss)
        # Assumes a custom logger that exposes log_image; coordinates are rescaled to [-1, 1].
        self.logger.log_image(output, x * 2 - 1, self.current_epoch)
        return {'loss': loss, 'output': output}
    def test_step(self, batch, batch_idx):
        x, y = batch
        loss, output = self.common_step(x, y, is_train=False)
        self.log('test_loss', loss)
        return {'test_loss': loss, 'output': output}
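    # The module defines no configure_optimizers, which Lightning's Trainer
    # needs in order to fit. A minimal sketch, assuming hparams carries an
    # `lr` field (an assumption; adjust to the actual hyperparameter file):
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)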
    def train_dataloader(self):
        # Split index 0: training set.
        return dataloader.create_coin_dataloader(self.hparams, 0)

    def val_dataloader(self):
        # Split index 1: validation set.
        return dataloader.create_coin_dataloader(self.hparams, 1)

    def test_dataloader(self):
        # Split index 2: test set.
        return dataloader.create_coin_dataloader(self.hparams, 2)
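
if __name__ == '__main__':
    # A minimal usage sketch, assuming hparams carries the fields used above
    # (`transfer_learning`, `lr`); real values would come from hparameter.yaml.
    from argparse import Namespace

    hparams = Namespace(transfer_learning=None, lr=1e-3)
    model = Model(hparams)
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model)  # uses the module's own train/val dataloaders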