SqueezeNet is a small CNN for image classification. It was trained on the ImageNet dataset and classifies images into 1,000 classes. The included ONNX model is taken from the ONNX Model Zoo; details of the architecture are in the original SqueezeNet paper.
At build time, the burn-import crate converts the ONNX model into an equivalent Burn model in Rust and saves the weights to a binary file in a Burn-compatible format. The model and its weights are then loaded at runtime.
Because the ONNX model is fully converted into a Burn model, it can also be fine-tuned. The pretrained weights come from training on ImageNet (about 1.2 million images); fine-tuning on a smaller, domain-specific dataset can improve accuracy for a particular use case.
The class labels are included in the crate; they are generated from labels.txt at build time.
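The build-time label generation could be sketched as follows. This is a hypothetical helper, not the crate's actual build script: it assumes labels.txt holds one class name per line (line index = class index) and emits Rust source for a static array, here called `LABELS` purely for illustration.

```rust
/// Hypothetical build-step helper: turns the contents of a labels.txt file
/// (assumed to hold one class name per line) into Rust source declaring a
/// static array. Line order is preserved so line i maps to class index i.
fn generate_labels_source(labels_txt: &str) -> String {
    let mut src = String::from("pub static LABELS: &[&str] = &[\n");
    for line in labels_txt.lines() {
        // `{:?}` renders the label as a quoted, escaped Rust string literal.
        src.push_str(&format!("    {:?},\n", line.trim_end()));
    }
    src.push_str("];\n");
    src
}

fn main() {
    let sample = "tench, Tinca tinca\ngoldfish, Carassius auratus\n";
    print!("{}", generate_labels_source(sample));
}
```

In a real build.rs, the generated string would typically be written to a file under Cargo's `OUT_DIR` and pulled into the crate with `include!(concat!(env!("OUT_DIR"), "/labels.rs"))`; the file name here is again an assumption.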
The data normalizer for the model is included in the crate. See Normalizer.
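As a rough illustration of what input normalization involves, the sketch below normalizes a single RGB pixel. It assumes the common ImageNet per-channel mean and standard deviation used by many pretrained CNNs; whether Normalizer uses exactly these constants (and operates on whole tensors rather than single pixels) should be confirmed in the crate's source.

```rust
/// ASSUMPTION: standard ImageNet per-channel statistics (R, G, B).
/// Verify against the crate's Normalizer before relying on these values.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalizes one RGB pixel given as 8-bit channel values:
/// scale each channel to [0, 1], subtract the channel mean, divide by the channel std.
fn normalize_pixel(rgb: [u8; 3]) -> [f32; 3] {
    std::array::from_fn(|c| (rgb[c] as f32 / 255.0 - MEAN[c]) / STD[c])
}

fn main() {
    println!("{:?}", normalize_pixel([255, 128, 0]));
}
```

The real normalizer applies the same per-channel arithmetic across every pixel of the input tensor.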
The model is no_std compatible.
See the classify example for how to use the model.
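The classify example turns the model's 1,000 output scores into predictions. A generic version of that post-processing step (softmax, then top-k) might look like the sketch below. It is not taken from the example itself; it assumes the output is a flat slice of per-class scores. If the model output is already a probability distribution, the softmax step does not change the ranking because softmax is monotonic.

```rust
/// Numerically stable softmax: subtract the max score before exponentiating.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns (class index, probability) for the `k` most likely classes,
/// sorted from most to least likely.
fn top_k(scores: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = softmax(scores).into_iter().enumerate().collect();
    // Sort descending by probability; total_cmp gives a total order over f32.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    indexed.truncate(k);
    indexed
}

fn main() {
    // Toy 5-class example; the real model produces 1,000 scores.
    let scores = [0.1, 2.0, -1.0, 3.5, 0.7];
    for (idx, p) in top_k(&scores, 3) {
        println!("class {idx}: {p:.3}");
    }
}
```

Each returned index can then be looked up in the generated class labels to print a human-readable name.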
Add this to your Cargo.toml:
[dependencies]
squeezenet-burn = { git = "https://github.com/tracel-ai/models", package = "squeezenet-burn", features = ["weights_embedded"], default-features = false }
- Use the weights_embedded feature to embed the weights in the binary:
cargo r --release --features weights_embedded --no-default-features --example classify samples/flamingo.jpg
- Use the weights_file feature to load the weights from a file:
cargo r --release --features weights_file --example classify samples/flamingo.jpg
- Use the weights_f16 feature to store the weights as 16-bit floating-point numbers:
cargo r --release --features "weights_embedded, weights_f16" --no-default-features --example classify samples/flamingo.jpg
Or
cargo r --release --features "weights_file, weights_f16" --example classify samples/flamingo.jpg
Available features:
- weights_file: Load the weights from a file (enabled by default).
- weights_embedded: Embed the weights in the binary.
- weights_f16: Use 16-bit floating-point numbers for the weights (32-bit is used by default).