-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_jgrp2o.py
64 lines (55 loc) · 2.77 KB
/
train_jgrp2o.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import glob
import os
import tensorflow as tf
import src.estimation.configuration as configs
from src.estimation.evaluation import evaluate
from src.estimation.train import train
def get_configs(dataset_name: str):
    """Return the (train, test) configuration pair for the named dataset.

    Supported names are 'msra' and 'bighand'; any other value raises
    ValueError.
    """
    if dataset_name == 'msra':
        return configs.TrainMsraConfig(), configs.TestMsraConfig()
    if dataset_name == 'bighand':
        return configs.TrainBighandConfig(), configs.TestBighandConfig()
    raise ValueError(F"Invalid dataset: {dataset_name}")
# Command-line interface: the positional argument selects the training
# dataset; the remaining flags control evaluation, checkpoint loading and
# training hyperparameters.
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str, action='store',
                    help='the dataset to be used for training (allowed options: msra, bighand)')
parser.add_argument('--evaluate', type=str, action='store', default=None,
                    help='the dataset to be used for evaluation (allowed options: msra, bighand)')
parser.add_argument('--model', type=str, action='store', default=None,
                    help='the weights to load the model from (default: none)')
parser.add_argument('--features', type=int, action='store', default=196,
                    help='the number of features (channels) throughout the network (default: 196)')
parser.add_argument('--batch-size', type=int, action='store', default=64,
                    help='the number of samples in a batch')
parser.add_argument('--learning-rate', type=float, action='store', default=0.0001,
                    help='learning rate')
parser.add_argument('--learning-decay-rate', type=float, action='store', default=0.93,
                    help='a decay of learning rate after each epoch')
# Fixed typo in the help text: "theshold" -> "threshold", "Otsus" -> "Otsu's".
parser.add_argument('--ignore-otsus-threshold', type=float, action='store', default=0.01,
                    help="a threshold for ignoring Otsu's thresholding method")
args = parser.parse_args()
# Build the dataset-specific configs, then overlay the CLI hyperparameters.
train_cfg, test_cfg = get_configs(args.dataset)
train_cfg.learning_rate = args.learning_rate
train_cfg.learning_decay_rate = args.learning_decay_rate
# Batch size and the Otsu-ignore threshold apply to both train and test.
for cfg in (train_cfg, test_cfg):
    cfg.batch_size = args.batch_size
    cfg.ignore_threshold_otsus = args.ignore_otsus_threshold
log_dir, model_filepath = train(args.dataset, args.model, train_cfg, model_features=args.features)
# Optionally evaluate the trained model: prefer the explicit model file
# returned by train(); otherwise fall back to the newest checkpoint in the
# run's log directory.
if args.evaluate is not None:
    if model_filepath is not None and os.path.isfile(model_filepath):
        path = model_filepath
    else:
        ckpts_pattern = os.path.join(str(log_dir), 'train_ckpts/*')
        ckpts = glob.glob(ckpts_pattern)
        # Bug fix: max() on an empty list raised a bare
        # "max() arg is an empty sequence" ValueError, making the intended
        # "No checkpoints available" error unreachable. Guard explicitly.
        if not ckpts:
            raise ValueError("No checkpoints available")
        # Pick the most recently created checkpoint.
        path = max(ckpts, key=os.path.getctime)
    thresholds, mje = evaluate(args.evaluate, path, args.features)
    tf.print("MJE:", mje)