From 9d073e724ede7f7d098e39ae659663a5ad44f0de Mon Sep 17 00:00:00 2001 From: jiawen wang Date: Thu, 28 Nov 2024 22:13:10 +0800 Subject: [PATCH] images source (#2558) * Keep the image source as a field of ImageDatasetItem * add image source to batch * Update data.rs * Fix fmt --------- Co-authored-by: Guillaume Lagrange --- crates/burn-dataset/src/vision/image_folder.rs | 4 ++++ examples/custom-image-dataset/src/data.rs | 10 +++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 94b3d74eb3..c562dbb8e6 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -97,6 +97,9 @@ pub struct ImageDatasetItem { /// Annotation for the image. pub annotation: Annotation, + + /// Original image source. + pub image_path: String, } /// Raw annotation types. @@ -250,6 +253,7 @@ impl Mapper for PathToImageDatasetItem { ImageDatasetItem { image: img_vec, annotation, + image_path: item.image_path.display().to_string(), } } } diff --git a/examples/custom-image-dataset/src/data.rs b/examples/custom-image-dataset/src/data.rs index f997b96a2a..a360ea9dd4 100644 --- a/examples/custom-image-dataset/src/data.rs +++ b/examples/custom-image-dataset/src/data.rs @@ -47,6 +47,7 @@ pub struct ClassificationBatcher { pub struct ClassificationBatch { pub images: Tensor, pub targets: Tensor, + pub images_path: Vec, } impl ClassificationBatcher { @@ -83,6 +84,9 @@ impl Batcher> for Classific }) .collect(); + // Original sample path + let images_path: Vec = items.iter().map(|item| item.image_path.clone()).collect(); + let images = items .into_iter() .map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([32, 32, 3]))) @@ -100,6 +104,10 @@ impl Batcher> for Classific let images = self.normalizer.normalize(images); - ClassificationBatch { images, targets } + ClassificationBatch { + images, + targets, + images_path, + } } }