Skip to content

Commit 889cda7

Browse files
committed
Added demo
1 parent 917a5fa commit 889cda7

File tree

5 files changed

+9403
-0
lines changed

5 files changed

+9403
-0
lines changed

docs/demo/compiled_model.js

+372
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
// This class stores information about model weights
2+
// this is useful for extracting trained model weights
3+
// or importing pre-trained model weights
4+
class LayerWeights {
5+
_layer;
6+
_weights;
7+
_bias;
8+
constructor(layer) {
9+
this._layer = layer;
10+
}
11+
_Float32ArrayFromBuffer(buffer, offset, byte_size) {
12+
return new Float32Array(buffer, offset, byte_size / Float32Array.BYTES_PER_ELEMENT);
13+
}
14+
_Float32ArrayFromArray(array) {
15+
return new Float32Array(array);
16+
}
17+
ImportWeightsFromBuffer(buffer, offset, byte_size) {
18+
this._weights = this._Float32ArrayFromBuffer(buffer, offset, byte_size);
19+
}
20+
ImportWeightsFromArray(array) {
21+
this._weights = this._Float32ArrayFromArray(array);
22+
}
23+
ImportBiasFromBuffer(buffer, offset, byte_size) {
24+
this._bias = this._Float32ArrayFromBuffer(buffer, offset, byte_size);
25+
}
26+
ImportBiasFromArray(array) {
27+
this._bias = this._Float32ArrayFromArray(array);
28+
}
29+
CopyWeights(layer_weights) {
30+
if(this._layer === layer_weights._layer) {
31+
if(this._weights.length === layer_weights._weights.length && this._bias.length === layer_weights._bias.length) {
32+
this._weights.set(layer_weights._weights);
33+
this._bias.set(layer_weights._bias);
34+
} else {
35+
console.error("Failed to copy weights/bias: arrays size are different (Weights: %d and %d), (Bias: %d and %d)",
36+
this._weights.length, layer_weights._weights.length, this._bias.length, layer_weights._bias.length);
37+
return false;
38+
}
39+
} else {
40+
console.error("Failed to copy weights: layer id are different (%d != %d)", this._layer, layer_weights._layer);
41+
return false;
42+
}
43+
return true;
44+
}
45+
ToJson() {
46+
return {
47+
layer: this._layer,
48+
weights: Array.from(this._weights),
49+
bias: Array.from(this._bias)
50+
}
51+
}
52+
}
53+
54+
// This class is a wrapper for the generated Wasm model
55+
// It also contains the import functions from JS to Wasm
56+
// Most functions simply calls the Wasm exported functions
57+
// but some might process the arguments in order to pass
58+
// them correctly to Wasm functions
59+
class CompiledModel {
60+
_wasm = null;
61+
_imports = {};
62+
63+
constructor() {
64+
this._imports = this._InitImports();
65+
}
66+
67+
// Set the wasm instance
68+
SetWasm(wasm) {
69+
this._wasm = wasm;
70+
}
71+
72+
// Get exports from Wasm to JS
73+
Exports() {
74+
if (this._wasm == null) {
75+
console.error("Wasm instance was not set");
76+
return null;
77+
}
78+
return this._wasm.instance.exports;
79+
}
80+
81+
// Get imports from JS to Wasm
82+
Imports() {
83+
return this._imports;
84+
}
85+
86+
// Run train Wasm function
87+
Train() {
88+
if (this.Exports() != null) {
89+
this.Exports().train();
90+
}
91+
}
92+
93+
// Run test Wasm function
94+
Test() {
95+
if (this.Exports() != null) {
96+
this.Exports().test();
97+
}
98+
}
99+
100+
// Run unit test Wasm function
101+
UnitTest() {
102+
if (this.Exports() != null) {
103+
Object.keys(this.Exports()).forEach((func) => {
104+
if (func.startsWith("test_")) {
105+
console.log(">> Testing function:", func);
106+
console.time(" exectuion time");
107+
this.Exports()[func]();
108+
console.timeEnd(" exectuion time");
109+
}
110+
});
111+
}
112+
}
113+
114+
// Run predict Wasm function
115+
Predict(data) {
116+
if (this.Exports() != null) {
117+
let offset = this._PredictionInputOffset();
118+
let batch_size = this._PredictionBatchSize();
119+
120+
if (data === undefined || data.length === 0 || batch_size !== data.length) {
121+
console.error("Data size should match the batch size %d != %d", data.length, batch_size);
122+
return false;
123+
}
124+
125+
let index = 0;
126+
let memory = new Float32Array(this.Exports().memory.buffer, offset, data[0].length * batch_size);
127+
for (let c = 0; c < data[0].length; c++) {
128+
for (let r = 0; r < data.length; r++) {
129+
memory[index++] = data[r][c];
130+
}
131+
}
132+
this.Exports().predict();
133+
return true;
134+
}
135+
return false;
136+
}
137+
138+
ExtractWeights() {
139+
let weights = [];
140+
if (this.Exports() != null) {
141+
for (let l = 0; l < this._TotalLayers(); l++) {
142+
let weight_info = this._WeightInfo(l);
143+
let bias_info = this._BiasInfo(l);
144+
if (weight_info != null && bias_info != null) {
145+
let layer_weight = new LayerWeights(l);
146+
layer_weight.ImportWeightsFromBuffer(this.Exports().memory.buffer, weight_info.offset, weight_info.byte_size);
147+
layer_weight.ImportBiasFromBuffer(this.Exports().memory.buffer, bias_info.offset, bias_info.byte_size);
148+
weights.push(layer_weight.ToJson());
149+
}
150+
}
151+
return weights;
152+
}
153+
}
154+
155+
ImportWeights(weights_array) {
156+
for(var i=0; i < weights_array.length; i++) {
157+
// Wrap JSON in a LayerWeight object
158+
let imported_layer_weights = new LayerWeights(weights_array[i].layer);
159+
imported_layer_weights.ImportWeightsFromArray(weights_array[i].weights);
160+
imported_layer_weights.ImportBiasFromArray(weights_array[i].bias);
161+
// Load model weights info
162+
let weights_info = this._WeightInfo(weights_array[i].layer);
163+
let bias_info = this._BiasInfo(weights_array[i].layer);
164+
if (weights_info != null && bias_info != null) {
165+
// Wrap Wasm model weight in a LayerWeight object
166+
let model_layer_weights = new LayerWeights(weights_array[i].layer);
167+
model_layer_weights.ImportWeightsFromBuffer(this.Exports().memory.buffer,
168+
weights_info.offset, weights_info.byte_size);
169+
model_layer_weights.ImportBiasFromBuffer(this.Exports().memory.buffer,
170+
bias_info.offset, bias_info.byte_size);
171+
// Set weights
172+
if (!model_layer_weights.CopyWeights(imported_layer_weights)) {
173+
console.log("Import failed!");
174+
return false;
175+
}
176+
} else {
177+
console.error("Import failed: Layer %d does not exists!", weights_array[i].layer);
178+
return false;
179+
}
180+
};
181+
return true;
182+
}
183+
184+
_PredictionInputOffset() {
185+
if(this.Exports() != null) {
186+
return this.Exports().prediction_input_offset();
187+
}
188+
return false;
189+
}
190+
191+
_PredictionBatchSize() {
192+
if(this.Exports() != null) {
193+
return this.Exports().prediction_batch_size();
194+
}
195+
return false;
196+
}
197+
198+
_TotalLayers() {
199+
if(this.Exports() != null) {
200+
return this.Exports().total_layers();
201+
}
202+
return 0;
203+
}
204+
205+
_WeightInfo(layer_index) {
206+
let offset_func = 'weight_offset_' + layer_index;
207+
let length_func = 'weight_byte_size_' + layer_index;
208+
if(this.Exports() != null
209+
&& this.Exports()[offset_func] !== undefined
210+
&& this.Exports()[length_func] !== undefined) {
211+
return {
212+
offset: this.Exports()[offset_func](),
213+
byte_size: this.Exports()[length_func]()
214+
}
215+
}
216+
return null;
217+
}
218+
219+
_BiasInfo(layer_index) {
220+
let offset_func = 'bias_offset_' + layer_index;
221+
let length_func = 'bias_byte_size_' + layer_index;
222+
if(this.Exports() != null
223+
&& this.Exports()[offset_func] !== undefined
224+
&& this.Exports()[length_func] !== undefined) {
225+
return {
226+
offset: this.Exports()[offset_func](),
227+
byte_size: this.Exports()[length_func]()
228+
}
229+
}
230+
return null;
231+
}
232+
233+
// Initialize imports
234+
_InitImports() {
235+
let math_imports = {
236+
exp: Math.exp,
237+
log: Math.log,
238+
random: Math.random
239+
};
240+
241+
let message_imports = {
242+
log_training_time: (epoch, time_epoch, time_total) => {
243+
console.log("Training time at epoch", epoch + 1, "is", time_epoch, "ms",
244+
"and total time so far is", time_total, "ms");
245+
},
246+
log_training_error: (epoch, error) => {
247+
console.log("Training Error in epoch", epoch + 1, ":", error);
248+
},
249+
log_training_accuracy: (epoch, acc) => {
250+
console.log("Training Accuracy in epoch", epoch + 1, ":",
251+
Math.round(acc * 10000) / 10000);
252+
},
253+
log_testing_time: (time) => {
254+
console.log("Testing time:", time, "ms");
255+
},
256+
log_testing_error: (error) => {
257+
console.log("Testing Error:", error);
258+
},
259+
log_testing_accuracy: (acc) => {
260+
console.log("Testing Accuracy:", Math.round(acc * 10000) / 10000);
261+
},
262+
log_prediction_time: (time) => {
263+
console.log("Prediction time:", time, "ms");
264+
},
265+
// Forward timing
266+
log_forward_Time: () => {
267+
console.log("\n>> Forward algorithm steps time:");
268+
},
269+
log_forward_A_1: (time) => {
270+
console.log("A) Z[l] = W[l] . A[l-1] + b[l]");
271+
console.log(" 1) Z[l] = W[l] . A[l-1]:", time);
272+
},
273+
log_forward_A_2: (time) => {
274+
console.log(" 2) Z[l] = Z[l] + b[l]:", time);
275+
},
276+
log_forward_B: (time) => {
277+
console.log("B) A[l] = g[l](Z[l]):", time);
278+
},
279+
// Backward timing
280+
log_backward_Time: () => {
281+
console.log("\n>> Backward algorithm steps time:");
282+
},
283+
log_backward_A: (time) => {
284+
console.log("A) dA[L] = L(T, A[L]):", time);
285+
},
286+
log_backward_B_1: (time) => {
287+
console.log("B) dZ[l] = dA[l] * g'[l](Z[l])");
288+
console.log(" 1) dZ[l] = g'[l](Z[l]):", time);
289+
},
290+
log_backward_B_2: (time) => {
291+
console.log(" 2) dZ[l] = dA[l] * dZ[l]:", time);
292+
},
293+
log_backward_C_1: (time) => {
294+
console.log("C) dW[l] = (1/m) dZ[l] . A[l-1]^T");
295+
console.log(" 1) dW[l] = dZ[l] . A[l-1]^T:", time);
296+
},
297+
log_backward_C_2: (time) => {
298+
console.log(" 2) dW[l] = (1/m) dW[l]:", time);
299+
},
300+
log_backward_D_1: (time) => {
301+
console.log("D) db[l] = (1/m) dZ[l]");
302+
console.log(" 1) db[l] = SUM(dZ[l], row wise):", time);
303+
},
304+
log_backward_D_2: (time) => {
305+
console.log(" 2) db[l] = (1/m) db[l]:", time);
306+
},
307+
log_backward_E: (time) => {
308+
console.log("E) dA[l-1] = W[l]^T . dZ[l]:", time);
309+
},
310+
log_backward_F_1: (time) => {
311+
console.log("F) W[l] = W[l] - alpha * dW[l]");
312+
console.log(" 1) dW[l] = alpha * dW[l]:", time);
313+
},
314+
log_backward_F_2: (time) => {
315+
console.log(" 2) W[l] = W[l] - dW[l]:", time);
316+
},
317+
log_backward_G_1: (time) => {
318+
console.log("G) b[l] = b[l] - alpha * db[l]");
319+
console.log(" 1) db[l] = alpha * db[l]:", time);
320+
},
321+
log_backward_G_2: (time) => {
322+
console.log(" 2) b[l] = b[l] - db[l]:", time);
323+
},
324+
};
325+
326+
let system_imports = {
327+
print: console.log,
328+
time: () => {
329+
return new Date().getTime();
330+
},
331+
print_table_f32: (index, rows, cols) => {
332+
if (this.Exports() != null) {
333+
let view = new Float32Array(this.Exports().memory.buffer, index);
334+
let table = [];
335+
for (let r = 0; r < rows; ++r) {
336+
table.push([]);
337+
for (let c = 0; c < cols; ++c) {
338+
table[r].push(view[r * cols + c]);
339+
}
340+
}
341+
console.table(table);
342+
}
343+
}
344+
};
345+
346+
let test_imports = {
347+
assert_matrix_eq: (mat1_index, mat2_index, rows, cols) => {
348+
if (this.Exports() != null) {
349+
let mat1 = new Float32Array(this.Exports().memory.buffer, mat1_index, rows * cols);
350+
let mat2 = new Float32Array(this.Exports().memory.buffer, mat2_index, rows * cols);
351+
for (let i = 0; i < rows * cols; i++) {
352+
if (mat1[i] !== mat2[i]) {
353+
console.error("Matrix equality failed!");
354+
system_imports.print_table_f32(mat1_index, rows, cols);
355+
system_imports.print_table_f32(mat2_index, rows, cols);
356+
return;
357+
}
358+
}
359+
}
360+
}
361+
};
362+
return {
363+
"Math": math_imports,
364+
"Message": message_imports,
365+
"System": system_imports,
366+
"Test": test_imports,
367+
};
368+
}
369+
}
370+
371+
var module = module || { exports: {} };
372+
module.exports.CompiledModel = CompiledModel;

0 commit comments

Comments
 (0)