average_ckpt.py
import argparse
import torch


def average_checkpoints(checkpoint_paths):
    """Average the 'state_dict' weights of several checkpoints and return a full checkpoint dict."""
    # Use the last checkpoint as the template so any non-weight fields it carries are preserved.
    averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=torch.device('cpu'))

    # Initialise the running sum with the template's own parameters.
    param_sum_dict = {}
    for key, value in averaged_ckpt['state_dict'].items():
        param_sum_dict[key] = value.clone()

    # Accumulate the parameters of the remaining checkpoints.
    for ckpt_path in checkpoint_paths[:-1]:
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        for key, value in checkpoint['state_dict'].items():
            param_sum_dict[key] += value

    # Divide the summed parameters by the number of checkpoints to get the average.
    num_checkpoints = len(checkpoint_paths)
    for key in param_sum_dict:
        param_sum_dict[key] = param_sum_dict[key] / num_checkpoints

    averaged_ckpt['state_dict'] = param_sum_dict
    return averaged_ckpt


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Averages the weights of multiple transformer model checkpoints.")
    parser.add_argument('--checkpoint_paths', nargs='+', required=True,
                        help='Paths of the checkpoints to be averaged. '
                             'Example: --checkpoint_paths path1 path2 path3')
    parser.add_argument('--output_path', type=str, required=True,
                        help='Path where the averaged checkpoint is saved.')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    averaged_ckpt = average_checkpoints(args.checkpoint_paths)
    torch.save(averaged_ckpt, args.output_path)
    print(f"Averaged checkpoint saved to {args.output_path}")