-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment.py
189 lines (157 loc) · 7.23 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
---
title: Finetune GPT-2 with LoRA
summary: This is training code with notes for fine-tuning a pre-trained GPT-2 model with LoRA.
---
# Finetune [GPT-2](gpt2.html) with [LoRA](index.html)
Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/lora/experiment.ipynb)
"""
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel
class Trainer(BaseConfigs):
    """
    ## Trainer configurations and the training loop

    The default configs can and will be over-ridden when we start the experiment
    """
    device: torch.device = DeviceConfigs()

    # GPT-2 configs
    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257

    # Training configs
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

    # LoRA rank
    lora_r: int = 32

    # Dataset
    text: TensorDataset = "tiny_shakespeare"
    # Huggingface tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # [GPT2 model](gpt2.html)
    model: GPTModel
    # Optimizer
    optimizer: torch.optim.Adam
    # Cross entropy loss
    loss_func = torch.nn.CrossEntropyLoss()
    # Dataloader
    data_loader: DataLoader

    def _load_pretrained_weights(self):
        """
        ### Load pre-trained [GPT-2 from huggingface](https://huggingface.co/openai-community/gpt2)

        Renames the huggingface parameters to the names used by `self.model`, transposes the
        weights that huggingface stores for 1D convolution layers, and loads the result into
        `self.model`. The number of decoder layers that are mapped follows `self.n_layers`,
        so it stays in step with the model built in `initialize`.
        """
        # Load the huggingface model and get the parameters
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

        # Transformer embedding and prediction layer parameter mapping (`hf: ours`)
        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

        # Mapping (`hf: ours`) of decoder layers.
        # Use the configured layer count instead of a fixed `12` so that only
        # the blocks that exist in our model are mapped.
        for i in range(self.n_layers):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

        # Move the parameters based on mapping
        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

        # GPT-2 hugging face uses 1D Convolution layers. We need to transpose those weights since we use linear layers
        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(self.n_layers)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(self.n_layers)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(self.n_layers)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(self.n_layers)])

        for layer in convo_layers:
            # A key can be absent when the checkpoint has fewer layers than `n_layers`;
            # skip it here and let the missing-keys check below report it.
            if layer in new_state_dict:
                new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        # Load out model. We use `strict = False` because the state does not have LoRA weights
        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

        # make sure that only lora weights are not loaded
        assert all('lora' in key for key in missing_keys), \
            f'Non-LoRA parameters were not loaded: {[key for key in missing_keys if "lora" not in key]}'
        assert not unexpected_keys, f'Unexpected parameters in the state dict: {unexpected_keys}'

    def initialize(self):
        """
        ### Initialize the model, optimizer and dataloader
        """
        # Initialize the [GPT2 model](gpt2.html)
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)
        # Load pre-trained model weights
        self._load_pretrained_weights()

        # Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

        # Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        """
        ### Training loop

        Runs `self.epochs` passes over `self.data_loader`, taking one optimizer step per batch.
        """

        for _ in monit.loop(self.epochs):
            # `inputs` has shape `[batch_size, seq_len]`
            for (inputs,) in monit.iterate('Train', self.data_loader):
                # Move `inputs` to device
                inputs = inputs.to(self.device)
                # Call the model, with the all but the last token
                logits = self.model(inputs[:, :-1])
                # Get cross entropy loss; the targets are the inputs shifted by one token
                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))

                # Make gradients 0
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Optimize
                self.optimizer.step()

                # Log the loss
                tracker.save({'loss': loss})
                tracker.add_global_step()
            # Called once per epoch, after all batches have been processed
            tracker.new_line()
@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    """
    ### Tiny Shakespeare dataset

    Reads `tiny_shakespeare.txt` from the lab data path, downloading it from the url first
    if the file is not present. The text is encoded with `c.tokenizer`, cut down to a whole
    number of batches, and returned as a `TensorDataset` whose rows hold `c.context_len`
    tokens each.
    """
    data_file = lab.get_data_path() / 'tiny_shakespeare.txt'
    # Only download when there is no local copy
    if not data_file.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", data_file)

    with open(data_file, 'r', encoding='utf-8') as f:
        corpus = f.read()

    token_ids = c.tokenizer.encode(corpus)

    # Drop the trailing tokens that do not fill a complete batch
    tokens_per_batch = c.batch_size * c.context_len
    usable_len = (len(token_ids) // tokens_per_batch) * tokens_per_batch
    token_ids = token_ids[:usable_len]

    # One row per training sequence
    sequences = torch.tensor(token_ids).view(-1, c.context_len)
    return TensorDataset(sequences)