Skip to content

Commit

Permalink
Merge pull request #437 from microbit-foundation/430-make-a-knnmodel-…
Browse files Browse the repository at this point in the history
…and-a-knnmodeltrainer

Added a KNN model and KNN model trainer
  • Loading branch information
Karlo-Emilo authored Jan 6, 2024
2 parents 4f22ef4 + 81d22c6 commit 16ed3ae
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 16 deletions.
10 changes: 10 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
},
"dependencies": {
"@microsoft/applicationinsights-web": "^3.0.0",
"@tensorflow-models/knn-classifier": "^1.2.6",
"@tensorflow/tfjs": "^4.4.0",
"bowser": "^2.11.0",
"browser-lang": "^0.2.1",
Expand Down
18 changes: 4 additions & 14 deletions src/components/playground/EngineInteractionButtons.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,19 @@
SPDX-License-Identifier: MIT
-->
<script lang="ts">
import StaticConfiguration from '../../StaticConfiguration';
import Gesture from '../../script/domain/Gesture';
import Model from '../../script/domain/Model';
import AccelerometerClassifierInput from '../../script/mlmodels/AccelerometerClassifierInput';
import LayersModelTrainer from '../../script/mlmodels/LayersModelTrainer';
import { classifier, engine, gestures } from '../../script/stores/Stores';
import playgroundContext from './PlaygroundContext';
import TrainKnnModelButton from './TrainKNNModelButton.svelte';
import TrainLayersModelButton from './TrainLayersModelButton.svelte';
const getRandomGesture = (): Gesture => {
return gestures.getGestures()[
Math.floor(Math.random() * gestures.getNumberOfGestures())
];
};
const model: Model = classifier.getModel();
const trainModelButtonClicked = () => {
playgroundContext.addMessage('training model...');
model
.train(new LayersModelTrainer(StaticConfiguration.layersModelTrainingSettings))
.then(() => {
playgroundContext.addMessage('Finished training!');
});
};
const predictButtonClicked = () => {
const randGesture = getRandomGesture();
playgroundContext.addMessage(
Expand All @@ -43,7 +32,8 @@
};
</script>

<button class="border-1 p-2 m-1" on:click={trainModelButtonClicked}>train model!</button>
<TrainLayersModelButton />
<TrainKnnModelButton />
<button class="border-1 p-2 m-1" on:click={predictButtonClicked}
>Predict random gesture!</button>
<button
Expand Down
24 changes: 24 additions & 0 deletions src/components/playground/TrainKNNModelButton.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<!--
(c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
SPDX-License-Identifier: MIT
-->
<script lang="ts">
import Model from '../../script/domain/Model';
import KNNModelTrainer from '../../script/mlmodels/KNNModelTrainer';
import { classifier } from '../../script/stores/Stores';
import playgroundContext from './PlaygroundContext';
const model: Model = classifier.getModel();
const trainModelButtonClicked = () => {
playgroundContext.addMessage('training model...');
const k = 10;
model.train(new KNNModelTrainer(k)).then(() => {
playgroundContext.addMessage('Finished training a KNN model (k=10)!');
});
};
</script>

<button class="border-1 p-2 m-1" on:click={trainModelButtonClicked}>
train KNN model!
</button>
25 changes: 25 additions & 0 deletions src/components/playground/TrainLayersModelButton.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<!--
(c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
SPDX-License-Identifier: MIT
-->
<script lang="ts">
import StaticConfiguration from '../../StaticConfiguration';
import Model from '../../script/domain/Model';
import LayersModelTrainer from '../../script/mlmodels/LayersModelTrainer';
import { classifier } from '../../script/stores/Stores';
import playgroundContext from './PlaygroundContext';
const model: Model = classifier.getModel();
const trainModelButtonClicked = () => {
playgroundContext.addMessage('training model...');
model
.train(new LayersModelTrainer(StaticConfiguration.layersModelTrainingSettings))
.then(() => {
playgroundContext.addMessage('Finished training!');
});
};
</script>

<button class="border-1 p-2 m-1" on:click={trainModelButtonClicked}
>train layers model!</button>
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ class AccelerometerSynthesizer implements Readable<AccelerometerSynthesizerData>
xSpeed: this.getInitialSineSpeed(),
ySpeed: this.getInitialSineSpeed() + 1 / 1000,
zSpeed: this.getInitialSineSpeed() + 2 / 1000,
isActive: true,
isActive: false,
});
this.updateInterval();
}

public subscribe(
Expand Down
40 changes: 40 additions & 0 deletions src/script/mlmodels/KNNMLModel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* (c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
*
* SPDX-License-Identifier: MIT
*/
import { tensor } from '@tensorflow/tfjs';
import MLModel from '../domain/MLModel';
import * as knnClassifier from '@tensorflow-models/knn-classifier';

class KNNMLModel implements MLModel {
constructor(
private model: knnClassifier.KNNClassifier,
private k: number,
) {}
public async predict(filteredData: number[]): Promise<number[]> {
const inputTensor = tensor([filteredData]);

try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call
const prediction = await this.model.predictClass(inputTensor, this.k);
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
const classes = Object.getOwnPropertyNames(prediction.confidences);
const confidences: number[] = [];

for (let i = 0; i < classes.length; i++) {
const clazz = classes[i];
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access
const confidence = prediction.confidences[clazz];
confidences.push(confidence as number);
}

return Promise.resolve(confidences);
} catch (err) {
console.error('Prediction error: ', err);
return Promise.reject(err);
}
}
}

export default KNNMLModel;
32 changes: 32 additions & 0 deletions src/script/mlmodels/KNNModelTrainer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* (c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
*
* SPDX-License-Identifier: MIT
*/
import ModelTrainer, { TrainingData } from '../domain/ModelTrainer';
import KNNMLModel from './KNNMLModel';
import * as tf from '@tensorflow/tfjs';
import * as knnClassifier from '@tensorflow-models/knn-classifier';

/**
* Trains a K-Nearest Neighbour model
*/
class KNNModelTrainer implements ModelTrainer<KNNMLModel> {
constructor(private k: number) {}
public trainModel(trainingData: TrainingData): Promise<KNNMLModel> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call
const knn: knnClassifier.KNNClassifier = knnClassifier.create();

trainingData.classes.forEach((gestureClass, index) => {
gestureClass.samples.forEach(sample => {
const example = tf.tensor([sample.value]);
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call
knn.addExample(example, index);
});
});

return Promise.resolve(new KNNMLModel(knn, this.k));
}
}

export default KNNModelTrainer;

0 comments on commit 16ed3ae

Please sign in to comment.