Skip to content

Commit

Permalink
Merge pull request #510 from microbit-foundation/refactor-confidence
Browse files Browse the repository at this point in the history
Refactor confidence
  • Loading branch information
r59q authored Jul 30, 2024
2 parents 25b21e7 + 01d8ef0 commit 0aa8618
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 60 deletions.
5 changes: 3 additions & 2 deletions src/pages/training/KnnModelTrainingPageView.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import { onMount } from 'svelte';
import { knnConfig } from '../../script/stores/knnConfig';
const classifier = stores.getClassifier();
const confidences = stores.getConfidences();
const gestures = stores.getGestures();
const confidences = gestures.getConfidences();
//const confidences = gestures.getConfidences();
const filters = classifier.getFilters();
onMount(() => {
Expand Down Expand Up @@ -85,7 +86,7 @@
</div>
{#if $state.isInputReady}
<p>
{(($confidences.get(gesture.ID)?.currentConfidence ?? 0) * 100).toFixed(2)}%
{(($confidences.get(gesture.ID) ?? 0) * 100).toFixed(2)}%
</p>
{/if}
</div>
Expand Down
3 changes: 3 additions & 0 deletions src/script/domain/ClassifierRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
* SPDX-License-Identifier: MIT
*/
import Classifier from './stores/Classifier';
import Confidences from './stores/Confidences';
import GestureConfidence from './stores/gesture/GestureConfidence';

interface ClassifierRepository {
getClassifier(): Classifier;

getGestureConfidence(gestureId: number): GestureConfidence;

getConfidences(): Confidences;
}

export default ClassifierRepository;
1 change: 1 addition & 0 deletions src/script/domain/Repositories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
import ClassifierRepository from './ClassifierRepository';
import GestureRepository from './GestureRepository';
import Confidences from './stores/Confidences';

interface Repositories {
getGestureRepository(): GestureRepository;
Expand Down
48 changes: 48 additions & 0 deletions src/script/domain/stores/Confidences.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* (c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
*
* SPDX-License-Identifier: MIT
*/
import {
Readable,
Subscriber,
Unsubscriber,
Writable,
get,
writable,
} from 'svelte/store';
import { GestureID } from './gesture/Gesture';

type GestureConfidenceMap = Map<GestureID, number>;

class Confidences implements Readable<GestureConfidenceMap> {
private confidenceStore: Writable<GestureConfidenceMap>;

constructor() {
this.confidenceStore = writable(new Map<GestureID, number>());
}

public subscribe(
run: Subscriber<GestureConfidenceMap>,
invalidate?: ((value?: GestureConfidenceMap | undefined) => void) | undefined,
): Unsubscriber {
return this.confidenceStore.subscribe(run, invalidate);
}

public setConfidence(gestureId: GestureID, confidence: number): void {
this.confidenceStore.update((map: GestureConfidenceMap) => {
map.set(gestureId, confidence);
return map;
});
}

public getConfidence(gestureId: GestureID): number {
const confidence = get(this.confidenceStore).get(gestureId);
if (confidence === undefined) {
throw new Error(`No confidence value found for gesture with ID ${gestureId}`);
}
return confidence;
}
}

export default Confidences;
34 changes: 1 addition & 33 deletions src/script/domain/stores/gesture/Gestures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,13 @@ export type RecordingData = {
class Gestures implements Readable<GestureData[]> {
private static subscribableGestures: Writable<Gesture[]>;
private repository: GestureRepository;
private confidenceStore: Readable<Map<number, Confidence>>;

constructor(classifierRepository: ClassifierRepository, repository: GestureRepository) {
constructor(repository: GestureRepository) {
this.repository = repository;
Gestures.subscribableGestures = writable();
this.repository.subscribe(storeArray => {
Gestures.subscribableGestures.set(storeArray);
});

this.confidenceStore = derived([this, ...this.getGestures()], stores => {
const confidenceMap: Map<number, Confidence> = new Map();

const [_, ...gestureStores] = stores;
const thiz = stores[0] as GestureData[];
thiz.forEach(gesture => {
// TODO: The following ought to be fixed. See https://github.com/microbit-foundation/cctd-ml-machine/issues/508
const store = gestureStores.find(store => store.ID === gesture.ID)
?.confidence || {
currentConfidence: 0,
requiredConfidence: 0,
isConfident: false,
};
confidenceMap.set(gesture.ID, {
...store,
currentConfidence: classifierRepository
.getGestureConfidence(gesture.ID)
.getCurrentConfidence(),
});
});

/*gestureStores.forEach(store => {
confidenceMap.set(store.ID, store.confidence);
});*/
return confidenceMap;
});
}

public subscribe(
Expand Down Expand Up @@ -159,10 +131,6 @@ class Gestures implements Readable<GestureData[]> {
);
}

public getConfidences(): Readable<Map<number, Confidence>> {
return this.confidenceStore;
}

private addGestureFromPersistedData(gestureData: PersistantGestureData): Gesture {
return this.repository.addGesture(gestureData);
}
Expand Down
33 changes: 13 additions & 20 deletions src/script/repository/LocalStorageClassifierRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,20 @@ import ClassifierRepository from '../domain/ClassifierRepository';
import Gesture, { GestureID } from '../domain/stores/gesture/Gesture';
import Classifier from '../domain/stores/Classifier';
import GestureConfidence from '../domain/stores/gesture/GestureConfidence';
import Confidences from '../domain/stores/Confidences';

export type TrainerConsumer = <T extends MLModel>(
trainer: ModelTrainer<T>,
) => Promise<void>;

class LocalStorageClassifierRepository implements ClassifierRepository {
private static readonly PERSISTANT_FILTERS_KEY = 'filters';
private static confidences: Writable<Map<GestureID, number>>;
private static mlModel: Writable<MLModel | undefined>;
private static filters: Filters;
private static persistedFilters: PersistantWritable<FilterType[]>;
private classifierFactory: ClassifierFactory;

constructor() {
const initialConfidence = new Map<GestureID, number>();
LocalStorageClassifierRepository.confidences = writable(initialConfidence);
constructor(private confidences: Confidences) {
LocalStorageClassifierRepository.mlModel = writable(undefined);
LocalStorageClassifierRepository.persistedFilters = new PersistantWritable(
FilterTypes.toIterable(),
Expand Down Expand Up @@ -82,9 +80,7 @@ class LocalStorageClassifierRepository implements ClassifierRepository {
if (confidence < 0 || confidence > 1) {
throw new Error('Cannot set gesture confidence. Must be in the range 0.0-1.0');
}
const newConfidences = get(LocalStorageClassifierRepository.confidences);
newConfidences.set(gestureId, confidence);
LocalStorageClassifierRepository.confidences.set(newConfidences);
this.confidences.setConfidence(gestureId, confidence);
}

private getFilters(): Writable<Filter[]> {
Expand Down Expand Up @@ -113,28 +109,25 @@ class LocalStorageClassifierRepository implements ClassifierRepository {
}

public getGestureConfidence(gestureId: number): GestureConfidence {
const derivedConfidence = derived(
[LocalStorageClassifierRepository.confidences],
stores => {
const confidenceStore = stores[0];
if (confidenceStore.has(gestureId)) {
return confidenceStore.get(gestureId) as number;
}
throw new Error("No confidence found for gesture with id '" + gestureId + "'");
},
);
const derivedConfidence = derived([this.confidences], stores => {
const confidenceStore = stores[0];
if (confidenceStore.has(gestureId)) {
return confidenceStore.get(gestureId) as number;
}
throw new Error("No confidence found for gesture with id '" + gestureId + "'");
});
return new GestureConfidence(
StaticConfiguration.defaultRequiredConfidence,
derivedConfidence,
);
}

public hasGestureConfidence(gestureId: number): boolean {
return get(LocalStorageClassifierRepository.confidences).has(gestureId);
return get(this.confidences).has(gestureId);
}

public getConfidences(): Writable<Map<GestureID, number>> {
return LocalStorageClassifierRepository.confidences;
public getConfidences(): Confidences {
return this.confidences;
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/script/repository/LocalStorageRepositories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import LocalStorageGestureRepository from './LocalStorageGestureRepository';
import LocalStorageClassifierRepository from './LocalStorageClassifierRepository';
import Repositories from '../domain/Repositories';
import Confidences from '../domain/stores/Confidences';

class LocalStorageRepositories implements Repositories {
private gestureRepository: LocalStorageGestureRepository;
Expand All @@ -20,7 +21,8 @@ class LocalStorageRepositories implements Repositories {
throw new Error('Could not instantiate repository. It is already instantiated!');
}
LocalStorageRepositories.instance = this;
this.classifierRepository = new LocalStorageClassifierRepository();
const confidences = new Confidences();
this.classifierRepository = new LocalStorageClassifierRepository(confidences);
this.gestureRepository = new LocalStorageGestureRepository(this.classifierRepository);
}

Expand Down
12 changes: 8 additions & 4 deletions src/script/stores/Stores.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Gestures from '../domain/stores/gesture/Gestures';
import PollingPredictorEngine from '../engine/PollingPredictorEngine';
import LocalStorageRepositories from '../repository/LocalStorageRepositories';
import Logger from '../utils/Logger';
import Confidences from '../domain/stores/Confidences';

type StoresType = {
liveData: LiveData<LiveDataVector>;
Expand All @@ -35,16 +36,15 @@ class Stores implements Readable<StoresType> {
private engine: Engine | undefined;
private classifier: Classifier;
private gestures: Gestures;
private confidences: Confidences;

public constructor() {
this.liveData = writable(undefined);
this.engine = undefined;
const repositories: Repositories = new LocalStorageRepositories();
this.classifier = repositories.getClassifierRepository().getClassifier();
this.gestures = new Gestures(
repositories.getClassifierRepository(),
repositories.getGestureRepository(),
);
this.confidences = repositories.getClassifierRepository().getConfidences();
this.gestures = new Gestures(repositories.getGestureRepository());
}

public subscribe(
Expand Down Expand Up @@ -94,6 +94,10 @@ class Stores implements Readable<StoresType> {
}
return this.engine;
}

public getConfidences(): Confidences {
return this.confidences;
}
}

export const stores = new Stores();

0 comments on commit 0aa8618

Please sign in to comment.