"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import time
from collections import defaultdict

import torch
from torch.utils.data.dataloader import DataLoader

from mingpt.logger import Logger
from mingpt.utils import CfgNode as CN
from mingpt.utils import try_auto_cast

class Trainer:

    @staticmethod
    def get_default_config():
        C = CN()
        C.epochs = 1
        # device to train on
        C.device = 'auto'
        # dataloader parameters
        C.num_workers = 4
        # optimizer parameters
        C.batch_size = 64
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1 # only applied on matmul weights
        C.grad_norm_clip = 1.0
        C.compile = False
        return C
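
    # A minimal sketch of overriding these defaults (CfgNode supports plain
    # attribute assignment, as get_default_config() itself does):
    #
    #   config = Trainer.get_default_config()
    #   config.device = 'cpu'
    #   config.batch_size = 32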
    def __init__(self, config, model, train_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)
        self.logger = Logger()

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        if config.compile:
            self.model = torch.compile(self.model)
        print("running on device", self.device)

        # variables that will be assigned to the trainer instance later, for logging etc.
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)
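
    # Callbacks receive the trainer itself, so they can read iter_num, iter_dt,
    # loss, etc. A sketch with a hypothetical callback:
    #
    #   def batch_end_callback(trainer):
    #       if trainer.iter_num % 100 == 0:
    #           print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; "
    #                 f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
    #
    #   trainer.set_callback('on_batch_end', batch_end_callback)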
    def run(self):
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )

        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        for epoch in range(config.epochs):
            self.epoch = epoch
            for batch in train_loader:
                batch = [t.to(self.device) for t in batch]

                # forward the model (try_auto_cast is assumed to enable
                # mixed-precision autocast where the device supports it)
                with try_auto_cast(self.device):
                    logits, self.loss = model(*batch)

                # backprop and update the parameters
                model.zero_grad(set_to_none=True)
                self.loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                self.optimizer.step()

                self.trigger_callbacks('on_batch_end')
                self.iter_num += 1
                tnow = time.time()
                self.iter_dt = tnow - self.iter_time
                self.iter_time = tnow
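
# Example usage (a sketch; `model` and `train_dataset` stand in for any objects
# meeting the contract noted at the top of this file):
#
#   config = Trainer.get_default_config()
#   config.learning_rate = 5e-4
#   trainer = Trainer(config, model, train_dataset)
#   trainer.run()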