Update the registry with configs + example with LSTM
alheliou authored and elmo55 committed Oct 29, 2024
1 parent 1f41c65 commit 8f5fd79
Showing 21 changed files with 1,883 additions and 25 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions configs/server.yml → configs/cifar10_cnn/server.yml
@@ -26,6 +26,6 @@ data:
server_adress: "[::]:22222"
num_rounds: 2
client_configs:
  - ${root_dir}/configs/client_1.yml
  - ${root_dir}/configs/client_2.yml
  - ${root_dir}/configs/cifar10_cnn/client_1.yml
  - ${root_dir}/configs/cifar10_cnn/client_2.yml
save_on_train_end: true
24 changes: 24 additions & 0 deletions configs/turbofan_lstm/client_1.yml
@@ -0,0 +1,24 @@
cid: 1
pre_train_val: true
fabric:
  accelerator: gpu
  devices:
    - 0
root_dir: ${oc.env:PWD}
model:
  name: lstm
  config:
    n_features: 24
    hidden_units: 12
    lr: 0.001
data:
  name: turbofan
  config:
    data_path: ${root_dir}/src/ml/data/turbofan/turbofan.txt
    engines_train_list: [52,62]
    engines_val_list: [64]
    engines_test_list: [69]
    window: 20
    batch_size: 8
    num_workers: 0
server_adress: localhost:22222
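The `model` block above selects the new `lstm` entry with `n_features: 24`, `hidden_units: 12`, and `lr: 0.001`. The commit's LSTM module itself is not rendered in this diff, so the following is only a minimal sketch of what a matching LightningModule could look like; the class name and everything beyond those three config keys are assumptions:

```python
import lightning.pytorch as pl
import torch
import torch.nn as nn


class LitLSTMRegressor(pl.LightningModule):
    """Hypothetical single-layer LSTM regressor matching the YAML config keys."""

    def __init__(self, n_features: int, hidden_units: int, lr: float):
        super().__init__()
        self.save_hyperparameters()
        self.lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_units, batch_first=True)
        self.head = nn.Linear(hidden_units, 1)
        self.loss_fn = nn.MSELoss()
        self.lr = lr

    def forward(self, x):
        # x: (batch, window, n_features); use the last hidden state as the summary
        _, (h_n, _) = self.lstm(x)
        return self.head(h_n[-1]).squeeze(-1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```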
23 changes: 23 additions & 0 deletions configs/turbofan_lstm/client_2.yml
@@ -0,0 +1,23 @@
cid: 2
fabric:
  accelerator: gpu
  devices:
    - 0
root_dir: ${oc.env:PWD}
model:
  name: lstm
  config:
    n_features: 24
    hidden_units: 12
    lr: 0.001
data:
  name: turbofan
  config:
    data_path: ${root_dir}/src/ml/data/turbofan/turbofan.txt
    engines_train_list: [2]
    engines_val_list: [64]
    engines_test_list: [69]
    window: 20
    batch_size: 8
    num_workers: 0
server_adress: localhost:22222
22 changes: 22 additions & 0 deletions configs/turbofan_lstm/local_train.yml
@@ -0,0 +1,22 @@
trainer:
  max_epochs: 10
  accelerator: gpu
  devices:
    - 0
root_dir: ${oc.env:PWD}
model:
  name: lstm
  config:
    n_features: 24
    hidden_units: 12
    lr: 0.001
data:
  name: turbofan
  config:
    data_path: ${root_dir}/src/ml/data/turbofan/turbofan.txt
    engines_train_list: [52,62,2]
    engines_val_list: [64]
    engines_test_list: [69]
    window: 20
    batch_size: 8
    num_workers: 0
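In all three configs, `window: 20` combined with the per-engine `engines_*_list` keys points to sliding-window sampling over each engine's sensor time series. A minimal sketch of that idea, assuming `turbofan.txt` parses into a frame with an engine-id column, feature columns, and a remaining-useful-life target (all column names hypothetical):

```python
import pandas as pd
import torch


def make_windows(df: pd.DataFrame, engine_ids: list[int], window: int):
    # Cut each selected engine's series into (window, n_features) samples,
    # pairing each window with the target at its last time step.
    feature_cols = [c for c in df.columns if c not in ("engine", "rul")]
    xs, ys = [], []
    for eid in engine_ids:
        series = df[df["engine"] == eid]
        feats = torch.tensor(series[feature_cols].to_numpy(), dtype=torch.float32)
        target = torch.tensor(series["rul"].to_numpy(), dtype=torch.float32)
        for start in range(len(series) - window + 1):
            xs.append(feats[start : start + window])
            ys.append(target[start + window - 1])
    return torch.stack(xs), torch.stack(ys)
```

With 24 sensor/feature columns and `window: 20`, each sample is a `(20, 24)` tensor, which matches `n_features: 24` in the model config.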
33 changes: 33 additions & 0 deletions configs/turbofan_lstm/server.yml
@@ -0,0 +1,33 @@
fabric:
  accelerator: gpu
  devices:
    - 0
root_dir: ${oc.env:PWD}
logger:
  subdir: /experiments/federated/test_1/
strategy:
  name: "fabric"
  config:
    min_fit_clients: 2
model:
  name: lstm
  config:
    n_features: 24
    hidden_units: 12
    lr: 0.001
data:
  name: turbofan
  config:
    data_path: ${root_dir}/src/ml/data/turbofan/turbofan.txt
    engines_train_list: [52,62]
    engines_val_list: [64]
    engines_test_list: [69]
    window: 20
    batch_size: 8
    num_workers: 0
server_adress: "[::]:22222"
num_rounds: 2
client_configs:
  - ${root_dir}/configs/turbofan_lstm/client_1.yml
  - ${root_dir}/configs/turbofan_lstm/client_2.yml
save_on_train_end: true
8 changes: 4 additions & 4 deletions docs/how-to.md
@@ -12,8 +12,8 @@ Available models are located in `ml/models/`. To add a new model, follow the few
- the second one contains the LightningModule based on the classical nn.Module.
* if needed, add another directory `ml/models/my-model/all-things-needed/` which would contain all things necessary for your model to work properly: specific losses and metrics, dedicated torch modules, and so on.
* update the file `ml/registry.py`:
- import your LightningModule;
- update `model_registry` by adding a new key linking to your LightningModule.
- import your LightningModule and its config;
- update `model_registry` and `ModelConfig` by adding a new key linking to your LightningModule, as sketched below.
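A hedged sketch of what the model half of the updated `ml/registry.py` might look like after those two steps. The discriminated-union shape follows the commented-out `Field(discriminator="name")` hint visible in `src/flower/client_fabric.py` further down this diff; every class name and registry key not imported elsewhere in the diff is an assumption:

```python
from typing import Annotated, Union

from pydantic import Field

from src.ml.models.cnn.lit_cnn import (  # path as used elsewhere in this diff
    ConfigModel_Cifar10,
    LitCNN,  # assumed class name
)
from src.ml.models.lstm.lit_lstm import (  # assumed module path
    ConfigModel_LSTM,  # assumed class name
    LitLSTM,           # assumed class name
)

# Maps the `name` key from the YAML `model` block to the LightningModule class.
model_registry = {
    "cnn": LitCNN,    # key assumed
    "lstm": LitLSTM,  # key matches `name: lstm` in the configs above
}

# Pydantic resolves the concrete config class from the `name` discriminator.
ModelConfig = Annotated[
    Union[ConfigModel_Cifar10, ConfigModel_LSTM],
    Field(discriminator="name"),
]
```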

## How to add datasets in Pybiscus

@@ -25,5 +25,5 @@ Available datasets are located in `ml/data/`. To add a new dataset, follow the few
- the second one contains the LightningDataModule based on the classical torch Dataset.
* if needed, add another directory `ml/data/my-data/all-things-needed/` which would contain all things necessary for your dataset to work properly, in particular preprocessing.
* update the file `ml/registry.py`:
- import your LightningDataModule;
- update `datamodule_registry` by adding a new key linking to your LightningDataModule.
- import your LightningDataModule and its config;
- update `datamodule_registry` and `DataConfig` by adding a new key linking to your LightningDataModule, as sketched below.
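And a matching sketch for the data half, under the same caveats (only `ConfigData_Cifar10` and its import path appear in this diff; the rest is assumed):

```python
from typing import Annotated, Union

from pydantic import Field

from src.ml.data.cifar10.cifar10_datamodule import (  # path as used in client_fabric.py
    CifarDataModule,  # assumed class name
    ConfigData_Cifar10,
)
from src.ml.data.turbofan.turbofan_datamodule import (  # assumed module path
    ConfigData_Turbofan,  # assumed class name
    TurbofanDataModule,   # assumed class name
)

# Maps the `name` key from the YAML `data` block to the LightningDataModule class.
datamodule_registry = {
    "cifar10": CifarDataModule,      # key assumed
    "turbofan": TurbofanDataModule,  # key matches `name: turbofan` in the configs above
}

# Pydantic resolves the concrete config class from the `name` discriminator.
DataConfig = Annotated[
    Union[ConfigData_Cifar10, ConfigData_Turbofan],
    Field(discriminator="name"),
]
```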
108 changes: 106 additions & 2 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ einops = "^0.6.1"
tensorboard = "^2.14.1"
pydantic = "2.1.1"
trogon = "^0.5.0"
pandas = "^2.2.3"

[tool.poetry.group.dev-dependencies.dependencies]
black = "<25.0"
35 changes: 29 additions & 6 deletions src/flower/client_fabric.py
@@ -1,5 +1,7 @@
from collections import OrderedDict
from collections.abc import Mapping
from typing import Union
from typing_extensions import Annotated

import flwr as fl
import torch
@@ -10,7 +10,7 @@
from src.console import console
from src.ml.data.cifar10.cifar10_datamodule import ConfigData_Cifar10
from src.ml.loops_fabric import test_loop, train_loop
from src.ml.models.cnn.lit_cnn import ConfigModel_Cifar10
from src.ml.registry import ModelConfig, DataConfig

torch.backends.cudnn.enabled = True

@@ -33,6 +35,7 @@ class ConfigFabric(BaseModel):
    accelerator: str
    devices: Union[int, list[int], str] = "auto"

ConfigModel = Annotated[int, lambda x: x > 0]

class ConfigClient(BaseModel):
"""A Pydantic Model to validate the Client configuration given by the user.
Expand Down Expand Up @@ -62,15 +65,35 @@ class ConfigClient(BaseModel):
    server_adress: str
    root_dir: str
    fabric: ConfigFabric
    model: ConfigModel_Cifar10
    data: ConfigData_Cifar10
    model: ModelConfig
    data: DataConfig

    # Below is used when several models and/or datasets are available.
    # model: Union[ConfigModel_Cifar10, ...] = Field(discriminator="name")
    # data: Union[ConfigData_Cifar10, ...] = Field(discriminator="name")

    model_config = ConfigDict(extra="forbid")

def parse_optimizers(lightning_optimizers):
    """
    Parse the output of Lightning's configure_optimizers
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers
    to extract only the optimizers (and not the lr_schedulers).
    """
    optimizers = []
    if lightning_optimizers:
        if isinstance(lightning_optimizers, Mapping):
            # A single dict like {"optimizer": ..., "lr_scheduler": ...}
            optimizers.append(lightning_optimizers["optimizer"])
        elif isinstance(lightning_optimizers, torch.optim.Optimizer):
            # A single bare optimizer
            optimizers.append(lightning_optimizers)
        else:
            # An iterable mixing bare optimizers and per-optimizer dicts
            for optimizer_conf in lightning_optimizers:
                if isinstance(optimizer_conf, Mapping):
                    optimizers.append(optimizer_conf["optimizer"])
                else:
                    optimizers.append(optimizer_conf)
    return optimizers
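
# Usage sketch (hypothetical optimizer/scheduler objects), covering the three
# return shapes of configure_optimizers handled above:
#   parse_optimizers(adam)                                        -> [adam]
#   parse_optimizers({"optimizer": adam, "lr_scheduler": sched})  -> [adam]
#   parse_optimizers([adam, {"optimizer": sgd}])                  -> [adam, sgd]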


class FlowerClient(fl.client.NumPyClient):
"""A Fabric-based, modular Flower Client.
@@ -116,13 +139,13 @@ def __init__(
        self.num_examples = num_examples
        self.pre_train_val = pre_train_val

        self.optimizer = self.model.configure_optimizers()
        self.optimizers = parse_optimizers(self.model.configure_optimizers())

        self.fabric = Fabric(**self.conf_fabric)

    def initialize(self):
        self.fabric.launch()
        self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)
        self.model, self.optimizers = self.fabric.setup(self.model, *self.optimizers)
        (
            self._train_dataloader,
            self._validation_dataloader,
@@ -160,7 +183,7 @@ def fit(self, parameters, config):
            self.fabric,
            self.model,
            self._train_dataloader,
            self.optimizer,
            self.optimizers,  # Alice TODO extend this to multiple optimizers ??
            epochs=config["local_epochs"],
        )
        console.log(f"Training Finished! Loss is {results_train['loss']}")
13 changes: 6 additions & 7 deletions src/flower/server_fabric.py
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Optional
from typing import Callable, Optional, Union

import flwr as fl
import numpy as np
@@ -12,9 +12,10 @@
from src.console import console
from src.flower.client_fabric import ConfigFabric
from src.flower.strategies import ConfigFabricStrategy
from src.ml.data.cifar10.cifar10_datamodule import ConfigData_Cifar10
from src.ml.loops_fabric import test_loop
from src.ml.models.cnn.lit_cnn import ConfigModel_Cifar10
from src.ml.registry import ModelConfig, DataConfig




class ConfigStrategy(BaseModel):
@@ -58,10 +59,8 @@ class ConfigServer(BaseModel):
    logger: dict
    strategy: ConfigStrategy
    fabric: ConfigFabric
    model: ConfigModel_Cifar10
    data: ConfigData_Cifar10
    # model: Union[ConfigModel_Cifar10] = Field(discriminator="name")
    # data: Union[ConfigData_Cifar10] = Field(discriminator="name")
    model: ModelConfig
    data: DataConfig
    client_configs: list[str] = Field(default=None)
    save_on_train_end: bool = Field(default=False)

Empty file.
