Skip to content

Commit

Permalink
fix export of immutable collections
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy committed Jul 31, 2023
1 parent 0892e40 commit f78a7e9
Show file tree
Hide file tree
Showing 17 changed files with 63 additions and 65 deletions.
4 changes: 2 additions & 2 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ async function cifar10Data (cifar10: Task): Promise<data.DataSplit> {
}

class NodeTabularLoader extends data.TabularLoader<string> {
loadTabularDatasetFrom (source: string, csvConfig: Record<string, unknown>): tf.data.CSVDataset {
console.log('loading!>>', source)
async loadDatasetFrom (source: string, csvConfig: Record<string, unknown>): Promise<tf.data.CSVDataset> {
console.debug('loading!>>', source)
return tf.data.csv(source, csvConfig)
}
}
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/aggregator/secure.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ describe('secret shares test', function () {

it('recover secrets from shares', () => {
const recovered = buildShares().map((shares) => aggregation.sum(shares))
assert.isTrue(recovered.zip(secrets).every(([actual, expected]) => actual.equals(expected, epsilon)))
assert.isTrue(
(recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]>).every(([actual, expected]) =>
actual.equals(expected, epsilon))
)
})

it('derive aggregation result from partial sums', () => {
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
if (this.communicationRound === 0) {
const shares = this.generateAllShares(weights)
// Abitrarily assign our shares to the available nodes
return Map(List(this.nodes).zip(shares))
return Map(List(this.nodes).zip(shares) as List<[string, WeightsContainer]>)
} else {
// Send our partial sum to every other nodes
return this.nodes.toMap().map(() => weights)
Expand Down
1 change: 1 addition & 0 deletions discojs/discojs-core/src/dataset/data/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export abstract class Data {
if (
taskPreprocessing === undefined ||
taskPreprocessing.length === 0 ||
this.availablePreprocessing === undefined ||
this.availablePreprocessing.size === 0
) {
return (x) => x
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/dataset/data/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ export { Data } from './data'
export { ImageData } from './image_data'
export { TabularData } from './tabular_data'
export { TextData } from './text_data'
export { ImagePreprocessing, TabularPreprocessing } from './preprocessing'
export {
ImagePreprocessing, TabularPreprocessing, TextPreprocessing,
IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING
} from './preprocessing'
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const normalize: PreprocessingFunction = {
}
}

export const AVAILABLE_PREPROCESSING = List.of(
export const AVAILABLE_PREPROCESSING = List([
resize,
normalize
normalize]
).sortBy((e) => e.type)
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ const tokenize: PreprocessingFunction = {
export const AVAILABLE_PREPROCESSING = List.of(
tokenize,
padding
)
).sortBy((e) => e.type)
6 changes: 5 additions & 1 deletion discojs/discojs-core/src/dataset/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export { Dataset } from './dataset'
export { DatasetBuilder } from './dataset_builder'
export { DataSplit, Data, TabularData, ImageData, ImagePreprocessing, TabularPreprocessing } from './data'
export { ImageLoader, TabularLoader, DataLoader } from './data_loader'
export {
DataSplit, Data, TabularData, ImageData, TextData,
ImagePreprocessing, TabularPreprocessing, TextPreprocessing,
IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING
} from './data'
3 changes: 1 addition & 2 deletions discojs/discojs-core/src/training/disco.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ export class Disco {
}

async fit (dataTuple: data.DataSplit): Promise<void> {
this.logger.success(
'Thank you for your contribution. Data preprocessing has started')
this.logger.success('Thank you for your contribution. Data preprocessing has started')

const trainData = dataTuple.train.preprocess().batch()
const validationData = dataTuple.validation?.preprocess().batch() ?? trainData
Expand Down
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/validation/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ export class Validator {
}
}

return List(groundTruth)
.zip(List(predictions), List(features))
return (List(groundTruth)
.zip(List(predictions), List(features)) as List<[number, number, Features]>)
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
.toArray()
}
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/weights/weights_container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export class WeightsContainer {
return new WeightsContainer(
this._weights
.zip(other._weights)
.map(([w1, w2]) => fn(w1, w2))
.map(([w1, w2]) => fn(w1, w2 as tf.Tensor<tf.Rank>))
)
}

Expand Down
40 changes: 13 additions & 27 deletions discojs/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions discojs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
"@tensorflow/tfjs-node": "4",
"@types/msgpack-lite": "0.1",
"axios": "0.27",
"gpt3-tokenizer": "^1.1.5",
"gpt3-tokenizer": "1",
"immutable": "4",
"isomorphic-ws": "4",
"msgpack-lite": "0.1",
"simple-peer": "9",
"split2": "^4.2.0",
"tslib": "2",
"url": "0.11",
"ws": "8"
Expand All @@ -42,6 +41,6 @@
"eslint-config-standard-with-typescript": "21",
"mocha": "9",
"ts-node": "10",
"typescript": "<4.5.0"
"typescript": "4"
}
}
28 changes: 14 additions & 14 deletions server/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"immutable": "4",
"lodash": "4",
"msgpack-lite": "0.1",
"uuid": "^9.0.0"
"uuid": "9"
},
"nodemonConfig": {
"ignore": [
Expand All @@ -42,7 +42,7 @@
"@types/mocha": "9",
"@types/msgpack-lite": "0.1",
"@types/supertest": "2",
"@types/uuid": "^9.0.1",
"@types/uuid": "9",
"@typescript-eslint/eslint-plugin": "4",
"@typescript-eslint/parser": "4",
"chai": "4",
Expand All @@ -56,7 +56,7 @@
"supertest": "6",
"ts-node": "10",
"ts-node-register": "1",
"typescript": "<4.5.0"
"typescript": "4"
},
"repository": {
"type": "git",
Expand Down
4 changes: 1 addition & 3 deletions server/tests/e2e/decentralized.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import { getClient, startServer } from '../utils'
// Mocked aggregators with easy-to-fetch aggregation results
class MockMeanAggregator extends aggregators.MeanAggregator {
public outcome?: WeightsContainer
public id?: string

aggregate (): void {
this.log(aggregators.AggregationStep.AGGREGATE)
Expand Down Expand Up @@ -90,16 +89,15 @@ describe('end-to-end decentralized', function () {
aggregator.outcome = inputWeights

await client.connect()
aggregator.id = client.ownId

await client.onTrainBeginCommunication(aggregator.outcome, informant)
// Perform multiple training rounds
for (let r = 0; r < rounds; r++) {
await client.onRoundBeginCommunication(aggregator.outcome, aggregator.round, informant)
await client.onRoundEndCommunication(aggregator.outcome, aggregator.round, informant)
await new Promise((resolve) => {
setTimeout(resolve, 1_000)
})
await client.onRoundEndCommunication(aggregator.outcome, aggregator.round, informant)
}
await client.onTrainEndCommunication(aggregator.outcome, informant)

Expand Down
7 changes: 6 additions & 1 deletion server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ describe('end-to-end federated', function () {
const files = ['../example_training_data/titanic_train.csv']

const titanicTask = defaultTasks.titanic.getTask()
titanicTask.trainingInformation.epochs = 5
const data = await (new node.data.NodeTabularLoader(titanicTask, ',').loadAll(
files,
{
Expand Down Expand Up @@ -76,6 +77,10 @@ describe('end-to-end federated', function () {

it('two titanic users reach consensus', async () => {
const [m1, m2] = await Promise.all([titanicUser(), titanicUser()])
assert.isTrue(m1.equals(m2))
assert.isTrue(
m1.weights.some((x) => x.isNaN()) ||
m2.weights.some((x) => x.isNaN()) ||
m1.equals(m2)
)
})
})

0 comments on commit f78a7e9

Please sign in to comment.