Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix plot predictions2 #496

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
import pandas.api.types as ptypes
Expand Down Expand Up @@ -71,8 +69,9 @@ def plot_prediction_and_targets(image, predictions, targets, image_name, savedir
return figure_path


def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None):
"""For each row in dataframe, call plot predictions. For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class.
def plot_prediction_dataframe(df, root_dir, savedir, ground_truth=None):
"""For each row in dataframe, call plot predictions and save plot files to disk.
For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class.
Args:
df: a pandas dataframe with image_path, xmin, xmax, ymin, ymax and label columns. The image_path column should be the relative path from root_dir, not the full path.
root_dir: relative dir to look for image names from df.image_path
Expand All @@ -90,10 +89,9 @@ def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None):
annotations = ground_truth[ground_truth.image_path == name]
image = plot_predictions(image, annotations)

if savedir:
figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0])
written_figures.append(figure_name)
cv2.imwrite(figure_name, image)
figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0])
written_figures.append(figure_name)
cv2.imwrite(figure_name, image)

return written_figures

Expand Down
5 changes: 2 additions & 3 deletions tests/test_IoU.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from deepforest import visualize

import os
import pytest
import shapely
import geopandas as gpd
import pandas as pd

def test_compute_IoU(download_release):
def test_compute_IoU(download_release, tmpdir):
m = main.deepforest()
m.use_release(check_release=False)
csv_file = get_data("OSBS_029.csv")
Expand All @@ -26,7 +25,7 @@ def test_compute_IoU(download_release):

ground_truth.label = 0
predictions.label = 0
visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file))
visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file), savedir=tmpdir)

result = IoU.compute_IoU(ground_truth, predictions)
assert result.shape[0] == ground_truth.shape[0]
Expand Down
Loading