Showing 5 changed files with 9,403 additions and 0 deletions.
@@ -0,0 +1,372 @@
// This class stores information about model weights.
// It is useful for extracting trained model weights
// or importing pre-trained model weights.
class LayerWeights {
  _layer;
  _weights;
  _bias;
  constructor(layer) {
    this._layer = layer;
  }
  _Float32ArrayFromBuffer(buffer, offset, byte_size) {
    return new Float32Array(buffer, offset, byte_size / Float32Array.BYTES_PER_ELEMENT);
  }
  _Float32ArrayFromArray(array) {
    return new Float32Array(array);
  }
  ImportWeightsFromBuffer(buffer, offset, byte_size) {
    this._weights = this._Float32ArrayFromBuffer(buffer, offset, byte_size);
  }
  ImportWeightsFromArray(array) {
    this._weights = this._Float32ArrayFromArray(array);
  }
  ImportBiasFromBuffer(buffer, offset, byte_size) {
    this._bias = this._Float32ArrayFromBuffer(buffer, offset, byte_size);
  }
  ImportBiasFromArray(array) {
    this._bias = this._Float32ArrayFromArray(array);
  }
  CopyWeights(layer_weights) {
    if (this._layer === layer_weights._layer) {
      if (this._weights.length === layer_weights._weights.length && this._bias.length === layer_weights._bias.length) {
        this._weights.set(layer_weights._weights);
        this._bias.set(layer_weights._bias);
      } else {
        console.error("Failed to copy weights/bias: array sizes are different (Weights: %d and %d), (Bias: %d and %d)",
          this._weights.length, layer_weights._weights.length, this._bias.length, layer_weights._bias.length);
        return false;
      }
    } else {
      console.error("Failed to copy weights: layer ids are different (%d != %d)", this._layer, layer_weights._layer);
      return false;
    }
    return true;
  }
  ToJson() {
    return {
      layer: this._layer,
      weights: Array.from(this._weights),
      bias: Array.from(this._bias)
    };
  }
}
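// Illustrative sketch (not part of the generated model code): one way LayerWeights
// can round-trip values through plain arrays, e.g. when loading weights from a JSON
// file. The function name, layer index, and numbers below are made-up examples.
function exampleLayerWeightsRoundTrip() {
  let lw = new LayerWeights(0);
  lw.ImportWeightsFromArray([0.1, -0.2, 0.3, 0.4]);
  lw.ImportBiasFromArray([0.05, -0.05]);
  // ToJson() returns plain arrays, so the result can be serialized with JSON.stringify
  // e.g. { layer: 0, weights: [...], bias: [...] } (values are stored as 32-bit floats)
  return lw.ToJson();
}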
// This class is a wrapper for the generated Wasm model.
// It also contains the import functions from JS to Wasm.
// Most functions simply call the Wasm exported functions,
// but some might process the arguments in order to pass
// them correctly to Wasm functions.
class CompiledModel {
  _wasm = null;
  _imports = {};

  constructor() {
    this._imports = this._InitImports();
  }

  // Set the Wasm instance
  SetWasm(wasm) {
    this._wasm = wasm;
  }

  // Get exports from Wasm to JS
  Exports() {
    if (this._wasm == null) {
      console.error("Wasm instance was not set");
      return null;
    }
    return this._wasm.instance.exports;
  }

  // Get imports from JS to Wasm
  Imports() {
    return this._imports;
  }

  // Run the train Wasm function
  Train() {
    if (this.Exports() != null) {
      this.Exports().train();
    }
  }

  // Run the test Wasm function
  Test() {
    if (this.Exports() != null) {
      this.Exports().test();
    }
  }

  // Run every exported unit-test Wasm function (prefixed with "test_")
  UnitTest() {
    if (this.Exports() != null) {
      Object.keys(this.Exports()).forEach((func) => {
        if (func.startsWith("test_")) {
          console.log(">> Testing function:", func);
          console.time(" execution time");
          this.Exports()[func]();
          console.timeEnd(" execution time");
        }
      });
    }
  }

  // Run the predict Wasm function on a batch of input rows
  Predict(data) {
    if (this.Exports() != null) {
      let offset = this._PredictionInputOffset();
      let batch_size = this._PredictionBatchSize();

      if (data === undefined || data.length === 0 || batch_size !== data.length) {
        console.error("Data size should match the batch size %d != %d", data.length, batch_size);
        return false;
      }

      // Copy the batch into Wasm memory in column-major order
      // (all rows' values for feature c, then feature c+1, ...)
      let index = 0;
      let memory = new Float32Array(this.Exports().memory.buffer, offset, data[0].length * batch_size);
      for (let c = 0; c < data[0].length; c++) {
        for (let r = 0; r < data.length; r++) {
          memory[index++] = data[r][c];
        }
      }
      this.Exports().predict();
      return true;
    }
    return false;
  }
  // Extract each layer's weights and bias from Wasm memory as plain JSON objects
  ExtractWeights() {
    let weights = [];
    if (this.Exports() != null) {
      for (let l = 0; l < this._TotalLayers(); l++) {
        let weight_info = this._WeightInfo(l);
        let bias_info = this._BiasInfo(l);
        if (weight_info != null && bias_info != null) {
          let layer_weight = new LayerWeights(l);
          layer_weight.ImportWeightsFromBuffer(this.Exports().memory.buffer, weight_info.offset, weight_info.byte_size);
          layer_weight.ImportBiasFromBuffer(this.Exports().memory.buffer, bias_info.offset, bias_info.byte_size);
          weights.push(layer_weight.ToJson());
        }
      }
      return weights;
    }
  }

  // Copy previously exported weights back into the Wasm model
  ImportWeights(weights_array) {
    for (let i = 0; i < weights_array.length; i++) {
      // Wrap JSON in a LayerWeights object
      let imported_layer_weights = new LayerWeights(weights_array[i].layer);
      imported_layer_weights.ImportWeightsFromArray(weights_array[i].weights);
      imported_layer_weights.ImportBiasFromArray(weights_array[i].bias);
      // Load model weights info
      let weights_info = this._WeightInfo(weights_array[i].layer);
      let bias_info = this._BiasInfo(weights_array[i].layer);
      if (weights_info != null && bias_info != null) {
        // Wrap the Wasm model weights in a LayerWeights object
        let model_layer_weights = new LayerWeights(weights_array[i].layer);
        model_layer_weights.ImportWeightsFromBuffer(this.Exports().memory.buffer,
          weights_info.offset, weights_info.byte_size);
        model_layer_weights.ImportBiasFromBuffer(this.Exports().memory.buffer,
          bias_info.offset, bias_info.byte_size);
        // Set weights
        if (!model_layer_weights.CopyWeights(imported_layer_weights)) {
          console.log("Import failed!");
          return false;
        }
      } else {
        console.error("Import failed: Layer %d does not exist!", weights_array[i].layer);
        return false;
      }
    }
    return true;
  }
  // Offset (in bytes) of the prediction input buffer in Wasm memory
  _PredictionInputOffset() {
    if (this.Exports() != null) {
      return this.Exports().prediction_input_offset();
    }
    return false;
  }

  // Number of input rows the generated predict() function expects
  _PredictionBatchSize() {
    if (this.Exports() != null) {
      return this.Exports().prediction_batch_size();
    }
    return false;
  }

  // Total number of layers in the generated model
  _TotalLayers() {
    if (this.Exports() != null) {
      return this.Exports().total_layers();
    }
    return 0;
  }

  // Offset and byte size of a layer's weight matrix in Wasm memory,
  // or null if the layer does not export them
  _WeightInfo(layer_index) {
    let offset_func = 'weight_offset_' + layer_index;
    let length_func = 'weight_byte_size_' + layer_index;
    if (this.Exports() != null
        && this.Exports()[offset_func] !== undefined
        && this.Exports()[length_func] !== undefined) {
      return {
        offset: this.Exports()[offset_func](),
        byte_size: this.Exports()[length_func]()
      };
    }
    return null;
  }

  // Offset and byte size of a layer's bias vector in Wasm memory,
  // or null if the layer does not export them
  _BiasInfo(layer_index) {
    let offset_func = 'bias_offset_' + layer_index;
    let length_func = 'bias_byte_size_' + layer_index;
    if (this.Exports() != null
        && this.Exports()[offset_func] !== undefined
        && this.Exports()[length_func] !== undefined) {
      return {
        offset: this.Exports()[offset_func](),
        byte_size: this.Exports()[length_func]()
      };
    }
    return null;
  }
  // Initialize imports
  _InitImports() {
    let math_imports = {
      exp: Math.exp,
      log: Math.log,
      random: Math.random
    };

    let message_imports = {
      log_training_time: (epoch, time_epoch, time_total) => {
        console.log("Training time at epoch", epoch + 1, "is", time_epoch, "ms",
          "and total time so far is", time_total, "ms");
      },
      log_training_error: (epoch, error) => {
        console.log("Training Error in epoch", epoch + 1, ":", error);
      },
      log_training_accuracy: (epoch, acc) => {
        console.log("Training Accuracy in epoch", epoch + 1, ":",
          Math.round(acc * 10000) / 10000);
      },
      log_testing_time: (time) => {
        console.log("Testing time:", time, "ms");
      },
      log_testing_error: (error) => {
        console.log("Testing Error:", error);
      },
      log_testing_accuracy: (acc) => {
        console.log("Testing Accuracy:", Math.round(acc * 10000) / 10000);
      },
      log_prediction_time: (time) => {
        console.log("Prediction time:", time, "ms");
      },
      // Forward timing
      log_forward_Time: () => {
        console.log("\n>> Forward algorithm steps time:");
      },
      log_forward_A_1: (time) => {
        console.log("A) Z[l] = W[l] . A[l-1] + b[l]");
        console.log(" 1) Z[l] = W[l] . A[l-1]:", time);
      },
      log_forward_A_2: (time) => {
        console.log(" 2) Z[l] = Z[l] + b[l]:", time);
      },
      log_forward_B: (time) => {
        console.log("B) A[l] = g[l](Z[l]):", time);
      },
      // Backward timing
      log_backward_Time: () => {
        console.log("\n>> Backward algorithm steps time:");
      },
      log_backward_A: (time) => {
        console.log("A) dA[L] = L(T, A[L]):", time);
      },
      log_backward_B_1: (time) => {
        console.log("B) dZ[l] = dA[l] * g'[l](Z[l])");
        console.log(" 1) dZ[l] = g'[l](Z[l]):", time);
      },
      log_backward_B_2: (time) => {
        console.log(" 2) dZ[l] = dA[l] * dZ[l]:", time);
      },
      log_backward_C_1: (time) => {
        console.log("C) dW[l] = (1/m) dZ[l] . A[l-1]^T");
        console.log(" 1) dW[l] = dZ[l] . A[l-1]^T:", time);
      },
      log_backward_C_2: (time) => {
        console.log(" 2) dW[l] = (1/m) dW[l]:", time);
      },
      log_backward_D_1: (time) => {
        console.log("D) db[l] = (1/m) dZ[l]");
        console.log(" 1) db[l] = SUM(dZ[l], row wise):", time);
      },
      log_backward_D_2: (time) => {
        console.log(" 2) db[l] = (1/m) db[l]:", time);
      },
      log_backward_E: (time) => {
        console.log("E) dA[l-1] = W[l]^T . dZ[l]:", time);
      },
      log_backward_F_1: (time) => {
        console.log("F) W[l] = W[l] - alpha * dW[l]");
        console.log(" 1) dW[l] = alpha * dW[l]:", time);
      },
      log_backward_F_2: (time) => {
        console.log(" 2) W[l] = W[l] - dW[l]:", time);
      },
      log_backward_G_1: (time) => {
        console.log("G) b[l] = b[l] - alpha * db[l]");
        console.log(" 1) db[l] = alpha * db[l]:", time);
      },
      log_backward_G_2: (time) => {
        console.log(" 2) b[l] = b[l] - db[l]:", time);
      },
    };

    let system_imports = {
      print: console.log,
      time: () => {
        return new Date().getTime();
      },
      print_table_f32: (index, rows, cols) => {
        if (this.Exports() != null) {
          let view = new Float32Array(this.Exports().memory.buffer, index);
          let table = [];
          for (let r = 0; r < rows; ++r) {
            table.push([]);
            for (let c = 0; c < cols; ++c) {
              table[r].push(view[r * cols + c]);
            }
          }
          console.table(table);
        }
      }
    };

    let test_imports = {
      assert_matrix_eq: (mat1_index, mat2_index, rows, cols) => {
        if (this.Exports() != null) {
          let mat1 = new Float32Array(this.Exports().memory.buffer, mat1_index, rows * cols);
          let mat2 = new Float32Array(this.Exports().memory.buffer, mat2_index, rows * cols);
          for (let i = 0; i < rows * cols; i++) {
            if (mat1[i] !== mat2[i]) {
              console.error("Matrix equality failed!");
              system_imports.print_table_f32(mat1_index, rows, cols);
              system_imports.print_table_f32(mat2_index, rows, cols);
              return;
            }
          }
        }
      }
    };
    return {
      "Math": math_imports,
      "Message": message_imports,
      "System": system_imports,
      "Test": test_imports,
    };
  }
}
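// Illustrative sketch (not part of the commit): one way to wire CompiledModel to a
// generated Wasm binary in a browser. The file name "model.wasm" and the input batch
// are hypothetical; the exports used (train, predict, memory, ...) are the ones the
// class above expects the generated module to provide.
async function exampleLoadTrainPredict() {
  const model = new CompiledModel();
  // The import object returned by Imports() provides the "Math", "Message",
  // "System" and "Test" modules the generated Wasm code links against.
  const wasm = await WebAssembly.instantiateStreaming(fetch("model.wasm"), model.Imports());
  model.SetWasm(wasm);
  model.Train();
  // Predict takes one row per example; the number of rows must equal prediction_batch_size()
  model.Predict([[0.1, 0.9], [0.8, 0.2]]);
}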
var module = module || { exports: {} };
module.exports.CompiledModel = CompiledModel;
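// Illustrative sketch (not part of the commit): saving trained weights and loading
// them back into another instance of the same generated model. Serializing to a JSON
// string is just an example; any storage mechanism would work.
function exampleSaveWeights(trainedModel) {
  // ExtractWeights() returns an array of { layer, weights, bias } objects
  return JSON.stringify(trainedModel.ExtractWeights());
}

function exampleRestoreWeights(freshModel, savedJson) {
  // ImportWeights() copies each layer's weights/bias into Wasm memory and
  // returns false if a layer index or array size does not match
  return freshModel.ImportWeights(JSON.parse(savedJson));
}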