Integrate zero-knowledge proof check for horizontal FL (#13)
* Refine the structure

* Add gitignore and update descriptions

* Add zkp step notes and clean code

* Add zkp client and server initialization

* Add zkp check sketch

* Add arguments for zkp check library

* Add interaction logic between client and server

* Add SignPubKey and SignPrvKey messages

* Add weights flatten/unflatten and fix training process (see the sketch below)

* Fix server message index issue

* Add random str broadcast and model dimension calculation

* Clean log and add function comments

* Fix the client weight norm bound issue

* Update README.md commands
lemonviv authored May 13, 2024
1 parent a3f46f3 commit dab5887
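The flatten/unflatten step mentioned in the commit message turns the model's state_dict into a single 1-D vector for the zkp norm-bound check (and back). `client_zkp.py` imports `flatten_model_weights` and `flattened_weight_size` from `src/utils.py`, which is not part of this diff; the sketch below is only an illustration of what such helpers could look like, assuming PyTorch weights and numpy, and the `unflatten_model_weights` name is hypothetical.

```python
# Illustrative sketch only: the real helpers live in src/utils.py, which is not shown in this diff.
import numpy as np
import torch


def flatten_model_weights(state_dict):
    """Concatenate every tensor in a state_dict into one 1-D float array."""
    return np.concatenate(
        [v.detach().cpu().numpy().ravel() for v in state_dict.values()]
    )


def flattened_weight_size(model):
    """Total number of scalars in the model, i.e. the vector dimension used for the zkp check."""
    return sum(v.numel() for v in model.state_dict().values())


def unflatten_model_weights(flat, reference_state_dict):
    """Hypothetical inverse: reshape a 1-D array back into tensors shaped like the reference."""
    new_state, offset = {}, 0
    for k, v in reference_state_dict.items():
        n = v.numel()
        new_state[k] = torch.as_tensor(flat[offset:offset + n], dtype=v.dtype).reshape(v.shape)
        offset += n
    return new_state
```

In `client_zkp.py` the flattened vector is what gets scaled to the norm bound and handed to the zkp library, so its length has to match the `dim` returned by `flattened_weight_size`.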
Showing 10 changed files with 847 additions and 355 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -154,4 +154,5 @@ data/dataset/breast_cancer_data/client0/log_reg/
data/dataset/breast_cancer_data/client1/log_reg/
data/dataset/breast_cancer_data/client2/log_reg/

-python/data/
+src/executor/python/data/bank/
+src/executor/python/data/mnist/
Empty file.
11 changes: 6 additions & 5 deletions src/executor/python/hfl/README.md
@@ -21,23 +21,24 @@ Download the bank dataset and split it into 3 partitions.
# 3. run the following command which:
# (1) splits the dataset into N subsets
# (2) splits each subset into train set and test set (8:2)
-python -m bank N
+# (3) puts the split sub-datasets to /FALCON_PATH/src/executor/python/data/bank/
+python3 -m bank N
```

## Run the example

Run the server first (set the number of epochs to 3)

```bash
-python -m src.server -m 3 --num_clients 3
+python3 -m src.server_zkp -m 3 --num_clients 3
```

Then, start 3 clients in different terminals

```bash
-python -m src.client --model mlp --data bank -m 3 -i 0 -d non-iid
-python -m src.client --model mlp --data bank -m 3 -i 1 -d non-iid
-python -m src.client --model mlp --data bank -m 3 -i 2 -d non-iid
+python3 -m src.client_zkp --model mlp --data bank --data_dir /opt/falcon/src/executor/python/data/bank -m 3 --num_clients 3 -i 0
+python3 -m src.client_zkp --model mlp --data bank --data_dir /opt/falcon/src/executor/python/data/bank -m 3 --num_clients 3 -i 1
+python3 -m src.client_zkp --model mlp --data bank --data_dir /opt/falcon/src/executor/python/data/bank -m 3 --num_clients 3 -i 2
```

Finally, the server and clients finish the FL training.
47 changes: 2 additions & 45 deletions src/executor/python/hfl/src/client.py
@@ -3,7 +3,8 @@
import socket

from .proto import interface_pb2 as proto
-from .utils import parseargs, get_dataset
+from .utils import parseargs
+from .dataset.data_loader import get_dataset

import numpy as np

@@ -95,50 +96,6 @@ def push(self) -> None:
        utils.send_message(self.sock, message)


-# Calculate accuracy
-def accuracy(pred, target):
-    # y is network output to be compared with ground truth (int)
-    y = np.argmax(pred, axis=1)
-    a = y == target
-    correct = np.array(a, "int").sum()
-    return correct
-
-
-# Data partition according to the rank
-def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
-    # Partition training data
-    data_per_rank = train_x.shape[0] // world_size
-    idx_start = global_rank * data_per_rank
-    idx_end = (global_rank + 1) * data_per_rank
-    train_x = train_x[idx_start:idx_end]
-    train_y = train_y[idx_start:idx_end]
-
-    # Partition evaluation data
-    data_per_rank = val_x.shape[0] // world_size
-    idx_start = global_rank * data_per_rank
-    idx_end = (global_rank + 1) * data_per_rank
-    val_x = val_x[idx_start:idx_end]
-    val_y = val_y[idx_start:idx_end]
-    return train_x, train_y, val_x, val_y
-
-
-
-def get_data(data, data_dist="iid", device_id=None):
-    if data == "bank":
-        train_x, train_y, val_x, val_y, num_classes = bank.load(device_id)
-    else:
-        raise NotImplementedError
-    return train_x, train_y, val_x, val_y, num_classes
-
-
-def get_model(model, num_channels=None, num_classes=None, data_size=None):
-    if model == "mlp":
-        model = MLP_Bank(dim_in=data_size, dim_hidden=64, dim_out=num_classes)
-    else:
-        raise NotImplementedError
-    return model
-
-
def run(
    args,
    global_rank,
291 changes: 291 additions & 0 deletions src/executor/python/hfl/src/client_zkp.py
@@ -0,0 +1,291 @@
#!/usr/bin/env python3

import socket
import sys

from .proto import interface_pb2 as proto
from .utils import parseargs, check_defense_type, send_string, receive_string, flatten_model_weights, flattened_weight_size
from .dataset.data_loader import get_dataset

import numpy as np
import copy
import os
import base64
from tqdm import tqdm

from . import utils
from .dataset import bank
from .models import MLP_Bank, CNNMnist

from .update import LocalUpdate, test_inference

import risefl_interface


class Client:
    """Client sends and receives protobuf messages.
    Create and start the client, then use pull and push to communicate with the server.
    Attributes:
        global_rank (int): The rank in training process.
        host (str): Host address of the server.
        port (int): Port of the server.
        sock (socket.socket): Socket of the client.
        weights (Dict[Any]): Weights stored locally.
    """

    def __init__(
        self,
        global_rank: int = 0,
        host: str = "127.0.0.1",
        port: int = 1234,
    ) -> None:
        """Class init method
        Args:
            global_rank (int, optional): The rank in training process. Defaults to 0.
            host (str, optional): Host ip address. Defaults to '127.0.0.1'.
            port (int, optional): Port. Defaults to 1234.
        """
        self.host = host
        self.port = port
        self.global_rank = global_rank

        self.sock = socket.socket()

        self.weights = {}

        self.zkp_client = None
        self.check_param = None
        self.random_bytes_str = None

    def __start_connection(self) -> None:
        """Start the network connection to server."""
        self.sock.connect((self.host, self.port))

    def __start_rank_pairing(self) -> None:
        """Sending global rank to server"""
        utils.send_int(self.sock, self.global_rank)

    def init_zkp_client(self, args, global_model) -> None:
        """Initialize zkp params etc."""
        defense_type = check_defense_type(args.check_type)

        sign_pub_keys_vec = risefl_interface.VecSignPubKeys(args.num_clients + 1)
        print("****** [client_zkp.init_zkp_client] args.num_clients + 1 = ", args.num_clients + 1)
        for j in range(args.num_clients + 1):
            # print("j = ", j)
            recv_pub_key_j_str = receive_string(self.sock)
            sign_pub_keys_vec[j] = risefl_interface.convert_string_to_sign_pub_key(recv_pub_key_j_str)
            # print("recv " + str(j) + " sign_pub_key success")
        # print("recv_pub_key finished")
        recv_prv_key_str = receive_string(self.sock)
        # print("recv_prv_key_str = ", recv_prv_key_str)
        sign_prv_keys_vec_i = risefl_interface.convert_string_to_sign_prv_key(recv_prv_key_str)

        print("****** [client_zkp.init_zkp_client] init sign_keys success")

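        # dim is the total number of scalars in the flattened model update,
        # i.e. the vector dimension handed to the zkp library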
        dim = int(flattened_weight_size(global_model))
        print("****** [client_zkp.init_zkp_client] dim = ", dim)

        # the client id in the zkp library starts from 1, so the index needs to be increased by 1
        self.zkp_client = risefl_interface.ClientInterface(
            args.num_clients, args.max_malicious_clients, dim,
            args.num_blinds_per_group_element, args.weight_bits, args.random_normal_bit_shifter,
            args.num_norm_bound_samples, args.inner_prod_bound_bits, args.max_bound_sq_bits,
            defense_type, self.global_rank + 1,
            sign_pub_keys_vec, sign_prv_keys_vec_i)

        print(f"****** [client_zkp.init_zkp_client] self.global_rank + 1 = {self.global_rank + 1}")
        print("****** [client_zkp.init_zkp_client] create zkp_client success")

        # initialize the check parameter
        self.check_param = risefl_interface.CheckParamFloat(defense_type)
        self.check_param.l2_param.bound = args.norm_bound

        # a random string used to generate independent group elements, to be used by both the server and clients
        # random_bytes = os.urandom(64)
        # random_bytes_str = base64.b64encode(random_bytes).decode('ascii')
        # random_bytes_str = "r0sdTz/eXbBDsPpB9QiB4P+ejll9juZdbYa4Xt+OZbFlV/n7FUcTMas64getSoWMoV5hE+UmiR6W554xa4SPnQ=="
        # print("random_bytes_str = " + random_bytes_str)
        self.random_bytes_str = receive_string(self.sock)
        self.zkp_client.initialize_from_seed(self.random_bytes_str)
        print("****** [client_zkp.init_zkp_client] received random_bytes_str from server")
        print("****** [client_zkp.init_zkp_client] init zkp_client success")

    def start(self) -> None:
        """Start the client.
        This method will first connect to the server. Then global rank is sent to the server.
        """
        self.__start_connection()
        self.__start_rank_pairing()

        print(f"[Client {self.global_rank}] Connect to {self.host}:{self.port}")

    def close(self) -> None:
        """Close the client socket."""
        self.sock.close()

    def pull(self) -> None:
        """Client pulls weights from the server.
        Namely, the server pushes weights to the clients.
        """
        message = proto.WeightsExchange()
        message = utils.receive_message(self.sock, message)
        for k, v in message.weights.items():
            self.weights[k] = utils.deserialize_tensor(v)

    def push(self) -> None:
        """Client pushes weights to the server.
        Namely, the server pulls weights from the clients.
        """
        message = proto.WeightsExchange()
        message.op_type = proto.GATHER
        for k, v in self.weights.items():
            message.weights[k] = utils.serialize_tensor(v)
        utils.send_message(self.sock, message)


def run(args, device_id):
    # Connect to server
    client = Client(global_rank=device_id)
    client.start()

    # if args.gpu_id:
    #     torch.cuda.set_device(args.gpu_id)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset = get_dataset(args.data, args.data_dir, client_id=client.global_rank)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.data == 'mnist':
            global_model = CNNMnist(num_channels=args.num_channels, num_classes=args.num_classes)

    elif args.model == 'mlp':
        # Multi-layer perceptron
        len_in = args.num_features
        global_model = MLP_Bank(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('****** [client_zkp.run] Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    client.weights = global_weights

    client.init_zkp_client(args, global_model)

    # Training
    train_loss, train_accuracy = [], []

    for epoch in tqdm(range(args.max_epoch)):
        # local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch} |\n')

        global_model.load_state_dict(client.weights)

        print("****** [client_zkp.run] client start local update...")
        local_model = LocalUpdate(args=args, dataset=train_dataset)
        w, loss = local_model.update_weights(
            model=copy.deepcopy(global_model), global_round=epoch)
        print("****** [client_zkp.run] client finish local update...")

        client.weights = copy.deepcopy(w)
        print(f'****** [client_zkp.run] Local training loss : {loss}')

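        # zkp check: five messages are exchanged with the server (send_1, then
        # receive_and_send_2 .. receive_and_send_5) so the server can verify the
        # flattened update against the L2 norm bound in check_param before aggregating it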
        # step 1 client sends message to the server
        # flatten weights to 1D array
        print(f"****** [client_zkp.run] client.weights.type: {type(client.weights)}")
        # print("client.weights: ", client.weights)
        flatten_weights = flatten_model_weights(w)
        # scale the flattened update so its L2 norm equals norm_bound; otherwise the server reports an error when aggregating
        flatten_weights *= args.norm_bound / np.linalg.norm(flatten_weights)
        flatten_weights = flatten_weights.tolist()
        print(f"****** [client_zkp.run] flatten_weights.type: {type(flatten_weights)}")
        print(f"****** [client_zkp.run] flatten_weights.length: {len(flatten_weights)}")
        # print("flatten_weights: ", flatten_weights)

        # for testing the correctness of summation when dim = 2
        # weight_updates_collection = [[0, 0], [0.09574025869369507, -0.0437011756002903],
        #                              [-0.012869355268776417, 0.0022518674377352],
        #                              [-0.07237587869167328, 0.12259631603956223]]
        # weight_updates_collection = np.random.rand(4, 31)
        # weight_updates_collection = weight_updates_collection.tolist()
        # print(f"weight_update_collection.type: {type(weight_updates_collection[1])}")
        # print(weight_updates_collection[client.global_rank + 1])
        # print(f"flatten_weights: {flatten_weights}")
        converted_weights = risefl_interface.VecFloat(flatten_weights)
        # converted_weights = risefl_interface.VecFloat(weight_updates_collection[client.global_rank + 1])
        client_send_str1 = client.zkp_client.send_1(client.check_param, converted_weights)
        # send this string to the server
        send_string(client.sock, client_send_str1)

        print("****** [client_zkp.run] client_sends_str1 finished")

        # step 2 receive message from the server and send message back to the server
        server_sent_2_str = receive_string(client.sock)
        # bytes_sent_2 = sent_2.encode()
        client_send_str2 = client.zkp_client.receive_and_send_2(server_sent_2_str)
        send_string(client.sock, client_send_str2)

        print("****** [client_zkp.run] client_sends_str2 finished")

        # step 3 receive message from the server and send message back to the server
        server_sent_3_str = receive_string(client.sock)
        client_send_str3 = client.zkp_client.receive_and_send_3(server_sent_3_str)
        send_string(client.sock, client_send_str3)

        print("****** [client_zkp.run] client_sends_str3 finished")

        # step 4 receive message from the server and send message back to the server
        server_sent_4_str = receive_string(client.sock)
        client_send_str4 = client.zkp_client.receive_and_send_4(server_sent_4_str)
        send_string(client.sock, client_send_str4)

        print("****** [client_zkp.run] client_sends_str4 finished")

        # step 5 receive message from the server and send message back to the server
        server_sent_5_str = receive_string(client.sock)
        client_send_str5 = client.zkp_client.receive_and_send_5(server_sent_5_str)
        send_string(client.sock, client_send_str5)

        print("****** [client_zkp.run] client_sends_str5 finished")

        # pull the latest round of model weights
        client.pull()
        print(f"****** [client_zkp.run] client finishes pulling weights from server")

        local_model = LocalUpdate(args=args, dataset=train_dataset)
        acc, loss = local_model.inference(model=global_model)
        train_accuracy.append(acc)
        train_loss.append(loss)
        print("|---- [client_zkp.run] Train Accuracy: {:.2f}%".format(100*acc))
        print("|---- [client_zkp.run] Train Loss: {:.2f}".format(loss))

    print(f"****** [client_zkp.run] Train Accuracy: {train_accuracy}")
    print(f"****** [client_zkp.run] Train Loss: {train_loss}")

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.max_epoch} global rounds of training:')
    print("|---- [client_zkp.run] Test Accuracy: {:.2f}%".format(100*test_acc))
    print("|---- [client_zkp.run] Test Loss: {:.2f}".format(test_loss))

    client.close()


if __name__ == "__main__":
    args = parseargs()
    run(args, args.device_id)
