Skip to content

Commit

Permalink
Batch Predictions (#856)
Browse files Browse the repository at this point in the history
Add a batch predict function to make dataloaders easier to integrate, closes #849
  • Loading branch information
RohitP2005 authored Jan 7, 2025
1 parent d6c8545 commit 2cf4198
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
language: system
types: [python]
pass_filenames: false
stages: [commit]
stages: [pre-commit]
- repo: local
hooks:
- id: docformatter
Expand All @@ -18,5 +18,5 @@ repos:
types: [python]
args: ['--in-place', '--recursive','src/deepforest/']
pass_filenames: false
stages: [commit]
stages: [pre-commit]

4 changes: 3 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ build:
os: ubuntu-22.04
tools:
python: "3.12"

python:
install:
- requirements: dev_requirements.txt
- method: pip
path: .

submodules:
include: all
include: []

36 changes: 36 additions & 0 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,43 @@ def predict_step(self, batch, batch_idx):
for result in batch_results:
boxes = visualize.format_boxes(result)
results.append(boxes)
return results

def predict_batch(self, images, preprocess_fn=None):
    """Predict a batch of images with the deepforest model.

    Each image in the batch is passed through ``predict_step`` as a batch of
    one, and the per-image results are collected in input order.

    Args:
        images (torch.Tensor or np.ndarray): A batch of images with shape
            (B, C, H, W) or (B, H, W, C). NumPy input is converted to a
            tensor on ``self.device``.
        preprocess_fn (callable, optional): A function to preprocess images
            before prediction. It receives the channels-first batch and its
            return value is what gets iterated. If None, assumes images are
            preprocessed.

    Returns:
        List[pd.DataFrame]: A list of dataframes with predictions for each
            image. Entries for which ``predict_step`` produced ``None`` are
            skipped.
    """
    self.model.eval()

    # Convert to tensor if input is an array
    if isinstance(images, np.ndarray):
        images = torch.tensor(images, device=self.device)

    # Check input format: a 4-D batch whose last axis has size 3 is treated as
    # channels_last (B, H, W, C) and converted to channels_first (B, C, H, W).
    # NOTE(review): a channels_first batch with width 3 would also match this
    # test -- confirm callers never pass such input.
    if images.dim() == 4 and images.shape[-1] == 3:
        images = images.permute(0, 3, 1, 2)

    # Apply preprocessing if available
    if preprocess_fn:
        images = preprocess_fn(images)

    # Use PyTorch Lightning's predict_step, one image at a time
    predictions = []
    with torch.no_grad():
        for idx, image in enumerate(images):
            # Keep the per-image result under its own name: rebinding
            # ``predictions`` here would discard everything accumulated so far.
            image_predictions = self.predict_step(image.unsqueeze(0), idx)
            predictions.extend(image_predictions)

    # Convert predictions to dataframes
    return [pd.DataFrame(pred) for pred in predictions if pred is not None]

def configure_optimizers(self):
Expand Down
Empty file added test.py
Empty file.
105 changes: 105 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader


from PIL import Image

Expand Down Expand Up @@ -674,3 +676,106 @@ def test_predict_tile_with_crop_model_empty():

# Assert the result
assert result is None


# @pytest.mark.parametrize("batch_size", [1, 4, 8])
# def test_batch_prediction(m, batch_size, raster_path):
#
# # Prepare input data
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=batch_size)

# # Perform prediction
# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# # Check results
# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# @pytest.mark.parametrize("batch_size", [1, 4])
# def test_batch_training(m, batch_size, tmpdir):
#
# # Generate synthetic training data
# csv_file = get_data("example.csv")
# root_dir = os.path.dirname(csv_file)
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# # Configure the model and trainer
# m.config["batch_size"] = batch_size
# m.create_trainer()
# trainer = m.trainer

# # Train the model
# trainer.fit(m, train_dl)

# # Assertions
# assert trainer.current_epoch == 1
# assert trainer.batch_size == batch_size

# @pytest.mark.parametrize("batch_size", [2, 4])
# def test_batch_data_augmentation(m, batch_size, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
# dl = DataLoader(ds, batch_size=batch_size)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# def test_batch_inference_consistency(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=4)

# batch_predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# batch_predictions.append(prediction)

# single_predictions = []
# for image in ds:
# prediction = m.predict_image(image=image)
# single_predictions.append(prediction)

# batch_df = pd.concat(batch_predictions, ignore_index=True)
# single_df = pd.concat(single_predictions, ignore_index=True)

# pd.testing.assert_frame_equal(batch_df, single_df)

# def test_large_batch_handling(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=16)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) > 0
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }
# assert not batch_pred.empty

0 comments on commit 2cf4198

Please sign in to comment.