Skip to content

Commit

Permalink
silence predict tile stdout if needed, see Lightning-AI/pytorch-light…
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 13, 2024
1 parent b578637 commit f4376a7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
14 changes: 10 additions & 4 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,8 @@ def predict_tile(self,
thickness=1,
crop_model=None,
crop_transform=None,
crop_augment=False):
crop_augment=False,
verbose=True):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassambles into a single array.
Expand All @@ -498,6 +499,7 @@ def predict_tile(self,
cropModel: a deepforest.model.CropModel object to predict on crops
crop_transform: a torchvision.transforms object to apply to crops
crop_augment: a boolean to apply augmentations to crops
verbose: a boolean to print the number of predictions in overlapping windows
(deprecated) return_plot: return a plot of the image with predictions overlaid
(deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
(deprecated) thickness: thickness of the rectangle border line in px
Expand Down Expand Up @@ -529,7 +531,7 @@ def predict_tile(self,
warnings.warn(
"More than one GPU detected. Using only the first GPU for predict_tile.")
self.config["devices"] = 1
self.create_trainer()
self.create_trainer(enable_progress_bar=verbose)

if (raster_path is None) and (image is None):
raise ValueError(
Expand Down Expand Up @@ -559,7 +561,10 @@ def predict_tile(self,
ds = dataset.RasterDataset(raster_path=raster_path,
patch_overlap=patch_overlap,
patch_size=patch_size)


if not verbose:
self.create_trainer(enable_progress_bar=False)

batched_results = self.trainer.predict(self, self.predict_dataloader(ds))

# Flatten list from batched prediction
Expand All @@ -573,7 +578,8 @@ def predict_tile(self,
ds.windows,
sigma=sigma,
thresh=thresh,
iou_threshold=iou_threshold)
iou_threshold=iou_threshold,
verbose=verbose)
results["label"] = results.label.apply(
lambda x: self.numeric_to_label_dict[x])
if raster_path:
Expand Down
12 changes: 7 additions & 5 deletions src/deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _predict_image_(model,
return df


def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1):
def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1, verbose=True):
# transform the coordinates to original system
for index, _ in enumerate(boxes):
xmin, ymin, xmax, ymax = windows[index].getRect()
Expand All @@ -71,9 +71,10 @@ def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1):
boxes[index].ymax += ymin

predicted_boxes = pd.concat(boxes)
print(
f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max supression"
)
if verbose:
print(
f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max supression"
)
# move prediciton to tensor
boxes = torch.tensor(predicted_boxes[["xmin", "ymin", "xmax", "ymax"]].values,
dtype=torch.float32)
Expand All @@ -98,7 +99,8 @@ def mosiac(boxes, windows, sigma=0.5, thresh=0.001, iou_threshold=0.1):
mosaic_df = pd.DataFrame(image_detections,
columns=["xmin", "ymin", "xmax", "ymax", "label", "score"])

print(f"{mosaic_df.shape[0]} predictions kept after non-max suppression")
if verbose:
print(f"{mosaic_df.shape[0]} predictions kept after non-max suppression")

return mosaic_df

Expand Down
19 changes: 19 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ def test_predict_tile_no_mosaic(m, raster_path):
assert len(prediction[0]) == 2
assert prediction[0][1].shape == (300, 300, 3)

@pytest.mark.parametrize("verbose", [True, False])
def test_predict_tile_verbose(m, raster_path, capsys, verbose):
m.config["train"]["fast_dev_run"] = False
m.create_trainer()
prediction = m.predict_tile(raster_path=raster_path,
patch_size=300,
patch_overlap=0,
mosaic=True,
verbose=verbose)

# Check no output was printed
if not verbose:
captured = capsys.readouterr()
print(captured.out)
assert not captured.out
else:
captured = capsys.readouterr()
print(captured.out)
assert captured.out

def test_evaluate(m, tmpdir):
csv_file = get_data("OSBS_029.csv")
Expand Down

0 comments on commit f4376a7

Please sign in to comment.