Skip to content

Commit

Permalink
Create one training controller per task rather than one for all tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Aug 12, 2024
1 parent ef39975 commit 5a9ce14
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 88 deletions.
13 changes: 9 additions & 4 deletions server/src/controllers/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import express from 'express'
import type WebSocket from 'ws'

import type { Model, Task, TaskID } from '@epfml/discojs'
import type { Model, Task } from '@epfml/discojs'

/**
* The Controller abstraction is commonly used in Express
Expand All @@ -14,12 +14,17 @@ import type { Model, Task, TaskID } from '@epfml/discojs'
* and what happens when receiving messages from participants
* of a training session.
*
* More info on controllers:
* https://developer.mozilla.org/en-US/docs/Learn/Server-side/Express_Nodejs/routes
*
*/
export abstract class TrainingController {
abstract initTask (task: TaskID, model: Model): void

abstract handle (
task: Task,
constructor(protected readonly task: Task) { }

abstract initTask (model: Model): void

abstract handle(
ws: WebSocket,
model: Model,
req: express.Request,
Expand Down
26 changes: 9 additions & 17 deletions server/src/controllers/decentralized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import msgpack from 'msgpack-lite'
import type WebSocket from 'ws'
import { Map, Set } from 'immutable'

import type { Task, TaskID } from '@epfml/discojs'
import { client } from '@epfml/discojs'

import { TrainingController } from './base.js'
Expand All @@ -17,19 +16,19 @@ const debug = createDebug("server:routes:decentralized")

export class DecentralizedController extends TrainingController {
/**
* Map associating task ids to their sets of nodes who have contributed.
* Set of nodes who have contributed.
*/
private readyNodes: Map<TaskID, Set<client.NodeID>> = Map()
private readyNodes = Set<client.NodeID>()
/**
* Map associating node ids to their open WebSocket connections.
*/
private connections: Map<client.NodeID, WebSocket> = Map()

initTask (): void {}

handle (task: Task, ws: WebSocket): void {
handle (ws: WebSocket): void {
// TODO @s314cy: add to task definition, to be used as threshold in aggregator
const minimumReadyPeers = task.trainingInformation?.minimumReadyPeers ?? 3
const minimumReadyPeers = this.task.trainingInformation?.minimumReadyPeers ?? 3

// Peer id of the message sender
let peerId = randomUUID()
Expand All @@ -53,28 +52,19 @@ export class DecentralizedController extends TrainingController {
const msg: AssignNodeID = {
type: MessageTypes.AssignNodeID,
id: peerId,
waitForMoreParticipants: (this.readyNodes.get(task.id)?.size ?? 0) < minimumReadyPeers
waitForMoreParticipants: this.readyNodes.size < minimumReadyPeers
}
debug("peer ${peerId} joined ${task.id}");

// Add the new task and its set of nodes
if (!this.readyNodes.has(task.id)) {
this.readyNodes = this.readyNodes.set(task.id, Set())
}

ws.send(msgpack.encode(msg), { binary: true })
break
}
// Send by peers at the beginning of each training round to get the list
// of active peers for this round.
case MessageTypes.PeerIsReady: {
const peers = this.readyNodes.get(task.id)?.add(peerId)
if (peers === undefined) {
throw new Error(`task ${task.id} doesn't exist in ready buffer`)
}
this.readyNodes = this.readyNodes.set(task.id, peers)
const peers = this.readyNodes.add(peerId)
if (peers.size >= minimumReadyPeers) {
this.readyNodes = this.readyNodes.set(task.id, Set())
this.readyNodes = Set()

peers
.map((id) => {
Expand All @@ -93,6 +83,8 @@ export class DecentralizedController extends TrainingController {
return [conn, encoded] as [WebSocket, Buffer]
}).forEach(([conn, encoded]) => { conn.send(encoded) }
)
} else {
this.readyNodes = peers
}
break
}
Expand Down
91 changes: 36 additions & 55 deletions server/src/controllers/federated.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@
import createDebug from "debug";
import WebSocket from 'ws'
import { v4 as randomUUID } from 'uuid'
import { Map } from 'immutable'
import msgpack from 'msgpack-lite'

import type {
Task,
TaskID,
WeightsContainer,
} from '@epfml/discojs'
import type { WeightsContainer } from '@epfml/discojs'
import {
aggregator as aggregators,
client,
Expand All @@ -28,22 +23,20 @@ const debug = createDebug("server:controllers:federated")
export class FederatedController extends TrainingController {
/**
* Aggregators for each hosted task.
By default the server waits for 100% of the nodes to send their contributions before aggregating the updates
*
*/
private aggregators = Map<TaskID, aggregators.Aggregator>()
private aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative')
/**
* Promises containing the current round's results. To be awaited on when providing clients
* with the most recent result.
*/
private results = Map<TaskID, Promise<WeightsContainer>>()
private result: Promise<WeightsContainer> | undefined = undefined
/**
* Map containing the latest global model of each task. The model is already serialized and
* can be sent to participants joining mid-training or contributing to previous rounds
*/
private latestGlobalModels = Map<TaskID, serialization.weights.Encoded>()
/**
* Mapping between tasks and their current round.
*/
private rounds = Map<TaskID, number>()
private latestGlobalModel: serialization.weights.Encoded | undefined = undefined

/**
* Loop creating an aggregation result promise at each round.
Expand All @@ -52,32 +45,24 @@ export class FederatedController extends TrainingController {
* one resolved and awaits until it resolves. The promise is used in createPromiseForWeights.
* @param aggregator The aggregation handler
*/
private async storeAggregationResult (task: TaskID, aggregator: aggregators.Aggregator): Promise<void> {
private async storeAggregationResult (): Promise<void> {
// Create a promise on the future aggregated weights
const resultPromise = new Promise<WeightsContainer>((resolve) => aggregator.once('aggregation', resolve))
// Store the promise such that it is accessible from other methods
this.results = this.results.set(task, resultPromise)
this.result = new Promise<WeightsContainer>((resolve) => this.aggregator.once('aggregation', resolve))
// The promise resolves once the server received enough contributions (through the handle method)
// and the aggregator aggregated the weights.
const globalModel = await resultPromise
const globalModel = await this.result
const serializedModel = await serialization.weights.encode(globalModel)
this.latestGlobalModels = this.latestGlobalModels.set(task, serializedModel)
this.latestGlobalModel = serializedModel

// Update the server round with the aggregator round
this.rounds = this.rounds.set(task, aggregator.round)
// Create a new promise for the next round
// TODO weird usage, should be handled inside of aggregator
void this.storeAggregationResult(task, aggregator)
void this.storeAggregationResult()
}

initTask(task: TaskID): void {
// The server waits for 100% of the nodes to send their contributions before aggregating the updates
const aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative')

this.aggregators = this.aggregators.set(task, aggregator)
this.rounds = this.rounds.set(task, 0)

void this.storeAggregationResult(task, aggregator)
initTask(): void {
// start the perpetual promise loop
void this.storeAggregationResult()
}

/**
Expand All @@ -92,13 +77,10 @@ export class FederatedController extends TrainingController {
* @param aggregator the server aggregator, in order to access the current round
* @param ws the websocket through which send the aggregated weights
*/
private createPromiseForWeights (
task: TaskID,
aggregator: aggregators.Aggregator,
ws: WebSocket): void {
const promisedResult = this.results.get(task)
private createPromiseForWeights (ws: WebSocket): void {
const promisedResult = this.result
if (promisedResult === undefined) {
throw new Error(`result promise was not set for task ${task}`)
throw new Error(`result promise was not set`)
}

// Wait for aggregation result to resolve with timeout, giving the network a time window
Expand All @@ -109,15 +91,15 @@ export class FederatedController extends TrainingController {
]).then((result) =>
// Reply with round - 1 because the round number should match the round at which the client sent its weights
// After the server aggregated the weights it also incremented the round so the server replies with round - 1
[result, aggregator.round - 1] as [WeightsContainer, number])
[result, this.aggregator.round - 1] as [WeightsContainer, number])
.then(async ([result, round]) =>
[await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number])
.then(([serialized, round]) => {
const msg: messages.ReceiveServerPayload = {
type: MessageTypes.ReceiveServerPayload,
round,
payload: serialized,
nbOfParticipants: aggregator.nodes.size
nbOfParticipants: this.aggregator.nodes.size
}
ws.send(msgpack.encode(msg))
})
Expand All @@ -134,14 +116,13 @@ export class FederatedController extends TrainingController {
* @param task the task associated with the current websocket (= participant)
* @param ws the websocket connection through which the participant and the server communicate
*/
handle(task: Task, ws: WebSocket): void {
const aggregator = this.aggregators.get(task.id)
if (aggregator === undefined)
throw new Error(`no aggregator for task ${task.id}`)
handle( ws: WebSocket): void {
if (this.aggregator === undefined)
throw new Error(`no aggregator for task ${this.task.id}`)

// Client id of the message sender
let clientId = randomUUID()
while (!aggregator.registerNode(clientId)) {
while (!this.aggregator.registerNode(clientId)) {
clientId = randomUUID()
}

Expand All @@ -153,45 +134,45 @@ export class FederatedController extends TrainingController {
}

if (msg.type === MessageTypes.ClientConnected) {
debug(`client ${clientId} joined ${task.id}`)
debug(`client ${clientId} joined ${this.task.id}`)
// at least two participants in federated
const waitForMoreParticipants = aggregator.nodes.size < 2
const waitForMoreParticipants = this.aggregator.nodes.size < 2
const msg: AssignNodeID = {
type: MessageTypes.AssignNodeID,
id: clientId,
waitForMoreParticipants
}
ws.send(msgpack.encode(msg))
} else if (msg.type === MessageTypes.ClientDisconnected) {
console.info('client', clientId, 'left', task.id)
debug(`client ${clientId} left ${this.task.id}`)

aggregator.removeNode(clientId)
this.aggregator.removeNode(clientId)
} else if (msg.type === MessageTypes.SendPayload) {

const { payload, round } = msg
if (aggregator.isValidContribution(clientId, round)) {
if (this.aggregator.isValidContribution(clientId, round)) {
// We need to create a promise waiting for the global model before adding the contribution to the aggregator
// (so that the aggregation and sending the global model to participants
// doesn't happen before the promise is created)
this.createPromiseForWeights(task.id, aggregator, ws)
this.createPromiseForWeights(ws)
// This is assuming that the federated server's aggregator
// always works with a single communication round
const weights = serialization.weights.decode(payload)
const addedSuccessfully = aggregator.add(clientId, weights, round)
const addedSuccessfully = this.aggregator.add(clientId, weights, round)
if (!addedSuccessfully) throw new Error("Aggregator's isValidContribution returned true but failed to add the contribution")
} else {
// If the client sent an invalid or outdated contribution
// the server answers with the current round and last global model update
debug(`Dropped contribution from client ${clientId} for round ${round}. Sending last global model from round ${aggregator.round - 1}`)
const latestSerializedModel = this.latestGlobalModels.get(task.id)
debug(`Dropped contribution from client ${clientId} for round ${round}` +
`Sending last global model from round ${this.aggregator.round - 1}`)
// no latest model at the first round
if (latestSerializedModel === undefined) return
if (this.latestGlobalModel === undefined) return

const msg: messages.ReceiveServerPayload = {
type: MessageTypes.ReceiveServerPayload,
round: aggregator.round - 1, // send the model from the previous round
payload: latestSerializedModel,
nbOfParticipants: aggregator.nodes.size
round: this.aggregator.round - 1, // send the model from the previous round
payload: this.latestGlobalModel,
nbOfParticipants: this.aggregator.nodes.size
}
ws.send(msgpack.encode(msg))
}
Expand Down
26 changes: 16 additions & 10 deletions server/src/routes/training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,15 @@ export class TrainingRouter {
private readonly UUIDRegexExp = /^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}$/gi

protected readonly description: string
// The controller handles the actual logic of collaborative training
// in its `handle` method
protected readonly controller: TrainingController

constructor (wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels, scheme: 'federated' | 'decentralized') {

constructor(private readonly trainingScheme: 'federated' | 'decentralized',
wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels) {
this.ownRouter = express.Router()
wsApplier.applyTo(this.ownRouter)

this.ownRouter.get('/', (_, res) => res.send(this.description + '\n'))

this.description = `Disco ${scheme} server`
this.controller = scheme == 'federated' ? new FederatedController() : new DecentralizedController()

this.description = `Disco ${this.trainingScheme} server`
// delay listener because this (object) isn't fully constructed yet. The lambda function inside process.nextTick is executed after the current operation on the JS stack runs to completion and before the event loop is allowed to continue.
/* this.onNewTask is registered as a listener to tasksAndModels, which has 2 consequences:
- this.onNewTask is executed on all the default tasks (which are already loaded in tasksAndModels)
Expand All @@ -47,16 +43,26 @@ export class TrainingRouter {
return this.ownRouter
}

private initController(task: Task): TrainingController {
return this.trainingScheme == 'federated' ?
new FederatedController(task)
: new DecentralizedController(task)
}

// Register the task and setup the controller to handle
// websocket connections
private onNewTask (task: Task, model: Model): void {
this.tasks.add(task.id)
this.controller.initTask(task.id, model)
// The controller handles the actual logic of collaborative training
// in its `handle` method. Each task has a dedicated controller which
// handles the training logic of this task only
const taskController = this.initController(task)
taskController.initTask(model)

// Setup a websocket route which calls the controller's `handle` method
this.ownRouter.ws(this.buildRoute(task.id), (ws, req) => {
if (this.isValidUrl(req.url)) {
this.controller.handle(task, ws, model, req)
taskController.handle(ws, model, req)
} else {
ws.terminate()
ws.close()
Expand Down
4 changes: 2 additions & 2 deletions server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ export class Server {
app.use(express.urlencoded({ limit: "50mb", extended: false }));

const taskRouter = new TaskRouter(this.#tasksAndModels)
const federatedRouter = new TrainingRouter(wsApplier, this.#tasksAndModels, 'federated')
const decentralizedRouter = new TrainingRouter(wsApplier, this.#tasksAndModels, 'decentralized')
const federatedRouter = new TrainingRouter('federated', wsApplier, this.#tasksAndModels)
const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#tasksAndModels)

process.nextTick(() =>
wsApplier.getWss().on('connection', (ws, req) => {
Expand Down

0 comments on commit 5a9ce14

Please sign in to comment.