run.py
import logging

from torch.optim import AdamW

# The `utils.*` wildcard imports are expected to supply the helpers used below
# (load_model_and_processor, compute_total_steps, cos_lr_scheduler,
# LossLoggerCallback, plot_loss_curve).
from utils.compute_para import *
from utils.train_type import *
from utils.lr_scheduler import *

from dataclasses import dataclass, field
from transformers import (
    AutoProcessor,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
)
from data import LlavaDataset, TrainLLavaModelCollator

logger = logging.getLogger(__name__)


@dataclass
class Arguments:
    model_name_or_path: str = field(default="mllm_chinese")
    train_type: str = field(
        default="freeze_vision_and_llm",
        metadata={"help": "Training types: 'use_lora', 'freeze_vision', 'freeze_vision_and_llm'"},
    )
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


# Training entry point
def train():
    parser = HfArgumentParser((Arguments, TrainingArguments))
    args, training_args = parser.parse_args_into_dataclasses()
    print(f"Parsed arguments: {args}")

    model, processor = load_model_and_processor(args)
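    # A minimal sketch of what `load_model_and_processor` presumably does (the
    # real implementation lives in `utils.train_type`; the Llava model class,
    # freezing logic, and LoRA config here are assumptions, not the repo's code):
    #
    # from transformers import LlavaForConditionalGeneration
    # from peft import LoraConfig, get_peft_model
    #
    # def load_model_and_processor(args):
    #     model = LlavaForConditionalGeneration.from_pretrained(args.model_name_or_path)
    #     processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    #     if args.train_type in ("freeze_vision", "freeze_vision_and_llm"):
    #         for p in model.vision_tower.parameters():
    #             p.requires_grad = False
    #     if args.train_type == "freeze_vision_and_llm":
    #         for p in model.language_model.parameters():
    #             p.requires_grad = False
    #     if args.train_type == "use_lora":
    #         model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))
    #     return model, processor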
    # -100 is the label id ignored by PyTorch's CrossEntropyLoss, so prompt and
    # padding tokens are masked out of the loss
    data_collator = TrainLLavaModelCollator(processor, -100)
    train_dataset = LlavaDataset(args.data_path)
    print(model)

    optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.1)
    total_steps = compute_total_steps(train_dataset, training_args)
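    # A plausible sketch of `compute_total_steps` (the real one comes from
    # `utils.compute_para`); it should return the number of optimizer steps the
    # scheduler will see:
    #
    # import math
    #
    # def compute_total_steps(train_dataset, training_args):
    #     effective_batch = (training_args.per_device_train_batch_size
    #                        * training_args.gradient_accumulation_steps
    #                        * max(1, training_args.world_size))
    #     steps_per_epoch = math.ceil(len(train_dataset) / effective_batch)
    #     return steps_per_epoch * int(training_args.num_train_epochs)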
    # lr_scheduler = custom_lr_scheduler(optimizer, total_steps)
    lr_scheduler = cos_lr_scheduler(optimizer, total_steps)
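    # `cos_lr_scheduler` likely wraps transformers' cosine schedule with warmup;
    # a sketch, with the 3% warmup ratio being an assumption:
    #
    # from transformers import get_cosine_schedule_with_warmup
    #
    # def cos_lr_scheduler(optimizer, total_steps):
    #     return get_cosine_schedule_with_warmup(
    #         optimizer,
    #         num_warmup_steps=int(0.03 * total_steps),
    #         num_training_steps=total_steps,
    #     )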

    # Create the loss-logging callback
    loss_logger = LossLoggerCallback()
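    # `LossLoggerCallback` presumably collects the loss values the Trainer emits
    # to its log stream; a minimal sketch of such a callback:
    #
    # from transformers import TrainerCallback
    #
    # class LossLoggerCallback(TrainerCallback):
    #     def __init__(self):
    #         self.losses = []
    #
    #     def on_log(self, args, state, control, logs=None, **kwargs):
    #         if logs and "loss" in logs:
    #             self.losses.append(logs["loss"])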

    # Build the Trainer with the custom optimizer/scheduler pair
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        optimizers=(optimizer, lr_scheduler),
        eval_dataset=None,
        data_collator=data_collator,
        callbacks=[loss_logger],
    )

    # Start training
    trainer.train()
    trainer.save_model(output_dir=training_args.output_dir)

    # Plot the loss curve
    plot_loss_curve(loss_logger.losses, output_dir=training_args.output_dir)
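    # `plot_loss_curve` is assumed to be a small matplotlib helper, roughly:
    #
    # import os
    # import matplotlib.pyplot as plt
    #
    # def plot_loss_curve(losses, output_dir):
    #     plt.plot(losses)
    #     plt.xlabel("logging step")
    #     plt.ylabel("training loss")
    #     plt.savefig(os.path.join(output_dir, "loss_curve.png"))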


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    train()
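
# Example invocation; the data path and output directory are placeholders
# (`--output_dir` is required by TrainingArguments):
#
#   python run.py \
#       --model_name_or_path mllm_chinese \
#       --train_type freeze_vision_and_llm \
#       --data_path data/train.json \
#       --output_dir output \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 3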