Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TEST]: Pr813 #815

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,14 @@ def predict_image(self,
result["label"] = result.label.apply(
lambda x: self.numeric_to_label_dict[x])

result = utilities.read_file(result)
if path is None:
result = utilities.read_file(result)
warnings.warn(
"An image was passed directly to predict_image, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
"An image was passed directly to predict_image, the result.root_dir attribute will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
)
else:
result.root_dir = os.path.dirname(path)
root_dir = getattr(result, 'root_dir', None) or os.path.dirname(path)
results = utilities.read_file(result, root_dir=root_dir)

return result

Expand Down Expand Up @@ -556,6 +557,7 @@ def predict_tile(self,
lambda x: self.numeric_to_label_dict[x])
if raster_path:
results["image_path"] = os.path.basename(raster_path)

if return_plot:
# Add deprecated warning
warnings.warn("return_plot is deprecated and will be removed in 2.0. "
Expand Down Expand Up @@ -591,18 +593,19 @@ def predict_tile(self,
trainer=self.trainer,
transform=crop_transform,
augment=crop_augment)

if results.empty:
warnings.warn("No predictions made, returning None")
return None

results = utilities.read_file(results)

if raster_path is None:
warnings.warn(
"An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
"An image was passed directly to predict_tile, the results.root_dir attribute will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
)
results = utilities.read_file(results)
else:
results.root_dir = os.path.dirname(raster_path)
root_dir = getattr(results, 'root_dir', None) or os.path.dirname(raster_path)
results = utilities.read_file(results, root_dir=root_dir)

return results

Expand Down
9 changes: 3 additions & 6 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def _predict_image_(model,

df = visualize.format_boxes(prediction[0])
df = across_class_nms(df, iou_threshold=nms_thresh)

if path:
df["image_path"] = os.path.basename(path)
# df["root_dir"] = os.path.basename(path)
if return_plot:
# Bring to gpu
image = image.cpu()
Expand All @@ -53,12 +55,7 @@ def _predict_image_(model,
image = image[:, :, ::-1] * 255
image = image.astype("uint8")
image = visualize.plot_predictions(image, df, color=color, thickness=thickness)

return image
else:
if path:
df["image_path"] = os.path.basename(path)

return df


Expand Down
11 changes: 9 additions & 2 deletions deepforest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,10 @@ def read_file(input, root_dir=None):
This is the main entry point for reading annotations into deepforest.
Args:
input: a path to a file or a pandas dataframe
root_dir: Optional directory to prepend to the image_path column
root_dir (str): location of the image files, if not in the same directory as the annotations file
Returns:
df: a geopandas dataframe with the properly formatted geometry column
df.root_dir: the root directory of the image files
"""
# read file
if isinstance(input, str):
Expand Down Expand Up @@ -352,7 +353,13 @@ def read_file(input, root_dir=None):
df = gpd.GeoDataFrame(df, geometry='geometry')

# If root_dir is specified, add as attribute
df.root_dir = root_dir
if root_dir is not None:
df.root_dir = root_dir
else:
try:
df.root_dir = os.path.dirname(input)
except TypeError:
warnings.warn("root_dir argument for the location of images should be specified if input is not a path, returning without results.root_dir attribute", UserWarning)

return df

Expand Down
132 changes: 97 additions & 35 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,98 @@ def convert_to_sv_format(df, width=None, height=None):

return detections

def __check_color__(color, num_labels):
if isinstance(color, list) and len(color) == 3:
if num_labels > 1:
warnings.warn(
"""Multiple labels detected in the results and results_color argument provides a single color.
Each label will be plotted with a different color using a built-in color ramp.
If you want to customize colors with multiple labels pass a supervision.ColorPalette object to results_color with the appropriate number of labels"""
)
return sv.ColorPalette.from_matplotlib('viridis', num_labels)
else:
return sv.Color(color[0], color[1], color[2])
elif isinstance(color, sv.draw.color.ColorPalette):
if num_labels > len(color.colors):
warnings.warn(
"""The number of colors provided in results_color does not match the number number of labels.
Replacing the provided color palette with a built-in built-in color palette.
To use a custom color palette make sure the number of values matches the number of labels"""
)
return sv.ColorPalette.from_matplotlib('viridis', num_labels)
else:
return color
elif isinstance(color, list):
raise ValueError(
"results_color must be either a 3 item list containing RGB values or an sv.ColorPalette instance"
)
else:
raise TypeError(
"results_color must be either a list of RGB values or an sv.ColorPalette instance"
)

def plot_annotations(annotations,
savedir=None,
height=None,
width=None,
color=[245, 135, 66],
thickness=2,
basename=None,
root_dir=None,
radius=3):
"""Plot the prediction results.

Args:
annotations: a pandas dataframe with prediction results
savedir: optional path to save the figure. If None (default), the figure will be interactively plotted.
height: height of the image in pixels. Required if the geometry type is 'polygon'.
width: width of the image in pixels. Required if the geometry type is 'polygon'.
results_color (list or sv.ColorPalette): color of the results annotations as a tuple of RGB color (if a single color), e.g. orange annotations is [245, 135, 66], or an supervision.ColorPalette if multiple labels and specifying colors for each label
thickness: thickness of the rectangle border line in px
basename: optional basename for the saved figure. If None (default), the basename will be extracted from the image path.
root_dir: optional path to the root directory of the images. If None (default), the root directory will be extracted from the annotations dataframe.root_dir attribute.
radius: radius of the points in px
Returns:
None
"""
# Convert colors, check for multi-class labels
num_labels = len(annotations.label.unique())
annotation_color = __check_color__(color, num_labels)

# Read images
if not hasattr(annotations, 'root_dir'):
if root_dir is None:
raise ValueError("The 'annotations.root_dir' attribute does not exist. Please specify the 'root_dir' argument.")
else:
root_dir = root_dir
else:
root_dir = annotations.root_dir

image_path = os.path.join(root_dir, annotations.image_path.unique()[0])
image = np.array(Image.open(image_path))

# Plot the results following https://supervision.roboflow.com/annotators/
fig, ax = plt.subplots()
annotated_scene = _plot_image_with_geometry(df=annotations,
image=image,
sv_color=annotation_color,
height=height,
width=width,
thickness=thickness,
radius=radius)

if savedir:
if basename is None:
basename = os.path.splitext(os.path.basename(
annotations.image_path.unique()[0]))[0]
image_name = "{}.png".format(basename)
image_path = os.path.join(savedir, image_name)
cv2.imwrite(image_path, annotated_scene)
else:
# Display the image using Matplotlib
plt.imshow(annotated_scene)
plt.axis('off') # Hide axes for a cleaner look
plt.show()

def plot_results(results,
ground_truth=None,
Expand Down Expand Up @@ -386,38 +478,8 @@ def plot_results(results,
"""
# Convert colors, check for multi-class labels
num_labels = len(results.label.unique())
if isinstance(results_color, list) and len(results_color) == 3:
if num_labels > 1:
warnings.warn(
"""Multiple labels detected in the results and results_color argument provides a single color.
Each label will be plotted with a different color using a built-in color ramp.
If you want to customize colors with multiple labels pass a supervision.ColorPalette object to results_color with the appropriate number of labels"""
)
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)
else:
results_color_sv = sv.Color(results_color[0], results_color[1],
results_color[2])
elif isinstance(results_color, sv.draw.color.ColorPalette):
if num_labels > len(results_color.colors):
warnings.warn(
"""The number of colors provided in results_color does not match the number number of labels.
Replacing the provided color palette with a built-in built-in color palette.
To use a custom color palette make sure the number of values matches the number of labels"""
)
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)
else:
results_color_sv = results_color
elif isinstance(results_color, list):
raise ValueError(
"results_color must be either a 3 item list containing RGB values or an sv.ColorPalette instance"
)
else:
raise TypeError(
"results_color must be either a list of RGB values or an sv.ColorPalette instance"
)

ground_truth_color_sv = sv.Color(ground_truth_color[0], ground_truth_color[1],
ground_truth_color[2])
results_color_sv = __check_color__(results_color, num_labels)
ground_truth_color_sv = __check_color__(ground_truth_color, num_labels)

# Read images
root_dir = results.root_dir
Expand All @@ -426,7 +488,7 @@ def plot_results(results,

# Plot the results following https://supervision.roboflow.com/annotators/
fig, ax = plt.subplots()
annotated_scene = _plot_image_with_results(df=results,
annotated_scene = _plot_image_with_geometry(df=results,
image=image,
sv_color=results_color_sv,
height=height,
Expand All @@ -436,7 +498,7 @@ def plot_results(results,

if ground_truth is not None:
# Plot the ground truth annotations
annotated_scene = _plot_image_with_results(df=ground_truth,
annotated_scene = _plot_image_with_geometry(df=ground_truth,
image=annotated_scene,
sv_color=ground_truth_color_sv,
height=height,
Expand All @@ -458,7 +520,7 @@ def plot_results(results,
plt.show()


def _plot_image_with_results(df,
def _plot_image_with_geometry(df,
image,
sv_color,
thickness=1,
Expand Down
17 changes: 15 additions & 2 deletions docs/user_guide/01_Reading_Writing_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ The most time-consuming part of many open-source projects is getting the data in

## Annotation Geometries and Coordinate Systems

DeepForest was originally designed for bounding box annotations. As of DeepForest 1.4.0, point and polygon annotations are also supported. There are two ways to format annotations, depending on the annotation platform you are using. `read_file` can read points, polygons, and boxes, in both image coordinate systems (relative to image origin at top-left 0,0) as well as projected coordinates on the Earth's surface.
DeepForest was originally designed for bounding box annotations. As of DeepForest 1.4.0, point and polygon annotations are also supported. There are two ways to format annotations, depending on the annotation platform you are using. `read_file` can read points, polygons, and boxes, in both image coordinate systems (relative to image origin at top-left 0,0) as well as projected coordinates on the Earth's surface. The `read_file` method also appends the location of the current image directory as an attribute. To access this attribute use

```
filename = get_data("OSBS_029.csv")

```

**Note:** For CSV files, coordinates are expected to be in the image coordinate system, not projected coordinates (such as latitude/longitude or UTM).

Expand All @@ -28,7 +33,7 @@ OSBS_029.tif,364,204,400,246,Tree

```python
filename = get_data("OSBS_029.csv")
utilities.read_file(filename)
df = utilities.read_file(filename)
```

Example output:
Expand All @@ -44,6 +49,14 @@ Example output:

**Note:** To maintain continuity with versions < 1.4.0, the function for boxes continues to output `xmin`, `ymin`, `xmax`, and `ymax` columns as individual columns as well.

The location of these image files is saved in the root_dir attribute

```
df.root_dir
'/Users/benweinstein/Documents/DeepForest/deepforest/data'
```


#### Shapefiles

Geographic data can also be saved as shapefiles with projected coordinates.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_predict_image_fromarray(m):

assert isinstance(prediction, pd.DataFrame)
assert set(prediction.columns) == {"xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"}

assert not hasattr(prediction, 'root_dir')

def test_predict_big_file(m, tmpdir):
m.config["train"]["fast_dev_run"] = False
Expand Down
Loading
Loading