-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pkg/cmsis-nn: add support to RIOT #13062
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
PKG_NAME=cmsis-nn | ||
PKG_URL=https://github.com/ARM-software/CMSIS_5 | ||
PKG_VERSION=5.6.0 | ||
PKG_LICENSE=Apache-2.0 | ||
CFLAGS += -Wno-strict-aliasing -Wno-unused-parameter | ||
|
||
include $(RIOTBASE)/pkg/pkg.mk | ||
|
||
all: | ||
"$(MAKE)" -C $(PKG_BUILDDIR) -f $(CURDIR)/Makefile.$(PKG_NAME) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
PKG_NAME=cmsis-nn | ||
|
||
# A list of all the directories to build for CMSIS-NN | ||
CMSIS_DIRS += \ | ||
CMSIS/NN/Source/ActivationFunctions \ | ||
CMSIS/NN/Source/BasicMathFunctions \ | ||
CMSIS/NN/Source/ConcatenationFunctions \ | ||
CMSIS/NN/Source/ConvolutionFunctions \ | ||
CMSIS/NN/Source/FullyConnectedFunctions \ | ||
CMSIS/NN/Source/NNSupportFunctions \ | ||
CMSIS/NN/Source/PoolingFunctions \ | ||
CMSIS/NN/Source/ReshapeFunctions \ | ||
CMSIS/NN/Source/SoftmaxFunctions \ | ||
# | ||
|
||
INCLUDES += -I$(CURDIR)/CMSIS/NN/Include | ||
CMSIS_BINDIRS = $(addprefix $(BINDIR)/$(PKG_NAME)/,$(CMSIS_DIRS)) | ||
|
||
# Override default RIOT search path for sources to include all of the CMSIS-NN | ||
# sources in one library instead of one library per subdirectory. | ||
SRC := $(foreach DIR,$(CMSIS_DIRS),$(wildcard $(DIR)/*.c)) | ||
SRCXX := $(foreach DIR,$(CMSIS_DIRS),$(wildcard $(DIR)/*.cpp)) | ||
ASMSRC := $(foreach DIR,$(CMSIS_DIRS),$(wildcard $(DIR)/*.s)) | ||
ASSMSRC := $(foreach DIR,$(CMSIS_DIRS),$(wildcard $(DIR)/*.S)) | ||
|
||
OBJC := $(SRC:%.c=$(BINDIR)/$(PKG_NAME)/%.o) | ||
OBJCXX := $(SRCXX:%.cpp=$(BINDIR)/$(PKG_NAME)/%.o) | ||
ASMOBJ := $(ASMSRC:%.s=$(BINDIR)/$(PKG_NAME)/%.o) | ||
ASSMOBJ := $(ASSMSRC:%.S=$(BINDIR)/$(PKG_NAME)/%.o) | ||
OBJ = $(OBJC) $(OBJCXX) $(ASMOBJ) $(ASSMOBJ) | ||
|
||
# Create subdirectories if they do not already exist | ||
$(OBJ): | $(CMSIS_BINDIRS) | ||
|
||
$(CMSIS_BINDIRS): | ||
@mkdir -p $@ | ||
|
||
# Reset the default goal. | ||
.DEFAULT_GOAL := | ||
|
||
# Include RIOT settings and recipes | ||
include $(RIOTBASE)/Makefile.base |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
FEATURES_REQUIRED += arch_cortexm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
INCLUDES += -I$(PKGDIRBASE)/cmsis-nn/CMSIS/NN/Include | ||
|
||
# Required for some basic math functions | ||
INCLUDES += -I$(PKGDIRBASE)/cmsis-nn/CMSIS/DSP/Include | ||
|
||
CFLAGS += -Wno-sign-compare | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
BOARD ?= nucleo-l476rg | ||
include ../Makefile.tests_common | ||
|
||
USEPKG += cmsis-nn | ||
|
||
BLOBS += input | ||
|
||
# Boards that were tested and are known to work | ||
# This package only works with Cortex M3, M4 and M7 CPUs but there's no easy | ||
# way provided by the build system to filter them at that level (arch_cortexm is | ||
# the only feature available) for the moment. | ||
BOARD_WHITELIST := \ | ||
b-l475e-iot01a \ | ||
iotlab-m3 \ | ||
nrf52832-mdk \ | ||
nrf52dk \ | ||
nucleo-l476rg \ | ||
same54-xpro \ | ||
stm32f723e-disco | ||
# | ||
|
||
include $(RIOTBASE)/Makefile.include |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
## ARM CMSIS-NN package | ||
|
||
This application shows how to use the neural network API provided by the ARM CMSIS | ||
package in order to determine the type of "object" present in an RGB image. | ||
The image are part of the [SIFAR10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html) | ||
which contains 10 classes of objects: plane, car, cat, bird, deer, dog, frog, | ||
horse, ship and truck. | ||
|
||
Expected output | ||
--------------- | ||
|
||
``` | ||
Predicted class: cat | ||
``` | ||
|
||
Change the input image | ||
---------------------- | ||
|
||
Use the `generate_image.py` script and the `-i` option to generate a new | ||
input image. | ||
For example, the following command | ||
``` | ||
./generate_image.py -i 1 | ||
``` | ||
will generate an input containing an image with a boat. | ||
|
||
The generated image is displayed at the end of the script execution, for visual | ||
validation of the prediction made by the neural network running on the device. | ||
|
||
Note that each time a new image is generated, the firmware must be rebuilt so | ||
that it embeds the new image. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/usr/bin/env python3 | ||
|
||
"""Generate a binary file from a sample image of the CIFAR-10 dataset. | ||
Pixel of the sample are stored as uint8, images have size 32x32x3. | ||
""" | ||
|
||
import os | ||
import argparse | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from tensorflow.keras.datasets import cifar10 | ||
|
||
|
||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) | ||
|
||
|
||
def main(args): | ||
_, (cifar10_test, _) = cifar10.load_data() | ||
data = cifar10_test[args.index] | ||
data = data.astype('uint8') | ||
|
||
output_path = os.path.join(SCRIPT_DIR, args.output) | ||
np.ndarray.tofile(data, output_path) | ||
|
||
if args.no_plot is False: | ||
plt.imshow(data) | ||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-i", "--index", type=int, default=0, | ||
help="Image index in CIFAR test dataset") | ||
parser.add_argument("-o", "--output", type=str, default='input', | ||
help="Output filename") | ||
parser.add_argument("--no-plot", default=False, action='store_true', | ||
help="Disable image display in matplotlib") | ||
main(parser.parse_args()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright (C) 2019 Inria | ||
* | ||
* This file is subject to the terms and conditions of the GNU Lesser | ||
* General Public License v2.1. See the file LICENSE in the top level | ||
* directory for more details. | ||
*/ | ||
|
||
/** | ||
* @ingroup tests | ||
* @{ | ||
* | ||
* @file | ||
* @brief Sample application for ARM CMSIS-NN package | ||
* | ||
* This example is adapted from ARM CMSIS CIFAR10 example to RIOT by Alexandre Abadie | ||
* https://github.com/ARM-software/CMSIS_5/tree/develop/CMSIS/NN/Examples/ARM/arm_nn_examples/cifar10 | ||
* | ||
* @author Alexandre Abadie <[email protected]> | ||
*/ | ||
|
||
#include <stdint.h> | ||
#include <stdio.h> | ||
#include "arm_math.h" | ||
#include "parameter.h" | ||
#include "weights.h" | ||
|
||
#include "arm_nnfunctions.h" | ||
|
||
#include "blob/input.h" | ||
|
||
/* There are 10 different classes of objects in the CIFAR10 dataset */ | ||
#define CLASSES_NUMOF 10 | ||
|
||
/* include the input and weights */ | ||
static const q7_t conv1_wt[CONV1_IM_CH * CONV1_KER_DIM * CONV1_KER_DIM * CONV1_OUT_CH] = CONV1_WT; | ||
static const q7_t conv1_bias[CONV1_OUT_CH] = CONV1_BIAS; | ||
|
||
static const q7_t conv2_wt[CONV2_IM_CH * CONV2_KER_DIM * CONV2_KER_DIM * CONV2_OUT_CH] = CONV2_WT; | ||
static const q7_t conv2_bias[CONV2_OUT_CH] = CONV2_BIAS; | ||
|
||
static const q7_t conv3_wt[CONV3_IM_CH * CONV3_KER_DIM * CONV3_KER_DIM * CONV3_OUT_CH] = CONV3_WT; | ||
static const q7_t conv3_bias[CONV3_OUT_CH] = CONV3_BIAS; | ||
|
||
static const q7_t ip1_wt[IP1_DIM * IP1_OUT] = IP1_WT; | ||
static const q7_t ip1_bias[IP1_OUT] = IP1_BIAS; | ||
|
||
/* Here the image_data should be the raw uint8 type RGB image in [RGB, RGB, RGB ... RGB] format */ | ||
// static const uint8_t image_data[CONV1_IM_CH * CONV1_IM_DIM * CONV1_IM_DIM] = IMG_DATA; | ||
static q7_t output_data[IP1_OUT]; | ||
|
||
/* vector buffer: max(im2col buffer,average pool buffer, fully connected buffer) */ | ||
static q7_t col_buffer[2 * 5 * 5 * 32 * 2]; | ||
static q7_t img_buffer1[32 * 32 * 10 * 4]; | ||
static q7_t *img_buffer2 = (q7_t *)(img_buffer1 + (32 * 32 * 32)); | ||
|
||
static const char classes[CLASSES_NUMOF][6] = { | ||
"plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck" }; | ||
|
||
int main(void) | ||
{ | ||
printf("start execution\n"); | ||
|
||
uint8_t *image_data = (uint8_t *)input; | ||
|
||
/* input pre-processing */ | ||
int mean_data[3] = INPUT_MEAN_SHIFT; | ||
unsigned int scale_data[3] = INPUT_RIGHT_SHIFT; | ||
for (unsigned i = 0; i < input_len; i += 3) { | ||
img_buffer2[i] = (q7_t)__SSAT( ((((int)image_data[i] - mean_data[0]) << 7) + (0x1 << (scale_data[0] - 1))) | ||
>> scale_data[0], 8); | ||
img_buffer2[i + 1] = (q7_t)__SSAT( ((((int)image_data[i + 1] - mean_data[1]) << 7) + (0x1 << (scale_data[1] - 1))) | ||
>> scale_data[1], 8); | ||
img_buffer2[i + 2] = (q7_t)__SSAT( ((((int)image_data[i + 2] - mean_data[2]) << 7) + (0x1 << (scale_data[2] - 1))) | ||
>> scale_data[2], 8); | ||
} | ||
|
||
/* conv1 img_buffer2 -> img_buffer1 */ | ||
arm_convolve_HWC_q7_RGB(img_buffer2, CONV1_IM_DIM, CONV1_IM_CH, conv1_wt, CONV1_OUT_CH, CONV1_KER_DIM, CONV1_PADDING, | ||
CONV1_STRIDE, conv1_bias, CONV1_BIAS_LSHIFT, CONV1_OUT_RSHIFT, img_buffer1, CONV1_OUT_DIM, | ||
(q15_t *)col_buffer, NULL); | ||
|
||
arm_relu_q7(img_buffer1, CONV1_OUT_DIM * CONV1_OUT_DIM * CONV1_OUT_CH); | ||
|
||
/* pool1 img_buffer1 -> img_buffer2 */ | ||
arm_maxpool_q7_HWC(img_buffer1, CONV1_OUT_DIM, CONV1_OUT_CH, POOL1_KER_DIM, | ||
POOL1_PADDING, POOL1_STRIDE, POOL1_OUT_DIM, NULL, img_buffer2); | ||
|
||
/* conv2 img_buffer2 -> img_buffer1 */ | ||
arm_convolve_HWC_q7_fast(img_buffer2, CONV2_IM_DIM, CONV2_IM_CH, conv2_wt, CONV2_OUT_CH, CONV2_KER_DIM, | ||
CONV2_PADDING, CONV2_STRIDE, conv2_bias, CONV2_BIAS_LSHIFT, CONV2_OUT_RSHIFT, img_buffer1, | ||
CONV2_OUT_DIM, (q15_t *)col_buffer, NULL); | ||
|
||
arm_relu_q7(img_buffer1, CONV2_OUT_DIM * CONV2_OUT_DIM * CONV2_OUT_CH); | ||
|
||
/* pool2 img_buffer1 -> img_buffer2 */ | ||
arm_maxpool_q7_HWC(img_buffer1, CONV2_OUT_DIM, CONV2_OUT_CH, POOL2_KER_DIM, | ||
POOL2_PADDING, POOL2_STRIDE, POOL2_OUT_DIM, col_buffer, img_buffer2); | ||
|
||
/* conv3 img_buffer2 -> img_buffer1 */ | ||
arm_convolve_HWC_q7_fast(img_buffer2, CONV3_IM_DIM, CONV3_IM_CH, conv3_wt, CONV3_OUT_CH, CONV3_KER_DIM, | ||
CONV3_PADDING, CONV3_STRIDE, conv3_bias, CONV3_BIAS_LSHIFT, CONV3_OUT_RSHIFT, img_buffer1, | ||
CONV3_OUT_DIM, (q15_t *)col_buffer, NULL); | ||
|
||
arm_relu_q7(img_buffer1, CONV3_OUT_DIM * CONV3_OUT_DIM * CONV3_OUT_CH); | ||
|
||
/* pool3 img_buffer-> img_buffer2 */ | ||
arm_maxpool_q7_HWC(img_buffer1, CONV3_OUT_DIM, CONV3_OUT_CH, POOL3_KER_DIM, | ||
POOL3_PADDING, POOL3_STRIDE, POOL3_OUT_DIM, col_buffer, img_buffer2); | ||
|
||
arm_fully_connected_q7_opt(img_buffer2, ip1_wt, IP1_DIM, IP1_OUT, IP1_BIAS_LSHIFT, IP1_OUT_RSHIFT, ip1_bias, | ||
output_data, (q15_t *)img_buffer1); | ||
|
||
arm_softmax_q7(output_data, CLASSES_NUMOF, output_data); | ||
|
||
int val = -1; | ||
uint8_t class_idx = 0; | ||
for (unsigned i = 0; i < CLASSES_NUMOF; i++) { | ||
if (output_data[i] > val) { | ||
val = output_data[i]; | ||
class_idx = i; | ||
} | ||
} | ||
|
||
if (val > 0) { | ||
printf("Predicted class: %s\n", classes[class_idx]); | ||
} | ||
else { | ||
puts("No match found"); | ||
} | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright (C) 2019 Inria | ||
* | ||
* This file is subject to the terms and conditions of the GNU Lesser | ||
* General Public License v2.1. See the file LICENSE in the top level | ||
* directory for more details. | ||
*/ | ||
|
||
/** | ||
* @ingroup tests | ||
* @{ | ||
* | ||
* @file | ||
* @brief CNN parameters | ||
*/ | ||
|
||
#ifndef PARAMETER_H | ||
#define PARAMETER_H | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
#define CONV1_IM_DIM 32 | ||
#define CONV1_IM_CH 3 | ||
#define CONV1_KER_DIM 5 | ||
#define CONV1_PADDING 2 | ||
#define CONV1_STRIDE 1 | ||
#define CONV1_OUT_CH 32 | ||
#define CONV1_OUT_DIM 32 | ||
|
||
#define POOL1_KER_DIM 3 | ||
#define POOL1_STRIDE 2 | ||
#define POOL1_PADDING 0 | ||
#define POOL1_OUT_DIM 16 | ||
|
||
#define CONV2_IM_DIM 16 | ||
#define CONV2_IM_CH 32 | ||
#define CONV2_KER_DIM 5 | ||
#define CONV2_PADDING 2 | ||
#define CONV2_STRIDE 1 | ||
#define CONV2_OUT_CH 16 | ||
#define CONV2_OUT_DIM 16 | ||
|
||
#define POOL2_KER_DIM 3 | ||
#define POOL2_STRIDE 2 | ||
#define POOL2_PADDING 0 | ||
#define POOL2_OUT_DIM 8 | ||
|
||
#define CONV3_IM_DIM 8 | ||
#define CONV3_IM_CH 16 | ||
#define CONV3_KER_DIM 5 | ||
#define CONV3_PADDING 2 | ||
#define CONV3_STRIDE 1 | ||
#define CONV3_OUT_CH 32 | ||
#define CONV3_OUT_DIM 8 | ||
|
||
#define POOL3_KER_DIM 3 | ||
#define POOL3_STRIDE 2 | ||
#define POOL3_PADDING 0 | ||
#define POOL3_OUT_DIM 4 | ||
|
||
#define IP1_DIM 4*4*32 | ||
#define IP1_IM_DIM 4 | ||
#define IP1_IM_CH 32 | ||
#define IP1_OUT 10 | ||
|
||
#ifdef __cplusplus | ||
} /* end extern "C" */ | ||
#endif | ||
|
||
#endif /* PARAMETER_H */ | ||
/** @} */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import sys | ||
from testrunner import run | ||
|
||
|
||
def testfunc(child): | ||
child.expect_exact("Predicted class: cat") | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(run(testfunc)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
already in
Makefile
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kept this one and removed the other: the warning occurs in a header file and is not silenced from the package
Makefile
. So I have to silent it globally unfortunately.