Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

530 add axis highlighting visually2 #547

Merged
merged 9 commits into from
Nov 12, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Move highlighted axis into stores!
r59q committed Oct 30, 2024
commit f92d53ce0339f82328c35e3026d387d71d65cf76
10 changes: 3 additions & 7 deletions src/components/graphs/MicrobitLiveGraph.svelte
Original file line number Diff line number Diff line change
@@ -7,26 +7,22 @@
<script lang="ts">
import StaticConfiguration from '../../StaticConfiguration';
import { Feature, hasFeature } from '../../script/FeatureToggles';
import Axes from '../../script/domain/Axes';
import { stores } from '../../script/stores/Stores';
import { highlightedAxis } from '../../script/stores/uiStore';
import LiveGraph from './LiveGraph.svelte';
//axis={Axes.X}
const highlightedAxis = stores.getHighlightedAxis();
export let width: number;
$: showhighlit = hasFeature(Feature.KNN_MODEL) && $highlightedAxis !== undefined;
console.log(hasFeature(Feature.KNN_MODEL) && $highlightedAxis !== undefined);
$: highlightedVectorIndex =
$highlightedAxis === Axes.X ? 0 : $highlightedAxis === Axes.Y ? 1 : 2;
</script>

{#if showhighlit}
{#key highlightedVectorIndex}
{#key $highlightedAxis}
<LiveGraph
minValue={StaticConfiguration.liveGraphValueBounds.min}
maxValue={StaticConfiguration.liveGraphValueBounds.max}
liveData={$stores.liveData}
highlightVectorIndex={highlightedVectorIndex}
highlightVectorIndex={$highlightedAxis}
{width} />
{/key}
{:else}
36 changes: 22 additions & 14 deletions src/components/graphs/knngraph/AxesFilterVectorView.svelte
Original file line number Diff line number Diff line change
@@ -7,10 +7,8 @@
<script lang="ts">
import { Unsubscriber, derived, get } from 'svelte/store';
import StaticConfiguration from '../../../StaticConfiguration';
import Axes from '../../../script/domain/Axes';
import { extractAxisFromAccelerometerData } from '../../../script/utils/graphUtils';
import StandardButton from '../../buttons/StandardButton.svelte';
import { highlightedAxis } from '../../../script/stores/uiStore';
import arrowCreate from 'arrows-svg';
import { onMount } from 'svelte';
import { vectorArrows } from './AxesFilterVector';
@@ -20,6 +18,7 @@
const classifier = stores.getClassifier();
$: liveData = $stores.liveData;
const highlightedAxis = stores.getHighlightedAxis();
const drawArrows = (fromId: string) => {
get(vectorArrows).forEach(arr => arr.clear());
@@ -48,19 +47,19 @@
});
};
const updateArrows = (axis: Axes | undefined) => {
const updateArrows = (axis: number | undefined) => {
if (axis) {
const getId = (): string => {
if ($highlightedAxis === Axes.X) {
if ($highlightedAxis === 0) {
return 'fromX';
}
if ($highlightedAxis === Axes.Y) {
if ($highlightedAxis === 1) {
return 'fromY';
}
if ($highlightedAxis === Axes.Z) {
if ($highlightedAxis === 2) {
return 'fromZ';
}
throw Error('This shouldnt happen');
throw Error('Cannot update arrows for axis ' + axis);
};
drawArrows(getId());
}
@@ -83,7 +82,7 @@
const filteredSeries = stores
.getClassifier()
.getFilters()
.compute(extractAxisFromAccelerometerData(series, get(highlightedAxis)!));
.compute(extractAxisFromAccelerometerData(series, $highlightedAxis!));
return filteredSeries;
} catch (e) {
return Array(classifier.getFilters().count()).fill(0);
@@ -142,22 +141,31 @@
<StandardButton
color={StaticConfiguration.liveGraphColors[0]}
small
outlined={$highlightedAxis !== Axes.X}
onClick={() => ($highlightedAxis = Axes.X)}>X</StandardButton>
outlined={$highlightedAxis !== 0}
onClick={() => {
$highlightedAxis = 0
stores.getHighlightedAxis().set(0);
}}>X</StandardButton>
</div>
<div class="flex flex-row space-x-2" id="fromY">
<StandardButton
color={StaticConfiguration.liveGraphColors[1]}
small
outlined={$highlightedAxis !== Axes.Y}
onClick={() => ($highlightedAxis = Axes.Y)}>Y</StandardButton>
outlined={$highlightedAxis !== 1}
onClick={() => {
$highlightedAxis = 1
stores.getHighlightedAxis().set(1);
}}>Y</StandardButton>
</div>
<div class="flex flex-row space-x-2" id="fromZ">
<StandardButton
color={StaticConfiguration.liveGraphColors[2]}
small
outlined={$highlightedAxis !== Axes.Z}
onClick={() => ($highlightedAxis = Axes.Z)}>Z</StandardButton>
outlined={$highlightedAxis !== 2}
onClick={() => {
$highlightedAxis = 2
stores.getHighlightedAxis().set(2);
}}>Z</StandardButton>
</div>
</div>
<div class="pl-20 flex flex-col justify-around">
11 changes: 5 additions & 6 deletions src/components/graphs/knngraph/KNNModelGraphController.ts
Original file line number Diff line number Diff line change
@@ -11,7 +11,6 @@ import {
MicrobitAccelerometerDataVector,
} from '../../../script/livedata/MicrobitAccelerometerData';
import { TimestampedData } from '../../../script/domain/LiveDataBuffer';
import Axes from '../../../script/domain/Axes';
import Filters from '../../../script/domain/Filters';
import { Point3D } from '../../../script/utils/graphUtils';
import StaticConfiguration from '../../../StaticConfiguration';
@@ -54,7 +53,7 @@ class KNNModelGraphController {
origin: { x: number; y: number },
classId: string,
colors: string[],
axis?: Axes,
axis?: number,
) {
this.filters = stores.getClassifier().getFilters();
this.trainingData = this.trainingDataToPoints();
@@ -229,16 +228,16 @@ class KNNModelGraphController {
}

// Called whenever any subscribed store is altered
private onUpdate(draw: UpdateCall, axis?: Axes) {
private onUpdate(draw: UpdateCall, axis?: number) {
let data: TimestampedData<MicrobitAccelerometerDataVector>[] = draw.data;

const getLiveFilteredData = () => {
switch (axis) {
case Axes.X:
case 0:
return this.filters.compute(data.map(d => d.value.getAccelerometerData().x));
case Axes.Y:
case 1:
return this.filters.compute(data.map(d => d.value.getAccelerometerData().y));
case Axes.Z:
case 2:
return this.filters.compute(data.map(d => d.value.getAccelerometerData().z));
default:
throw new Error("Shouldn't happen");
18 changes: 8 additions & 10 deletions src/components/graphs/knngraph/KnnModelGraph.svelte
Original file line number Diff line number Diff line change
@@ -10,21 +10,19 @@
import ClassifierFactory from '../../../script/domain/ClassifierFactory';
import KnnModelGraphSvgWithControls from './KnnModelGraphSvgWithControls.svelte';
import { extractAxisFromTrainingData } from '../../../script/utils/graphUtils';
import Axes from '../../../script/domain/Axes';
import { TrainingData } from '../../../script/domain/ModelTrainer';
import { highlightedAxis } from '../../../script/stores/uiStore';
import KnnPointToolTipView from './KnnPointToolTipView.svelte';
import { stores } from '../../../script/stores/Stores';
import { get } from 'svelte/store';
import StaticConfiguration from '../../../StaticConfiguration';
import Filters from '../../../script/domain/Filters';
import { FilterType } from '../../../script/domain/FilterTypes';
const classifierFactory = new ClassifierFactory();
const classifier = stores.getClassifier();
const gestures = stores.getGestures();
const filters = classifier.getFilters();
const highlightedAxis = stores.getHighlightedAxis();
const canvasWidth = 450;
const canvasHeight = 300;
@@ -40,20 +38,20 @@
const accelZData = extractAxisFromTrainingData(allData, 2, 3);
const dataGetter = (): TrainingData => {
const axis = get(highlightedAxis);
if (axis === Axes.X) {
const axis = $highlightedAxis;
if (axis === 0) {
return accelXData;
}
if (axis === Axes.Y) {
if (axis === 1) {
return accelYData;
}
if (axis === Axes.Z) {
if (axis === 2) {
return accelZData;
}
throw new Error('Should not happen');
throw new Error('Cannot get data for axis ' + axis);
};
const initSingle = (axis: Axes) => {
const initSingle = (axis: number) => {
const svgSingle = d3.select('.d3-3d-single');
const graphColors = [
...$gestures.map(data => data.color),
@@ -88,7 +86,7 @@
});
onMount(() => {
controller.set(initSingle(Axes.X));
controller.set(initSingle(0));
return () => {
get(controller)?.destroy();
};
9 changes: 3 additions & 6 deletions src/pages/training/KnnModelTrainingPageView.svelte
Original file line number Diff line number Diff line change
@@ -5,30 +5,27 @@
-->
<script lang="ts">
import { stores } from '../../script/stores/Stores';
import { highlightedAxis, state } from '../../script/stores/uiStore';
import { state } from '../../script/stores/uiStore';
import AxesFilterVectorView from '../../components/graphs/knngraph/AxesFilterVectorView.svelte';
import { trainModel } from './TrainingPage';
import ModelRegistry from '../../script/domain/ModelRegistry';
import KnnModelGraph from '../../components/graphs/knngraph/KnnModelGraph.svelte';
import StaticConfiguration from '../../StaticConfiguration';
import Axes from '../../script/domain/Axes';
import { t } from '../../i18n';
import { onMount } from 'svelte';
import { knnConfig } from '../../script/stores/knnConfig';
import Logger from '../../script/utils/Logger';
const classifier = stores.getClassifier();
const confidences = stores.getConfidences();
const gestures = stores.getGestures();
//const confidences = gestures.getConfidences();
const filters = classifier.getFilters();
const highlightedAxis = stores.getHighlightedAxis();
onMount(() => {
trainModel(ModelRegistry.KNN);
return () => unsubscribe();
});
$: {
if ($highlightedAxis === undefined) {
$highlightedAxis = Axes.X;
$highlightedAxis = 0;
}
if (!$classifier.model.isTrained) {
trainModel(ModelRegistry.KNN);
2 changes: 1 addition & 1 deletion src/pages/training/NeuralNetworkTrainingPageView.svelte
Original file line number Diff line number Diff line change
@@ -14,10 +14,10 @@
import Logger from '../../script/utils/Logger';
import { Feature, hasFeature } from '../../script/FeatureToggles';
import { onMount } from 'svelte';
import { highlightedAxis } from '../../script/stores/uiStore';
const classifier = stores.getClassifier();
const model = classifier.getModel();
const highlightedAxis = stores.getHighlightedAxis();
const trainModelClickHandler = () => {
trainModel(ModelRegistry.NeuralNetwork).then(() => {
6 changes: 2 additions & 4 deletions src/pages/training/TrainModelButton.svelte
Original file line number Diff line number Diff line change
@@ -11,11 +11,8 @@
import { Feature, hasFeature } from '../../script/FeatureToggles';
import StandardButton from '../../components/buttons/StandardButton.svelte';
import { Writable } from 'svelte/store';
import { highlightedAxis } from '../../script/stores/uiStore';
import { stores } from '../../script/stores/Stores';
import { options, trainModel } from './TrainModelButton';
import Axes from '../../script/domain/Axes';
import ModelRegistry, { ModelInfo } from '../../script/domain/ModelRegistry';
import { LossTrainingIteration } from '../../script/mlmodels/LayersModelTrainer';
@@ -25,6 +22,7 @@
export let selectedOption: Writable<DropdownOption>;
const classifier = stores.getClassifier();
const highlightedAxis = stores.getHighlightedAxis();
const model = classifier.getModel();
@@ -59,7 +57,7 @@
$: {
if ($selectedOption.id === 'KNN' && !$highlightedAxis) {
highlightedAxis.set(Axes.X);
highlightedAxis.set(0);
}
if ($selectedOption.id === 'NN' && $highlightedAxis) {
highlightedAxis.set(undefined);
5 changes: 2 additions & 3 deletions src/pages/training/TrainModelButton.ts
Original file line number Diff line number Diff line change
@@ -5,10 +5,8 @@
*/
import { Writable, get } from 'svelte/store';
import { DropdownOption } from '../../components/buttons/Buttons';
import { highlightedAxis } from '../../script/stores/uiStore';
import ModelTrainer from '../../script/domain/ModelTrainer';
import MLModel from '../../script/domain/MLModel';
import Axes from '../../script/domain/Axes';
import { stores } from '../../script/stores/Stores';
import StaticConfiguration from '../../StaticConfiguration';
import KNNNonNormalizedModelTrainer from '../../script/mlmodels/KNNNonNormalizedModelTrainer';
@@ -22,6 +20,7 @@ import ModelRegistry, { ModelInfo } from '../../script/domain/ModelRegistry';
import { knnConfig } from '../../script/stores/knnConfig';

const classifier = stores.getClassifier();
const highlightedAxis = stores.getHighlightedAxis();

export const options: DropdownOption[] = ModelRegistry.getModels().map(model => {
return {
@@ -36,7 +35,7 @@ export const getModelTrainer = (
): ModelTrainer<MLModel> => {
const currentAxis = get(highlightedAxis);
if (model.id === ModelRegistry.KNN.id) {
const offset = currentAxis === Axes.X ? 0 : currentAxis === Axes.Y ? 1 : 2;
const offset = currentAxis === 0 ? 0 : currentAxis === 1 ? 1 : 2; // TODO: Rewrite to use just use the axis as offset directly
return new KNNNonNormalizedModelTrainer(get(knnConfig).k, data =>
extractAxisFromTrainingData(data, offset, 3),
);
14 changes: 7 additions & 7 deletions src/pages/training/TrainingPage.ts
Original file line number Diff line number Diff line change
@@ -4,8 +4,7 @@
* SPDX-License-Identifier: MIT
*/
import { get, writable } from 'svelte/store';
import { highlightedAxis, selectedModel } from '../../script/stores/uiStore';
import Axes from '../../script/domain/Axes';
import { selectedModel } from '../../script/stores/uiStore';
import KNNNonNormalizedModelTrainer from '../../script/mlmodels/KNNNonNormalizedModelTrainer';
import StaticConfiguration from '../../StaticConfiguration';
import { extractAxisFromTrainingData } from '../../script/utils/graphUtils';
@@ -29,7 +28,7 @@ const trainingIterationHandler = (h: LossTrainingIteration) => {
};

const trainNNModel = async () => {
highlightedAxis.set(undefined);
stores.getHighlightedAxis().set(undefined);
loss.set([]);
const modelTrainer = new LayersModelTrainer(
StaticConfiguration.layersModelTrainingSettings,
@@ -39,11 +38,12 @@ const trainNNModel = async () => {
};

const trainKNNModel = async () => {
if (get(highlightedAxis) === undefined) {
highlightedAxis.set(Axes.X);
if (get(stores.getHighlightedAxis()) === undefined) {
stores.getHighlightedAxis().set(0);
}
const currentAxis = get(highlightedAxis);
const offset = currentAxis === Axes.X ? 0 : currentAxis === Axes.Y ? 1 : 2;
const currentAxis = get(stores.getHighlightedAxis());
// TODO: Rewrite offset to use the axis directly instead
const offset = currentAxis === 0 ? 0 : currentAxis === 1 ? 1 : 2;
const modelTrainer = new KNNNonNormalizedModelTrainer(
get(knnConfig).k,
data => {
12 changes: 0 additions & 12 deletions src/script/domain/Axes.ts

This file was deleted.

19 changes: 0 additions & 19 deletions src/script/domain/MLModelFactory.ts

This file was deleted.

11 changes: 5 additions & 6 deletions src/script/mlmodels/AccelerometerClassifierInput.ts
Original file line number Diff line number Diff line change
@@ -6,8 +6,7 @@
import { get } from 'svelte/store';
import ClassifierInput from '../domain/ClassifierInput';
import Filters from '../domain/Filters';
import { highlightedAxis } from '../stores/uiStore';
import Axes from '../domain/Axes';
import { stores } from '../stores/Stores';

class AccelerometerClassifierInput implements ClassifierInput {
constructor(
@@ -18,15 +17,15 @@ class AccelerometerClassifierInput implements ClassifierInput {

public getInput(filters: Filters): number[] {
// TODO: Bad! How should we go about deciding what axes are provided for prediction when axes are highlighted?
const axis = get(highlightedAxis);
const axis = get(stores.getHighlightedAxis());
if (axis) {
if (axis === Axes.X) {
if (axis === 0) {
return [...filters.compute(this.xs)];
}
if (axis === Axes.Y) {
if (axis === 1) {
return [...filters.compute(this.ys)];
}
if (axis === Axes.Z) {
if (axis === 2) {
return [...filters.compute(this.zs)];
}
}
6 changes: 6 additions & 0 deletions src/script/stores/Stores.ts
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ class Stores implements Readable<StoresType> {
private classifier: Classifier;
private gestures: Gestures;
private confidences: Confidences;
private highlightedAxis: Writable<number | undefined>;

public constructor() {
this.liveData = writable(undefined);
@@ -45,6 +46,7 @@ class Stores implements Readable<StoresType> {
this.classifier = repositories.getClassifierRepository().getClassifier();
this.confidences = repositories.getClassifierRepository().getConfidences();
this.gestures = new Gestures(repositories.getGestureRepository());
this.highlightedAxis = writable(undefined);
}

public subscribe(
@@ -98,6 +100,10 @@ class Stores implements Readable<StoresType> {
public getConfidences(): Confidences {
return this.confidences;
}

public getHighlightedAxis(): Writable<number | undefined> {
return this.highlightedAxis;
}
}

export const stores = new Stores();
5 changes: 0 additions & 5 deletions src/script/stores/uiStore.ts
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@ import { DeviceRequestStates } from './connectDialogStore';
import CookieManager from '../CookieManager';
import { isInputPatternValid } from './connectionStore';
import Gesture from '../domain/stores/gesture/Gesture';
import Axes from '../domain/Axes';
import PersistantWritable from '../repository/PersistantWritable';
import { DropdownOption } from '../../components/buttons/Buttons';
import { stores } from './Stores';
@@ -136,10 +135,6 @@ export const selectedModel = new PersistantWritable<ModelInfo>(
'selectedModel',
);

// TODO: Should probably be elsewhere
export const prevHighlightedAxis = writable<Axes | undefined>(undefined);
export const highlightedAxis = writable<Axes | undefined>(undefined);

const initialMicrobitInteraction: MicrobitInteractions = MicrobitInteractions.AB;

export const microbitInteraction = writable<MicrobitInteractions>(
10 changes: 5 additions & 5 deletions src/script/utils/graphUtils.ts
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
* SPDX-License-Identifier: MIT
*/

import Axes from '../domain/Axes';
import { TrainingData } from '../domain/ModelTrainer';
import { MicrobitAccelerometerData } from '../livedata/MicrobitAccelerometerData';

@@ -101,16 +100,17 @@ export const extractFilterFromTrainingData = (

export const extractAxisFromAccelerometerData = (
data: MicrobitAccelerometerData[],
axis: Axes,
axis: number,
) => {
switch (axis) {
case Axes.X:
case 0:
return data.map(val => val.x);
case Axes.Y:
case 1:
return data.map(val => val.y);
case Axes.Z:
case 2:
return data.map(val => val.z);
}
throw new Error(`Cannot extract from axis ${axis}`)
};

export const distanceBetween = (point1: Point3D, point2: Point3D): number => {