-
Notifications
You must be signed in to change notification settings - Fork 0
/
federated_averaging.py
130 lines (98 loc) · 4.4 KB
/
federated_averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import keras
from keras.models import load_model
def federated_averaging(updates, model_path, ckpt_path):
print("Model path: ", model_path)
print("Checkpoint path: ", ckpt_path)
print("Updates array: ", updates)
total_num_batches = 0
# Load model architecture
average_weight_updates = load_model(model_path)
sum_weight_updates = load_model(model_path)
device_weight_updates = load_model(model_path)
# Calculate sum of weight updates from all devices
for device_index in range(len(updates)):
n, weight_updates_path = updates[device_index]
total_num_batches += n
if device_index == 0:
sum_weight_updates.load_weights(weight_updates_path)
else:
# Load device weight updates checkpoint
device_weight_updates.load_weights(weight_updates_path)
# Add weight updates from device to prefix sum
for layer_index in range(len(sum_weight_updates.layers)):
# Old sum of weight updates
old_sum_weight_updates_values = sum_weight_updates.layers[layer_index].get_weights(
)
# Device weight updates
device_weight_updates_values = device_weight_updates.layers[layer_index].get_weights(
)
# Weight updates calculation
sum_weight_updates.layers[layer_index].set_weights(np.asarray(old_sum_weight_updates_values)
+ np.asarray(device_weight_updates_values))
# print("old weights: ", old_layer_weights)
# print("new weights: ", new_layer_weights)
# print("update weights: ", update_weights.layers[i].get_weights())
# Calculate average
for layer_index in range(len(sum_weight_updates.layers)):
# Value of sum of weight updates
sum_weight_updates_values = sum_weight_updates.layers[layer_index].get_weights(
)
# Calculate average and store
average_weight_updates.layers[layer_index].set_weights(
np.asarray(sum_weight_updates_values)/total_num_batches)
# print("old weights: ", old_layer_weights)
# print("new weights: ", new_layer_weights)
# print("update weights: ", update_weights.layers[i].get_weights())
# Add average of weight updates to checkpoint
# Load model and checkpoints
model = load_model(model_path)
model.load_weights(ckpt_path)
for layer_index in range(len(model.layers)):
# Average of weight updates values
average_weight_updates_values = average_weight_updates.layers[layer_index].get_weights(
)
old_model_values = model.layers[layer_index].get_weights()
model.layers[layer_index].set_weights(np.asarray(
old_model_values) + np.asarray(average_weight_updates_values))
# Save updated model checkpoints
model.save_weights(ckpt_path)
print("New checkpoint saved at ", ckpt_path)
def main():
# define Arguments
parser = argparse.ArgumentParser(description='Perform Federated Averaging')
parser.add_argument("--cf", "--ckpt-file-path", required=True, nargs=1)
parser.add_argument("--mf", "--model-file-path", required=True, nargs=1)
parser.add_argument("--u", "--updates", required=True, nargs='*')
# params for federated averaging
model_path = ''
ckpt_path = ''
updates = []
# parse arguments
args = parser.parse_args()
for arg in vars(args):
# print(arg, getattr(args, arg))
if (arg == "mf"):
model_path = getattr(args, arg)[0]
elif (arg == "cf"):
ckpt_path = getattr(args, arg)[0]
else:
update_args = getattr(args, arg)
print(update_args)
print("Len: ", len(update_args))
for i in range(0, len(update_args), 2):
print("index: ", i)
# print(update_args[i])
# print(update_args[i + 1])
n = int(update_args[i])
path = update_args[i + 1]
updates.append((n, path))
# run federated averaging
federated_averaging(updates, model_path, ckpt_path)
if __name__ == "__main__":
main()