Skip to content

Commit

Permalink
Extend ImageFolderDataset to support import of COCO detection
Browse files Browse the repository at this point in the history
COCO is a popular dataset which defines its own dataset format:
https://cocodataset.org/#format-data

This commit introduces an ImageFolderDataset::new_coco_detection()
function, which expects a path to the JSON annotations file in COCO
format and a path where the actual images are stored.

Note that while COCO also offers segmentation and pose estimation data,
for now only the import of detection (bounding boxes) data is supported.
  • Loading branch information
jin-eld committed Dec 19, 2024
1 parent ebd7649 commit 1687e55
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 3 deletions.
16 changes: 15 additions & 1 deletion burn-book/src/building-blocks/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ those are the only requirements.
### Images

`ImageFolderDataset` is a generic vision dataset used to load images from disk. It is currently
available for multi-class and multi-label classification tasks as well as semantic segmentation tasks.
available for multi-class and multi-label classification tasks as well as semantic segmentation and object detection tasks.

```rust, ignore
// Create an image classification dataset from the root folder,
Expand Down Expand Up @@ -197,6 +197,20 @@ let dataset = ImageFolderDataset::new_segmentation_with_items(
.unwrap();
```

```rust, ignore
// Create an object detection dataset from a COCO dataset. Currently only
// the import of object detection data (bounding boxes) is supported.
//
// COCO offers separate annotation and image archives for training and
// validation, paths to the unpacked files need to be passed as parameters:
let dataset = ImageFolderDataset::new_coco_detection(
"/path/to/coco/instances_train2017.json",
"/path/to/coco/images/train2017"
)
.unwrap();
```
### Comma-Separated Values (CSV)

Loading records from a CSV file into memory is simple with the `InMemDataset`:
Expand Down
296 changes: 294 additions & 2 deletions crates/burn-dataset/src/vision/image_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ use crate::{Dataset, InMemDataset};
use globwalk::{self, DirEntry};
use image::{self, ColorType};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use thiserror::Error;

/// Image file extensions (lower-case, without the leading dot) listed as supported.
// NOTE(review): presumably used to filter files when scanning a folder — usage is not visible here.
const SUPPORTED_FILES: [&str; 4] = ["bmp", "jpg", "jpeg", "png"];
/// Minimum number of values a COCO `bbox` array must contain; the first four values are
/// used as the bounding box coordinates.
const BBOX_MIN_NUM_VALUES: usize = 4;

/// Image data type.
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -82,7 +85,7 @@ pub struct SegmentationMask {
/// Object detection bounding box annotation.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub struct BoundingBox {
/// Coordinates.
/// Coordinates in [x_min, y_min, width, height] format.
pub coords: [f32; 4],

/// Box class label.
Expand All @@ -107,8 +110,8 @@ pub struct ImageDatasetItem {
// Raw annotation stored alongside an image path before the item is converted
// into its final `Annotation`.
enum AnnotationRaw {
/// A single label given as a string.
Label(String),
/// Several labels given as strings.
MultiLabel(Vec<String>),
/// Bounding boxes; copied unchanged into `Annotation::BoundingBoxes` when the item is parsed.
BoundingBoxes(Vec<BoundingBox>),
/// Segmentation mask, referenced by the path of the mask file.
SegmentationMask(PathBuf),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
Expand Down Expand Up @@ -184,9 +187,169 @@ fn parse_image_annotation(
mask: segmentation_mask_to_vec_usize(mask_path),
})
}
AnnotationRaw::BoundingBoxes(v) => Annotation::BoundingBoxes(v.clone()),
}
}

/// Build the class-name to class-ID lookup from the `categories` array of a COCO JSON document.
///
/// Every category must carry an unsigned integer `id` that fits into `usize` and a non-empty
/// string `name`. A missing or empty `categories` array is reported as a parsing error.
fn parse_coco_classes(
    json: &serde_json::Value,
) -> Result<HashMap<String, usize>, ImageLoaderError> {
    let parse_err = |msg: &str| ImageLoaderError::ParsingError(msg.to_string());

    let mut classes = HashMap::new();

    // A missing `categories` key yields no iterations and falls through to the emptiness check.
    for category in json["categories"].as_array().into_iter().flatten() {
        let raw_id = match category["id"].as_u64() {
            Some(raw_id) => raw_id,
            None => return Err(parse_err("Invalid class ID")),
        };
        let id = match usize::try_from(raw_id) {
            Ok(id) => id,
            Err(_) => return Err(parse_err("Class ID out of usize range")),
        };

        let name = match category["name"].as_str() {
            Some(name) if !name.is_empty() => name.to_string(),
            _ => return Err(parse_err("Invalid class name")),
        };

        classes.insert(name, id);
    }

    if classes.is_empty() {
        Err(parse_err("No classes found in annotations"))
    } else {
        Ok(classes)
    }
}

/// Collect the bounding boxes of every entry in the COCO `annotations` array, grouped by image ID.
///
/// Each box keeps the first four `bbox` values as its coordinates and uses the annotation's
/// `category_id` as its label. A document without any annotations is reported as a parsing error.
fn parse_coco_bbox_annotations(
    json: &serde_json::Value,
) -> Result<HashMap<u64, AnnotationRaw>, ImageLoaderError> {
    let mut annotations: HashMap<u64, AnnotationRaw> = HashMap::new();

    // A missing `annotations` key yields no iterations and falls through to the emptiness check.
    for entry in json["annotations"].as_array().into_iter().flatten() {
        let image_id = match entry["image_id"].as_u64() {
            Some(id) => id,
            None => {
                return Err(ImageLoaderError::ParsingError(
                    "Invalid image ID in annotation".into(),
                ))
            }
        };

        let label = match entry["category_id"].as_u64() {
            Some(raw) => usize::try_from(raw).map_err(|_| {
                ImageLoaderError::ParsingError(
                    "Class ID in annotations out of usize range".to_string(),
                )
            })?,
            None => {
                return Err(ImageLoaderError::ParsingError(
                    "Invalid class ID in annotations".to_string(),
                ))
            }
        };

        let raw_coords = match entry["bbox"].as_array() {
            Some(values) => values,
            None => {
                return Err(ImageLoaderError::ParsingError(
                    "missing bbox array".to_string(),
                ))
            }
        };

        // Every value is validated (not just the first four) before the length is checked.
        let mut coords = Vec::with_capacity(raw_coords.len());
        for value in raw_coords {
            match value.as_f64() {
                Some(number) => coords.push(number as f32),
                None => {
                    return Err(ImageLoaderError::ParsingError(
                        "invalid bbox value".to_string(),
                    ))
                }
            }
        }

        if coords.len() < BBOX_MIN_NUM_VALUES {
            return Err(ImageLoaderError::ParsingError(format!(
                "not enough bounding box coordinates in annotation for image {}",
                image_id
            )));
        }

        let bbox = BoundingBox {
            coords: [coords[0], coords[1], coords[2], coords[3]],
            label,
        };

        // Create the per-image list on first sight, then append to it.
        let slot = annotations
            .entry(image_id)
            .or_insert_with(|| AnnotationRaw::BoundingBoxes(Vec::new()));
        if let AnnotationRaw::BoundingBoxes(boxes) = slot {
            boxes.push(bbox);
        }
    }

    if annotations.is_empty() {
        Err(ImageLoaderError::ParsingError(
            "no annotations found".to_string(),
        ))
    } else {
        Ok(annotations)
    }
}

/// Build the raw dataset items from the `images` array of a COCO JSON document.
///
/// # Arguments
///
/// * `images_path` - Directory onto which each image's `file_name` is joined.
/// * `annotations` - Bounding box annotations keyed by image ID. The entry matching an image is
///   moved into its item; images without an entry get an empty bounding box list. Entries whose
///   ID matches no image are dropped.
/// * `json` - The parsed COCO annotations document.
///
/// # Errors
///
/// Returns `ImageLoaderError::ParsingError` if an image has no unsigned integer `id`, no string
/// `file_name`, or if the document lists no images at all. Returns `ImageLoaderError::IOError`
/// if a listed image file does not exist.
fn parse_coco_images<P: AsRef<Path>>(
    images_path: &P,
    mut annotations: HashMap<u64, AnnotationRaw>,
    json: &serde_json::Value,
) -> Result<Vec<ImageDatasetItemRaw>, ImageLoaderError> {
    let mut images = Vec::new();
    if let Some(json_images) = json["images"].as_array() {
        for image in json_images {
            let image_id = image["id"].as_u64().ok_or_else(|| {
                ImageLoaderError::ParsingError("Invalid image ID in image list".to_string())
            })?;

            // This failure concerns the file name, not the ID; report it as such.
            let file_name = image["file_name"].as_str().ok_or_else(|| {
                ImageLoaderError::ParsingError("Invalid image file name".to_string())
            })?;

            // Note: per `Path::join`, an absolute `file_name` replaces `images_path` entirely.
            let image_path = images_path.as_ref().join(file_name);

            if !image_path.exists() {
                return Err(ImageLoaderError::IOError(format!(
                    "Image {} not found",
                    image_path.display()
                )));
            }

            // Images without any annotation are kept, with an empty list of boxes.
            let annotation = annotations
                .remove(&image_id)
                .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new()));

            images.push(ImageDatasetItemRaw {
                annotation,
                image_path,
            });
        }
    }

    if images.is_empty() {
        return Err(ImageLoaderError::ParsingError(
            "No images found in annotations".to_string(),
        ));
    }

    Ok(images)
}

impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
/// Convert a raw image dataset item (path-like) to a 3D image array with a target label.
fn map(&self, item: &ImageDatasetItemRaw) -> ImageDatasetItem {
Expand Down Expand Up @@ -272,6 +435,10 @@ pub enum ImageLoaderError {
/// Invalid file error.
#[error("Invalid file extension: `{0}`")]
InvalidFileExtensionError(String),

/// Parsing error.
#[error("Parsing error: `{0}`")]
ParsingError(String),
}

type ImageDatasetMapper =
Expand Down Expand Up @@ -468,6 +635,37 @@ impl ImageFolderDataset {
Self::with_items(items, classes)
}

/// Create a COCO detection dataset based on the annotations JSON and image directory.
///
/// # Arguments
///
/// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for
///   example instances_train2017.json).
///
/// * `images_path` - Path containing the images matching the annotations JSON.
///
/// # Returns
/// A new dataset instance.
///
/// # Errors
/// Returns `ImageLoaderError::IOError` if the annotations file cannot be opened or an image
/// listed in it is missing from `images_path`, and `ImageLoaderError::ParsingError` if the file
/// is not valid JSON or contains no usable categories, annotations or images.
pub fn new_coco_detection<A: AsRef<Path>, I: AsRef<Path>>(
    annotations_json: A,
    images_path: I,
) -> Result<Self, ImageLoaderError> {
    let file = fs::File::open(annotations_json)
        .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {}", e)))?;
    // serde_json does not buffer its input, so reading straight from a `File` results in many
    // tiny reads; COCO annotation files are large, hence the explicit buffering.
    let reader = std::io::BufReader::new(file);
    let json: Value = serde_json::from_reader(reader).map_err(|e| {
        ImageLoaderError::ParsingError(format!("Failed to parse annotations: {}", e))
    })?;

    let classes = parse_coco_classes(&json)?;
    let annotations = parse_coco_bbox_annotations(&json)?;
    let items = parse_coco_images(&images_path, annotations, &json)?;
    let dataset = InMemDataset::new(items);
    let mapper = PathToImageDatasetItem { classes };
    let dataset = MapperDataset::new(dataset, mapper);

    Ok(Self { dataset })
}

/// Create an image dataset with the specified items.
///
/// # Arguments
Expand Down Expand Up @@ -519,6 +717,8 @@ mod tests {
use super::*;
const DATASET_ROOT: &str = "tests/data/image_folder";
const SEGMASK_ROOT: &str = "tests/data/segmask_folder";
const COCO_JSON: &str = "tests/data/dataset_coco.json";
const COCO_IMAGES: &str = "tests/data/image_folder_coco";

#[test]
pub fn image_folder_dataset() {
Expand Down Expand Up @@ -809,4 +1009,96 @@ mod tests {
})
);
}

/// Loads the COCO fixture and checks that every image carries exactly the expected boxes.
#[test]
pub fn coco_detection_dataset() {
    let dataset = ImageFolderDataset::new_coco_detection(COCO_JSON, COCO_IMAGES).unwrap();
    assert_eq!(dataset.len(), 3); // we have only three images defined
    assert_eq!(dataset.get(3), None);

    // Expected boxes per fixture image.
    let two_dots_and_triangle = [
        BoundingBox {
            coords: [
                3.1251719394773056,
                18.0907840440165,
                10.96011004126548,
                10.740027510316379,
            ],
            label: 0,
        },
        BoundingBox {
            coords: [
                3.2572214580467658,
                3.0371389270976605,
                10.563961485557085,
                10.828060522696012,
            ],
            label: 0,
        },
        BoundingBox {
            coords: [
                15.097661623108666,
                3.3892709766162312,
                12.632737276478679,
                11.18019257221458,
            ],
            label: 1,
        },
    ];

    let dot_triangle = [
        BoundingBox {
            coords: [
                3.125171939477304,
                17.914718019257222,
                10.82806052269601,
                11.004126547455297,
            ],
            label: 0,
        },
        BoundingBox {
            coords: [
                15.27372764786794,
                3.301237964236589,
                12.192572214580478,
                11.708390646492433,
            ],
            label: 1,
        },
    ];

    let one_dot = [BoundingBox {
        coords: [
            10.07977991746905,
            9.59559834938102,
            10.960110041265464,
            11.356258596973863,
        ],
        label: 0,
    }];

    for item in dataset.iter() {
        let file_name = Path::new(&item.image_path).file_name().unwrap();

        // Only bounding box annotations are expected from a COCO detection import.
        let boxes = match item.annotation {
            Annotation::BoundingBoxes(boxes) => boxes,
            _ => panic!("unexpected annotation"),
        };

        let expected: &[BoundingBox] = if file_name == "two_dots_and_triangle.jpg" {
            &two_dots_and_triangle
        } else if file_name == "dot_triangle.jpg" {
            &dot_triangle
        } else if file_name == "one_dot.jpg" {
            &one_dot
        } else {
            panic!("unexpected image name: {}", item.image_path)
        };

        // Same count plus containment of each expected box; order is not significant.
        assert_eq!(boxes.len(), expected.len());
        for bbox in expected {
            assert!(boxes.contains(bbox));
        }
    }
}
}
Loading

0 comments on commit 1687e55

Please sign in to comment.