-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
37 lines (30 loc) · 894 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from hbconfig import Config
from acgan import ACGAN
from gan import GAN
from one_shot_aug import OneShotAug
class Model:
    """Build and run the model selected by ``Config.model.name``.

    ``mode`` selects which entry point of the chosen model
    :meth:`model_builder` invokes: training, evaluation, or prediction.
    """

    TRAIN_MODE = "train"
    EVALUATE_MODE = "evaluate"
    PREDICT_MODE = "predict"

    def __init__(self, mode):
        """Store the run mode.

        Args:
            mode: one of ``TRAIN_MODE``, ``EVALUATE_MODE`` or ``PREDICT_MODE``.
                It is not validated here; :meth:`model_builder` raises
                ``ValueError`` for an unknown mode.
        """
        self.mode = mode

    @staticmethod
    def _find_model_class(name, candidates):
        """Return the first class in *candidates* whose ``model_name`` equals *name*.

        Raises:
            ValueError: if no candidate matches; the message lists known names.
        """
        for candidate in candidates:
            if candidate.model_name == name:
                return candidate
        known = ", ".join(repr(c.model_name) for c in candidates)
        raise ValueError(f"unknown model name: {name!r} (expected one of {known})")

    def model_builder(self, seed):
        """Instantiate the configured model and run the entry point for ``self.mode``.

        Args:
            seed: forwarded unchanged to the selected model class's constructor.

        Returns:
            Whatever the selected model's ``train_fn``, ``evaluate_model`` or
            ``predict`` returns, depending on ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` is not one of the mode constants, or if
                ``Config.model.name`` matches none of ACGAN, GAN, OneShotAug.
        """
        # The built model is also bound to the module-level global ``model``
        # (kept for compatibility with any code that reads that global).
        global model

        # Reject an unknown mode before doing the (potentially expensive)
        # model construction.
        if self.mode not in (self.TRAIN_MODE, self.EVALUATE_MODE, self.PREDICT_MODE):
            raise ValueError(f"unknown mode: {self.mode}")

        # Raise on an unmatched name instead of falling through to an unbound
        # or stale ``model`` from a previous call.
        model_cls = self._find_model_class(
            Config.model.name, (ACGAN, GAN, OneShotAug)
        )
        model = model_cls(seed)

        # Built for every mode, as before, although evaluate_model() does not take it.
        criterion = model.build_criterion()
        if self.mode == self.TRAIN_MODE:
            optimizers = model.build_optimizers(model.meta_net)
            return model.train_fn(criterion, optimizers)
        if self.mode == self.EVALUATE_MODE:
            return model.evaluate_model()
        # Mode was validated above, so only PREDICT_MODE remains.
        return model.predict(criterion)

    def build_metric(self):
        """Placeholder: not implemented yet; returns ``None``."""
        pass