diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 63ed3d0f..8bd5417f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
         language: system
         types: [python]
         pass_filenames: false
-        stages: [commit]
+        stages: [pre-commit]
   - repo: local
     hooks:
       - id: docformatter
@@ -18,5 +18,5 @@ repos:
         types: [python]
         args: ['--in-place', '--recursive','src/deepforest/']
         pass_filenames: false
-        stages: [commit]
+        stages: [pre-commit]
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 4db54370..c9ead8e3 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -4,6 +4,7 @@ build:
   os: ubuntu-22.04
   tools:
     python: "3.12"
+
 python:
   install:
     - requirements: dev_requirements.txt
@@ -11,4 +12,5 @@ python:
       path: .
 submodules:
-  include: all
+  include: []
+
diff --git a/docs/user_guide/batch_predicitons.md b/docs/user_guide/batch_predicitons.md
new file mode 100644
index 00000000..78273e7a
--- /dev/null
+++ b/docs/user_guide/batch_predicitons.md
@@ -0,0 +1,92 @@
+# Batch Prediction in DeepForest
+
+In this documentation, we highlight an efficient approach for batch prediction in DeepForest using the `predict_step` method and discuss a proposed enhancement to make this functionality more user-friendly.
+
+---
+
+## Current Challenges with Image Prediction
+
+When working with dataloaders yielding batches of images, existing prediction methods (`predict_image`, `predict_file`, `predict_tile`) might require excessive preprocessing or manual intervention, such as:
+
+- **Saving images to disk** to use `predict_file`.
+- **Manipulating dataloaders** to ensure images are preprocessed as expected.
+- Looping through each image in a batch and using `predict_image`, which is inefficient when modern GPUs can handle larger batches.
+
+For example:
+```python
+for batch in test_loader:
+    for image_metadata, image, image_targets in batch:
+        # Preprocess image, e.g., DeepForest requires 0-255 data, channels first
+        pred = m.predict_image(image)
+```
+This is suboptimal when GPU memory allows larger batch processing.
+
+---
+
+## Optimized Batch Prediction
+
+DeepForest provides a batch prediction mechanism through the `predict_step` method. This method is part of the PyTorch Lightning framework and is intended for `trainer.predict`. While not explicitly documented, it can be leveraged directly:
+
+### Example:
+```python
+for idx, batch in enumerate(test_loader):
+    metadata, images, targets = batch
+    # Apply necessary preprocessing to the batch
+    predictions = m.predict_step(images, idx)
+```
+Here:
+- `predict_step` processes the batch efficiently on the GPU.
+- `predictions` is a list of results, formatted as dataframes, consistent with other `predict_*` methods.
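+
+A fuller sketch of this pattern is shown below. It assumes `m` is a DeepForest model that has already been created and loaded, and that `test_loader` yields `(metadata, images, targets)` batches that are already preprocessed as the model expects (for example, scaled to 0-1 and channels first); the `image_index` bookkeeping is purely illustrative:
+
+```python
+import pandas as pd
+
+all_predictions = []
+for idx, batch in enumerate(test_loader):
+    metadata, images, targets = batch
+    # predict_step returns one dataframe of boxes per image in the batch
+    batch_predictions = m.predict_step(images, idx)
+    for image_idx, pred in enumerate(batch_predictions):
+        if pred is None:
+            # No detections for this image
+            continue
+        # Record which image in the dataset each set of boxes came from
+        pred["image_index"] = idx * test_loader.batch_size + image_idx
+        all_predictions.append(pred)
+
+# Combine the per-image results into a single dataframe
+results = pd.concat(all_predictions, ignore_index=True)
+```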
+
+---
+
+## Limitations of Current Implementation
+
+- **Undocumented Pathway:** The use of `predict_step` for batch predictions is not well-documented and may not be intuitive for users.
+- **Reserved Method Name:** Since `predict_step` is reserved by PyTorch Lightning, it cannot be renamed. However, it can be wrapped in a user-friendly function.
+
+---
+
+## Proposed Solution: `predict_batch` Function
+
+To enhance usability, we propose adding a `predict_batch` function to the API. This function would:
+- Mirror the format of `predict_file`, `predict_image`, and `predict_tile`.
+- Help guide users through batch prediction workflows.
+
+### Example API for `predict_batch`
+```python
+def predict_batch(self, images, preprocess=True):
+    """
+    Predict a batch of images with the DeepForest model.
+
+    Args:
+        images: A batch of images in a numpy array or PyTorch tensor format.
+        preprocess: Whether to apply preprocessing (e.g., scaling to 0-1, channels first).
+
+    Returns:
+        predictions: A list of pandas dataframes with bounding boxes, labels, and scores for each image in the batch.
+    """
+    # Preprocess images if required
+    if preprocess:
+        images = preprocess_images(images)
+
+    # Use predict_step for efficient batch processing
+    predictions = self.predict_step(images, batch_idx=0)
+
+    return predictions
+```
+
+### Benefits:
+- Streamlines batch prediction.
+- Reduces the learning curve for new users.
+- Ensures consistent API behavior.
+
+---
+
+## Next Steps
+
+1. **Documentation:** Clearly document the current behavior of `predict_step` for advanced users.
+2. **New Functionality:** Implement a `predict_batch` function to simplify batch processing workflows.
+3. **Examples and Tutorials:** Add examples demonstrating batch predictions with and without the proposed `predict_batch` function.
+
+By documenting and enhancing this functionality, DeepForest can provide a more intuitive and efficient experience for users handling large datasets.
diff --git a/src/deepforest/main.py b/src/deepforest/main.py
index 320beca0..f2f1695d 100644
--- a/src/deepforest/main.py
+++ b/src/deepforest/main.py
@@ -763,7 +763,43 @@ def predict_step(self, batch, batch_idx):
         for result in batch_results:
             boxes = visualize.format_boxes(result)
             results.append(boxes)
+        return results
+
+    def predict_batch(self, images, preprocess_fn=None):
+        """Predict a batch of images with the deepforest model.
+
+        Args:
+            images (torch.Tensor or np.ndarray): A batch of images with shape (B, C, H, W) or (B, H, W, C).
+            preprocess_fn (callable, optional): A function to preprocess images before prediction.
+                If None, assumes images are preprocessed.
+
+        Returns:
+            List[pd.DataFrame]: A list of dataframes with predictions for each image.
+        """
+        self.model.eval()
+
+        # convert to a tensor if the input is a numpy array
+        if isinstance(images, np.ndarray):
+            images = torch.tensor(images, device=self.device)
+
+        # check input format
+        if images.dim() == 4 and images.shape[-1] == 3:
+            # Convert channels_last (B, H, W, C) to channels_first (B, C, H, W)
+            images = images.permute(0, 3, 1, 2)
+
+        # apply preprocessing if available
+        if preprocess_fn:
+            images = preprocess_fn(images)
+
+        # use PyTorch Lightning's predict_step to run inference
+        with torch.no_grad():
+            predictions = []
+            for idx, image in enumerate(images):
+                batch_predictions = self.predict_step(image.unsqueeze(0), idx)
+                predictions.extend(batch_predictions)
+
+        # predict_step already returns formatted dataframes; drop empty results
+        results = [pred for pred in predictions if pred is not None]
         return results
 
     def configure_optimizers(self):
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_main.py b/tests/test_main.py
index a8c3543e..61371ccf 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -18,6 +18,8 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers import TensorBoardLogger
+from torch.utils.data import DataLoader
+
 from PIL import Image
 
@@ -674,3 +676,116 @@ def test_predict_tile_with_crop_model_empty():
 
     # Assert the result
     assert result is None
+
+
+# @pytest.mark.parametrize("batch_size", [1, 4, 8])
+# def test_batch_prediction(m, batch_size, raster_path):
+#     """
+#     Test batch prediction using a DataLoader.
+# """ +# # Prepare input data +# tile = np.array(Image.open(raster_path)) +# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) +# dl = DataLoader(ds, batch_size=batch_size) + +# # Perform prediction +# predictions = [] +# for batch in dl: +# prediction = m.predict_batch(batch) +# predictions.append(prediction) + +# # Check results +# assert len(predictions) == len(dl) +# for batch_pred in predictions: +# assert isinstance(batch_pred, pd.DataFrame) +# assert set(batch_pred.columns) == { +# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" +# } + +# @pytest.mark.parametrize("batch_size", [1, 4]) +# def test_batch_training(m, batch_size, tmpdir): +# """ +# Test batch training with a DataLoader. +# """ +# # Generate synthetic training data +# csv_file = get_data("example.csv") +# root_dir = os.path.dirname(csv_file) +# train_ds = m.load_dataset(csv_file, root_dir=root_dir) +# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + +# # Configure the model and trainer +# m.config["batch_size"] = batch_size +# m.create_trainer() +# trainer = m.trainer + +# # Train the model +# trainer.fit(m, train_dl) + +# # Assertions +# assert trainer.current_epoch == 1 +# assert trainer.batch_size == batch_size + +# @pytest.mark.parametrize("batch_size", [2, 4]) +# def test_batch_data_augmentation(m, batch_size, raster_path): +# """ +# Test batch prediction with data augmentation. +# """ +# tile = np.array(Image.open(raster_path)) +# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True) +# dl = DataLoader(ds, batch_size=batch_size) + +# predictions = [] +# for batch in dl: +# prediction = m.predict_batch(batch) +# predictions.append(prediction) + +# assert len(predictions) == len(dl) +# for batch_pred in predictions: +# assert isinstance(batch_pred, pd.DataFrame) +# assert set(batch_pred.columns) == { +# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" +# } + +# def test_batch_inference_consistency(m, raster_path): +# """ +# Test that batch inference produces consistent results with single image predictions. +# """ +# tile = np.array(Image.open(raster_path)) +# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) +# dl = DataLoader(ds, batch_size=4) + +# batch_predictions = [] +# for batch in dl: +# prediction = m.predict_batch(batch) +# batch_predictions.append(prediction) + +# single_predictions = [] +# for image in ds: +# prediction = m.predict_image(image=image) +# single_predictions.append(prediction) + +# batch_df = pd.concat(batch_predictions, ignore_index=True) +# single_df = pd.concat(single_predictions, ignore_index=True) + +# pd.testing.assert_frame_equal(batch_df, single_df) + +# def test_large_batch_handling(m, raster_path): +# """ +# Test model's ability to handle large batch sizes. +# """ +# tile = np.array(Image.open(raster_path)) +# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) +# dl = DataLoader(ds, batch_size=16) + +# predictions = [] +# for batch in dl: +# prediction = m.predict_batch(batch) +# predictions.append(prediction) + +# assert len(predictions) > 0 +# for batch_pred in predictions: +# assert isinstance(batch_pred, pd.DataFrame) +# assert set(batch_pred.columns) == { +# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry" +# } +# assert not batch_pred.empty