diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 94b3d74eb3..c562dbb8e6 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -97,6 +97,9 @@ pub struct ImageDatasetItem { /// Annotation for the image. pub annotation: Annotation, + + /// Original image source. + pub image_path: String, } /// Raw annotation types. @@ -250,6 +253,7 @@ impl Mapper for PathToImageDatasetItem { ImageDatasetItem { image: img_vec, annotation, + image_path: item.image_path.display().to_string(), } } } diff --git a/examples/custom-image-dataset/src/data.rs b/examples/custom-image-dataset/src/data.rs index f997b96a2a..a360ea9dd4 100644 --- a/examples/custom-image-dataset/src/data.rs +++ b/examples/custom-image-dataset/src/data.rs @@ -47,6 +47,7 @@ pub struct ClassificationBatcher { pub struct ClassificationBatch { pub images: Tensor, pub targets: Tensor, + pub images_path: Vec, } impl ClassificationBatcher { @@ -83,6 +84,9 @@ impl Batcher> for Classific }) .collect(); + // Original sample path + let images_path: Vec = items.iter().map(|item| item.image_path.clone()).collect(); + let images = items .into_iter() .map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([32, 32, 3]))) @@ -100,6 +104,10 @@ impl Batcher> for Classific let images = self.normalizer.normalize(images); - ClassificationBatch { images, targets } + ClassificationBatch { + images, + targets, + images_path, + } } }