Skip to content

Commit

Permalink
*: make aggregator.add check if contribution is valid
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Oct 9, 2024
1 parent 85b60a9 commit 5c960cb
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 22 deletions.
7 changes: 4 additions & 3 deletions discojs/src/aggregator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ AGGREGATORS.forEach(([name, Aggregator]) =>
let promises = List<Promise<WeightsContainer>>()
for (let i = 0; i < 3; i++)
for (let r = 0; r < aggregator.communicationRounds; r++)
promises = promises.push(aggregator.add(`client ${i}`, WeightsContainer.of([i]), r))
promises = promises.push(aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r))
await Promise.all(promises)
await results; // nothing to test

Expand All @@ -58,7 +58,7 @@ AGGREGATORS.forEach(([name, Aggregator]) =>
id,
[agg, WeightsContainer.of([ws])],
]),
),
), 0
)
)
.valueSeq()
Expand Down Expand Up @@ -95,6 +95,7 @@ export function setupNetwork<A extends Aggregator>(
// run all rounds of communication
export async function communicate<A extends Aggregator>(
networkWithContributions: Map<NodeID, [A, WeightsContainer]>,
aggregationRound: number
): Promise<Map<NodeID, WeightsContainer>> {
const communicationsRound =
networkWithContributions.first()?.[0].communicationRounds;
Expand Down Expand Up @@ -126,7 +127,7 @@ export async function communicate<A extends Aggregator>(
agg
.makePayloads(contrib)
.entrySeq()
.forEach(([to, payload]) => network.get(to)?.add(id, payload, r)),
.forEach(([to, payload]) => network.get(to)?.add(id, payload, aggregationRound, r)),
);

contributions = Map(await Promise.all(nextContributions));
Expand Down
11 changes: 9 additions & 2 deletions discojs/src/aggregator/aggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon
* @param contribution The node's contribution
* @returns a promise for the aggregated weights, or undefined if the contribution is invalid
*/
async add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): Promise<WeightsContainer> {
// Calls the abstract method _add, which is implemented in the subclasses
async add(nodeId: client.NodeID, contribution: WeightsContainer,
aggregationRound: number, communicationRound?: number): Promise<WeightsContainer> {
if (!this.isValidContribution(nodeId, aggregationRound))
throw new Error("Tried adding an invalid contribution. Handle this case before calling add.")

// call the abstract method _add, implemented by subclasses
this._add(nodeId, contribution, communicationRound)
return this.createAggregationPromise()
}
Expand Down Expand Up @@ -124,6 +128,9 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon
/**
* Evaluates whether a given participant contribution can be used in the current aggregation round
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
*
* @param nodeId the node id of the contribution to be added
* @param round the aggregation round of the contribution to be added
*/
isValidContribution(nodeId: client.NodeID, round: number): boolean {
if (!this.nodes.has(nodeId)) {
Expand Down
26 changes: 13 additions & 13 deletions discojs/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@ async function WSIntoArrays(ws: WeightsContainer): Promise<number[][]> {

describe("mean aggregator", () => {
it("updates only within round cutoff", async () => {
const aggregator = new MeanAggregator(1, 1, 'relative');
const aggregator = new MeanAggregator(1, 1, 'relative'); // use a round cutoff of 1
aggregator.setNodes(Set.of("client 1"));

// round 0
expect(aggregator.round).to.equal(0)
expect(aggregator.isValidContribution("client 1", 0)).to.be.true;
const client1Round0Promise = await aggregator.add("client 1", WeightsContainer.of([1]));
const client1Round0Promise = await aggregator.add("client 1", WeightsContainer.of([1]), 0);
expect(WeightsContainer.of([1]).equals(client1Round0Promise)).to.be.true
expect(aggregator.round).to.equal(1)

// round 1
aggregator.registerNode("client 2");
expect(aggregator.isValidContribution("client 2", 0)).to.be.true; // round 0 should be within the cutoff
void aggregator.add("client 1", WeightsContainer.of([1]));
const client2Round0Promise = await aggregator.add("client 2", WeightsContainer.of([2]));
void aggregator.add("client 1", WeightsContainer.of([1]), 1);
const client2Round0Promise = await aggregator.add("client 2", WeightsContainer.of([2]), 0);
expect(WeightsContainer.of([1.5]).equals(client2Round0Promise)).to.be.true
expect(aggregator.round).to.equal(2)

// round 2
aggregator.registerNode("client 3");
expect(aggregator.isValidContribution("client 3", 0)).to.be.false; // round 0 is now out of the cutoff
expect(aggregator.isValidContribution("client 3", 1)).to.be.true;
void aggregator.add("client 1", WeightsContainer.of([1]));
void aggregator.add("client 2", WeightsContainer.of([1]));
const client3Round2Promise = await aggregator.add("client 3", WeightsContainer.of([4]));
void aggregator.add("client 1", WeightsContainer.of([1]), 2);
void aggregator.add("client 2", WeightsContainer.of([1]), 2);
const client3Round2Promise = await aggregator.add("client 3", WeightsContainer.of([4]), 1);
expect(WeightsContainer.of([2]).equals(client3Round2Promise)).to.be.true
expect(aggregator.round).to.equal(3)
});
Expand All @@ -51,8 +51,8 @@ describe("mean aggregator", () => {
aggregator.once("aggregation", resolve),
);

const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]));
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]));
const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true

expect(await WSIntoArrays(await results)).to.deep.equal([[1], [2]]);
Expand All @@ -69,7 +69,7 @@ describe("mean aggregator", () => {
expect(aggregator.round).equals(0)

aggregator.registerNode(id2);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]));
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true
expect(aggregator.round).equals(1) // round should be one now
});
Expand All @@ -80,7 +80,7 @@ describe("mean aggregator", () => {
aggregator.setNodes(Set.of(id1, id2)); // register two clients

// should aggregate with only one contribution
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]));
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]);
});

Expand All @@ -90,7 +90,7 @@ describe("mean aggregator", () => {
aggregator.setNodes(Set.of(id1, id2)); // register two clients

// should aggregate with only 50% of the contribution (1 contribution)
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]));
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]);
});

Expand All @@ -105,7 +105,7 @@ describe("mean aggregator", () => {
expect(aggregator.round).equals(0)

aggregator.registerNode(id2);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]));
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true
expect(aggregator.round).equals(1)
});
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/aggregator/secure.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ describe("secure aggregator", () => {
.entrySeq()
.zip(Range(0, 3))
.map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
),
), 0
);
const secureResults = await communicate(
Map(
secureNetwork
.entrySeq()
.zip(Range(0, 3))
.map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]),
),
), 0
);

List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)))
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ export class DecentralizedClient extends Client {
debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` +
` for round (%d, %d)`, message.aggregationRound, message.communicationRound);
// Make sure to not await this promise in order to not miss subsequent messages
void this.aggregator.add(peerId, decoded, message.communicationRound)
void this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound)
.then(() =>
debug(`[${shortenId(this.ownId)}] aggregated the model` +
` for round (%d, %d)`, message.aggregationRound, message.communicationRound)
Expand Down
2 changes: 1 addition & 1 deletion server/src/controllers/federated_controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export class FederatedController extends TrainingController {
const weights = serialization.weights.decode(payload)

// Send the aggregated weight to the client when enough contributions are received
this.#aggregator.add(clientId, weights)
this.#aggregator.add(clientId, weights, round)
.then(async (weightUpdate) => {
debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId)
const msg: FederatedMessages.ReceiveServerPayload = {
Expand Down

0 comments on commit 5c960cb

Please sign in to comment.