-
Notifications
You must be signed in to change notification settings - Fork 1
/
maskDetector.js
61 lines (53 loc) · 2.34 KB
/
maskDetector.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/**
 * Creates `count` <img> elements pointing at `${folder}/1.jpg` …
 * `${folder}/${count}.jpg`, tags each with `className`, and appends them to
 * `container` in ascending order.
 *
 * @param {Element} container - parent element that receives the images
 * @param {string} folder - image folder path (no trailing slash)
 * @param {number} count - number of images, numbered from 1
 * @param {string} className - CSS class added to every image
 * @returns {HTMLImageElement[]} the created elements, in DOM order
 */
function appendTrainingImages(container, folder, count, className) {
  return Array.from({ length: count }, (_, index) => {
    const image = document.createElement('IMG');
    image.setAttribute('src', `${folder}/${index + 1}.jpg`);
    image.classList.add(className);
    container.appendChild(image);
    return image;
  });
}

/**
 * Page entry point: adds the training images to the DOM, loads MobileNet,
 * trains the KNN classifier, classifies `#test-img`, and marks it with the
 * `no-mask` (label 1) or `mask` (label 0) CSS class.
 */
window.onload = async () => {
  const maskImageCount = 13;
  const noMaskImageCount = 10;
  const trainImagesContainer = document.querySelector('.train-images');

  // Mask images get class `mask-img`, no-mask images get class `no-mask-img`.
  const trainingImages = [
    ...appendTrainingImages(trainImagesContainer, 'images/mask', maskImageCount, 'mask-img'),
    ...appendTrainingImages(trainImagesContainer, 'images/no_mask', noMaskImageCount, 'no-mask-img'),
  ];

  // These images were added after the page's load event, so they may still be
  // downloading. Wait for them to decode while the model downloads; reading
  // an unloaded image with fromPixels would yield an empty (0x0) input.
  const [mobilenetModule] = await Promise.all([
    mobilenet.load({ version: 2, alpha: 1 }),
    ...trainingImages.map((image) => image.decode()),
  ]);

  // Add examples to the KNN Classifier.
  const classifier = await trainClassifier(mobilenetModule);

  // Predict the class of the test image. tidy() frees the intermediate pixel
  // tensor; the returned logits are freed explicitly once prediction is done.
  const testImage = document.getElementById('test-img');
  const logits = tf.tidy(() =>
    mobilenetModule.infer(tf.browser.fromPixels(testImage), 'conv_preds'));
  let prediction;
  try {
    prediction = await classifier.predictClass(logits);
  } finally {
    logits.dispose();
  }

  // The label may come back as a string (e.g. "1"), so compare numerically.
  // Label 1 = no mask (red border), anything else = mask (green border).
  if (Number(prediction.label) === 1) {
    testImage.classList.add('no-mask');
  } else {
    testImage.classList.add('mask');
  }
};
/**
 * Builds a KNN classifier from the training images currently in the DOM.
 * Images with class `mask-img` are added under label 0 (has mask). Images
 * with class `no-mask-img` are added under label 1 (no mask).
 *
 * @param {object} mobilenetModule - loaded MobileNet model exposing `infer`
 * @returns {Promise<object>} the trained knn-classifier instance
 * @throws {Error} if a training image fails to load or decode
 */
async function trainClassifier(mobilenetModule) {
  const classifier = knnClassifier.create();
  await addExamplesFromImages(classifier, mobilenetModule, '.mask-img', 0); // has mask
  await addExamplesFromImages(classifier, mobilenetModule, '.no-mask-img', 1); // no mask
  return classifier;
}

/**
 * Adds one classifier example, labelled `label`, for every element matching
 * `selector`, in document order.
 *
 * @param {object} classifier - knn-classifier instance to add examples to
 * @param {object} mobilenetModule - loaded MobileNet model exposing `infer`
 * @param {string} selector - CSS selector for the training images
 * @param {number} label - class id to add the examples under
 * @throws {Error} if any matched image fails to load or decode
 */
async function addExamplesFromImages(classifier, mobilenetModule, selector, label) {
  const images = [...document.querySelectorAll(selector)];

  // Images inserted by script may still be downloading. decode() resolves
  // once pixel data is available, so fromPixels never reads an empty image.
  await Promise.all(images.map((img) => img.decode().catch((err) => {
    throw new Error(`Could not load training image ${img.src}`, { cause: err });
  })));

  for (const img of images) {
    // tidy() frees the intermediate pixel tensor. The logits are copied into
    // the classifier's own dataset by addExample, so they can be freed after.
    const logits = tf.tidy(() =>
      mobilenetModule.infer(tf.browser.fromPixels(img), 'conv_preds'));
    classifier.addExample(logits, label);
    logits.dispose();
  }
}