diff --git a/README.md b/README.md index 44d46aa6..b1180f80 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,10 @@ which uses Useful files (sample data, biomonitoring stations, icons) can be found on our [Drive](https://drive.google.com/drive/folders/1eQWuf5WCT429xogQ2HiZqapehvweAtxP) (Appsilon internal). +To be able to run inference, +download the [ONNX models](https://drive.google.com/drive/folders/1jQppnRm_4kLDPspg0fAulK_MK_RvxIEB) +and place them in the `assets/models` directory. + ### Commands You can run the following commands from the project root. diff --git a/package-lock.json b/package-lock.json index e39ddc48..d0bf3dc7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -22,7 +22,6 @@ "i18next": "^21.8.16", "lodash": "^4.17.21", "mapbox-gl": "^1.13.2", - "ndarray": "^1.0.19", "node-machine-id": "^1.1.12", "react": "^18.2.0", "react-dom": "^18.2.0", @@ -41,7 +40,6 @@ "@types/jest": "28.1.8", "@types/lodash": "4.14.185", "@types/mapbox-gl": "1.13.4", - "@types/ndarray": "1.0.11", "@types/node": "18.6.3", "@types/react": "18.0.20", "@types/react-dom": "18.0.6", @@ -2548,12 +2546,6 @@ "integrity": "sha512-iiUgKzV9AuaEkZqkOLDIvlQiL6ltuZd9tGcW3gwpnX8JbuiuhFlEGmmFXEXkN50Cvq7Os88IY2v0dkDqXYWVgA==", "dev": true }, - "node_modules/@types/ndarray": { - "version": "1.0.11", - "resolved": "https://registry.npmjs.org/@types/ndarray/-/ndarray-1.0.11.tgz", - "integrity": "sha512-hOZVTN24zDHwCHaW7mF9n1vHJt83fZhNZ0YYRBwQGhA96yBWWDPTDDlqJatagHIOJB0a4xoNkNc+t/Cxd+6qUA==", - "dev": true - }, "node_modules/@types/node": { "version": "18.6.3", "resolved": "https://registry.npmjs.org/@types/node/-/node-18.6.3.tgz", @@ -9906,11 +9898,6 @@ "node": ">= 0.10" } }, - "node_modules/iota-array": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/iota-array/-/iota-array-1.0.0.tgz", - "integrity": "sha512-pZ2xT+LOHckCatGQ3DcG/a+QuEqvoxqkiL7tvE8nn3uuu+f6i1TtpB5/FtWFbxUuVr5PZCx8KskuGatbJDXOWA==" - }, "node_modules/ip": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ip/-/ip-2.0.0.tgz", @@ -9986,11 +9973,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/is-buffer": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.6.tgz", - "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==" - }, "node_modules/is-callable": { "version": "1.2.4", "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.4.tgz", @@ -12744,15 +12726,6 @@ "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", "dev": true }, - "node_modules/ndarray": { - "version": "1.0.19", - "resolved": "https://registry.npmjs.org/ndarray/-/ndarray-1.0.19.tgz", - "integrity": "sha512-B4JHA4vdyZU30ELBw3g7/p9bZupyew5a7tX1Y/gGeF2hafrPaQZhgrGQfsvgfYbgdFZjYwuEcnaobeM/WMW+HQ==", - "dependencies": { - "iota-array": "^1.0.0", - "is-buffer": "^1.0.2" - } - }, "node_modules/negotiator": { "version": "0.6.3", "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz", @@ -19626,12 +19599,6 @@ "integrity": "sha512-iiUgKzV9AuaEkZqkOLDIvlQiL6ltuZd9tGcW3gwpnX8JbuiuhFlEGmmFXEXkN50Cvq7Os88IY2v0dkDqXYWVgA==", "dev": true }, - "@types/ndarray": { - "version": "1.0.11", - "resolved": "https://registry.npmjs.org/@types/ndarray/-/ndarray-1.0.11.tgz", - "integrity": "sha512-hOZVTN24zDHwCHaW7mF9n1vHJt83fZhNZ0YYRBwQGhA96yBWWDPTDDlqJatagHIOJB0a4xoNkNc+t/Cxd+6qUA==", - "dev": true - }, "@types/node": { "version": "18.6.3", "resolved": "https://registry.npmjs.org/@types/node/-/node-18.6.3.tgz", @@ -25205,11 +25172,6 @@ "integrity": "sha512-agE4QfB2Lkp9uICn7BAqoscw4SZP9kTE2hxiFI3jBPmXJfdqiahTbUuKGsMoN2GtqL9AxhYioAcVvgsb1HvRbA==", "dev": true }, - "iota-array": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/iota-array/-/iota-array-1.0.0.tgz", - "integrity": "sha512-pZ2xT+LOHckCatGQ3DcG/a+QuEqvoxqkiL7tvE8nn3uuu+f6i1TtpB5/FtWFbxUuVr5PZCx8KskuGatbJDXOWA==" - }, "ip": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ip/-/ip-2.0.0.tgz", @@ -25264,11 +25226,6 @@ "has-tostringtag": "^1.0.0" } }, - "is-buffer": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.6.tgz", - "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==" - }, "is-callable": { "version": "1.2.4", "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.4.tgz", @@ -27318,15 +27275,6 @@ "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", "dev": true }, - "ndarray": { - "version": "1.0.19", - "resolved": "https://registry.npmjs.org/ndarray/-/ndarray-1.0.19.tgz", - "integrity": "sha512-B4JHA4vdyZU30ELBw3g7/p9bZupyew5a7tX1Y/gGeF2hafrPaQZhgrGQfsvgfYbgdFZjYwuEcnaobeM/WMW+HQ==", - "requires": { - "iota-array": "^1.0.0", - "is-buffer": "^1.0.2" - } - }, "negotiator": { "version": "0.6.3", "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz", diff --git a/package.json b/package.json index bfafa2d8..f8180890 100644 --- a/package.json +++ b/package.json @@ -112,7 +112,6 @@ "i18next": "^21.8.16", "lodash": "^4.17.21", "mapbox-gl": "^1.13.2", - "ndarray": "^1.0.19", "node-machine-id": "^1.1.12", "react": "^18.2.0", "react-dom": "^18.2.0", @@ -131,7 +130,6 @@ "@types/jest": "28.1.8", "@types/lodash": "4.14.185", "@types/mapbox-gl": "1.13.4", - "@types/ndarray": "1.0.11", "@types/node": "18.6.3", "@types/react": "18.0.20", "@types/react-dom": "18.0.6", diff --git a/src/common/models.ts b/src/common/models.ts index 0c759db0..153a7053 100644 --- a/src/common/models.ts +++ b/src/common/models.ts @@ -1,7 +1,8 @@ export const MODELS = { CENTRAL_AFRICAN_FORESTS: { - file: 'central_african_forests.onnx', name: 'Central African forests', + file: 'central_african_forests.onnx', + photoShape: { width: 768, height: 576 }, labels: [ 'Bird', 'Blank', @@ -34,8 +35,9 @@ export const MODELS = { ], }, EAST_AFRICAN_SAVANNAS: { - file: 'east_african_savannas.onnx', name: 'East African savannas', + file: 'east_african_savannas.onnx', + photoShape: { width: 512, height: 384 }, labels: [ 'aardvark', 'aardwolf', diff --git a/src/main/tools/runInference.ts b/src/main/tools/runInference.ts index e28f59c9..9c6f3adf 100644 --- a/src/main/tools/runInference.ts +++ b/src/main/tools/runInference.ts @@ -1,4 +1,3 @@ -import ndarray from 'ndarray'; import * as ort from 'onnxruntime-node'; import { join } from 'path'; import sharp from 'sharp'; @@ -6,32 +5,24 @@ import sharp from 'sharp'; import { Model, MODELS } from '../../common/models'; import { RESOURCES_PATH } from '../util'; -const WIDTH = 512; -const HEIGHT = 384; +const BATCH_SIZE = 1; const CHANNELS = 3; function modelPath(model: Model) { return join(RESOURCES_PATH, 'assets', 'models', MODELS[model].file); } -async function readPhoto(path: string) { - const raw = await sharp(path).resize(WIDTH, HEIGHT, { fit: 'fill' }).raw().toBuffer(); - - // Reshape and normalize to match model input shape. - const srcShape = [HEIGHT, WIDTH, CHANNELS]; - const dstShape = [CHANNELS, HEIGHT, WIDTH]; - const src = ndarray(raw, srcShape); - const dst = ndarray(new Float32Array(raw.length), dstShape); - for (let i = 0; i < HEIGHT; i += 1) { - for (let j = 0; j < WIDTH; j += 1) { - for (let k = 0; k < CHANNELS; k += 1) { - dst.set(k, i, j, src.get(i, j, k) / 255); - } - } - } - return dst; +async function readPhoto(path: string, shape: { width: number; height: number }) { + const raw = await sharp(path) + .resize({ ...shape, fit: 'fill' }) + .raw() + .toBuffer(); + const photo = new Float32Array(raw); + for (let i = 0; i < photo.length; i += 1) photo[i] /= 255; + return photo; } +// If output is not provided, build an empty result. function buildResult(model: Model, output?: ort.Tensor) { const entries = MODELS[model].labels.map((label, idx) => [ label, @@ -48,10 +39,11 @@ export default async function runInference(model: Model, photoPaths: string[]) { /* eslint-disable no-await-in-loop */ for (const path of photoPaths) { try { - const photo = await readPhoto(path); - // The first dimension of the tensor is the photo index (used for batch processing). - // TODO: Process photos in batches to improve performance. - const input = new ort.Tensor('float32', photo.data, [1, ...photo.shape]); + const { photoShape } = MODELS[model]; + const photo = await readPhoto(path, photoShape); + // TODO: Process multiple photos per batch to improve performance. + const inputShape = [BATCH_SIZE, photoShape.height, photoShape.width, CHANNELS]; + const input = new ort.Tensor('float32', photo, inputShape); const { output } = await session.run({ input }); results.push(buildResult(model, output)); } catch (e) {