Skip to content

Commit

Permalink
more comments in the helium application code
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Jun 18, 2024
1 parent a5ef077 commit 0c5347c
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 133 deletions.
298 changes: 167 additions & 131 deletions helium/app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"github.com/ChristianMct/helium/node"
)

const DefaultAddress = ":40000"
// cloudAddress is the local address that the cloud node listens on.
const cloudAddress = ":40000"

// defines the command-line flags
var (
nodeId = flag.String("node_id", "", "the id of the node")
nParty = flag.Int("n_party", -1, "the number of parties")
Expand All @@ -35,6 +37,155 @@ var (
expRounds = flag.Int("expRounds", 1, "number of circuit evaluatation rounds to perform")
)

func main() {

flag.Parse()

if *nParty < 2 {
panic("n_party argument should be provided and > 2")
}

if len(*nodeId) == 0 {
panic("node_id argument should be provided")
}

if len(*cloudAddr) == 0 {
panic("cloud_address argument must be provided for session nodes")
}

// sets the threshold to the number of parties if not provided
var threshold int
switch {
case *argThreshold == -1:
threshold = *nParty
case *argThreshold > 0 && *argThreshold <= *nParty:
threshold = *argThreshold
default:
flag.Usage()
panic("threshold argument must be between 1 and N")
}

nid := sessions.NodeID(*nodeId)

// generates a test node list from the command-line arguments
nids, nl, shamirPks, nodeMapping := genNodeLists(*nParty, *cloudAddr)

// generates a config for the node running this program
nc := genConfigForNode(nid, nids, threshold, shamirPks)

// retreives the session parameters from the node config
params, err := bgv.NewParametersFromLiteral(nc.SessionParameters[0].FHEParameters.(bgv.ParametersLiteral))
if err != nil {
panic(err)
}

// the matrix size
m := params.MaxSlots() / 2

// generates a test matrix
a := genTestMatrix(m)

// generates the Helium application (see helium/node/app.go).
// The app declares a circuit "matmul4-dec" that computes the
// encrypted matrix-vector product followed by a collective
// decryption.
app := getApp(params, m)

// creates a context for the session
ctx := sessions.NewBackgroundContext(nc.SessionParameters[0].ID)

// runs Helium as a server or client
var timeSetup, timeCompute time.Duration
var stats map[string]interface{}
start := time.Now()
if nc.ID == "cloud" {

// runs the Helium server. The method returns when the setup phase has completed.
// It returns a channel to send circuit descriptors (evaluation requests) and a channel to
// receive the evaluation outputs.
hsv, cdescs, outs, err := helium.RunHeliumServer(ctx, nc, nl, app, compute.NoInput)
if err != nil {
log.Fatalf("error running helium server: %v", err)
}
timeSetup = time.Since(start)

// One the setup has completed, the collectice public key is available
// and the test matrix can be encrypted with it.
if err := encryptTestMatrix(ctx, a, params, hsv, hsv); err != nil {
log.Fatalf("error encrypting test matrix: %v", err)
}

start = time.Now()
// sends *expRounds evaluation requests to the server for circuit "matmul4-dec".
go func() {
var nSig int
for i := 0; i < *expRounds; i++ {
cdescs <- circuits.Descriptor{
Signature: circuits.Signature{Name: "matmul4-dec"},
CircuitID: sessions.CircuitID(fmt.Sprintf("matmul-%d", nSig)),
NodeMapping: nodeMapping,
Evaluator: "cloud",
}
nSig++
}
close(cdescs)
}()

// the cloud is not supposed to receive any output
out, has := <-outs
if has {
log.Fatalf("unexpected output: %v", out.OperandLabel)
}

hsv.GracefulStop() // waits for the last client to disconnect
timeCompute = time.Since(start)
stats = map[string]interface{}{
"Time": map[string]interface{}{
"Setup": timeSetup,
"Compute": timeCompute,
},
"Net": hsv.GetStats(),
}
} else {

// creates an input provider function for the node (see getInputProvider).
encoder := bgv.NewEncoder(params)
var ip compute.InputProvider = getInputProvider(params, encoder, m)

// runs the Helium client. The method returns a channel to receive the evaluation outputs
// for which the node is the receiver.
hc, outs, err := helium.RunHeliumClient(ctx, nc, nl, app, ip)
if err != nil {
log.Fatalf("error running helium client: %v", err)
}

// checks the results
for out := range outs {
if err = checkResultCorrect(params, *encoder, out, a); err != nil {
log.Printf("error checking result: %v", err)
} else {
log.Printf("got correct result for %s", out.OperandLabel)
}
}

if err := hc.Close(); err != nil {
log.Fatalf("error closing helium client: %v", err)
}
stats = map[string]interface{}{
"net": hc.GetStats(),
}
}

//outputs the stats as JSON on stdout
statsJson, err := json.Marshal(stats)
if err != nil {
log.Fatalf("error marshalling stats: %v", err)
}
fmt.Println("STATS", string(statsJson))
}

// genNodeLists generates a test list of node informations from the experiments parameters.
// In a real scenarios, the node informations would be provided by the user application.
func genNodeLists(nParty int, cloudAddr string) (nids []sessions.NodeID, nl node.List, shamirPks map[sessions.NodeID]mhe.ShamirPublicPoint, nodeMapping map[string]sessions.NodeID) {
nids = make([]sessions.NodeID, nParty)
nl = make(node.List, nParty)
Expand All @@ -54,6 +205,8 @@ func genNodeLists(nParty int, cloudAddr string) (nids []sessions.NodeID, nl node
return
}

// genConfigForNode generates a node.Config for the node with the provided node ID. It also simulates the loading of the secret-key for the node.
// In a real scenario, the secret-key would be loaded from a secure storage.
func genConfigForNode(nid sessions.NodeID, nids []sessions.NodeID, threshold int, shamirPks map[sessions.NodeID]mhe.ShamirPublicPoint) (nc node.Config) {
sessParams := sessions.Parameters{
ID: "test-session",
Expand All @@ -80,7 +233,7 @@ func genConfigForNode(nid sessions.NodeID, nids []sessions.NodeID, threshold int
}

if nid == "cloud" {
nc.Address = DefaultAddress
nc.Address = cloudAddress
nc.SetupConfig.Protocols.MaxAggregation = 32
nc.ComputeConfig.Protocols.MaxAggregation = 32
} else {
Expand All @@ -93,6 +246,8 @@ func genConfigForNode(nid sessions.NodeID, nids []sessions.NodeID, threshold int
return
}

// getApp generates the Helium application for the test.
// The application specifies the setup phase and declares the circuits that can be executed by the nodes.
func getApp(params bgv.Parameters, m int) node.App {
diagGalEl := make(map[int]uint64)
for k := 0; k < m; k++ {
Expand All @@ -110,6 +265,8 @@ func getApp(params bgv.Parameters, m int) node.App {
}
}

// getInputProvider generates an input provider function for the node. The input provider function
// is registered to with the Helium node and is called by Helium to provide the input for the circuit evaluation.
func getInputProvider(params bgv.Parameters, encoder *bgv.Encoder, m int) compute.InputProvider {
return func(ctx context.Context, ci sessions.CircuitID, ol circuits.OperandLabel, s sessions.Session) (any, error) {

Expand All @@ -135,6 +292,7 @@ func getInputProvider(params bgv.Parameters, encoder *bgv.Encoder, m int) comput
}
}

// checkResultCorrect checks if the result of the circuit evaluation is correct by computing the matrix-vector product.
func checkResultCorrect(params bgv.Parameters, encoder bgv.Encoder, out circuits.Output, a *mat.Dense) error {
_, m := a.Dims()

Expand Down Expand Up @@ -164,148 +322,25 @@ func checkResultCorrect(params bgv.Parameters, encoder bgv.Encoder, out circuits
return nil
}

func getTestMatrix(m int) *mat.Dense {
// genTestMatrix generates a test secret matrix of size mxm for the experiment.
func genTestMatrix(m int) *mat.Dense {
a := mat.NewDense(m, m, nil)
a.Apply(func(i, j int, v float64) float64 {
return float64(i) + float64(2*j)
}, a)
return a
}

// main is the entrypoint of the node application.
// Instructions to run: go run main.go node.go -config [nodeconfigfile].
func main() {

flag.Parse()

if *nParty < 2 {
panic("n_party argument should be provided and > 2")
}

if len(*nodeId) == 0 {
panic("node_id argument should be provided")
}

if len(*cloudAddr) == 0 {
panic("cloud_address argument must be provided for session nodes")
}

var threshold int
switch {
case *argThreshold == -1:
threshold = *nParty
case *argThreshold > 0 && *argThreshold <= *nParty:
threshold = *argThreshold
default:
flag.Usage()
panic("threshold argument must be between 1 and N")
}

nid := sessions.NodeID(*nodeId)

nids, nl, shamirPks, nodeMapping := genNodeLists(*nParty, *cloudAddr)

nc := genConfigForNode(nid, nids, threshold, shamirPks)

params, err := bgv.NewParametersFromLiteral(nc.SessionParameters[0].FHEParameters.(bgv.ParametersLiteral))
if err != nil {
panic(err)
}
m := params.MaxSlots() / 2
app := getApp(params, m)
a := getTestMatrix(m)
encoder := bgv.NewEncoder(params) // TODO pass encoder in ip ?

var start time.Time
var ip compute.InputProvider = getInputProvider(params, encoder, m)

sessId := sessions.ID("test-session")
ctx := sessions.NewBackgroundContext(sessId)

start = time.Now()
var timeSetup, timeCompute time.Duration
var stats map[string]interface{}

var nSig int
if nc.ID == "cloud" {

hsv, cdescs, outs, err := helium.RunHeliumServer(ctx, nc, nl, app, compute.NoInput)
if err != nil {
log.Fatalf("error running helium server: %v", err)
}

timeSetup = time.Since(start)

if err := encryptTestMatrix(ctx, a, params, encoder, hsv, hsv); err != nil {
log.Fatalf("error encrypting test matrix: %v", err)
}

go func() {
for i := 0; i < *expRounds; i++ {
cdescs <- circuits.Descriptor{
Signature: circuits.Signature{Name: "matmul4-dec"},
CircuitID: sessions.CircuitID(fmt.Sprintf("matmul-%d", nSig)),
NodeMapping: nodeMapping,
Evaluator: "cloud",
}
nSig++
}
close(cdescs)
}()

out, has := <-outs
if has {
log.Fatalf("unexpected output: %v", out.OperandLabel)
}

hsv.GracefulStop() // waits for the last client to disconnect

timeCompute = time.Since(start) - timeSetup

stats = map[string]interface{}{
"Time": map[string]interface{}{
"Setup": timeSetup,
"Compute": timeCompute,
},
"Net": hsv.GetStats(),
}
} else {
hc, outs, err := helium.RunHeliumClient(ctx, nc, nl, app, ip)
if err != nil {
log.Fatalf("error running helium client: %v", err)
}

for out := range outs {
if err = checkResultCorrect(params, *encoder, out, a); err != nil {
log.Printf("error checking result: %v", err)
} else {
log.Printf("got correct result for %s", out.OperandLabel)
}
}

if err := hc.Close(); err != nil {
log.Fatalf("error closing helium client: %v", err)
}

stats = map[string]interface{}{
"net": hc.GetStats(),
}
}

statsJson, err := json.Marshal(stats)
if err != nil {
log.Fatalf("error marshalling stats: %v", err)
}
fmt.Println("STATS", string(statsJson))
}

func encryptTestMatrix(ctx context.Context, a *mat.Dense, params bgv.Parameters, encoder *bgv.Encoder, pkb circuits.PublicKeyProvider, opp compute.OperandProvider) error {
// encryptTestMatrix encrypts the test matrix with the collective public key of the session, and
// stores the encrypted matrix in the operand provider.
func encryptTestMatrix(ctx context.Context, a *mat.Dense, params bgv.Parameters, pkb circuits.PublicKeyProvider, opp compute.OperandProvider) error {

cpk, err := pkb.GetCollectivePublicKey(ctx)
if err != nil {
return err
}
encryptor := bgv.NewEncryptor(params, cpk)
encoder := bgv.NewEncoder(params)

pta := make(map[int]*rlwe.Plaintext)
cta := make(map[int]*rlwe.Ciphertext)
Expand Down Expand Up @@ -338,6 +373,7 @@ func encryptTestMatrix(ctx context.Context, a *mat.Dense, params bgv.Parameters,
return nil
}

// matmul4dec is a circuit that computes the encrypted matrix-vector product followed by a collective decryption.
func matmul4dec(e circuits.Runtime) error {
params := e.Parameters().(bgv.Parameters)

Expand Down
4 changes: 2 additions & 2 deletions helium/app/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestApp(t *testing.T) {

m := params.MaxSlots() / 2
app := getApp(params, m)
a := getTestMatrix(m)
a := genTestMatrix(m)
encoder := bgv.NewEncoder(params) // TODO pass encoder in ip ?

ctx := sessions.NewBackgroundContext("test-session")
Expand All @@ -65,7 +65,7 @@ func TestApp(t *testing.T) {
return err
}

if err := encryptTestMatrix(ctx, a, params, encoder, cloud, cloud); err != nil {
if err := encryptTestMatrix(ctx, a, params, cloud, cloud); err != nil {
log.Fatalf("error encrypting test matrix: %v", err)
}

Expand Down

0 comments on commit 0c5347c

Please sign in to comment.