diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 320beca0..1ee7511a 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -477,7 +477,8 @@ def predict_tile(self, thickness=1, crop_model=None, crop_transform=None, - crop_augment=False): + crop_augment=False, + verbose=True): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. @@ -498,6 +499,7 @@ def predict_tile(self, cropModel: a deepforest.model.CropModel object to predict on crops crop_transform: a torchvision.transforms object to apply to crops crop_augment: a boolean to apply augmentations to crops + verbose: a boolean to print the number of predictions in overlapping windows (deprecated) return_plot: return a plot of the image with predictions overlaid (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255) (deprecated) thickness: thickness of the rectangle border line in px @@ -529,7 +531,7 @@ def predict_tile(self, warnings.warn( "More than one GPU detected. Using only the first GPU for predict_tile.") self.config["devices"] = 1 - self.create_trainer() + self.create_trainer(enable_progress_bar=verbose) if (raster_path is None) and (image is None): raise ValueError( @@ -559,7 +561,10 @@ def predict_tile(self, ds = dataset.RasterDataset(raster_path=raster_path, patch_overlap=patch_overlap, patch_size=patch_size) - + + if not verbose: + self.create_trainer(enable_progress_bar=False) + batched_results = self.trainer.predict(self, self.predict_dataloader(ds)) # Flatten list from batched prediction @@ -573,7 +578,8 @@ def predict_tile(self, ds.windows, sigma=sigma, thresh=thresh, - iou_threshold=iou_threshold) + iou_threshold=iou_threshold, + verbose=verbose) results["label"] = results.label.apply( lambda x: self.numeric_to_label_dict[x]) if raster_path: diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 03fcb11f..c9059ea4 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -61,7 +61,7 @@ def _predict_image_(model, return df -def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1): +def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1, verbose=True): # transform the coordinates to original system for index, _ in enumerate(boxes): xmin, ymin, xmax, ymax = windows[index].getRect() @@ -71,9 +71,10 @@ def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1): boxes[index].ymax += ymin predicted_boxes = pd.concat(boxes) - print( - f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max supression" - ) + if verbose: + print( + f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max supression" + ) # move prediciton to tensor boxes = torch.tensor(predicted_boxes[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32) @@ -98,7 +99,8 @@ def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1): mosaic_df = pd.DataFrame(image_detections, columns=["xmin", "ymin", "xmax", "ymax", "label", "score"]) - print(f"{mosaic_df.shape[0]} predictions kept after non-max suppression") + if verbose: + print(f"{mosaic_df.shape[0]} predictions kept after non-max suppression") return mosaic_df diff --git a/tests/test_main.py b/tests/test_main.py index a8c3543e..86d9b125 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -322,6 +322,25 @@ def test_predict_tile_no_mosaic(m, raster_path): assert len(prediction[0]) == 2 assert prediction[0][1].shape == (300, 300, 3) +@pytest.mark.parametrize("verbose", [True, False]) +def test_predict_tile_verbose(m, raster_path, capsys, verbose): + m.config["train"]["fast_dev_run"] = False + m.create_trainer() + prediction = m.predict_tile(raster_path=raster_path, + patch_size=300, + patch_overlap=0, + mosaic=True, + verbose=verbose) + + # Check no output was printed + if not verbose: + captured = capsys.readouterr() + print(captured.out) + assert not captured.out + else: + captured = capsys.readouterr() + print(captured.out) + assert captured.out def test_evaluate(m, tmpdir): csv_file = get_data("OSBS_029.csv")