From a882843294ebbd725a634fef4e260c55856aef2b Mon Sep 17 00:00:00 2001 From: Sergey 'Jin' Bostandzhyan Date: Fri, 20 Dec 2024 14:04:54 +0100 Subject: [PATCH] Extend ImageFolderDataset to support import of COCO detection (#2612) COCO is a popular dataset which defines its own dataset format: https://cocodataset.org/#format-data This commit introduces an ImageFolderDataset::new_coco_detection() function, which expects a path to the JSON annotations file in COCO format and a path where the actual images are stored. Note that while COCO also offers segmentation and pose estimation data, for now only the import of detection (bounding boxes) data is supported. --- burn-book/src/building-blocks/dataset.md | 16 +- .../burn-dataset/src/vision/image_folder.rs | 296 +++++++++++++++++- .../burn-dataset/tests/data/dataset_coco.json | 132 ++++++++ .../data/image_folder_coco/dot_triangle.jpg | Bin 0 -> 1566 bytes .../tests/data/image_folder_coco/one_dot.jpg | Bin 0 -> 1434 bytes .../two_dots_and_triangle.jpg | Bin 0 -> 1706 bytes 6 files changed, 441 insertions(+), 3 deletions(-) create mode 100644 crates/burn-dataset/tests/data/dataset_coco.json create mode 100644 crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg create mode 100644 crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg create mode 100644 crates/burn-dataset/tests/data/image_folder_coco/two_dots_and_triangle.jpg diff --git a/burn-book/src/building-blocks/dataset.md b/burn-book/src/building-blocks/dataset.md index 00a82a46f6..b33ee2f084 100644 --- a/burn-book/src/building-blocks/dataset.md +++ b/burn-book/src/building-blocks/dataset.md @@ -137,7 +137,7 @@ those are the only requirements. ### Images `ImageFolderDataset` is a generic vision dataset used to load images from disk. It is currently -available for multi-class and multi-label classification tasks as well as semantic segmentation tasks. 
+available for multi-class and multi-label classification tasks as well as semantic segmentation and object detection tasks. ```rust, ignore // Create an image classification dataset from the root folder, @@ -197,6 +197,20 @@ let dataset = ImageFolderDataset::new_segmentation_with_items( .unwrap(); ``` +```rust, ignore +// Create an object detection dataset from a COCO dataset. Currently only +// the import of object detection data (bounding boxes) is supported. +// +// COCO offers separate annotation and image archives for training and +// validation, paths to the unpacked files need to be passed as parameters: + +let dataset = ImageFolderDataset::new_coco_detection( + "/path/to/coco/instances_train2017.json", + "/path/to/coco/images/train2017" +) +.unwrap(); + +``` ### Comma-Separated Values (CSV) Loading records from a simple CSV file in-memory is simple with the `InMemDataset`: diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index c562dbb8e6..8b73e46367 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -4,11 +4,14 @@ use crate::{Dataset, InMemDataset}; use globwalk::{self, DirEntry}; use image::{self, ColorType}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::{HashMap, HashSet}; +use std::fs; use std::path::{Path, PathBuf}; use thiserror::Error; const SUPPORTED_FILES: [&str; 4] = ["bmp", "jpg", "jpeg", "png"]; +const BBOX_MIN_NUM_VALUES: usize = 4; /// Image data type. #[derive(Debug, Clone, PartialEq)] @@ -82,7 +85,7 @@ pub struct SegmentationMask { /// Object detection bounding box annotation. #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] pub struct BoundingBox { - /// Coordinates. + /// Coordinates in [x_min, y_min, width, height] format. pub coords: [f32; 4], /// Box class label. 
@@ -107,8 +110,8 @@ pub struct ImageDatasetItem { enum AnnotationRaw { Label(String), MultiLabel(Vec), + BoundingBoxes(Vec), SegmentationMask(PathBuf), - // TODO: bounding boxes } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -184,9 +187,169 @@ fn parse_image_annotation( mask: segmentation_mask_to_vec_usize(mask_path), }) } + AnnotationRaw::BoundingBoxes(v) => Annotation::BoundingBoxes(v.clone()), } } +/// Retrieve all available classes from the COCO JSON +fn parse_coco_classes( + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut classes = HashMap::new(); + + if let Some(json_classes) = json["categories"].as_array() { + for class in json_classes { + let id = class["id"] + .as_u64() + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class ID".to_string())) + .and_then(|v| { + usize::try_from(v).map_err(|_| { + ImageLoaderError::ParsingError("Class ID out of usize range".to_string()) + }) + })?; + + let name = class["name"] + .as_str() + .filter(|&s| !s.is_empty()) + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class name".to_string()))? 
+ .to_string(); + + classes.insert(name, id); + } + } + + if classes.is_empty() { + return Err(ImageLoaderError::ParsingError( + "No classes found in annotations".to_string(), + )); + } + + Ok(classes) +} + +/// Retrieve annotations from COCO JSON +fn parse_coco_bbox_annotations( + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut annotations = HashMap::new(); + + if let Some(json_annotations) = json["annotations"].as_array() { + for annotation in json_annotations { + let image_id = annotation["image_id"].as_u64().ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid image ID in annotation".into()) + })?; + + let class_id = annotation["category_id"] + .as_u64() + .ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid class ID in annotations".to_string()) + }) + .and_then(|v| { + usize::try_from(v).map_err(|_| { + ImageLoaderError::ParsingError( + "Class ID in annotations out of usize range".to_string(), + ) + }) + })?; + + let bbox_coords = annotation["bbox"] + .as_array() + .ok_or_else(|| ImageLoaderError::ParsingError("missing bbox array".to_string()))? 
+ .iter() + .map(|v| { + v.as_f64() + .ok_or_else(|| { + ImageLoaderError::ParsingError("invalid bbox value".to_string()) + }) + .map(|val| val as f32) + }) + .collect::, _>>()?; + + if bbox_coords.len() < BBOX_MIN_NUM_VALUES { + return Err(ImageLoaderError::ParsingError(format!( + "not enough bounding box coordinates in annotation for image {}", + image_id + ))); + } + + let bbox = BoundingBox { + coords: [ + bbox_coords[0], + bbox_coords[1], + bbox_coords[2], + bbox_coords[3], + ], + label: class_id, + }; + + annotations + .entry(image_id) + .and_modify(|entry| { + if let AnnotationRaw::BoundingBoxes(ref mut bboxes) = entry { + bboxes.push(bbox.clone()); + } + }) + .or_insert_with(|| AnnotationRaw::BoundingBoxes(vec![bbox])); + } + } + + if annotations.is_empty() { + return Err(ImageLoaderError::ParsingError( + "no annotations found".to_string(), + )); + } + + Ok(annotations) +} + +/// Retrieve all available images from the COCO JSON +fn parse_coco_images>( + images_path: &P, + mut annotations: HashMap, + json: &serde_json::Value, +) -> Result, ImageLoaderError> { + let mut images = Vec::new(); + if let Some(json_images) = json["images"].as_array() { + for image in json_images { + let image_id = image["id"].as_u64().ok_or_else(|| { + ImageLoaderError::ParsingError("Invalid image ID in image list".to_string()) + })?; + + let file_name = image["file_name"] + .as_str() + .ok_or_else(|| ImageLoaderError::ParsingError("Invalid image ID".to_string()))? 
+ .to_string(); + + let mut image_path = images_path.as_ref().to_path_buf(); + image_path.push(file_name); + + if !image_path.exists() { + return Err(ImageLoaderError::IOError(format!( + "Image {} not found", + image_path.display() + ))); + } + + let annotation = annotations + .remove(&image_id) + .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new())); + + images.push(ImageDatasetItemRaw { + annotation, + image_path, + }); + } + } + + if images.is_empty() { + return Err(ImageLoaderError::ParsingError( + "No images found in annotations".to_string(), + )); + } + + Ok(images) +} + impl Mapper for PathToImageDatasetItem { /// Convert a raw image dataset item (path-like) to a 3D image array with a target label. fn map(&self, item: &ImageDatasetItemRaw) -> ImageDatasetItem { @@ -272,6 +435,10 @@ pub enum ImageLoaderError { /// Invalid file error. #[error("Invalid file extension: `{0}`")] InvalidFileExtensionError(String), + + /// Parsing error. + #[error("Parsing error: `{0}`")] + ParsingError(String), } type ImageDatasetMapper = @@ -468,6 +635,37 @@ impl ImageFolderDataset { Self::with_items(items, classes) } + /// Create a COCO detection dataset based on the annotations JSON and image directory. + /// + /// # Arguments + /// + /// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for + /// example instances_train2017.json). + /// + /// * `images_path` - Path containing the images matching the annotations JSON. + /// + /// # Returns + /// A new dataset instance. 
+ pub fn new_coco_detection, I: AsRef>( + annotations_json: A, + images_path: I, + ) -> Result { + let file = fs::File::open(annotations_json) + .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {}", e)))?; + let json: Value = serde_json::from_reader(file).map_err(|e| { + ImageLoaderError::ParsingError(format!("Failed to parse annotations: {}", e)) + })?; + + let classes = parse_coco_classes(&json)?; + let annotations = parse_coco_bbox_annotations(&json)?; + let items = parse_coco_images(&images_path, annotations, &json)?; + let dataset = InMemDataset::new(items); + let mapper = PathToImageDatasetItem { classes }; + let dataset = MapperDataset::new(dataset, mapper); + + Ok(Self { dataset }) + } + /// Create an image dataset with the specified items. /// /// # Arguments @@ -519,6 +717,8 @@ mod tests { use super::*; const DATASET_ROOT: &str = "tests/data/image_folder"; const SEGMASK_ROOT: &str = "tests/data/segmask_folder"; + const COCO_JSON: &str = "tests/data/dataset_coco.json"; + const COCO_IMAGES: &str = "tests/data/image_folder_coco"; #[test] pub fn image_folder_dataset() { @@ -809,4 +1009,96 @@ mod tests { }) ); } + + #[test] + pub fn coco_detection_dataset() { + let dataset = ImageFolderDataset::new_coco_detection(COCO_JSON, COCO_IMAGES).unwrap(); + assert_eq!(dataset.len(), 3); // we have only three images defined + assert_eq!(dataset.get(3), None); + + const TWO_DOTS_AND_TRIANGLE_B1: BoundingBox = BoundingBox { + coords: [ + 3.1251719394773056, + 18.0907840440165, + 10.96011004126548, + 10.740027510316379, + ], + label: 0, + }; + + const TWO_DOTS_AND_TRIANGLE_B2: BoundingBox = BoundingBox { + coords: [ + 3.2572214580467658, + 3.0371389270976605, + 10.563961485557085, + 10.828060522696012, + ], + label: 0, + }; + + const TWO_DOTS_AND_TRIANGLE_B3: BoundingBox = BoundingBox { + coords: [ + 15.097661623108666, + 3.3892709766162312, + 12.632737276478679, + 11.18019257221458, + ], + label: 1, + }; + + const DOTS_TRIANGLE_B1: BoundingBox = 
BoundingBox { + coords: [ + 3.125171939477304, + 17.914718019257222, + 10.82806052269601, + 11.004126547455297, + ], + label: 0, + }; + + const DOTS_TRIANGLE_B2: BoundingBox = BoundingBox { + coords: [ + 15.27372764786794, + 3.301237964236589, + 12.192572214580478, + 11.708390646492433, + ], + label: 1, + }; + + const ONE_DOT_B1: BoundingBox = BoundingBox { + coords: [ + 10.07977991746905, + 9.59559834938102, + 10.960110041265464, + 11.356258596973863, + ], + label: 0, + }; + + for item in dataset.iter() { + let file_name = Path::new(&item.image_path).file_name().unwrap(); + match item.annotation { + // check if the number of bounding boxes is correct + Annotation::BoundingBoxes(v) => { + if file_name == "two_dots_and_triangle.jpg" { + assert_eq!(v.len(), 3); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B1)); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B2)); + assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B3)); + } else if file_name == "dot_triangle.jpg" { + assert_eq!(v.len(), 2); + assert!(v.contains(&DOTS_TRIANGLE_B1)); + assert!(v.contains(&DOTS_TRIANGLE_B2)); + } else if file_name == "one_dot.jpg" { + assert_eq!(v.len(), 1); + assert!(v.contains(&ONE_DOT_B1)); + } else { + panic!("{}", format!("unexpected image name: {}", item.image_path)); + } + } + _ => panic!("unexpected annotation"), + } + } + } } diff --git a/crates/burn-dataset/tests/data/dataset_coco.json b/crates/burn-dataset/tests/data/dataset_coco.json new file mode 100644 index 0000000000..6a75bf9e0e --- /dev/null +++ b/crates/burn-dataset/tests/data/dataset_coco.json @@ -0,0 +1,132 @@ +{ + "images": [ + { + "width": 32, + "height": 32, + "id": 0, + "file_name": "two_dots_and_triangle.jpg" + }, + { + "width": 32, + "height": 32, + "id": 1, + "file_name": "dot_triangle.jpg" + }, + { + "width": 32, + "height": 32, + "id": 2, + "file_name": "one_dot.jpg" + } + ], + "categories": [ + { + "id": 0, + "name": "dot" + }, + { + "id": 1, + "name": "triangle" + } + ], + "annotations": [ + { + "id": 0, + "image_id": 
0, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.1251719394773056, + 18.0907840440165, + 10.96011004126548, + 10.740027510316379 + ], + "ignore": 0, + "iscrowd": 0, + "area": 117.71188335928603 + }, + { + "id": 1, + "image_id": 0, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.2572214580467658, + 3.0371389270976605, + 10.563961485557085, + 10.828060522696012 + ], + "ignore": 0, + "iscrowd": 0, + "area": 114.38721432504178 + }, + { + "id": 2, + "image_id": 0, + "category_id": 1, + "segmentation": [], + "bbox": [ + 15.097661623108666, + 3.3892709766162312, + 12.632737276478679, + 11.18019257221458 + ], + "ignore": 0, + "iscrowd": 0, + "area": 141.23643546522516 + }, + { + "id": 3, + "image_id": 1, + "category_id": 0, + "segmentation": [], + "bbox": [ + 3.125171939477304, + 17.914718019257222, + 10.82806052269601, + 11.004126547455297 + ], + "ignore": 0, + "iscrowd": 0, + "area": 119.15334825525184 + }, + { + "id": 4, + "image_id": 1, + "category_id": 1, + "segmentation": [], + "bbox": [ + 15.27372764786794, + 3.301237964236589, + 12.192572214580478, + 11.708390646492433 + ], + "ignore": 0, + "iscrowd": 0, + "area": 142.7553984738776 + }, + { + "id": 5, + "image_id": 2, + "category_id": 0, + "segmentation": [], + "bbox": [ + 10.07977991746905, + 9.59559834938102, + 10.960110041265464, + 11.356258596973863 + ], + "ignore": 0, + "iscrowd": 0, + "area": 124.46584387990049 + } + ], + "info": { + "year": 2024, + "version": "1.0", + "description": "", + "contributor": "", + "url": "", + "date_created": "2024-12-11 22:16:31.823494" + } +} diff --git a/crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg b/crates/burn-dataset/tests/data/image_folder_coco/dot_triangle.jpg new file mode 100644 index 0000000000000000000000000000000000000000..572ab2e6d136739b944fe9aa0a39e4c1e1bd084c GIT binary patch literal 1566 zcmb7EeM}p57=G@q#}(SLyK;O=^w^jmSmHo}#u)kN1PaU*wndPd+Lq5iyV4d|oc}Bm z7l}@oGB=%Wby?Io=jL?8X&jlGEYT@Z<~CT&LNtQeAKMIzi-GgKV*_Ty#W&ZxC-3__ 
zZ{LsKoqA5a2(i4xQ38g{#@GOL1_}01gD28ZuXMFUnw3(=DyKTm9(364ZfEuC5=Z6Z z#0J>iZ9ae4Pzd+~kx+GMu~Jj}lp;*P3=@*TBj4i-2aBA}N-S#Eegb6ze$3W#ElT_M zfTg}6>;o8u!iBzIC_?xd!ueg1poX^-mVM2hdcyk&+d@LCB%C z)pqJlCag3qz}^Me*BtT_CyzLp_5J|mXKQA0y@sFDu)nQ?YST^Y(A*GMw}zgX^h-x6 z99RV>6nN2r7PKLPY3S!ryrgP@j}$Ah9)5UqJ+vJVhirbf;l>AZ zbr)kU2Kb*<^^fI%@G3CftEzuRRdxCufV%(;2IlqacLNVkk-vA|m$?gA(hVFrG4J!f z4h(DtQqBiGA&>q)9NpQ*M&RnZfTb2l7zX0%X6HsREsi~BfweTN%G1D(H-NjFD7R~o zzKtZ9o7;bKZp=>|gB={pY8RP-F5bxV1_K`_2r));oY`!OGnp*$R?!kK#hXlGk|-s} ziHV8kCCMpCa*9<>l(i%bM;Zek%k#0a#blBHx2eY=8EH)l$4Fo$hLafe1X5@U+FP{1 zVvM{%5&0i(VyL4qa|~FK6Y2c+<5{WDWZNci&ErZz+R*(RsnSJIrwgTgj6q~LAu);6 zMhY-I8!b<$*mUseSeMxOK*921>&(|(#lIaLiTtp!`?c%ey>ivn?$!#BZmo;nEKIEds5=_!fBV8%UPJ5U$q%=bk5A>*_VgVI)LrhEOSblxu5fN^{bt7_DyT6RzH4v zX^fOn1Tm+|)2Fm=x;dwUjM2K^?>aoOt9xehz-M2c={WU%=Uay+He5*)D^G3hy!X{)HI38{KNT_(odgoc;aR0QQzgAavtm5R@FE6KBY|+hM nweK7-Kb9R`_UVvgyIXpXJ!&p9UdaF0R<-U6`h>bV)boD>$RbUE literal 0 HcmV?d00001 diff --git a/crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg b/crates/burn-dataset/tests/data/image_folder_coco/one_dot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b719f53061cd3b2cfc5b71a7a2244a91df02aefd GIT binary patch literal 1434 zcmb7DeQXnT7=G?|y;qoL6NO)ZLN9OQd9k&^n5rmR# z3i)X_g|OPT2uBv-XnQhFIh!b_G!{=#ebzjOV>*6W$MMc?n$0kMLt9H?=MH+8(p`dj z1ki{eR7B8?4s;@o3VJ8VlR}8%zGe8u20uuw2(da4CC1a(k2t~xANm&3n1UCOp{(-Eo^aPOK4~&jweUbNo zvu^-J6J6nC*yzL4leM-2zYPQSU4Zi&!1DZ@ZWPnw82S!qBCQ%Sz_H`Nnl`GvXPLdt zL|D*m#hiKlv>Wh)XH37yOmqomK`@yFizw!pZ5Er&YOz}Fc@D{*C+AtMQh_8p6_?9p z%P%Y{P>LLiOVNvPJaJ4yt{~(pcB@_a-=MdWIX~37%v&^|tl7 zdd^&F9=rN@z*OO#dhxG6XOlEci9{mMstc{a$+d7x>W1wPR;_X0Q6p)26e<%iRhZ0KoxhtiRN>9s)V3cKoJNdF~OU$2bDZOVU6Q330|@fmL8x&}D-{v2b5JW?+G3iL{(Z zErSw;RMcV=)Paa*=mAdFq@s;A6AcInBsEUdSVYiZDQS2+cM-vvpy|wgd~?oszW4Vx3SsT+|E zK%Oi#7nOAUKwr3^wIJ=^ z0}6|^%nU%NFdSztu{#iNM;u+@C=u{c#GEv#KAAOxEdqW~z(s}SSew&?4&_!`=9Bmi#(OCw zffmw04=ON0ITS!4I3OI~HuRK%0pke^oLoOSp{Imj15o>MPPL1gRlq4L7SUx 
zS?-!hr;Cs$0MOs_{H+iG@tXkC2YLQpBhODa0Z_vL7i@F+#U}w`{zU)5IbZMz0QbEB zotNi)rh0(0`v81~N{n`+a~%qwWNt3NWHW$b3xG!tfNS>5+%P7@;nV;?2F@zA8{pV+ zfMt28y={TMC7AGVZi{jr%#Xhg8lXs0xX?s+(Gr>#iD*}`*hM0DmCI$WGMU0n$tc`d zH<^s_VptE()6-M#?(O5n`6xM0P7om|%n{L2nwD}3nS%S@#&-cL!EI7e1PdffP%Ob; z1|OV)@D`myP(_DDq}j7P@o^y^CgtK0ue^OYQ-D)y0Lv8XtwD6Xd_#ETHfS;BG#NP;576g}HBL6IU> z2{hvousSk5zO*(lD8XQSP^*iYG#S2VOz}2bRvoC>s?&`;oznMOeob`!?k93Ol(FBe z8nKm4zc;paeMd-t2e?QoWR9R8f8joAIbPr$8?bWx(7By2jGoAQrzRpJY_n)OseiO7 zd!^wp_qASG6T?H)?*rV`zRE1)aB+|Op3|9z&unKbQxhoQf zp1*mzr77xA>pEI$-LWI|C2ZB-IqCVMCfERoXI7^hYfE3AowYh3g02ba88S}&bYahW z4`a{7CfzR0J+<^k`kFxPVDrU^6~3*3fpri*5qc-B{JXpFHdNOaT=%C2S}uydstPq7 z2yJWMbtWQ=47~O+N(Bt|hTIjT*tn~|7OQa?=b3PROiX2*N|_wI$TkT3#QfJjISES{ zJSp>=1MCi)GeYCT)|k|;sr|6A=dle2Ye((&>b6mlrpZdamCLk#;;rB!H;IeuqQ;oI zx`wJ_)=Kt{Rr<`bB;8u2-+=hT(fggd$Fq~q>n4AjOnl?jBl=6HO5Z*!j}31O>bN33 an!~n`=j6$f;ph*-(lbBB-}$yue&`==a*EXe literal 0 HcmV?d00001