fix bug: _param_filter (#642)
rayrayraykk authored Jul 3, 2023
1 parent fa6e1a5 commit 6dfe8d4
Showing 12 changed files with 47 additions and 14 deletions.
federatedscope/attack/trainer/gaussian_attack_trainer.py (3 additions, 0 deletions)
@@ -31,6 +31,9 @@ def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx):
     ctx.optimizer.zero_grad()
     ctx.loss_task.backward()

+    if ctx.grad_clip > 0:
+        torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), ctx.grad_clip)
+
     grad_values = list()
     for name, param in ctx.model.named_parameters():
         if 'bn' not in name:

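The clipping guard added above presumably mirrors what the regular torch trainer does before gradients are read out for the attack. A minimal, self-contained sketch of the same call, where the toy model and max_norm (standing in for ctx.grad_clip) are illustrative assumptions:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

max_norm = 1.0  # stand-in for ctx.grad_clip
if max_norm > 0:
    # rescale all gradients in place so their global L2 norm is at most max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
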
federatedscope/core/aggregators/bulyan_aggregator.py (1 addition, 1 deletion)
@@ -77,7 +77,7 @@ def _aggre_with_bulyan(self, models):
         Apply MultiKrum to select \theta (\theta <= client_num-
         2*self.byzantine_node_num) local models
         '''
-        init_model = self.model.state_dict()
+        _, init_model = models[0]
         global_update = copy.deepcopy(init_model)
         models_para = [each_model[1] for each_model in models]
         krum_scores = self._calculate_score(models_para)

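Taking the key set from the first received model instead of self.model.state_dict() keeps the robust aggregators consistent with the _param_filter fix: once clients stop uploading filtered-out entries such as BN running statistics, only keys that are actually present in the uploads get aggregated. The same substitution appears in the median and trimmed-mean aggregators below. A toy, self-contained sketch of the resulting coordinate-wise median over client uploads (all tensors and sample sizes are made up):

import copy
import torch

# each entry is (sample_size, state_dict); toy values for illustration only
models = [
    (10, {'w': torch.tensor([1.0, 5.0])}),
    (20, {'w': torch.tensor([2.0, 4.0])}),
    (30, {'w': torch.tensor([3.0, 3.0])}),
]

_, init_model = models[0]                  # keys come from a client upload
global_update = copy.deepcopy(init_model)
for key in init_model:
    stacked = torch.stack([m[1][key] for m in models], dim=0)
    global_update[key] = torch.median(stacked, dim=0).values

print(global_update['w'])  # tensor([2., 4.])
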
federatedscope/core/aggregators/clients_avg_aggregator.py (5 additions, 0 deletions)
@@ -71,6 +71,9 @@ def _para_weighted_avg(self, models, recover_fun=None):
             for i in range(len(models)):
                 local_sample_size, local_model = models[i]

+                if key not in local_model:
+                    continue
+
                 if self.cfg.federate.ignore_weight:
                     weight = 1.0 / len(models)
                 elif self.cfg.federate.use_ss:
@@ -126,6 +129,8 @@ def inc(self, content):
         if isinstance(content, tuple):
             sample_size, model_params = content
             for key in self.maintained:
+                if key not in model_params:
+                    continue
                 # if model_params[key].device != self.maintained[key].device:
                 # model_params[key].to(self.maintained[key].device)
                 self.maintained[key] = (self.cnt * self.maintained[key] +

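With _param_filter fixed, a client's upload may no longer contain non-trainable entries, so both the weighted average and the online aggregator's inc() now skip keys a given client did not send. A minimal sketch of that guard with hypothetical toy uploads (the key names and sizes are invented):

import torch

# (sample_size, state_dict) pairs; the second upload deliberately omits the
# non-trainable key to mimic a filtered upload
models = [
    (10, {'w': torch.ones(2), 'bn.running_mean': torch.zeros(2)}),
    (30, {'w': torch.full((2,), 3.0)}),
]

training_set_size = sum(size for size, _ in models)
_, avg_model = models[0]
for key in avg_model:
    acc = None
    for local_sample_size, local_model in models:
        if key not in local_model:   # the new guard: skip absent keys
            continue
        part = local_model[key] * (local_sample_size / training_set_size)
        acc = part if acc is None else acc + part
    avg_model[key] = acc

print(avg_model['w'])  # tensor([2.5000, 2.5000])
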
federatedscope/core/aggregators/median_aggregator.py (1 addition, 1 deletion)
@@ -41,7 +41,7 @@ def aggregate(self, agg_info):
         return updated_model

     def _aggre_with_median(self, models):
-        init_model = self.model.state_dict()
+        _, init_model = models[0]
         global_update = copy.deepcopy(init_model)
         for key in init_model:
             temp = torch.stack([each_model[1][key] for each_model in models],

federatedscope/core/aggregators/normbounding_aggregator.py (11 additions, 5 deletions)
@@ -35,28 +35,34 @@ def aggregate(self, agg_info):
     def _aggre_with_normbounding(self, models):
         models_temp = []
         for each_model in models:
-            param = self._flatten_updates(each_model[1])
+            param, ignore_keys = self._flatten_updates(each_model[1])
             if torch.norm(param, p=2) > self.norm_bound:
                 scaling_rate = self.norm_bound / torch.norm(param, p=2)
                 scaled_param = scaling_rate * param
                 models_temp.append(
-                    (each_model[0], self._reconstruct_updates(scaled_param)))
+                    (each_model[0],
+                     self._reconstruct_updates(scaled_param, ignore_keys)))
             else:
                 models_temp.append(each_model)
         return self._para_weighted_avg(models_temp)

     def _flatten_updates(self, model):
-        model_update = []
+        model_update, ignore_keys = [], []
         init_model = self.model.state_dict()
         for key in init_model:
+            if key not in model:
+                ignore_keys.append(key)
+                continue
             model_update.append(model[key].view(-1))
-        return torch.cat(model_update, dim=0)
+        return torch.cat(model_update, dim=0), ignore_keys

-    def _reconstruct_updates(self, flatten_updates):
+    def _reconstruct_updates(self, flatten_updates, ignore_keys):
         start_idx = 0
         init_model = self.model.state_dict()
         reconstructed_model = copy.deepcopy(init_model)
         for key in init_model:
+            if key in ignore_keys:
+                continue
             reconstructed_model[key] = flatten_updates[
                 start_idx:start_idx + len(init_model[key].view(-1))].reshape(
                     init_model[key].shape)

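_flatten_updates now also reports which global keys were absent from the client's upload, and _reconstruct_updates leaves those entries untouched, so the flattened vector and the reconstructed dict stay aligned. A standalone sketch of the same round-trip; the dict names and tensors are illustrative assumptions, and the norm-scaling step is elided:

import copy
import torch

global_state = {'w': torch.zeros(2, 2), 'bn.running_mean': torch.zeros(2)}
client_update = {'w': torch.ones(2, 2)}   # filtered upload without BN stats

# flatten only the keys the client sent, remembering the skipped ones
flat, ignore_keys = [], []
for key in global_state:
    if key not in client_update:
        ignore_keys.append(key)
        continue
    flat.append(client_update[key].view(-1))
flat = torch.cat(flat, dim=0)

# ... scale `flat` here if its L2 norm exceeds the bound ...

# write the (possibly scaled) values back, skipping the ignored keys
rebuilt = copy.deepcopy(global_state)
start_idx = 0
for key in global_state:
    if key in ignore_keys:
        continue
    numel = global_state[key].numel()
    rebuilt[key] = flat[start_idx:start_idx + numel].reshape(global_state[key].shape)
    start_idx += numel
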
federatedscope/core/aggregators/trimmedmean_aggregator.py (1 addition, 1 deletion)
@@ -42,7 +42,7 @@ def aggregate(self, agg_info):
         return updated_model

     def _aggre_with_trimmedmean(self, models):
-        init_model = self.model.state_dict()
+        _, init_model = models[0]
         global_update = copy.deepcopy(init_model)
         excluded_num = int(len(models) * self.excluded_ratio)
         for key in init_model:

federatedscope/core/auxiliaries/model_builder.py (5 additions, 1 deletion)
@@ -204,4 +204,8 @@ def get_model(model_config, local_data=None, backend='torch'):


 def get_trainable_para_names(model):
-    return set(dict(list(model.named_parameters())).keys())
+    grad_params = set()
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            grad_params.add(name)
+    return grad_params

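Restricting get_trainable_para_names to parameters with requires_grad=True means frozen parameters no longer count as trainable, which is what _param_filter relies on when deciding which entries to share. A short usage sketch with a hypothetical, partially frozen model:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for param in model[0].parameters():   # freeze the first layer
    param.requires_grad = False

trainable = {name for name, p in model.named_parameters() if p.requires_grad}
print(trainable)  # only '1.weight' and '1.bias'; the frozen layer is excluded
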
federatedscope/core/trainers/context.py (1 addition, 0 deletions)
@@ -154,6 +154,7 @@ def __init__(self, model, cfg, data=None, device=None):

         # Setup optimize-related context variable
         if self.cfg.backend == 'torch':
+            # TODO: should we make `self.trainable_para_names` @property?
             self.trainable_para_names = get_trainable_para_names(self.model)
             # TODO: make `criterion` and `regularizer` @property and cached
             # to compare whether changes happen

federatedscope/core/trainers/trainer.py (2 additions, 2 deletions)
@@ -392,11 +392,11 @@ def _param_filter(self, state_dict, filter_keywords=None):

         trainable_filter = lambda p: True if \
             self.cfg.personalization.share_non_trainable_para else \
-            lambda p: p in self.ctx.trainable_para_names
+            p in self.ctx.trainable_para_names
         keyword_filter = filter_by_specified_keywords
         return dict(
             filter(
-                lambda elem: trainable_filter(elem[1]) and keyword_filter(
+                lambda elem: trainable_filter(elem[0]) and keyword_filter(
                     elem[0], filter_keywords), state_dict.items()))

     def save_model(self, path, cur_round=-1):

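This is the bug named in the commit title, and it is actually two bugs. First, the old right-hand side nested a second lambda inside the conditional expression, so when share_non_trainable_para was False, calling trainable_filter returned a lambda object instead of a boolean; a function object is always truthy, so every entry passed the filter. The fixed expression yields the membership test itself. Second, the filter is now applied to the key (elem[0]) rather than the tensor value (elem[1]), since trainable_para_names is a set of parameter names. A minimal sketch of the corrected behaviour, restructured slightly for readability and using invented inputs:

# hypothetical toy inputs, for illustration only
trainable_para_names = {'fc.weight', 'fc.bias'}
state_dict = {'fc.weight': 'W', 'fc.bias': 'b', 'bn.running_mean': 'rm'}
share_non_trainable_para = False

# behaviourally equivalent to the fixed lambda: keep every key only when
# sharing non-trainable parameters, otherwise keep trainable names only
trainable_filter = (lambda p: True) if share_non_trainable_para \
    else (lambda p: p in trainable_para_names)

shared = {k: v for k, v in state_dict.items() if trainable_filter(k)}
print(shared)  # {'fc.weight': 'W', 'fc.bias': 'b'}
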
federatedscope/core/workers/client.py (3 additions, 2 deletions)
@@ -200,8 +200,10 @@ def _calculate_model_delta(self, init_model, updated_model):

         model_deltas = list()
         for model_index in range(len(init_model)):
-            model_delta = copy.deepcopy(init_model[model_index])
+            model_delta = copy.deepcopy(updated_model[model_index])
             for key in init_model[model_index].keys():
+                if key not in updated_model[model_index].keys():
+                    continue
                 model_delta[key] = updated_model[model_index][
                     key] - init_model[model_index][key]
             model_deltas.append(model_delta)
@@ -425,7 +427,6 @@ def callback_funcs_for_model_para(self, message: Message):
                 else:
                     shared_model_para = symmetric_uniform_quantization(
                         shared_model_para, nbits)
-
             self.comm_manager.send(
                 Message(msg_type='model_para',
                         sender=self.ID,

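Building the delta from a copy of the updated model (rather than the initial one) and skipping keys the update does not contain keeps the delta's key set consistent with what the client actually sends back. A toy sketch of the adjusted computation with invented state dicts:

import copy
import torch

init_model = {'w': torch.zeros(2), 'bn.running_mean': torch.ones(2)}
updated_model = {'w': torch.full((2,), 3.0)}   # filtered upload, no BN stats

model_delta = copy.deepcopy(updated_model)     # keys follow the update
for key in init_model:
    if key not in updated_model:               # new guard
        continue
    model_delta[key] = updated_model[key] - init_model[key]

print(model_delta)  # {'w': tensor([3., 3.])}
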
federatedscope/core/workers/server.py (13 additions, 0 deletions)
@@ -134,6 +134,19 @@ def __init__(self,
             # set up a trainer for conducting evaluation in server
             assert self.models is not None
             assert self.data is not None
+
+            if self._cfg.backend == 'torch':
+                import torch.nn as nn
+                # Set BN track_running_stats to False
+                for name, module in model.named_modules():
+                    if isinstance(module, nn.BatchNorm2d):
+                        module.track_running_stats = False
+            elif self._cfg.backend == 'tensorflow':
+                # TODO: implement this
+                pass
+            else:
+                raise ValueError(f'Unknown backend named {self._cfg.backend}.')
+
             self.trainer = get_trainer(
                 model=self.models[0],
                 data=self.data,

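Presumably because clients may no longer upload BatchNorm running statistics once non-trainable parameters are filtered, the server switches off running-statistics tracking on the BatchNorm2d layers of the model used for server-side evaluation. A minimal sketch of the same toggle on a hypothetical model:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # stop tracking/updating running statistics in this layer
        module.track_running_stats = False
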
tests/test_krum_algo.py (1 addition, 1 deletion)
@@ -199,7 +199,7 @@ def test_guassian_attack_krum(self):
         init_cfg.merge_from_other_cfg(backup_cfg)
         self.assertGreater(
             test_best_results['client_summarized_weighted_avg']['test_acc'],
-            0.2)
+            0.15)
         init_cfg.merge_from_other_cfg(backup_cfg)

     def test_guassian_attack_multi_krum(self):

