Skip to content

Commit

Permalink
Merge pull request #496 from weecology/bug_fix_plot_predictions2
Browse files Browse the repository at this point in the history
Bug fix plot predictions2
  • Loading branch information
ethanwhite authored Oct 3, 2023
2 parents fd237cf + 56d204e commit 1d0bda6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
14 changes: 6 additions & 8 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
import pandas.api.types as ptypes
Expand Down Expand Up @@ -71,8 +69,9 @@ def plot_prediction_and_targets(image, predictions, targets, image_name, savedir
return figure_path


def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None):
"""For each row in dataframe, call plot predictions. For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class.
def plot_prediction_dataframe(df, root_dir, savedir, ground_truth=None):
"""For each row in dataframe, call plot predictions and save plot files to disk.
For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class.
Args:
df: a pandas dataframe with image_path, xmin, xmax, ymin, ymax and label columns. The image_path column should be the relative path from root_dir, not the full path.
root_dir: relative dir to look for image names from df.image_path
Expand All @@ -90,10 +89,9 @@ def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None):
annotations = ground_truth[ground_truth.image_path == name]
image = plot_predictions(image, annotations)

if savedir:
figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0])
written_figures.append(figure_name)
cv2.imwrite(figure_name, image)
figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0])
written_figures.append(figure_name)
cv2.imwrite(figure_name, image)

return written_figures

Expand Down
5 changes: 2 additions & 3 deletions tests/test_IoU.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from deepforest import visualize

import os
import pytest
import shapely
import geopandas as gpd
import pandas as pd

def test_compute_IoU(download_release):
def test_compute_IoU(download_release, tmpdir):
m = main.deepforest()
m.use_release(check_release=False)
csv_file = get_data("OSBS_029.csv")
Expand All @@ -26,7 +25,7 @@ def test_compute_IoU(download_release):

ground_truth.label = 0
predictions.label = 0
visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file))
visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file), savedir=tmpdir)

result = IoU.compute_IoU(ground_truth, predictions)
assert result.shape[0] == ground_truth.shape[0]
Expand Down

0 comments on commit 1d0bda6

Please sign in to comment.