
Commit

Merge pull request #354 from Appsilon/update-models
Update models
kamilzyla authored Sep 22, 2022
2 parents 508daf3 + 478cb30 commit 425d929
Showing 5 changed files with 23 additions and 79 deletions.
README.md: 4 changes (4 additions, 0 deletions)
@@ -26,6 +26,10 @@ which uses
 Useful files (sample data, biomonitoring stations, icons)
 can be found on our [Drive](https://drive.google.com/drive/folders/1eQWuf5WCT429xogQ2HiZqapehvweAtxP) (Appsilon internal).
 
+To be able to run inference,
+download the [ONNX models](https://drive.google.com/drive/folders/1jQppnRm_4kLDPspg0fAulK_MK_RvxIEB)
+and place them in the `assets/models` directory.
+
 ### Commands
 
 You can run the following commands from the project root.
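At runtime the app resolves each model as `assets/models/<file>` (see `modelPath` in `src/main/tools/runInference.ts` below), so the download location matters. A minimal sketch for sanity-checking the layout from the project root; the script is illustrative, not part of the repository, and assumes Node with ts-node or similar:

```ts
// Illustrative only: check that every model declared in src/common/models.ts
// has a matching ONNX file under assets/models.
import { existsSync } from 'fs';
import { join } from 'path';

import { MODELS } from './src/common/models';

for (const { name, file } of Object.values(MODELS)) {
  const path = join('assets', 'models', file);
  console.log(`${name}: ${existsSync(path) ? 'OK' : 'missing'} (${path})`);
}
```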
package-lock.json: 52 changes (0 additions, 52 deletions)

Some generated files are not rendered by default.

package.json: 2 changes (0 additions, 2 deletions)
@@ -112,7 +112,6 @@
     "i18next": "^21.8.16",
     "lodash": "^4.17.21",
     "mapbox-gl": "^1.13.2",
-    "ndarray": "^1.0.19",
     "node-machine-id": "^1.1.12",
     "react": "^18.2.0",
     "react-dom": "^18.2.0",
@@ -131,7 +130,6 @@
     "@types/jest": "28.1.8",
     "@types/lodash": "4.14.185",
     "@types/mapbox-gl": "1.13.4",
-    "@types/ndarray": "1.0.11",
     "@types/node": "18.6.3",
     "@types/react": "18.0.20",
     "@types/react-dom": "18.0.6",
src/common/models.ts: 6 changes (4 additions, 2 deletions)
@@ -1,7 +1,8 @@
 export const MODELS = {
   CENTRAL_AFRICAN_FORESTS: {
-    file: 'central_african_forests.onnx',
     name: 'Central African forests',
+    file: 'central_african_forests.onnx',
+    photoShape: { width: 768, height: 576 },
     labels: [
       'Bird',
       'Blank',
@@ -34,8 +35,9 @@ export const MODELS = {
     ],
   },
   EAST_AFRICAN_SAVANNAS: {
-    file: 'east_african_savannas.onnx',
     name: 'East African savannas',
+    file: 'east_african_savannas.onnx',
+    photoShape: { width: 512, height: 384 },
     labels: [
       'aardvark',
       'aardwolf',
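With this change, each `MODELS` entry carries its own input resolution (`photoShape`) instead of relying on the `WIDTH`/`HEIGHT` constants previously hard-coded in `runInference.ts`. A sketch of the shape every entry now appears to satisfy; the interface name is illustrative, not necessarily how the repository types it:

```ts
// Illustrative only: the structure each MODELS entry follows after this commit.
interface ModelConfig {
  name: string; // display name, e.g. 'Central African forests'
  file: string; // ONNX file expected under assets/models
  photoShape: { width: number; height: number }; // input resolution for this model
  labels: string[]; // class names, matched to model outputs by index (see buildResult below)
}
```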
src/main/tools/runInference.ts: 38 changes (15 additions, 23 deletions)
@@ -1,37 +1,28 @@
-import ndarray from 'ndarray';
 import * as ort from 'onnxruntime-node';
 import { join } from 'path';
 import sharp from 'sharp';
 
 import { Model, MODELS } from '../../common/models';
 import { RESOURCES_PATH } from '../util';
 
-const WIDTH = 512;
-const HEIGHT = 384;
+const BATCH_SIZE = 1;
 const CHANNELS = 3;
 
 function modelPath(model: Model) {
   return join(RESOURCES_PATH, 'assets', 'models', MODELS[model].file);
 }
 
-async function readPhoto(path: string) {
-  const raw = await sharp(path).resize(WIDTH, HEIGHT, { fit: 'fill' }).raw().toBuffer();
-
-  // Reshape and normalize to match model input shape.
-  const srcShape = [HEIGHT, WIDTH, CHANNELS];
-  const dstShape = [CHANNELS, HEIGHT, WIDTH];
-  const src = ndarray(raw, srcShape);
-  const dst = ndarray(new Float32Array(raw.length), dstShape);
-  for (let i = 0; i < HEIGHT; i += 1) {
-    for (let j = 0; j < WIDTH; j += 1) {
-      for (let k = 0; k < CHANNELS; k += 1) {
-        dst.set(k, i, j, src.get(i, j, k) / 255);
-      }
-    }
-  }
-  return dst;
+async function readPhoto(path: string, shape: { width: number; height: number }) {
+  const raw = await sharp(path)
+    .resize({ ...shape, fit: 'fill' })
+    .raw()
+    .toBuffer();
+  const photo = new Float32Array(raw);
+  for (let i = 0; i < photo.length; i += 1) photo[i] /= 255;
+  return photo;
 }
 
 // If output is not provided, build an empty result.
 function buildResult(model: Model, output?: ort.Tensor) {
   const entries = MODELS[model].labels.map((label, idx) => [
     label,
@@ -48,10 +39,11 @@ export default async function runInference(model: Model, photoPaths: string[]) {
   /* eslint-disable no-await-in-loop */
   for (const path of photoPaths) {
     try {
-      const photo = await readPhoto(path);
-      // The first dimension of the tensor is the photo index (used for batch processing).
-      // TODO: Process photos in batches to improve performance.
-      const input = new ort.Tensor('float32', photo.data, [1, ...photo.shape]);
+      const { photoShape } = MODELS[model];
+      const photo = await readPhoto(path, photoShape);
+      // TODO: Process multiple photos per batch to improve performance.
+      const inputShape = [BATCH_SIZE, photoShape.height, photoShape.width, CHANNELS];
+      const input = new ort.Tensor('float32', photo, inputShape);
       const { output } = await session.run({ input });
       results.push(buildResult(model, output));
     } catch (e) {
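The rewritten `readPhoto` returns a flat, normalized `Float32Array` in height-width-channel order, and the input tensor's shape now comes from the model's `photoShape` rather than hard-coded constants. For the TODO about batching, a rough sketch of one possible direction, purely illustrative and not part of this commit; it assumes the ONNX models accept a dynamic batch dimension:

```ts
// Illustrative only: stack several photos into a single NHWC input tensor.
// readPhoto, CHANNELS, and ort refer to the definitions in runInference.ts above.
async function readBatch(paths: string[], shape: { width: number; height: number }) {
  const photos = await Promise.all(paths.map((p) => readPhoto(p, shape)));
  const pixelsPerPhoto = shape.height * shape.width * CHANNELS;
  const batch = new Float32Array(paths.length * pixelsPerPhoto);
  photos.forEach((photo, i) => batch.set(photo, i * pixelsPerPhoto));
  return new ort.Tensor('float32', batch, [paths.length, shape.height, shape.width, CHANNELS]);
}
```

The resulting tensor could then be passed to `session.run({ input })` in the same way as the single-photo tensor above.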
