Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch Predicitions #856

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
language: system
types: [python]
pass_filenames: false
stages: [commit]
stages: [pre-commit]
- repo: local
hooks:
- id: docformatter
Expand All @@ -18,5 +18,5 @@ repos:
types: [python]
args: ['--in-place', '--recursive','src/deepforest/']
pass_filenames: false
stages: [commit]
stages: [pre-commit]

4 changes: 3 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ build:
os: ubuntu-22.04
tools:
python: "3.12"

python:
install:
- requirements: dev_requirements.txt
- method: pip
path: .

submodules:
include: all
include: []

92 changes: 92 additions & 0 deletions docs/user_guide/batch_predicitons.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Batch Prediction in DeepForest

In this documentation, we highlight an efficient approach for batch prediction in DeepForest using the `predict_step` method and discuss a proposed enhancement to make this functionality more user-friendly.

---

## Current Challenges with Image Prediction

When working with dataloaders yielding batches of images, existing prediction methods (`predict_image`, `predict_file`, `predict_tile`) might require excessive preprocessing or manual intervention, such as:

- **Saving images to disk** to use `predict_file`.
- **Manipulating dataloaders** to ensure images are preprocessed as expected.
- Looping through each image in a batch and using `predict_image`, which is inefficient when modern GPUs can handle larger batches.

For example:
```python
for batch in test_loader:
for image_metadata, image, image_targets in batch:
# Preprocess image, e.g., DeepForest requires 0-255 data, channels first
pred = m.predict_image(image)
```
This is suboptimal when GPU memory allows larger batch processing.

---

## Optimized Batch Prediction

DeepForest provides a batch prediction mechanism through the `predict_step` method. This method is part of the PyTorch Lightning framework and is intended for `trainer.predict`. While not explicitly documented, it can be leveraged directly:

### Example:
```python
for idx, batch in enumerate(test_loader):
metadata, images, targets = batch
# Apply necessary preprocessing to the batch
predictions = m.predict_step(images, idx)
```
Here:
- `predict_step` processes the batch efficiently on the GPU.
- `predictions` is a list of results, formatted as dataframes, consistent with other `predict_*` methods.

---

## Limitations of Current Implementation

- **Undocumented Pathway:** The use of `predict_step` for batch predictions is not well-documented and may not be intuitive for users.
- **Reserved Method Name:** Since `predict_step` is reserved by PyTorch Lightning, it cannot be renamed. However, it can be wrapped in a user-friendly function.

---

## Proposed Solution: `predict_batch` Function

To enhance usability, we propose adding a `predict_batch` function to the API. This function would:
- Mirror the format of `predict_file`, `predict_image`, and `predict_tile`.
- Help guide users through batch prediction workflows.

### Example API for `predict_batch`
```python
def predict_batch(self, images, preprocess=True):
"""
Predict a batch of images with the DeepForest model.

Args:
images: A batch of images in a numpy array or PyTorch tensor format.
preprocess: Whether to apply preprocessing (e.g., scaling to 0-1, channels first).

Returns:
predictions: A list of pandas dataframes with bounding boxes, labels, and scores for each image in the batch.
"""
# Preprocess images if required
if preprocess:
images = preprocess_images(images)

# Use predict_step for efficient batch processing
predictions = self.predict_step(images)

return predictions
```

### Benefits:
- Streamlines batch prediction.
- Reduces the learning curve for new users.
- Ensures consistent API behavior.

---

## Next Steps

1. **Documentation:** Clearly document the current behavior of `predict_step` for advanced users.
2. **New Functionality:** Implement a `predict_batch` function to simplify batch processing workflows.
3. **Examples and Tutorials:** Add examples demonstrating batch predictions with and without the proposed `predict_batch` function.

By documenting and enhancing this functionality, DeepForest can provide a more intuitive and efficient experience for users handling large datasets.
36 changes: 36 additions & 0 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,43 @@ def predict_step(self, batch, batch_idx):
for result in batch_results:
boxes = visualize.format_boxes(result)
results.append(boxes)
return results

def predict_batch(self, images, preprocess_fn=None):
"""Predict a batch of images with the deepforest model.

Args:
images (torch.Tensor or np.ndarray): A batch of images with shape (B, C, H, W) or (B, H, W, C).
preprocess_fn (callable, optional): A function to preprocess images before prediction.
If None, assumes images are preprocessed.

Returns:
List[pd.DataFrame]: A list of dataframes with predictions for each image.
"""

self.model.eval()

#conver to tensor if input is array
if isinstance(images, np.ndarray):
images = torch.tensor(images, device=self.device)

#check input format
if images.dim() == 4 and images.shape[-1] == 3:
#Convert channels_last (B, H, W, C) to channels_first (B, C, H, W)
images = images.permute(0, 3, 1, 2)

#appy preprocessing if available
if preprocess_fn:
images = preprocess_fn(images)

#using Pytorch Ligthning's predict_step
with torch.no_grad():
predictions = []
for idx, image in enumerate(images):
predictions = self.predict_step(image.unsqueeze(0), idx)
predictions.extend(predictions)
#convert predictions to dataframes
results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
return results

def configure_optimizers(self):
Expand Down
Empty file added test.py
Empty file.
115 changes: 115 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader


from PIL import Image

Expand Down Expand Up @@ -674,3 +676,116 @@ def test_predict_tile_with_crop_model_empty():

# Assert the result
assert result is None


# @pytest.mark.parametrize("batch_size", [1, 4, 8])
# def test_batch_prediction(m, batch_size, raster_path):
# """
# Test batch prediction using a DataLoader.
# """
# # Prepare input data
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=batch_size)

# # Perform prediction
# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# # Check results
# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# @pytest.mark.parametrize("batch_size", [1, 4])
# def test_batch_training(m, batch_size, tmpdir):
# """
# Test batch training with a DataLoader.
# """
# # Generate synthetic training data
# csv_file = get_data("example.csv")
# root_dir = os.path.dirname(csv_file)
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# # Configure the model and trainer
# m.config["batch_size"] = batch_size
# m.create_trainer()
# trainer = m.trainer

# # Train the model
# trainer.fit(m, train_dl)

# # Assertions
# assert trainer.current_epoch == 1
# assert trainer.batch_size == batch_size

# @pytest.mark.parametrize("batch_size", [2, 4])
# def test_batch_data_augmentation(m, batch_size, raster_path):
# """
# Test batch prediction with data augmentation.
# """
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
# dl = DataLoader(ds, batch_size=batch_size)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# def test_batch_inference_consistency(m, raster_path):
# """
# Test that batch inference produces consistent results with single image predictions.
# """
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=4)

# batch_predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# batch_predictions.append(prediction)

# single_predictions = []
# for image in ds:
# prediction = m.predict_image(image=image)
# single_predictions.append(prediction)

# batch_df = pd.concat(batch_predictions, ignore_index=True)
# single_df = pd.concat(single_predictions, ignore_index=True)

# pd.testing.assert_frame_equal(batch_df, single_df)

# def test_large_batch_handling(m, raster_path):
# """
# Test model's ability to handle large batch sizes.
# """
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=16)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) > 0
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }
# assert not batch_pred.empty