# trainer.py
import os

import torch
from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)
from peft import get_peft_model_state_dict


class MultiModalTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Forward the multimodal batch through the model and return its loss.
        return model(
            image=inputs["images"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss
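
    # A hypothetical sketch of the batch this trainer expects from its data
    # collator (the collator itself is not part of this file); tensor shapes
    # are assumptions, but each batch must expose the three keys used above:
    #
    #     batch = {
    #         "images": torch.stack(image_tensors),     # e.g. (B, C, H, W)
    #         "input_ids": padded_token_ids,            # e.g. (B, T)
    #         "labels": padded_label_ids,               # e.g. (B, T), -100 on ignored positions
    #     }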
    def save_model(self, output_dir=None, _internal_call=False):
        from transformers.trainer import TRAINING_ARGS_NAME

        # Ensure output_dir is not None
        if output_dir is None:
            output_dir = self.args.output_dir
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        # Save training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        # Access the original model (unwrap the DDP wrapper if present)
        model = self.model.module if hasattr(self.model, "module") else self.model
        # Save LLM parameters (the PEFT adapter weights only)
        saved_params_LLM = get_peft_model_state_dict(model.LLM)
        torch.save(saved_params_LLM, os.path.join(output_dir, "adapter_model.bin"))
        # Save other parameters (the feature projector)
        saved_params_other = model.feature_proj.state_dict()
        torch.save(saved_params_other, os.path.join(output_dir, "other_params.bin"))
        # Save the first adapter's PEFT configuration
        config = model.LLM.peft_config
        selected_adapters = list(config.keys())
        config[selected_adapters[0]].save_pretrained(output_dir, auto_mapping_dict=None)
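
    # A hypothetical reload sketch for the artifacts written above (assumes the
    # same model class is re-instantiated first; `set_peft_model_state_dict` is
    # the counterpart of `get_peft_model_state_dict` in peft):
    #
    #     from peft import set_peft_model_state_dict
    #
    #     adapter_state = torch.load(os.path.join(output_dir, "adapter_model.bin"))
    #     set_peft_model_state_dict(model.LLM, adapter_state)
    #     proj_state = torch.load(os.path.join(output_dir, "other_params.bin"))
    #     model.feature_proj.load_state_dict(proj_state)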
    def create_optimizer(self):
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        opt_model = self.model
        if self.optimizer is None:
            # Apply weight decay to everything except bias and LayerNorm weights.
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.feature_proj_lr is not None:
                # Give the feature projector its own learning rate; all other
                # trainable parameters keep the base learning rate.
                projector_parameters = [
                    name for name, _ in opt_model.named_parameters() if "feature_proj" in name
                ]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.feature_proj_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.feature_proj_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer
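
    # Note: `feature_proj_lr` is not a standard `transformers.TrainingArguments`
    # field; it is assumed to come from a custom TrainingArguments subclass
    # (see the usage sketch at the bottom of this file).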
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        super().create_optimizer_and_scheduler(num_training_steps)
        # Once the optimizer and scheduler exist, wrap the model in DDP for
        # distributed runs (local_rank != -1) so gradients are synchronized.
        if self.args.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )
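

# A minimal usage sketch (hypothetical; not part of the original file). It assumes
# a multimodal model exposing `.LLM` (a PEFT-wrapped language model) and
# `.feature_proj`, a collator that yields "images", "input_ids", and "labels",
# and a TrainingArguments subclass carrying the extra `feature_proj_lr` field:
#
#     from dataclasses import dataclass
#     from typing import Optional
#     from transformers import TrainingArguments
#
#     @dataclass
#     class MultiModalTrainingArguments(TrainingArguments):
#         feature_proj_lr: Optional[float] = None
#
#     trainer = MultiModalTrainer(
#         model=multimodal_model,
#         args=MultiModalTrainingArguments(output_dir="checkpoints", feature_proj_lr=1e-4),
#         train_dataset=train_dataset,
#         data_collator=multimodal_collator,
#     )
#     trainer.train()
#     trainer.save_model("checkpoints")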