-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.js
157 lines (132 loc) · 5.25 KB
/
main.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/* eslint-disable no-console */
'use strict'
// Networking stack: libp2p with TCP transport, mplex stream multiplexing,
// Noise connection encryption and gossipsub pubsub (wired up in createNode).
const PeerId = require('peer-id')
const { Multiaddr } = require('multiaddr')
const Libp2p = require('libp2p')
const TCP = require('libp2p-tcp')
const Mplex = require('libp2p-mplex')
const {NOISE} = require('libp2p-noise')
const Gossipsub = require('libp2p-gossipsub')
// Map of node id (string key) -> multiaddr string; used both as this node's
// listen address and as the base address when dialling the other nodes.
const nodeAddressPort = require('./data/nodes/address.json')
// This node's id, taken from the first CLI argument. It is treated as 1-based:
// node "1" uses peer-id-node-1 and is the node that runs the final evaluation.
const currentNodeId = process.argv.slice(2)[0]
// NOTE(review): `tf` is not referenced in this file; presumably required for its
// side effect of loading the tfjs-node backend — confirm before removing.
const tf = require('@tensorflow/tfjs-node')
// Project-local data loader, parameterised by node id (presumably selects this
// node's share of the training data — confirm in ./data).
const data = require('./data')(currentNodeId)
//const data = require('./data')(1)
// Project-local model and helper modules.
const model = require('./model')
const utils = require('./utils')
// Number of training/aggregation rounds every node runs.
const COMMUNICATION_ROUND = 10
/**
 * Creates a libp2p node (TCP transport, mplex muxer, Noise encryption,
 * gossipsub pubsub) listening on a single address, starts it and returns it.
 *
 * @param {string} peerAddress - multiaddr string the node should listen on.
 * @param {PeerId} peerIdFromJson - identity the node should use.
 * @returns {Promise<Libp2p>} the started node.
 */
async function createNode (peerAddress, peerIdFromJson) {
  const options = {
    peerId: peerIdFromJson,
    addresses: { listen: [peerAddress] },
    modules: {
      transport: [TCP],
      streamMuxer: [Mplex],
      connEncryption: [NOISE],
      pubsub: Gossipsub
    }
  }
  const libp2pNode = await Libp2p.create(options)
  await libp2pNode.start()
  return libp2pNode
}
/**
 * Loads the PeerId of every node from ./data/nodes/peer-id-node-<n>
 * (n = 1 … count), preserving node order.
 *
 * @param {number} count - number of nodes in the network.
 * @returns {Promise<PeerId[]>} index i holds the identity of node i + 1.
 */
const loadPeerIds = (count) => Promise.all(
  Array.from({ length: count }, (_, i) =>
    PeerId.createFromJSON(require(`./data/nodes/peer-id-node-${i + 1}`)))
)

/**
 * Dials every peer in `pending` until each has accepted a connection.
 * Before every pass it pauses via utils.sleep(2000); peers that connect are
 * removed from `pending`, failures are logged and retried on the next pass.
 *
 * @param {Libp2p} node - the local, already-started node.
 * @param {Map<number, PeerId>} pending - zero-based node index -> PeerId.
 */
async function connectToPeers (node, pending) {
  while (pending.size > 0) {
    await utils.sleep(2000)
    // Iterate over a snapshot so deleting connected peers is unambiguous.
    for (const [index, peerId] of [...pending]) {
      try {
        await node.dial(new Multiaddr(`${nodeAddressPort[String(index + 1)]}/p2p/${peerId.toB58String()}`))
        console.log("Connected !")
        pending.delete(index)
      } catch (err) {
        // Most likely the peer is not listening yet; keep it and retry.
        console.log("Connection Error", err.message)
      }
    }
  }
}

/*
 * Entry point: connect to all other nodes, then run COMMUNICATION_ROUND rounds
 * of "train locally -> publish weights -> wait for every peer's weights ->
 * average". Node "1" finally evaluates the averaged model on the test set.
 */
; (async () => {
  // Number of nodes in the network (one peer-id file per node).
  const NODE_COUNT = 10

  // Validate the CLI node id up front: it selects this node's identity and address.
  const selfIndex = Number.parseInt(currentNodeId, 10) - 1
  if (!/^\d+$/.test(String(currentNodeId)) || selfIndex < 0 || selfIndex >= NODE_COUNT ||
      !Object.prototype.hasOwnProperty.call(nodeAddressPort, currentNodeId)) {
    throw new Error(`Invalid node id "${currentNodeId}": expected an integer 1-${NODE_COUNT} listed in data/nodes/address.json`)
  }

  await data.loadData()
  const { images: trainImages, labels: trainLabels } = data.getTrainData()
  const topic = 'mnist'
  let currentRound = 0

  const peerIds = await loadPeerIds(NODE_COUNT)
  const pending = new Map()
  peerIds.forEach((peerId, index) => {
    if (index !== selfIndex) pending.set(index, peerId)
  })
  // Number of peer models to wait for in every round.
  const peerCount = pending.size

  const node = await createNode(nodeAddressPort[currentNodeId], peerIds[selfIndex])
  console.log("Current Node Address: ")
  node.multiaddrs.forEach((ma) => {
    console.log(ma.toString())
  })

  const globalModel = await utils.modelToDict(model)
  // nodeIds whose model for `currentRound` has already been folded into globalModel.
  const received = new Set()
  // roundIndex -> [{ nodeId, layers }] for models from peers already ahead of us.
  const earlyModels = new Map()

  // Folds one peer's layers into globalModel. The first model of a round
  // replaces the layers; later ones are summed in (the division happens once
  // our own model has been added). Self/duplicate models are ignored so they
  // cannot skew the average.
  const foldModel = (nodeId, layers) => {
    if (nodeId === currentNodeId || received.has(nodeId)) return
    const isFirst = received.size === 0
    received.add(nodeId)
    for (const layerName of Object.keys(layers)) {
      if (isFirst) {
        globalModel[layerName] = layers[layerName]
      } else {
        // weight
        utils.dictSum(globalModel[layerName]["data"][0], layers[layerName]["data"][0])
        // bias
        utils.dictSum(globalModel[layerName]["data"][1], layers[layerName]["data"][1])
      }
    }
  }

  node.pubsub.on(topic, (msg) => {
    let receivedModel
    try {
      // Decode explicitly instead of relying on JSON.parse's implicit toString().
      const text = typeof msg.data === 'string' ? msg.data : new TextDecoder().decode(msg.data)
      receivedModel = JSON.parse(text)
    } catch (err) {
      console.log("Discarding unreadable model message:", err.message)
      return
    }
    if (receivedModel === null || typeof receivedModel !== 'object') return
    const roundIndex = receivedModel["roundIndex"]
    const nodeId = receivedModel["nodeId"]
    console.log("Model arrived from: ", nodeId)
    delete receivedModel["roundIndex"]
    delete receivedModel["nodeId"]
    if (roundIndex === currentRound) {
      foldModel(nodeId, receivedModel)
    } else if (roundIndex > currentRound) {
      // A faster peer is already in a later round. Keep its model until we get
      // there; dropping it would leave us waiting forever in that round.
      if (!earlyModels.has(roundIndex)) earlyModels.set(roundIndex, [])
      earlyModels.get(roundIndex).push({ nodeId, layers: receivedModel })
    }
    // roundIndex < currentRound: stale duplicate, ignore.
  })
  // Subscribe before dialling so peers that connect (and publish) early
  // already know this node is interested in the topic.
  await node.pubsub.subscribe(topic)

  await connectToPeers(node, pending)
  console.log("All Nodes Connected !")

  while (currentRound < COMMUNICATION_ROUND) {
    console.log("Current Round:", currentRound + 1)
    await model.fit(trainImages, trainLabels, {
      epochs: 10,
      batchSize: 10
    })
    console.log("Train End!")
    const trainedModelToDict = await utils.modelToDict(model)
    trainedModelToDict["roundIndex"] = currentRound
    trainedModelToDict["nodeId"] = currentNodeId
    await node.pubsub.publish(topic, JSON.stringify(trainedModelToDict))
    console.log("Model Published!!")

    // Poll until every other node's model for this round has been folded in.
    while (received.size < peerCount) {
      await utils.sleep(100)
    }

    // Unweighted average over the peers' models plus our own.
    const contributors = received.size + 1
    for (const layerName of Object.keys(globalModel)) {
      for (const part of [0, 1]) { // 0 = weight, 1 = bias
        utils.dictSum(globalModel[layerName]["data"][part], trainedModelToDict[layerName]["data"][part])
        utils.dictDivide(globalModel[layerName]["data"][part], contributors)
      }
    }
    // `await` is harmless if dictToModel is synchronous and ensures the
    // weights are in place before the next fit if it is not.
    await utils.dictToModel(model, globalModel)
    received.clear()
    currentRound += 1

    // Fold in models for the new round that arrived while we were finishing the last one.
    const buffered = earlyModels.get(currentRound) || []
    earlyModels.delete(currentRound)
    buffered.forEach(({ nodeId, layers }) => foldModel(nodeId, layers))
  }

  if (currentNodeId === "1") {
    const { images: testImages, labels: testLabels } = data.getTestData()
    const evalOutput = model.evaluate(testImages, testLabels)
    console.log(
      `\nEvaluation result:\n` +
      ` Loss = ${evalOutput[0].dataSync()[0].toFixed(3)}; ` +
      `Accuracy = ${evalOutput[1].dataSync()[0].toFixed(3)}`)
  }
})().catch((err) => {
  // Fail loudly and consistently across Node versions instead of leaving an
  // unhandled rejection while the libp2p node keeps the process alive.
  console.error(err)
  process.exit(1)
})