Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add YOLOX object detection #24

Merged
merged 22 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

| Model | Description | Repository Link |
|------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| Model | Description | Repository Link |
|------------------------------------------------|----------------------------------------------------------|----------------------------------------------|
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |

## Community Contributions

Expand Down
2 changes: 2 additions & 0 deletions yolox-burn/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Output image
*.output.png
30 changes: 30 additions & 0 deletions yolox-burn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "yolox-burn"
version = "0.1.0"
edition = "2021"

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn.git", default-features = false, rev = "0138e16af6a85c746035632b2a2df0df8098ebc2" }
burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "0138e16af6a85c746035632b2a2df0df8098ebc2" }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { git = "https://github.com/tracel-ai/burn.git", features = [
"ndarray",
], rev = "0138e16af6a85c746035632b2a2df0df8098ebc2" }
image = { version = "0.24.7", features = ["png", "jpeg"] }
1 change: 1 addition & 0 deletions yolox-burn/LICENSE-APACHE
1 change: 1 addition & 0 deletions yolox-burn/LICENSE-MIT
16 changes: 16 additions & 0 deletions yolox-burn/NOTICES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Man with Bike and Pet Dog circa 1900 (archive ref DDX1319-2-3)
Author: East Riding Archives
Source: https://commons.wikimedia.org/wiki/File:Man_with_Bike_and_Pet_Dog_circa_1900_%28archive_ref_DDX1319-2-3%29_%2826507570321%29.jpg
License: [Creative Commons](https://www.flickr.com/commons/usage/)

## Pre-trained Model

The COCO pre-trained model was ported from the original [YOLOX implementation](https://github.com/Megvii-BaseDetection/YOLOX).

As opposed to other YOLO variants (YOLOv8, YOLO-NAS, etc.), both the code and pre-trained weights are distributed under the [Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license.
44 changes: 44 additions & 0 deletions yolox-burn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# YOLOX Burn

There have been many different object detection models with the YOLO prefix released in the recent
years, though most of them carry a GPL or AGPL license which restricts their usage. For this reason,
we selected [YOLOX](https://arxiv.org/abs/2107.08430) as the first object detection architecture
since both the original code and pre-trained weights are released under the
[Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license.

You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the YOLOX variants in
[src/model/yolox.rs](src/model/yolox.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
yolox-burn = { git = "https://github.com/burn-rs/models", package = "yolox-burn", default-features = false }
```

If you want to get the COCO pre-trained weights, enable the `pretrained` feature flag.

```toml
[dependencies]
yolox-burn = { git = "https://github.com/burn-rs/models", package = "yolox-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a YOLOX-Tiny from the COCO
[pre-trained weights](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#standard-models)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog_bike_man.jpg
```
145 changes: 145 additions & 0 deletions yolox-burn/examples/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use std::path::Path;

use image::{DynamicImage, ImageBuffer};
use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox};

use burn::{
backend::NdArray,
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};

const HEIGHT: usize = 640;
const WIDTH: usize = 640;

fn to_tensor<B: Backend, T: Element>(
data: Vec<T>,
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// [H, W, C] -> [C, H, W]
.permute([2, 0, 1])
}

/// Draws bounding boxes on the given image.
///
/// # Arguments
///
/// * `image`: Original input image.
/// * `boxes` - Bounding boxes, grouped per class.
/// * `color` - [R, G, B] color values to draw the boxes.
/// * `ratio` - [x, y] aspect ratio to scale the predicted boxes.
///
/// # Returns
///
/// The image annotated with bounding boxes.
fn draw_boxes(
image: DynamicImage,
boxes: &[Vec<BoundingBox>],
color: &[u8; 3],
ratio: &[f32; 2], // (x, y) ratio
) -> DynamicImage {
// Assumes x1 <= x2 and y1 <= y2
fn draw_rect(
image: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
x1: u32,
x2: u32,
y1: u32,
y2: u32,
color: &[u8; 3],
) {
for x in x1..=x2 {
let pixel = image.get_pixel_mut(x, y1);
*pixel = image::Rgb(*color);
let pixel = image.get_pixel_mut(x, y2);
*pixel = image::Rgb(*color);
}
for y in y1..=y2 {
let pixel = image.get_pixel_mut(x1, y);
*pixel = image::Rgb(*color);
let pixel = image.get_pixel_mut(x2, y);
*pixel = image::Rgb(*color);
}
}

// Annotate the original image and print boxes information.
let (image_h, image_w) = (image.height(), image.width());
let mut image = image.to_rgb8();
for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
let xmin = (b.xmin * ratio[0]).clamp(0., image_w as f32 - 1.);
let ymin = (b.ymin * ratio[1]).clamp(0., image_h as f32 - 1.);
let xmax = (b.xmax * ratio[0]).clamp(0., image_w as f32 - 1.);
let ymax = (b.ymax * ratio[1]).clamp(0., image_h as f32 - 1.);

println!(
"Predicted {} ({:.2}) at [{:.2}, {:.2}, {:.2}, {:.2}]",
class_index, b.confidence, xmin, ymin, xmax, ymax,
);

draw_rect(
&mut image,
xmin as u32,
xmax as u32,
ymin as u32,
ymax as u32,
color,
);
}
}
DynamicImage::ImageRgb8(image)
}

pub fn main() {
// Parse arguments
let img_path = std::env::args().nth(1).expect("No image path provided");

// Create YOLOX-Tiny
let device = Default::default();
let model: Yolox<NdArray> = Yolox::yolox_tiny_pretrained(weights::YoloxTiny::Coco, &device)
.map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
.unwrap();

// Load image
let img = image::open(&img_path)
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
.unwrap();

// Resize to 640x640
let resized_img = img.resize_exact(
WIDTH as u32,
HEIGHT as u32,
image::imageops::FilterType::Triangle, // also known as bilinear in 2D
);

// Create tensor from image data
let x = to_tensor(
resized_img.into_rgb8().into_raw(),
[HEIGHT, WIDTH, 3],
&device,
)
.unsqueeze::<4>(); // [B, C, H, W]

// Forward pass
let out = model.forward(x);

// Post-processing
let [_, num_boxes, num_outputs] = out.dims();
let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]);
let obj_scores = out.clone().slice([0..1, 0..num_boxes, 4..5]);
let cls_scores = out.slice([0..1, 0..num_boxes, 5..num_outputs]);
let scores = cls_scores * obj_scores;
let boxes = nms(boxes, scores, 0.65, 0.5);

// Draw outputs and save results
let (h, w) = (img.height(), img.width());
let img_out = draw_boxes(
img,
&boxes[0],
&[239u8, 62u8, 5u8],
&[w as f32 / WIDTH as f32, h as f32 / HEIGHT as f32],
);

let img_path = Path::new(&img_path);
let _ = img_out.save(img_path.with_extension("output.png"));
}
Binary file added yolox-burn/samples/dog_bike_man.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions yolox-burn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;
Loading