diff --git a/examples/flowertune-llm/flowertune_llm/client_app.py b/examples/flowertune-llm/flowertune_llm/client_app.py index 992b0f1a3e09..b61a733b29cf 100644 --- a/examples/flowertune-llm/flowertune_llm/client_app.py +++ b/examples/flowertune-llm/flowertune_llm/client_app.py @@ -3,7 +3,6 @@ import os import warnings from typing import Dict, Tuple -from collections import OrderedDict import torch from flwr.client import ClientApp, NumPyClient @@ -12,7 +11,6 @@ from flwr.common.typing import NDArrays, Scalar from omegaconf import DictConfig -from peft import get_peft_model_state_dict, set_peft_model_state_dict from transformers import TrainingArguments from trl import SFTTrainer @@ -24,6 +22,8 @@ from flowertune_llm.models import ( cosine_annealing, get_model, + set_parameters, + get_parameters, ) # Avoid warnings @@ -96,20 +96,6 @@ def fit( ) -def set_parameters(model, parameters: NDArrays) -> None: - """Change the parameters of the model using the given ones.""" - peft_state_dict_keys = get_peft_model_state_dict(model).keys() - params_dict = zip(peft_state_dict_keys, parameters) - state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) - set_peft_model_state_dict(model, state_dict) - - -def get_parameters(model) -> NDArrays: - """Return the parameters of the current net.""" - state_dict = get_peft_model_state_dict(model) - return [val.cpu().numpy() for _, val in state_dict.items()] - - def client_fn(context: Context) -> FlowerClient: """Create a Flower client representing a single organization.""" partition_id = context.node_config["partition-id"] diff --git a/examples/flowertune-llm/flowertune_llm/models.py b/examples/flowertune-llm/flowertune_llm/models.py index 7d0c8f391687..e1609caeb2fc 100644 --- a/examples/flowertune-llm/flowertune_llm/models.py +++ b/examples/flowertune-llm/flowertune_llm/models.py @@ -2,13 +2,18 @@ import torch from omegaconf import DictConfig +from collections import OrderedDict from peft import ( LoraConfig, get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, ) from peft.utils import prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from flwr.common.typing import NDArrays + def cosine_annealing( current_round: int, @@ -56,3 +61,17 @@ def get_model(model_cfg: DictConfig): ) return get_peft_model(model, peft_config) + + +def set_parameters(model, parameters: NDArrays) -> None: + """Change the parameters of the model using the given ones.""" + peft_state_dict_keys = get_peft_model_state_dict(model).keys() + params_dict = zip(peft_state_dict_keys, parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + set_peft_model_state_dict(model, state_dict) + + +def get_parameters(model) -> NDArrays: + """Return the parameters of the current net.""" + state_dict = get_peft_model_state_dict(model) + return [val.cpu().numpy() for _, val in state_dict.items()] diff --git a/examples/flowertune-llm/flowertune_llm/server_app.py b/examples/flowertune-llm/flowertune_llm/server_app.py index 309166cc30a3..ff0da90c8b9b 100644 --- a/examples/flowertune-llm/flowertune_llm/server_app.py +++ b/examples/flowertune-llm/flowertune_llm/server_app.py @@ -9,9 +9,8 @@ from flwr.server.strategy import FedAvg from omegaconf import DictConfig -from flowertune_llm.models import get_model +from flowertune_llm.models import get_model, get_parameters, set_parameters from flowertune_llm.dataset import replace_keys -from flowertune_llm.client_app import get_parameters, set_parameters # Get function that will be executed by the strategy's evaluate() method diff --git a/examples/flowertune-llm/pyproject.toml b/examples/flowertune-llm/pyproject.toml index 20aa7267d9d5..5c057de2ea70 100644 --- a/examples/flowertune-llm/pyproject.toml +++ b/examples/flowertune-llm/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" description = "FlowerTune LLM: Federated LLM Fine-tuning with Flower" license = "Apache-2.0" dependencies = [ - "flwr[simulation]==1.11.0", + "flwr[simulation]==1.11.1", "flwr-datasets>=0.3.0", "trl==0.8.1", "bitsandbytes==0.43.0",