diff --git a/deepforest/visualize.py b/deepforest/visualize.py index 8cac36f7..6dabee1e 100644 --- a/deepforest/visualize.py +++ b/deepforest/visualize.py @@ -2,8 +2,6 @@ import os import pandas as pd import matplotlib -import matplotlib.pyplot as plt -import matplotlib.patches as patches from PIL import Image import numpy as np import pandas.api.types as ptypes @@ -71,8 +69,9 @@ def plot_prediction_and_targets(image, predictions, targets, image_name, savedir return figure_path -def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None): - """For each row in dataframe, call plot predictions. For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class. +def plot_prediction_dataframe(df, root_dir, savedir, ground_truth=None): + """For each row in dataframe, call plot predictions and save plot files to disk. + For multi-class labels, boxes will be colored by labels. Ground truth boxes will all be same color, regardless of class. Args: df: a pandas dataframe with image_path, xmin, xmax, ymin, ymax and label columns. The image_path column should be the relative path from root_dir, not the full path. root_dir: relative dir to look for image names from df.image_path @@ -90,10 +89,9 @@ def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None): annotations = ground_truth[ground_truth.image_path == name] image = plot_predictions(image, annotations) - if savedir: - figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0]) - written_figures.append(figure_name) - cv2.imwrite(figure_name, image) + figure_name = "{}/{}.png".format(savedir, os.path.splitext(name)[0]) + written_figures.append(figure_name) + cv2.imwrite(figure_name, image) return written_figures diff --git a/tests/test_IoU.py b/tests/test_IoU.py index bbc3d75f..92b38355 100644 --- a/tests/test_IoU.py +++ b/tests/test_IoU.py @@ -6,12 +6,11 @@ from deepforest import visualize import os -import pytest import shapely import geopandas as gpd import pandas as pd -def test_compute_IoU(download_release): +def test_compute_IoU(download_release, tmpdir): m = main.deepforest() m.use_release(check_release=False) csv_file = get_data("OSBS_029.csv") @@ -26,7 +25,7 @@ def test_compute_IoU(download_release): ground_truth.label = 0 predictions.label = 0 - visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file)) + visualize.plot_prediction_dataframe(df=predictions, ground_truth=ground_truth, root_dir=os.path.dirname(csv_file), savedir=tmpdir) result = IoU.compute_IoU(ground_truth, predictions) assert result.shape[0] == ground_truth.shape[0]