main.py
import logging
import os

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader

import config
from approaches.noncl import Appr as Appr_noncl
from approaches.train import Appr
from dataloader.data import get_dataset
from utils import utils

logger = logging.getLogger(__name__)

args = config.parse_args()
args = utils.prepare_sequence_train(args)

# Set all relevant seeds for reproducibility.
random_seed = args.seed
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

if args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)

dataset = get_dataset(args)
model = utils.lookfor_model(args)

# The 'full' baseline is non-continual: it trains jointly on the data of all
# tasks up to and including the current one. Otherwise, train on the current
# task only.
if 'full' in args.baseline:
    train_loader = DataLoader(
        ConcatDataset([dataset[t]['train'] for t in range(args.task + 1)]),
        batch_size=args.batch_size, shuffle=True, num_workers=8)
else:
    train_loader = DataLoader(
        dataset[args.task]['train'],
        batch_size=args.batch_size, shuffle=True, num_workers=8)

# Build one test loader per task so the approach can evaluate on every task.
test_loaders = []
for eval_t in range(args.ntasks):
    test_dataset = dataset[eval_t]['test']
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
    test_loaders.append(test_dataloader)

# Optional replay loader over stored samples from earlier tasks.
replay_loader = None
if dataset[args.task]['replay'] is not None:
    replay_loader = DataLoader(
        dataset[args.task]['replay'],
        batch_size=args.replay_batch_size, shuffle=True, num_workers=8)

# Select the training approach (non-continual for the 'full' baseline,
# continual-learning approach otherwise) and run training.
if 'full' in args.baseline:
    appr = Appr_noncl(args)
else:
    appr = Appr(args)
appr.train(model, train_loader, test_loaders, replay_loader)
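
A plausible invocation of this script, assuming config.parse_args exposes the attributes used above (seed, task, ntasks, baseline, batch_size, replay_batch_size, output_dir) as same-named command-line flags; the flag names and values shown here are hypothetical, not confirmed by this file:

    python main.py --seed 42 --task 1 --ntasks 5 --baseline full \
        --batch_size 32 --replay_batch_size 16 --output_dir ./outputs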