-
Notifications
You must be signed in to change notification settings - Fork 941
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create new example flower-via-docker-compose (#2626)
Co-authored-by: jafermarq <[email protected]>
- Loading branch information
1 parent
b02f263
commit 944afce
Showing
16 changed files
with
1,818 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# ignore __pycache__ directories | ||
__pycache__/ | ||
|
||
# ignore .pyc files | ||
*.pyc | ||
|
||
# ignore .vscode directory | ||
.vscode/ | ||
|
||
# ignore .npz files | ||
*.npz | ||
|
||
# ignore .csv files | ||
*.csv | ||
|
||
# ignore docker-compose.yaml file | ||
docker-compose.yml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Use an official Python runtime as a parent image | ||
FROM python:3.10-slim-buster | ||
|
||
# Set the working directory in the container to /app | ||
WORKDIR /app | ||
|
||
# Copy the requirements file into the container | ||
COPY ./requirements.txt /app/requirements.txt | ||
|
||
# Install gcc and other dependencies | ||
RUN apt-get update && apt-get install -y \ | ||
gcc \ | ||
python3-dev && \ | ||
rm -rf /var/lib/apt/lists/* | ||
|
||
# Install any needed packages specified in requirements.txt | ||
RUN pip install -r requirements.txt | ||
|
||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import os | ||
import argparse | ||
import flwr as fl | ||
import tensorflow as tf | ||
import logging | ||
from helpers.load_data import load_data | ||
import os | ||
from model.model import Model | ||
|
||
logging.basicConfig(level=logging.INFO) # Configure logging | ||
logger = logging.getLogger(__name__) # Create logger for the module | ||
|
||
# Make TensorFlow log less verbose | ||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | ||
|
||
# Parse command line arguments | ||
parser = argparse.ArgumentParser(description="Flower client") | ||
|
||
parser.add_argument( | ||
"--server_address", type=str, default="server:8080", help="Address of the server" | ||
) | ||
parser.add_argument( | ||
"--batch_size", type=int, default=32, help="Batch size for training" | ||
) | ||
parser.add_argument( | ||
"--learning_rate", type=float, default=0.1, help="Learning rate for the optimizer" | ||
) | ||
parser.add_argument("--client_id", type=int, default=1, help="Unique ID for the client") | ||
parser.add_argument( | ||
"--total_clients", type=int, default=2, help="Total number of clients" | ||
) | ||
parser.add_argument( | ||
"--data_percentage", type=float, default=0.5, help="Portion of client data to use" | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
# Create an instance of the model and pass the learning rate as an argument | ||
model = Model(learning_rate=args.learning_rate) | ||
|
||
# Compile the model | ||
model.compile() | ||
|
||
|
||
class Client(fl.client.NumPyClient): | ||
def __init__(self, args): | ||
self.args = args | ||
|
||
logger.info("Preparing data...") | ||
(x_train, y_train), (x_test, y_test) = load_data( | ||
data_sampling_percentage=self.args.data_percentage, | ||
client_id=self.args.client_id, | ||
total_clients=self.args.total_clients, | ||
) | ||
|
||
self.x_train = x_train | ||
self.y_train = y_train | ||
self.x_test = x_test | ||
self.y_test = y_test | ||
|
||
def get_parameters(self, config): | ||
# Return the parameters of the model | ||
return model.get_model().get_weights() | ||
|
||
def fit(self, parameters, config): | ||
# Set the weights of the model | ||
model.get_model().set_weights(parameters) | ||
|
||
# Train the model | ||
history = model.get_model().fit( | ||
self.x_train, self.y_train, batch_size=self.args.batch_size | ||
) | ||
|
||
# Calculate evaluation metric | ||
results = { | ||
"accuracy": float(history.history["accuracy"][-1]), | ||
} | ||
|
||
# Get the parameters after training | ||
parameters_prime = model.get_model().get_weights() | ||
|
||
# Directly return the parameters and the number of examples trained on | ||
return parameters_prime, len(self.x_train), results | ||
|
||
def evaluate(self, parameters, config): | ||
# Set the weights of the model | ||
model.get_model().set_weights(parameters) | ||
|
||
# Evaluate the model and get the loss and accuracy | ||
loss, accuracy = model.get_model().evaluate( | ||
self.x_test, self.y_test, batch_size=self.args.batch_size | ||
) | ||
|
||
# Return the loss, the number of examples evaluated on and the accuracy | ||
return float(loss), len(self.x_test), {"accuracy": float(accuracy)} | ||
|
||
|
||
# Function to Start the Client | ||
def start_fl_client(): | ||
try: | ||
client = Client(args).to_client() | ||
fl.client.start_client(server_address=args.server_address, client=client) | ||
except Exception as e: | ||
logger.error("Error starting FL client: %s", e) | ||
return {"status": "error", "message": str(e)} | ||
|
||
|
||
if __name__ == "__main__": | ||
# Call the function to start the client | ||
start_fl_client() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[security] | ||
allow_embedding = true | ||
admin_user = admin | ||
admin_password = admin | ||
|
||
[dashboards] | ||
default_home_dashboard_path = /etc/grafana/provisioning/dashboards/dashboard_index.json | ||
|
||
[auth.anonymous] | ||
enabled = true | ||
org_name = Main Org. | ||
org_role = Admin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
global: | ||
scrape_interval: 1s | ||
evaluation_interval: 1s | ||
|
||
rule_files: | ||
scrape_configs: | ||
- job_name: 'cadvisor' | ||
scrape_interval: 1s | ||
metrics_path: '/metrics' | ||
static_configs: | ||
- targets: ['cadvisor:8080'] | ||
labels: | ||
group: 'cadvisor' | ||
- job_name: 'server_metrics' | ||
scrape_interval: 1s | ||
metrics_path: '/metrics' | ||
static_configs: | ||
- targets: ['server:8000'] |
Oops, something went wrong.