diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index 92e9f4d47..064a1381c 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -35,8 +35,10 @@ AGGREGATORS.forEach(([name, Aggregator]) => let promises = List>() for (let i = 0; i < 3; i++) - for (let r = 0; r < aggregator.communicationRounds; r++) - promises = promises.push(aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r)) + for (let r = 0; r < aggregator.communicationRounds; r++){ + promises = promises.push(aggregator.getPromiseForAggregation()) + aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r) + } await Promise.all(promises) await results; // nothing to test diff --git a/discojs/src/aggregator/aggregator.ts b/discojs/src/aggregator/aggregator.ts index 0aa288484..41611a0f4 100644 --- a/discojs/src/aggregator/aggregator.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -17,8 +17,8 @@ export enum AggregationStep { * Main, abstract, aggregator class whose role is to buffer contributions and to produce * a result based off their aggregation, whenever some defined condition is met. * - * Emits an event whenever an aggregation step is performed. - * Users wait for this event to fetch the aggregation result. + * Emits an event whenever an aggregation step is performed with the counrd's aggregated weights. + * Users subscribes to this event to get the aggregation result. */ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsContainer }> { /** @@ -61,19 +61,16 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon this.contributions = Map() this._nodes = Set() + } - // On each aggregation, increment - // updates the aggregator's state to proceed to the next communication round. - // If all communication rounds were performed, proceeds to the next aggregation round - // and empties the collection of stored contributions. - this.on('aggregation', () => { - this._communicationRound++; - if (this.communicationRound === this.communicationRounds) { - this._communicationRound = 0 - this._round++ - this.contributions = Map() - } - }) + /** + * Convenience method to subscribe to the 'aggregation' event. + * Await this promise returns the aggregated weights for the current round. + * + * @returns a promise for the aggregated weights + */ + getPromiseForAggregation(): Promise { + return new Promise((resolve) => this.once('aggregation', resolve)); } /** @@ -81,49 +78,39 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * The aggregation round is increased whenever a new global model is obtained and local models are updated. * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation * which requires multiple steps to obtain a global model) - * The contribution will be aggregated during the next aggregation step. + * The contribution is aggregated during the next aggregation step. * * @param nodeId The node's id * @param contribution The node's contribution - * @returns a promise for the aggregated weights, or undefined if the contribution is invalid */ - async add(nodeId: client.NodeID, contribution: WeightsContainer, - aggregationRound: number, communicationRound?: number): Promise { + add(nodeId: client.NodeID, contribution: WeightsContainer, + aggregationRound: number, communicationRound?: number): void { if (!this.isValidContribution(nodeId, aggregationRound)) throw new Error("Tried adding an invalid contribution. Handle this case before calling add.") // call the abstract method _add, implemented by subclasses this._add(nodeId, contribution, communicationRound) - return this.createAggregationPromise() - } - - // Abstract method to be implemented by subclasses - // Handles logging and adding the contribution to the list of the current round's contributions - protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void - - /** - * Create a promise which resolves when enough contributions are received and - * local updates are aggregated. - * If the aggregator has enough contribution then we can aggregate the weights - * directly (and emit the 'aggregation' event) - * Otherwise we wait for the 'aggregation' event which will be emitted once - * enough contributions are received - * - * @returns a promise for the aggregated weights - */ - protected createAggregationPromise(): Promise { - // Wait for the aggregation event to be emitted - const ret = new Promise((resolve) => this.once('aggregation', resolve)); - + // If the aggregator has enough contributions then aggregate the weights + // and emit the 'aggregation' event if (this.isFull()) { const aggregatedWeights = this.aggregate() - // Emitting the 'aggregation' communicates the aggregation to other clients and - // takes care of incrementing the round + // On each aggregation, increment the communication round + // If all communication rounds were performed, proceed to the next aggregation round + // and empty the past contributions. + this._communicationRound++; + if (this.communicationRound === this.communicationRounds) { + this._communicationRound = 0 + this._round++; + this.contributions = Map() + } + // Emitting the 'aggregation' communicates the weights to subscribers this.emit('aggregation', aggregatedWeights) } - - return ret } + + // Abstract method to be implemented by subclasses + // Handles logging and adding the contribution to the list of the current round's contributions + protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void /** * Evaluates whether a given participant contribution can be used in the current aggregation round diff --git a/discojs/src/aggregator/mean.spec.ts b/discojs/src/aggregator/mean.spec.ts index b179e5dee..dc7575efe 100644 --- a/discojs/src/aggregator/mean.spec.ts +++ b/discojs/src/aggregator/mean.spec.ts @@ -18,26 +18,29 @@ describe("mean aggregator", () => { // round 0 expect(aggregator.round).to.equal(0) expect(aggregator.isValidContribution("client 1", 0)).to.be.true; - const client1Round0Promise = await aggregator.add("client 1", WeightsContainer.of([1]), 0); - expect(WeightsContainer.of([1]).equals(client1Round0Promise)).to.be.true + const client1Round0Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 1", WeightsContainer.of([1]), 0); + expect(WeightsContainer.of([1]).equals(await client1Round0Promise)).to.be.true expect(aggregator.round).to.equal(1) // round 1 aggregator.registerNode("client 2"); expect(aggregator.isValidContribution("client 2", 0)).to.be.true; // round 0 should be within the cutoff - void aggregator.add("client 1", WeightsContainer.of([1]), 1); - const client2Round0Promise = await aggregator.add("client 2", WeightsContainer.of([2]), 0); - expect(WeightsContainer.of([1.5]).equals(client2Round0Promise)).to.be.true + aggregator.add("client 1", WeightsContainer.of([1]), 1); + const client2Round0Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 2", WeightsContainer.of([2]), 0); + expect(WeightsContainer.of([1.5]).equals(await client2Round0Promise)).to.be.true expect(aggregator.round).to.equal(2) // round 2 aggregator.registerNode("client 3"); expect(aggregator.isValidContribution("client 3", 0)).to.be.false; // round 0 is now out of the cutoff expect(aggregator.isValidContribution("client 3", 1)).to.be.true; - void aggregator.add("client 1", WeightsContainer.of([1]), 2); - void aggregator.add("client 2", WeightsContainer.of([1]), 2); - const client3Round2Promise = await aggregator.add("client 3", WeightsContainer.of([4]), 1); - expect(WeightsContainer.of([2]).equals(client3Round2Promise)).to.be.true + aggregator.add("client 1", WeightsContainer.of([1]), 2); + aggregator.add("client 2", WeightsContainer.of([1]), 2); + const client3Round2Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 3", WeightsContainer.of([4]), 1); + expect(WeightsContainer.of([2]).equals(await client3Round2Promise)).to.be.true expect(aggregator.round).to.equal(3) }); @@ -51,8 +54,10 @@ describe("mean aggregator", () => { aggregator.once("aggregation", resolve), ); - const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0); - const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); expect((await result1).equals(await result2)).to.be.true expect(await WSIntoArrays(await results)).to.deep.equal([[1], [2]]); @@ -64,12 +69,14 @@ describe("mean aggregator", () => { aggregator.setNodes(Set.of(id1, id2)); - const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); // Make sure that the aggregation isn't triggered expect(aggregator.round).equals(0) aggregator.registerNode(id2); - const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); expect((await result1).equals(await result2)).to.be.true expect(aggregator.round).equals(1) // round should be one now }); @@ -80,8 +87,9 @@ describe("mean aggregator", () => { aggregator.setNodes(Set.of(id1, id2)); // register two clients // should aggregate with only one contribution - const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0); - expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]); + const result = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]); }); it("can wait for an relative number of contributions", async () => { @@ -90,8 +98,9 @@ describe("mean aggregator", () => { aggregator.setNodes(Set.of(id1, id2)); // register two clients // should aggregate with only 50% of the contribution (1 contribution) - const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0); - expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]); + const result = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]); }); it("doesn't aggregate when not enough participants", async () => { @@ -100,12 +109,14 @@ describe("mean aggregator", () => { const [id1, id2] = ["client 1", "client 2"] aggregator.setNodes(Set.of(id1)); - const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); // Make sure that the aggregation isn't triggered expect(aggregator.round).equals(0) aggregator.registerNode(id2); - const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); expect((await result1).equals(await result2)).to.be.true expect(aggregator.round).equals(1) }); diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 473514948..a0e6e70e5 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -131,7 +131,7 @@ export class DecentralizedClient extends Client { this.server.send({ type: type.JoinRound }) // Store the promise for the current round's aggregation result. // We will await for it to resolve at the end of the round when exchanging weight updates. - this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) + this.aggregationResult = this.aggregator.getPromiseForAggregation() this.saveAndEmit("local training") return Promise.resolve() } @@ -223,12 +223,11 @@ export class DecentralizedClient extends Client { else { debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` + ` for round (%d, %d)`, message.aggregationRound, message.communicationRound); - // Make sure to not await this promise in order to not miss subsequent messages - void this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound) - .then(() => + this.aggregator.once("aggregation", () => debug(`[${shortenId(this.ownId)}] aggregated the model` + ` for round (%d, %d)`, message.aggregationRound, message.communicationRound) ) + this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound) } } catch (e) { if (this.isDisconnected) return @@ -257,7 +256,7 @@ export class DecentralizedClient extends Client { payloads.forEach(async (payload, id) => { // add our own contribution to the aggregator if (id === this.ownId) { - void this.aggregator.add(this.ownId, payload, communicationRound) + this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound) return } // Send our payload to each peer @@ -294,7 +293,7 @@ export class DecentralizedClient extends Client { // There is at least one communication round remaining if (communicationRound < this.aggregator.communicationRounds - 1) { // Reuse the aggregation result - this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) + this.aggregationResult = this.aggregator.getPromiseForAggregation() } } return await this.aggregationResult diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 74eab22b2..aa1e760e9 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -101,19 +101,20 @@ export class FederatedController extends TrainingController { if (this.#aggregator.isValidContribution(clientId, round)) { const weights = serialization.weights.decode(payload) - // Send the aggregated weight to the client when enough contributions are received + // Create a callback to send the aggregated weight to the client + // when enough contributions are received + this.#aggregator.once('aggregation', async (weightUpdate) => { + debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId) + const msg: FederatedMessages.ReceiveServerPayload = { + type: MessageTypes.ReceiveServerPayload, + round: this.#aggregator.round, // send the current round number after aggregation + payload: await serialization.weights.encode(weightUpdate), + nbOfParticipants: this.connections.size + } + ws.send(msgpack.encode(msg)) + }) + // Add the contribution this.#aggregator.add(clientId, weights, round) - .then(async (weightUpdate) => { - debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId) - const msg: FederatedMessages.ReceiveServerPayload = { - type: MessageTypes.ReceiveServerPayload, - round: this.#aggregator.round, // send the current round number after aggregation - payload: await serialization.weights.encode(weightUpdate), - nbOfParticipants: this.connections.size - } - ws.send(msgpack.encode(msg)) - }) - .catch((e) => debug("while waiting for weights: %o", e)) debug(`Successfully added contribution from client [%s] for round ${round}`, shortId) } else { // If the client sent an invalid or outdated contribution