Skip to content

Commit

Permalink
Fix main.py logic
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Dec 21, 2023
1 parent 8d22f2e commit 56e2bb8
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 111 deletions.
21 changes: 18 additions & 3 deletions baselines/flanders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,29 @@ Ensure that the environment is properly set up, then run:
python -m flanders.main
```

This will run the entire batch of experiments with all the settings specified in `conf/base.yaml`. You can change them if you want to run single experiments.
To execute a single experiment with the default values in `conf/base.yaml`.

Note that also CIFAR-10 and California Housing are implemented.
To run custom experiments, you can override the defaults values like that:

```bash
python -m flanders.main dataset=income server.attack_fn=lie server.num_malicious=1
```

Finally, to run multiple custom experiments:

```bash
python -m flanders.main --multirun dataset=income,mnist server.attack_fn=gaussian,lie,fang,minmax server.num_malicious=0,1,2,3,4,5
```


## Expected Results

By running the above command, it will generate the results in `results/all_results.csv`. To generate the plots, use the notebook in `plotting/plots.ipynb`.
By running;
```bash
python -m flanders.main --multirun dataset=income,mnist server.attack_fn=gaussian,lie,fang,minmax server.num_malicious=0,1,2,3,4,5,6,7,8,9
```

It will generate the results in `results/all_results.csv`. To generate the plots, use the notebook in `plotting/plots.ipynb`.

Expected maximum accuracy achieved across different number of malicious clients and different attacks:
![](_static/max_acc.jpg)
Expand Down
25 changes: 3 additions & 22 deletions baselines/flanders/flanders/conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ hydra:
run:
dir: .

dataset:
name:
- income
#- house
- mnist
#- cifar
dataset: mnist

strategy:
_target_: flanders.strategy.Flanders
Expand All @@ -28,25 +23,11 @@ server:
_target_: flanders.server.EnhancedServer
num_rounds: 50
pool_size: 10
num_malicious:
- 0
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
warmup_rounds: 30
sampling: 500
history_dir: clients_params
magnitude: 100
attack_fn:
- gaussian_attack
- lie_attack
- fang_attack
- minmax_attack
threshold: 1e-05
attack_fn: gaussian
num_malicious: 0
omniscent: True
166 changes: 80 additions & 86 deletions baselines/flanders/flanders/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def main(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))

attacks = {
"gaussian_attack": gaussian_attack,
"lie_attack": lie_attack,
"fang_attack": fang_attack,
"minmax_attack": minmax_attack
"gaussian": gaussian_attack,
"lie": lie_attack,
"fang": fang_attack,
"minmax": minmax_attack
}

clients = {
Expand All @@ -73,88 +73,82 @@ def main(cfg: DictConfig) -> None:
"income": (IncomeClient, income_evaluate)
}

for dataset_name in cfg.dataset.name:
for attack_fn in cfg.server.attack_fn:
for num_malicious in cfg.server.num_malicious:
# the experiment with num_malicious = 0 should be done only one time
if num_malicious == 0 and attack_fn != "gaussian_attack":
continue

if attack_fn == "fang_attack" and num_malicious == 1:
continue

# Delete old client_params and clients_predicted_params
if os.path.exists(cfg.server.history_dir):
shutil.rmtree(cfg.server.history_dir)

# 2. Prepare your dataset
sampling = cfg.server.sampling
if dataset_name == "cifar":
train_path, testset = get_cifar_10()
fed_dir = do_fl_partitioning(
train_path, pool_size=cfg.server.pool_size, alpha=10000, num_classes=10, val_ratio=0.5, seed=1234
)
elif dataset_name == "income":
sampling = 0
X_train, X_test, y_train, y_test = get_partitioned_income("flanders/datasets_files/adult.csv", cfg.server.pool_size)
elif dataset_name == "house":
sampling = 0
X_train, X_test, y_train, y_test = get_partitioned_house("flanders/datasets_files/houses_preprocessed.csv", cfg.server.pool_size)


# 3. Define your clients
def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name):
client = clients[dataset_name][0]
cid_idx = int(cid)
if dataset_name == "cifar":
return client(cid, fed_dir)
elif dataset_name == "mnist":
return client(cid, pool_size)
elif dataset_name == "income":
return client(cid, X_train[cid_idx], y_train[cid_idx], X_test[cid_idx], y_test[cid_idx])
elif dataset_name == "house":
return client(cid, X_train[cid_idx], y_train[cid_idx], X_test[cid_idx], y_test[cid_idx])
else:
raise ValueError("Dataset not supported")

# 4. Define your strategy
strategy = instantiate(
cfg.strategy,
evaluate_fn = clients[dataset_name][1],
on_fit_config_fn=fit_config,
fraction_fit=1,
fraction_evaluate=0,
min_fit_clients=cfg.server.pool_size,
min_evaluate_clients=0,
warmup_rounds=cfg.server.warmup_rounds,
to_keep=cfg.strategy.to_keep,
min_available_clients=cfg.server.pool_size,
window=cfg.server.warmup_rounds,
distance_function=l2_norm,
maxiter=cfg.strategy.maxiter,
)

# 5. Start Simulation
fl.simulation.start_simulation(
client_fn=client_fn,
client_resources={"num_cpus": 10},
num_clients=cfg.server.pool_size,
server=EnhancedServer(
warmup_rounds=cfg.server.warmup_rounds,
num_malicious=num_malicious,
attack_fn=attacks[attack_fn],
magnitude=cfg.server.magnitude,
client_manager=SimpleClientManager(),
strategy=strategy,
sampling=cfg.server.sampling,
history_dir=cfg.server.history_dir,
dataset_name=dataset_name,
threshold=cfg.server.threshold,
omniscent=cfg.server.omniscent,
),
config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
strategy=strategy
)
# Delete old client_params and clients_predicted_params
if os.path.exists(cfg.server.history_dir):
shutil.rmtree(cfg.server.history_dir)

dataset_name = cfg.dataset
attack_fn = cfg.server.attack_fn
num_malicious = cfg.server.num_malicious

# 2. Prepare your dataset
sampling = cfg.server.sampling
if dataset_name == "cifar":
train_path, testset = get_cifar_10()
fed_dir = do_fl_partitioning(
train_path, pool_size=cfg.server.pool_size, alpha=10000, num_classes=10, val_ratio=0.5, seed=1234
)
elif dataset_name == "income":
sampling = 0
X_train, X_test, y_train, y_test = get_partitioned_income("flanders/datasets_files/adult.csv", cfg.server.pool_size)
elif dataset_name == "house":
sampling = 0
X_train, X_test, y_train, y_test = get_partitioned_house("flanders/datasets_files/houses_preprocessed.csv", cfg.server.pool_size)


# 3. Define your clients
def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name):
client = clients[dataset_name][0]
cid_idx = int(cid)
if dataset_name == "cifar":
return client(cid, fed_dir)
elif dataset_name == "mnist":
return client(cid, pool_size)
elif dataset_name == "income":
return client(cid, X_train[cid_idx], y_train[cid_idx], X_test[cid_idx], y_test[cid_idx])
elif dataset_name == "house":
return client(cid, X_train[cid_idx], y_train[cid_idx], X_test[cid_idx], y_test[cid_idx])
else:
raise ValueError("Dataset not supported")

# 4. Define your strategy
strategy = instantiate(
cfg.strategy,
evaluate_fn = clients[dataset_name][1],
on_fit_config_fn=fit_config,
fraction_fit=1,
fraction_evaluate=0,
min_fit_clients=cfg.server.pool_size,
min_evaluate_clients=0,
warmup_rounds=cfg.server.warmup_rounds,
to_keep=cfg.strategy.to_keep,
min_available_clients=cfg.server.pool_size,
window=cfg.server.warmup_rounds,
distance_function=l2_norm,
maxiter=cfg.strategy.maxiter,
)

# 5. Start Simulation
fl.simulation.start_simulation(
client_fn=client_fn,
client_resources={"num_cpus": 10},
num_clients=cfg.server.pool_size,
server=EnhancedServer(
warmup_rounds=cfg.server.warmup_rounds,
num_malicious=num_malicious,
attack_fn=attacks[attack_fn],
magnitude=cfg.server.magnitude,
client_manager=SimpleClientManager(),
strategy=strategy,
sampling=cfg.server.sampling,
history_dir=cfg.server.history_dir,
dataset_name=dataset_name,
threshold=cfg.server.threshold,
omniscent=cfg.server.omniscent,
),
config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
strategy=strategy
)

def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
Expand Down

0 comments on commit 56e2bb8

Please sign in to comment.