-
Notifications
You must be signed in to change notification settings - Fork 7
/
FL_Base.py
154 lines (120 loc) · 4.19 KB
/
FL_Base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json
from enum import Enum, auto
import numpy as np
import requests
from keras.models import load_model
from time import localtime, strftime
import tensorflow as tf
class Institution(Enum):
    """Closed set of institution identifiers.

    Members are numbered 1, 2, 3 in declaration order.
    """

    SMC = 1
    NNC = 2
    ETC = 3
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy values.

    - Arrays become (nested) Python lists via ``ndarray.tolist()``.
    - NumPy scalars such as ``np.float32`` / ``np.int64`` become the matching
      Python scalar via ``.item()``. The stock encoder rejects these because
      they are not subclasses of ``float`` / ``int``.
    - Anything else is delegated to the base class, which raises ``TypeError``.
    """

    def default(self, o):
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.generic):
            # e.g. a single weight or metric pulled out of an array
            return o.item()
        return super().default(o)
class API:
    """Endpoint URLs of the FL aggregation server.

    Every endpoint is rooted at ``base_url``. The global-weight download and
    upload share a single URL and are distinguished only by HTTP method.
    """

    base_url = "http://FlServer.d6mm7kyzdp.ap-northeast-2.elasticbeanstalk.com"

    request_round_api = f"{base_url}/round"
    # Same resource: fetched with GET, uploaded with PUT.
    request_global_weight = put_global_weight = f"{base_url}/weight"
    request_global_params = f"{base_url}/params"
    request_client_count = f"{base_url}/client_count"
class FL_Base:
    """Base client for a federated-learning round loop.

    Holds round counters and local-training hyper-parameters as class
    attributes, and wraps the HTTP calls to the aggregation server whose
    URLs live on ``API``. ``task``, ``local_training``, ``local_evaluate``
    and ``delay_round`` only log here. NOTE(review): they presumably are
    overridden by subclasses; confirm.
    """

    # Executed once, when the class body runs at import time. It seeds
    # NumPy's *global* RNG for the whole process.
    np.random.seed(19)
    max_round = 10
    global_round = 0
    current_round = 0
    delay_time = 10  # second
    local_epochs = 10
    local_batch_size = 100
    # NOTE(review): class attribute, so this one Sequential instance is shared
    # by every instance/subclass that does not assign its own self.model.
    model = tf.keras.Sequential()

    def set_institution(self, institution):
        """Record the institution this client represents.

        The argument is presumably an ``Institution`` member; confirm.
        """
        self.institution = institution

    def task(self):
        """task sequence
        1. request global round
        2. compare global round and local round
        3. request global weight
        4. if the round same & exist global weight > run validation
        5. start local training
        6. update local weight to aggregation server
        7. delay, do next round
        """

    def request_global_round(self):
        """Fetch the current global round from the server.

        The decoded JSON body is stored in ``self.global_round``, printed
        alongside ``self.current_round``, and returned.
        """
        print("> FL_Base: request_global_round")
        result = requests.get(API.request_round_api)
        self.global_round = result.json()
        print("------------------------------")
        print("> global round: {}, local round : {}".format(self.global_round, self.current_round))
        print("------------------------------")
        return self.global_round

    def request_global_weight(self):
        """Fetch the global weights from the server.

        Sets and returns ``self.global_weight``:
        - ``None`` when the server's JSON body is ``null``;
        - otherwise a list with one ``float32`` NumPy array per top-level
          element of the JSON list.
        """
        print("> FL_Base: request_global_weight")
        self.global_weight = None
        result_data = requests.get(API.request_global_weight).json()
        if result_data is not None:
            self.global_weight = [np.array(w, dtype=np.float32) for w in result_data]
        return self.global_weight

    def local_training(self):
        """Placeholder for local training; only logs in this base class."""
        print("> FL_Base: local_training")

    def local_evaluate(self):
        """Placeholder for local evaluation; only logs in this base class."""
        print("> FL_Base: local_evaluate")

    def update_local_weight(self, local_weight=None):
        """Upload this client's weights to the aggregation server.

        ``local_weight`` is serialised with ``NumpyEncoder``, so NumPy arrays
        become nested lists. It is then sent as the body of a PUT request.

        Returns the ``requests`` response. Returns ``None``, with nothing
        sent, when ``local_weight`` is ``None``.
        """
        print("> FL_Base: update_local_weight")
        if local_weight is None:
            # Return explicitly: no request was made, so there is no response
            # (falling through used to hit an unbound `result`).
            print("> FL_Base: update_local_weight: error: local_weight None")
            return None
        local_weight_to_json = json.dumps(local_weight, cls=NumpyEncoder)
        return requests.put(API.put_global_weight, data=local_weight_to_json)

    def save_model(self, file_path_name, ci_train, ci_test):
        """Save ``self.model``'s weights to a timestamped ``.h5`` file.

        The path is ``<file_path_name>_<ci_train>_<ci_test>_<time>.h5``, where
        ``<time>`` is local time formatted as ``%y-%m-%d_%I:%M:%S``.

        NOTE(review): two issues with this format:
        - ``%I`` is a 12-hour clock with no AM/PM, so saves exactly 12 h apart
          get the same name.
        - ``:`` is not allowed in Windows file names.
        """
        print("> FL_Base: save_local_model")
        current_time = strftime("%y-%m-%d_%I:%M:%S", localtime())
        total_path = file_path_name + "_" + str(ci_train) + "_" + str(ci_test) + "_" + current_time + ".h5"
        print("> FL_Base: save_model: " + total_path)
        self.model.save_weights(total_path)

    def delay_round(self):
        """Hook run between rounds; in this base class it only logs.

        A subclass could schedule the next round, e.g.::

            if self.current_round < self.max_round:
                threading.Timer(self.delay_time, self.task).start()
        """
        print("> FL_Base: delay_round")

    # TODO: check response object
    def request_global_params(self):
        """Fetch the global parameters from the server.

        Returns the raw ``requests`` response object, not the decoded JSON.
        """
        print("> FL_Base: request_global_params")
        result = requests.get(API.request_global_params)
        # NOTE(review): the body is decoded, so a non-JSON reply raises here,
        # yet the raw response is returned. request_client_count returns the
        # decoded JSON instead; confirm which form callers expect.
        result.json()
        return result

    # TODO: check response object
    def request_client_count(self):
        """Fetch the client count reported by the server (decoded JSON body)."""
        print("> FL_Base: request_client_count")
        return requests.get(API.request_client_count).json()