diff --git a/burn-book/src/building-blocks/dataset.md b/burn-book/src/building-blocks/dataset.md index 00a82a46f6..b33ee2f084 100644 --- a/burn-book/src/building-blocks/dataset.md +++ b/burn-book/src/building-blocks/dataset.md @@ -137,7 +137,7 @@ those are the only requirements. ### Images `ImageFolderDataset` is a generic vision dataset used to load images from disk. It is currently -available for multi-class and multi-label classification tasks as well as semantic segmentation tasks. +available for multi-class and multi-label classification tasks as well as semantic segmentation and object detection tasks. ```rust, ignore // Create an image classification dataset from the root folder, @@ -197,6 +197,20 @@ let dataset = ImageFolderDataset::new_segmentation_with_items( .unwrap(); ``` +```rust, ignore +// Create an object detection dataset from a COCO dataset. Currently only +// the import of object detection data (bounding boxes) is supported. +// +// COCO offers separate annotation and image archives for training and +// validation, paths to the unpacked files need to be passed as parameters: + +let dataset = ImageFolderDataset::new_coco_detection( + "/path/to/coco/instances_train2017.json", + "/path/to/coco/images/train2017" +) +.unwrap(); + +``` ### Comma-Separated Values (CSV) Loading records from a simple CSV file in-memory is simple with the `InMemDataset`: diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index c562dbb8e6..8b73e46367 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -4,11 +4,14 @@ use crate::{Dataset, InMemDataset}; use globwalk::{self, DirEntry}; use image::{self, ColorType}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::{HashMap, HashSet}; +use std::fs; use std::path::{Path, PathBuf}; use thiserror::Error; const SUPPORTED_FILES: [&str; 4] = ["bmp", "jpg", "jpeg", "png"]; +const BBOX_MIN_NUM_VALUES: usize = 4; /// Image data type. #[derive(Debug, Clone, PartialEq)] @@ -82,7 +85,7 @@ pub struct SegmentationMask { /// Object detection bounding box annotation. #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] pub struct BoundingBox { - /// Coordinates. + /// Coordinates in [x_min, y_min, width, height] format. pub coords: [f32; 4], /// Box class label. @@ -107,8 +110,8 @@ pub struct ImageDatasetItem { enum AnnotationRaw { Label(String), MultiLabel(Vec), + BoundingBoxes(Vec), SegmentationMask(PathBuf), - // TODO: bounding boxes } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -184,9 +187,169 @@ fn parse_image_annotation( mask: segmentation_mask_to_vec_usize(mask_path), }) } + AnnotationRaw::BoundingBoxes(v) => Annotation::BoundingBoxes(v.clone()), } } +/// Retrieve all available classes from the COCO JSON +fn parse_coco_classes( + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut classes = HashMap::new(); + + if let Some(json_classes) = json["categories"].as_array() { + for class in json_classes { + let id = class["id"] + .as_u64() + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class ID".to_string())) + .and_then(|v| { + usize::try_from(v).map_err(|_| { + ImageLoaderError::ParsingError("Class ID out of usize range".to_string()) + }) + })?; + + let name = class["name"] + .as_str() + .filter(|&s| !s.is_empty()) + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class name".to_string()))? + .to_string(); + + classes.insert(name, id); + } + } + + if classes.is_empty() { + return Err(ImageLoaderError::ParsingError( + "No classes found in annotations".to_string(), + )); + } + + Ok(classes) +} + +/// Retrieve annotations from COCO JSON +fn parse_coco_bbox_annotations( + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut annotations = HashMap::new(); + + if let Some(json_annotations) = json["annotations"].as_array() { + for annotation in json_annotations { + let image_id = annotation["image_id"].as_u64().ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid image ID in annotation".into()) + })?; + + let class_id = annotation["category_id"] + .as_u64() + .ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid class ID in annotations".to_string()) + }) + .and_then(|v| { + usize::try_from(v).map_err(|_| { + ImageLoaderError::ParsingError( + "Class ID in annotations out of usize range".to_string(), + ) + }) + })?; + + let bbox_coords = annotation["bbox"] + .as_array() + .ok_or_else(|| ImageLoaderError::ParsingError("missing bbox array".to_string()))? + .iter() + .map(|v| { + v.as_f64() + .ok_or_else(|| { + ImageLoaderError::ParsingError("invalid bbox value".to_string()) + }) + .map(|val| val as f32) + }) + .collect::, _>>()?; + + if bbox_coords.len() < BBOX_MIN_NUM_VALUES { + return Err(ImageLoaderError::ParsingError(format!( + "not enough bounding box coordinates in annotation for image {}", + image_id + ))); + } + + let bbox = BoundingBox { + coords: [ + bbox_coords[0], + bbox_coords[1], + bbox_coords[2], + bbox_coords[3], + ], + label: class_id, + }; + + annotations + .entry(image_id) + .and_modify(|entry| { + if let AnnotationRaw::BoundingBoxes(ref mut bboxes) = entry { + bboxes.push(bbox.clone()); + } + }) + .or_insert_with(|| AnnotationRaw::BoundingBoxes(vec![bbox])); + } + } + + if annotations.is_empty() { + return Err(ImageLoaderError::ParsingError( + "no annotations found".to_string(), + )); + } + + Ok(annotations) +} + +/// Retrieve all available images from the COCO JSON +fn parse_coco_images>( + images_path: &P, + mut annotations: HashMap, + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut images = Vec::new(); + if let Some(json_images) = json["images"].as_array() { + for image in json_images { + let image_id = image["id"].as_u64().ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid image ID in image list".to_string()) + })?; + + let file_name = image["file_name"] + .as_str() + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid image ID".to_string()))? + .to_string(); + + let mut image_path = images_path.as_ref().to_path_buf(); + image_path.push(file_name); + + if !image_path.exists() { + return Err(ImageLoaderError::IOError(format!( + "Image {} not found", + image_path.display() + ))); + } + + let annotation = annotations + .remove(&image_id) + .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new())); + + images.push(ImageDatasetItemRaw { + annotation, + image_path, + }); + } + } + + if images.is_empty() { + return Err(ImageLoaderError::ParsingError( + "No images found in annotations".to_string(), + )); + } + + Ok(images) +} + impl Mapper for PathToImageDatasetItem { /// Convert a raw image dataset item (path-like) to a 3D image array with a target label. fn map(&self, item: &ImageDatasetItemRaw) -> ImageDatasetItem { @@ -272,6 +435,10 @@ pub enum ImageLoaderError { /// Invalid file error. #[error("Invalid file extension: `{0}`")] InvalidFileExtensionError(String), + + /// Parsing error. + #[error("Parsing error: `{0}`")] + ParsingError(String), } type ImageDatasetMapper = @@ -468,6 +635,37 @@ impl ImageFolderDataset { Self::with_items(items, classes) } + /// Create a COCO detection dataset based on the annotations JSON and image directory. + /// + /// # Arguments + /// + /// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for + /// example instances_train2017.json). + /// + /// * `images_path` - Path containing the images matching the annotations JSON. + /// + /// # Returns + /// A new dataset instance. + pub fn new_coco_detection, I: AsRef>( + annotations_json: A, + images_path: I, + ) -> Result { + let file = fs::File::open(annotations_json) + .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {}", e)))?; + let json: Value = serde_json::from_reader(file).map_err(|e| { + ImageLoaderError::ParsingError(format!("Failed to parse annotations: {}", e)) + })?; + + let classes = parse_coco_classes(&json)?; + let annotations = parse_coco_bbox_annotations(&json)?; + let items = parse_coco_images(&images_path, annotations, &json)?; + let dataset = InMemDataset::new(items); + let mapper = PathToImageDatasetItem { classes }; + let dataset = MapperDataset::new(dataset, mapper); + + Ok(Self { dataset }) + } + /// Create an image dataset with the specified items. /// /// # Arguments @@ -519,6 +717,8 @@ mod tests { use super::*; const DATASET_ROOT: &str = "tests/data/image_folder"; const SEGMASK_ROOT: &str = "tests/data/segmask_folder"; + const COCO_JSON: &str = "tests/data/dataset_coco.json"; + const COCO_IMAGES: &str = "tests/data/image_folder_coco"; #[test] pub fn image_folder_dataset() { @@ -809,4 +1009,96 @@ mod tests { }) ); } + + #[test] + pub fn coco_detection_dataset() { + let dataset = ImageFolderDataset::new_coco_detection(COCO_JSON, COCO_IMAGES).unwrap(); + assert_eq!(dataset.len(), 3); // we have only three images defined + assert_eq!(dataset.get(3), None); + + const TWO_DOTS_AND_TRIANGLE_B1: BoundingBox = BoundingBox { + coords: [ + 3.1251719394773056, + 18.0907840440165, + 10.96011004126548, + 10.740027510316379, + ], + label: 0, + }; + + const TWO_DOTS_AND_TRIANGLE_B2: BoundingBox = BoundingBox { + coords: [ + 3.2572214580467658, + 3.0371389270976605, + 10.563961485557085, + 10.828060522696012, + ], + label: 0, + }; + + const TWO_DOTS_AND_TRIANGLE_B3: BoundingBox = BoundingBox { + coords: [ + 15.097661623108666, + 3.3892709766162312, + 12.632737276478679, + 11.18019257221458, + ], + label: 1, + }; + + const DOTS_TRIANGLE_B1: BoundingBox = BoundingBox { + coords: [ + 3.125171939477304, + 17.914718019257222, + 10.82806052269601, + 11.004126547455297, + ], + label: 0, + }; + + const DOTS_TRIANGLE_B2: BoundingBox = BoundingBox { + coords: [ + 15.27372764786794, + 3.301237964236589, + 12.192572214580478, + 11.708390646492433, + ], + label: 1, + }; + + const ONE_DOT_B1: BoundingBox = BoundingBox { + coords: [ + 10.07977991746905, + 9.59559834938102, + 10.960110041265464, + 11.356258596973863, + ], + label: 0, + }; + + for item in dataset.iter() { + let file_name = Path::new(&item.image_path).file_name().unwrap(); + match item.annotation { + // check if the number of bounding boxes is correct + Annotation::BoundingBoxes(v) => { + if file_name == "two_dots_and_triangle.jpg" { + assert_eq!(v.len(), 3); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B1)); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B2)); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B3)); + } else if file_name == "dot_triangle.jpg" { + assert_eq!(v.len(), 2); + assert!(v.contains(&DOTS_TRIANGLE_B1)); + assert!(v.contains(&DOTS_TRIANGLE_B2)); + } else if file_name == "one_dot.jpg" { + assert_eq!(v.len(), 1); + assert!(v.contains(&ONE_DOT_B1)); + } else { + panic!("{}", format!("unexpected image name: {}", item.image_path)); + } + } + _ => panic!("unexpected annotation"), + } + } + } } diff --git a/crates/burn-dataset/tests/data/dataset_coco.json b/crates/burn-dataset/tests/data/dataset_coco.json new file mode 100644 index 0000000000..6a75bf9e0e --- /dev/null +++ b/crates/burn-dataset/tests/data/dataset_coco.json @@ -0,0 +1,132 @@ +{ + "images": [ + { + "width": 32, + "height": 32, + "id": 0, + "file_name": "two_dots_and_triangle.jpg" + }, + { + "width": 32, + "height": 32, + "id": 1, + "file_name": "dot_triangle.jpg" + }, + { + "width": 32, + "height": 32, + "id": 2, + "file_name": "one_dot.jpg" + } + ], + "categories": [ + { + "id": 0, + "name": "dot" + }, + { + "id": 1, + "name": "triangle" + } + ], + "annotations": [ + { + "id": 0, + "image_id": 0, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.1251719394773056, + 18.0907840440165, + 10.96011004126548, + 10.740027510316379 + ], + "ignore": 0, + "iscrowd": 0, + "area": 117.71188335928603 + }, + { + "id": 1, + "image_id": 0, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.2572214580467658, + 3.0371389270976605, + 10.563961485557085, + 10.828060522696012 + ], + "ignore": 0, + "iscrowd": 0, + "area": 114.38721432504178 + }, + { + "id": 2, + "image_id": 0, + "category_id": 1, + "segmentation": [], + "bbox": [ + 15.097661623108666, + 3.3892709766162312, + 12.632737276478679, + 11.18019257221458 + ], + "ignore": 0, + "iscrowd": 0, + "area": 141.23643546522516 + }, + { + "id": 3, + "image_id": 1, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.125171939477304, + 17.914718019257222, + 10.82806052269601, + 11.004126547455297 + ], + "ignore": 0, + "iscrowd": 0, + "area": 119.15334825525184 + }, + { + "id": 4, + "image_id": 1, + "category_id": 1, + "segmentation": [], + "bbox": [ + 15.27372764786794, + 3.301237964236589, + 12.192572214580478, + 11.708390646492433 + ], + "ignore": 0, + "iscrowd": 0, + "area": 142.7553984738776 + }, + { + "id": 5, + "image_id": 2, + "category_id": 0, + "segmentation": [], + "bbox": [ + 10.07977991746905, + 9.59559834938102, + 10.960110041265464, + 11.356258596973863 + ], + "ignore": 0, + "iscrowd": 0, + "area": 124.46584387990049 + } + ], + "info": { + "year": 2024, + "version": "1.0", + "description": "", + "contributor": "", + "url": "", + "date_created": "2024-12-11 22:16:31.823494" + } +} diff --git a/crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg b/crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg new file mode 100644 index 0000000000..572ab2e6d1 Binary files /dev/null and b/crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg differ diff --git a/crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg b/crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg new file mode 100644 index 0000000000..b719f53061 Binary files /dev/null and b/crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg differ diff --git a/crates/burn-dataset/tests/data/image_folder_coco/two_dots_and_triangle.jpg b/crates/burn-dataset/tests/data/image_folder_coco/two_dots_and_triangle.jpg new file mode 100644 index 0000000000..fe52aa6170 Binary files /dev/null and b/crates/burn-dataset/tests/data/image_folder_coco/two_dots_and_triangle.jpg differ