diff --git a/server/src/controllers/base.ts b/server/src/controllers/base.ts index 98b7321a0..46cc60ec9 100644 --- a/server/src/controllers/base.ts +++ b/server/src/controllers/base.ts @@ -1,7 +1,7 @@ import express from 'express' import type WebSocket from 'ws' -import type { Model, Task, TaskID } from '@epfml/discojs' +import type { Model, Task } from '@epfml/discojs' /** * The Controller abstraction is commonly used in Express @@ -14,12 +14,17 @@ import type { Model, Task, TaskID } from '@epfml/discojs' * and what happens when receiving messages from participants * of a training session. * + * More info on controllers: + * https://developer.mozilla.org/en-US/docs/Learn/Server-side/Express_Nodejs/routes + * */ export abstract class TrainingController { - abstract initTask (task: TaskID, model: Model): void - abstract handle ( - task: Task, + constructor(protected readonly task: Task) { } + + abstract initTask (model: Model): void + + abstract handle( ws: WebSocket, model: Model, req: express.Request, diff --git a/server/src/controllers/decentralized.ts b/server/src/controllers/decentralized.ts index 773b51025..7b02f338e 100644 --- a/server/src/controllers/decentralized.ts +++ b/server/src/controllers/decentralized.ts @@ -4,7 +4,6 @@ import msgpack from 'msgpack-lite' import type WebSocket from 'ws' import { Map, Set } from 'immutable' -import type { Task, TaskID } from '@epfml/discojs' import { client } from '@epfml/discojs' import { TrainingController } from './base.js' @@ -17,9 +16,9 @@ const debug = createDebug("server:routes:decentralized") export class DecentralizedController extends TrainingController { /** - * Map associating task ids to their sets of nodes who have contributed. + * Set of nodes who have contributed. */ - private readyNodes: Map> = Map() + private readyNodes = Set() /** * Map associating node ids to their open WebSocket connections. */ @@ -27,9 +26,9 @@ export class DecentralizedController extends TrainingController { initTask (): void {} - handle (task: Task, ws: WebSocket): void { + handle (ws: WebSocket): void { // TODO @s314cy: add to task definition, to be used as threshold in aggregator - const minimumReadyPeers = task.trainingInformation?.minimumReadyPeers ?? 3 + const minimumReadyPeers = this.task.trainingInformation?.minimumReadyPeers ?? 3 // Peer id of the message sender let peerId = randomUUID() @@ -53,28 +52,19 @@ export class DecentralizedController extends TrainingController { const msg: AssignNodeID = { type: MessageTypes.AssignNodeID, id: peerId, - waitForMoreParticipants: (this.readyNodes.get(task.id)?.size ?? 0) < minimumReadyPeers + waitForMoreParticipants: this.readyNodes.size < minimumReadyPeers } debug("peer ${peerId} joined ${task.id}"); - // Add the new task and its set of nodes - if (!this.readyNodes.has(task.id)) { - this.readyNodes = this.readyNodes.set(task.id, Set()) - } - ws.send(msgpack.encode(msg), { binary: true }) break } // Send by peers at the beginning of each training round to get the list // of active peers for this round. case MessageTypes.PeerIsReady: { - const peers = this.readyNodes.get(task.id)?.add(peerId) - if (peers === undefined) { - throw new Error(`task ${task.id} doesn't exist in ready buffer`) - } - this.readyNodes = this.readyNodes.set(task.id, peers) + const peers = this.readyNodes.add(peerId) if (peers.size >= minimumReadyPeers) { - this.readyNodes = this.readyNodes.set(task.id, Set()) + this.readyNodes = Set() peers .map((id) => { @@ -93,6 +83,8 @@ export class DecentralizedController extends TrainingController { return [conn, encoded] as [WebSocket, Buffer] }).forEach(([conn, encoded]) => { conn.send(encoded) } ) + } else { + this.readyNodes = peers } break } diff --git a/server/src/controllers/federated.ts b/server/src/controllers/federated.ts index ee8c53a96..7373ddd9a 100644 --- a/server/src/controllers/federated.ts +++ b/server/src/controllers/federated.ts @@ -2,14 +2,9 @@ import createDebug from "debug"; import WebSocket from 'ws' import { v4 as randomUUID } from 'uuid' -import { Map } from 'immutable' import msgpack from 'msgpack-lite' -import type { - Task, - TaskID, - WeightsContainer, -} from '@epfml/discojs' +import type { WeightsContainer } from '@epfml/discojs' import { aggregator as aggregators, client, @@ -28,22 +23,20 @@ const debug = createDebug("server:controllers:federated") export class FederatedController extends TrainingController { /** * Aggregators for each hosted task. + By default the server waits for 100% of the nodes to send their contributions before aggregating the updates + * */ - private aggregators = Map() + private aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') /** * Promises containing the current round's results. To be awaited on when providing clients * with the most recent result. */ - private results = Map>() + private result: Promise | undefined = undefined /** * Map containing the latest global model of each task. The model is already serialized and * can be sent to participants joining mid-training or contributing to previous rounds */ - private latestGlobalModels = Map() - /** - * Mapping between tasks and their current round. - */ - private rounds = Map() + private latestGlobalModel: serialization.weights.Encoded | undefined = undefined /** * Loop creating an aggregation result promise at each round. @@ -52,32 +45,24 @@ export class FederatedController extends TrainingController { * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. * @param aggregator The aggregation handler */ - private async storeAggregationResult (task: TaskID, aggregator: aggregators.Aggregator): Promise { + private async storeAggregationResult (): Promise { // Create a promise on the future aggregated weights - const resultPromise = new Promise((resolve) => aggregator.once('aggregation', resolve)) // Store the promise such that it is accessible from other methods - this.results = this.results.set(task, resultPromise) + this.result = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) // The promise resolves once the server received enough contributions (through the handle method) // and the aggregator aggregated the weights. - const globalModel = await resultPromise + const globalModel = await this.result const serializedModel = await serialization.weights.encode(globalModel) - this.latestGlobalModels = this.latestGlobalModels.set(task, serializedModel) + this.latestGlobalModel = serializedModel - // Update the server round with the aggregator round - this.rounds = this.rounds.set(task, aggregator.round) // Create a new promise for the next round // TODO weird usage, should be handled inside of aggregator - void this.storeAggregationResult(task, aggregator) + void this.storeAggregationResult() } - initTask(task: TaskID): void { - // The server waits for 100% of the nodes to send their contributions before aggregating the updates - const aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') - - this.aggregators = this.aggregators.set(task, aggregator) - this.rounds = this.rounds.set(task, 0) - - void this.storeAggregationResult(task, aggregator) + initTask(): void { + // start the perpetual promise loop + void this.storeAggregationResult() } /** @@ -92,13 +77,10 @@ export class FederatedController extends TrainingController { * @param aggregator the server aggregator, in order to access the current round * @param ws the websocket through which send the aggregated weights */ - private createPromiseForWeights ( - task: TaskID, - aggregator: aggregators.Aggregator, - ws: WebSocket): void { - const promisedResult = this.results.get(task) + private createPromiseForWeights (ws: WebSocket): void { + const promisedResult = this.result if (promisedResult === undefined) { - throw new Error(`result promise was not set for task ${task}`) + throw new Error(`result promise was not set`) } // Wait for aggregation result to resolve with timeout, giving the network a time window @@ -109,7 +91,7 @@ export class FederatedController extends TrainingController { ]).then((result) => // Reply with round - 1 because the round number should match the round at which the client sent its weights // After the server aggregated the weights it also incremented the round so the server replies with round - 1 - [result, aggregator.round - 1] as [WeightsContainer, number]) + [result, this.aggregator.round - 1] as [WeightsContainer, number]) .then(async ([result, round]) => [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) .then(([serialized, round]) => { @@ -117,7 +99,7 @@ export class FederatedController extends TrainingController { type: MessageTypes.ReceiveServerPayload, round, payload: serialized, - nbOfParticipants: aggregator.nodes.size + nbOfParticipants: this.aggregator.nodes.size } ws.send(msgpack.encode(msg)) }) @@ -134,14 +116,13 @@ export class FederatedController extends TrainingController { * @param task the task associated with the current websocket (= participant) * @param ws the websocket connection through which the participant and the server communicate */ - handle(task: Task, ws: WebSocket): void { - const aggregator = this.aggregators.get(task.id) - if (aggregator === undefined) - throw new Error(`no aggregator for task ${task.id}`) + handle( ws: WebSocket): void { + if (this.aggregator === undefined) + throw new Error(`no aggregator for task ${this.task.id}`) // Client id of the message sender let clientId = randomUUID() - while (!aggregator.registerNode(clientId)) { + while (!this.aggregator.registerNode(clientId)) { clientId = randomUUID() } @@ -153,9 +134,9 @@ export class FederatedController extends TrainingController { } if (msg.type === MessageTypes.ClientConnected) { - debug(`client ${clientId} joined ${task.id}`) + debug(`client ${clientId} joined ${this.task.id}`) // at least two participants in federated - const waitForMoreParticipants = aggregator.nodes.size < 2 + const waitForMoreParticipants = this.aggregator.nodes.size < 2 const msg: AssignNodeID = { type: MessageTypes.AssignNodeID, id: clientId, @@ -163,35 +144,35 @@ export class FederatedController extends TrainingController { } ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.ClientDisconnected) { - console.info('client', clientId, 'left', task.id) + debug(`client ${clientId} left ${this.task.id}`) - aggregator.removeNode(clientId) + this.aggregator.removeNode(clientId) } else if (msg.type === MessageTypes.SendPayload) { const { payload, round } = msg - if (aggregator.isValidContribution(clientId, round)) { + if (this.aggregator.isValidContribution(clientId, round)) { // We need to create a promise waiting for the global model before adding the contribution to the aggregator // (so that the aggregation and sending the global model to participants // doesn't happen before the promise is created) - this.createPromiseForWeights(task.id, aggregator, ws) + this.createPromiseForWeights(ws) // This is assuming that the federated server's aggregator // always works with a single communication round const weights = serialization.weights.decode(payload) - const addedSuccessfully = aggregator.add(clientId, weights, round) + const addedSuccessfully = this.aggregator.add(clientId, weights, round) if (!addedSuccessfully) throw new Error("Aggregator's isValidContribution returned true but failed to add the contribution") } else { // If the client sent an invalid or outdated contribution // the server answers with the current round and last global model update - debug(`Dropped contribution from client ${clientId} for round ${round}. Sending last global model from round ${aggregator.round - 1}`) - const latestSerializedModel = this.latestGlobalModels.get(task.id) + debug(`Dropped contribution from client ${clientId} for round ${round}` + + `Sending last global model from round ${this.aggregator.round - 1}`) // no latest model at the first round - if (latestSerializedModel === undefined) return + if (this.latestGlobalModel === undefined) return const msg: messages.ReceiveServerPayload = { type: MessageTypes.ReceiveServerPayload, - round: aggregator.round - 1, // send the model from the previous round - payload: latestSerializedModel, - nbOfParticipants: aggregator.nodes.size + round: this.aggregator.round - 1, // send the model from the previous round + payload: this.latestGlobalModel, + nbOfParticipants: this.aggregator.nodes.size } ws.send(msgpack.encode(msg)) } diff --git a/server/src/routes/training.ts b/server/src/routes/training.ts index 70609ed0d..e45823204 100644 --- a/server/src/routes/training.ts +++ b/server/src/routes/training.ts @@ -19,19 +19,15 @@ export class TrainingRouter { private readonly UUIDRegexExp = /^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}$/gi protected readonly description: string - // The controller handles the actual logic of collaborative training - // in its `handle` method - protected readonly controller: TrainingController - - constructor (wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels, scheme: 'federated' | 'decentralized') { + + constructor(private readonly trainingScheme: 'federated' | 'decentralized', + wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels) { this.ownRouter = express.Router() wsApplier.applyTo(this.ownRouter) this.ownRouter.get('/', (_, res) => res.send(this.description + '\n')) - this.description = `Disco ${scheme} server` - this.controller = scheme == 'federated' ? new FederatedController() : new DecentralizedController() - + this.description = `Disco ${this.trainingScheme} server` // delay listener because this (object) isn't fully constructed yet. The lambda function inside process.nextTick is executed after the current operation on the JS stack runs to completion and before the event loop is allowed to continue. /* this.onNewTask is registered as a listener to tasksAndModels, which has 2 consequences: - this.onNewTask is executed on all the default tasks (which are already loaded in tasksAndModels) @@ -47,16 +43,26 @@ export class TrainingRouter { return this.ownRouter } + private initController(task: Task): TrainingController { + return this.trainingScheme == 'federated' ? + new FederatedController(task) + : new DecentralizedController(task) + } + // Register the task and setup the controller to handle // websocket connections private onNewTask (task: Task, model: Model): void { this.tasks.add(task.id) - this.controller.initTask(task.id, model) + // The controller handles the actual logic of collaborative training + // in its `handle` method. Each task has a dedicated controller which + // handles the training logic of this task only + const taskController = this.initController(task) + taskController.initTask(model) // Setup a websocket route which calls the controller's `handle` method this.ownRouter.ws(this.buildRoute(task.id), (ws, req) => { if (this.isValidUrl(req.url)) { - this.controller.handle(task, ws, model, req) + taskController.handle(ws, model, req) } else { ws.terminate() ws.close() diff --git a/server/src/server.ts b/server/src/server.ts index e0e194eea..996024005 100644 --- a/server/src/server.ts +++ b/server/src/server.ts @@ -51,8 +51,8 @@ export class Server { app.use(express.urlencoded({ limit: "50mb", extended: false })); const taskRouter = new TaskRouter(this.#tasksAndModels) - const federatedRouter = new TrainingRouter(wsApplier, this.#tasksAndModels, 'federated') - const decentralizedRouter = new TrainingRouter(wsApplier, this.#tasksAndModels, 'decentralized') + const federatedRouter = new TrainingRouter('federated', wsApplier, this.#tasksAndModels) + const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#tasksAndModels) process.nextTick(() => wsApplier.getWss().on('connection', (ws, req) => {