Skip to content

Commit

Permalink
refactor(examples) Update get/set parameters functions for FlowerTune…
Browse files Browse the repository at this point in the history
… LLM example (#4219)
  • Loading branch information
yan-gao-GY authored Sep 16, 2024
1 parent 13decda commit dd22cd9
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
18 changes: 2 additions & 16 deletions examples/flowertune-llm/flowertune_llm/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import warnings
from typing import Dict, Tuple
from collections import OrderedDict

import torch
from flwr.client import ClientApp, NumPyClient
Expand All @@ -12,7 +11,6 @@
from flwr.common.typing import NDArrays, Scalar
from omegaconf import DictConfig

from peft import get_peft_model_state_dict, set_peft_model_state_dict
from transformers import TrainingArguments
from trl import SFTTrainer

Expand All @@ -24,6 +22,8 @@
from flowertune_llm.models import (
cosine_annealing,
get_model,
set_parameters,
get_parameters,
)

# Avoid warnings
Expand Down Expand Up @@ -96,20 +96,6 @@ def fit(
)


def set_parameters(model, parameters: NDArrays) -> None:
"""Change the parameters of the model using the given ones."""
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
params_dict = zip(peft_state_dict_keys, parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
set_peft_model_state_dict(model, state_dict)


def get_parameters(model) -> NDArrays:
"""Return the parameters of the current net."""
state_dict = get_peft_model_state_dict(model)
return [val.cpu().numpy() for _, val in state_dict.items()]


def client_fn(context: Context) -> FlowerClient:
"""Create a Flower client representing a single organization."""
partition_id = context.node_config["partition-id"]
Expand Down
19 changes: 19 additions & 0 deletions examples/flowertune-llm/flowertune_llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

import torch
from omegaconf import DictConfig
from collections import OrderedDict
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from peft.utils import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from flwr.common.typing import NDArrays


def cosine_annealing(
current_round: int,
Expand Down Expand Up @@ -56,3 +61,17 @@ def get_model(model_cfg: DictConfig):
)

return get_peft_model(model, peft_config)


def set_parameters(model, parameters: NDArrays) -> None:
"""Change the parameters of the model using the given ones."""
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
params_dict = zip(peft_state_dict_keys, parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
set_peft_model_state_dict(model, state_dict)


def get_parameters(model) -> NDArrays:
"""Return the parameters of the current net."""
state_dict = get_peft_model_state_dict(model)
return [val.cpu().numpy() for _, val in state_dict.items()]
3 changes: 1 addition & 2 deletions examples/flowertune-llm/flowertune_llm/server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from flwr.server.strategy import FedAvg
from omegaconf import DictConfig

from flowertune_llm.models import get_model
from flowertune_llm.models import get_model, get_parameters, set_parameters
from flowertune_llm.dataset import replace_keys
from flowertune_llm.client_app import get_parameters, set_parameters


# Get function that will be executed by the strategy's evaluate() method
Expand Down
2 changes: 1 addition & 1 deletion examples/flowertune-llm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "1.0.0"
description = "FlowerTune LLM: Federated LLM Fine-tuning with Flower"
license = "Apache-2.0"
dependencies = [
"flwr[simulation]==1.11.0",
"flwr[simulation]==1.11.1",
"flwr-datasets>=0.3.0",
"trl==0.8.1",
"bitsandbytes==0.43.0",
Expand Down

0 comments on commit dd22cd9

Please sign in to comment.