Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pathlib.Path objects as input #469

Open
wants to merge 7 commits into
base: 0.5.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions solaris/data/coco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
from pathlib import Path

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -47,7 +48,7 @@ def geojson2coco(

Arguments
---------
image_src : :class:`str` or :class:`list` or :class:`dict`
image_src : :class:`str` or :class:`pathlib.Path` or :class:`list` or :class:`dict`
Source image(s) to use in the dataset. This can be::

1. a string path to an image,
Expand Down Expand Up @@ -149,8 +150,8 @@ def geojson2coco(
logger.setLevel(_get_logging_level(int(verbose)))
logger.debug("Preparing image filename: image ID dict.")
# pdb.set_trace()
if isinstance(image_src, str):
if image_src.endswith("json"):
if isinstance(image_src, (str, Path)):
if str(image_src).endswith("json"):
logger.debug("COCO json provided. Extracting fname:id dict.")
with open(image_src, "r") as f:
image_ref = json.load(f)
Expand Down Expand Up @@ -599,13 +600,13 @@ def _get_fname_list(p, recursive=False, extension=".tif"):
"""Get a list of filenames from p, which can be a dir, fname, or list."""
if isinstance(p, list):
return p
elif isinstance(p, str):
if os.path.isdir(p):
elif isinstance(p, (str, Path)):
if Path(p).is_dir():
return get_files_recursively(
p, traverse_subdirs=recursive, extension=extension
)
elif os.path.isfile(p):
return [p]
elif Path(p).is_file():
return [str(p)]
else:
raise ValueError("If a string is provided, it must be a valid" " path.")
else:
Expand Down
34 changes: 16 additions & 18 deletions solaris/eval/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from pathlib import Path

import geopandas as gpd
import pandas as pd
import shapely.wkt
from fiona._err import CPLE_OpenFailedError
from fiona.errors import DriverError
from solaris.utils.core import _check_gdf_load
from tqdm.auto import tqdm

from . import iou
Expand All @@ -29,28 +31,29 @@ class Evaluator:

Arguments
---------
ground_truth_vector_file : str
ground_truth_vector_file : `str` or :class:`pathlib.Path`
Path to a .geojson or .csv file for ground truth.

"""

def __init__(self, ground_truth_vector_file):
# Load Ground Truth : Ground Truth should be in geojson or shape file
try:
if ground_truth_vector_file.lower().endswith("json"):
self.load_truth(ground_truth_vector_file)
elif ground_truth_vector_file.lower().endswith("csv"):
self.load_truth(ground_truth_vector_file, truthCSV=True)
self.ground_truth_fname = ground_truth_vector_file
except AttributeError: # handles passing gdf instead of path to file
self.ground_truth_GDF = ground_truth_vector_file
if isinstance(ground_truth_vector_file, (str, Path)):
self.ground_truth_fname = str(ground_truth_vector_file)
else:
self.ground_truth_fname = "GeoDataFrame variable"

if isinstance(ground_truth_vector_file, (str, Path)) and ground_truth_vector_file.lower().endswith("csv"):
self.load_truth(ground_truth_vector_file, truthCSV=True)
else:
self.load_truth(ground_truth_vector_file)
self.ground_truth_sindex = self.ground_truth_GDF.sindex # get sindex
# create deep copy of ground truth file for calculations
self.ground_truth_GDF_Edit = self.ground_truth_GDF.copy(deep=True)
self.proposal_GDF = gpd.GeoDataFrame([]) # initialize proposal GDF

def __repr__(self):

return "Evaluator {}".format(os.path.split(self.ground_truth_fname)[-1])

def get_iou_by_building(self):
Expand Down Expand Up @@ -509,7 +512,7 @@ def load_proposal(

Arguments
---------
proposal_vector_file : str
proposal_vector_file : `str` or :class:`pathlib.Path`
Path to the file containing proposal vector objects. This can be
a .geojson or a .csv.
conf_field_list : list, optional
Expand Down Expand Up @@ -540,7 +543,7 @@ def load_proposal(
"""

# Load Proposal if proposal_vector_file is a path to a file
if os.path.isfile(proposal_vector_file):
if Path(proposal_vector_file).is_file():
# if it's a CSV format, first read into a pd df and then convert
# to gpd gdf by loading in geometries using shapely
if proposalCSV:
Expand Down Expand Up @@ -588,7 +591,7 @@ def load_truth(

Arguments
---------
ground_truth_vector_file : str
ground_truth_vector_file : `str` or :class:`pathlib.Path`
Path to the ground truth vector file. Must be either .geojson or
.csv format.
truthCSV : bool, optional
Expand Down Expand Up @@ -617,12 +620,7 @@ def load_truth(
],
)
else:
try:
self.ground_truth_GDF = gpd.read_file(ground_truth_vector_file)
except (CPLE_OpenFailedError, DriverError): # empty geojson
self.ground_truth_GDF = gpd.GeoDataFrame(
{"sindex": [], "condition": [], "geometry": []}
)
self.ground_truth_GDF = _check_gdf_load(ground_truth_vector_file)
# force calculation of spatialindex
self.ground_truth_sindex = self.ground_truth_GDF.sindex
# create deep copy of ground truth file for calculations
Expand Down
8 changes: 4 additions & 4 deletions solaris/eval/pixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def f1(
``1``, values < `prop_threshold` will be set to ``0``.
show_plot : bool, optional
Switch to plot the outputs. Defaults to ``False``.
im_file : str, optional
im_file : `str` or :class:`pathlib.Path`, optional
Image file corresponding to the masks. Ignored if
``show_plot == False``. Defaults to ``''``.
show_colorbar : bool, optional
Switch to show colorbar. Ignored if ``show_plot == False``.
Defaults to ``False``.
plot_file : str, optional
plot_file : `str` or :class:`pathlib.Path`, optional
Output file if plotting. Ignored if ``show_plot == False``.
Defaults to ``''``.
dpi : int, optional
Expand Down Expand Up @@ -167,7 +167,7 @@ def f1(
plt.suptitle(title, fontsize=fontsize)

# ground truth
if len(im_file) > 0:
if len(str(im_file)) > 0:
# raw image
ax1.imshow(cv2.imread(im_file, 1))
# ground truth
Expand Down Expand Up @@ -211,7 +211,7 @@ def f1(
# fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.subplots_adjust(top=0.8)

if len(plot_file) > 0:
if len(str(plot_file)) > 0:
plt.savefig(plot_file, dpi=dpi)
print("Time to create and save F1 plots:", time.time() - t0, "seconds")

Expand Down
37 changes: 19 additions & 18 deletions solaris/eval/vector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import glob
import os
from pathlib import Path

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -49,9 +50,9 @@ def get_all_objects(
unique classes present in each
Arguments
---------
proposal_polygons_dir : str
proposal_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains any model proposal polygons
gt_polygons_dir : str
gt_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains the ground truth polygons
prediction_cat_attrib : str
The column or attribute within the predictions that specifies
Expand All @@ -71,23 +72,23 @@ def get_all_objects(
A union of the prop_objs and gt_objs lists
"""
objs = []
os.chdir(proposal_polygons_dir)
os.chdir(str(proposal_polygons_dir))
search = "*" + file_format
proposal_geojsons = glob.glob(search)
for geojson in tqdm(proposal_geojsons):
ground_truth_poly = os.path.join(gt_polygons_dir, geojson)
ground_truth_poly = Path(gt_polygons_dir) / geojson
if os.path.exists(ground_truth_poly):
ground_truth_gdf = gpd.read_file(ground_truth_poly)
proposal_gdf = gpd.read_file(geojson)
for index, row in proposal_gdf.iterrows():
objs.append(row[prediction_cat_attrib])
prop_objs = list(set(objs))
os.chdir(gt_polygons_dir)
os.chdir(str(gt_polygons_dir))
search = "*" + file_format
objs = []
gt_geojsons = glob.glob(search)
for geojson in tqdm(gt_geojsons):
proposal_poly = os.path.join(proposal_polygons_dir, geojson)
proposal_poly = Path(proposal_polygons_dir) / geojson
if os.path.exists(proposal_poly):
proposal_gdf = gpd.read_file(proposal_poly)
ground_truth_gdf = gpd.read_file(geojson)
Expand All @@ -114,9 +115,9 @@ def precision_calc(
calculate metric for classes that exist in the ground truth.
Arguments
---------
proposal_polygons_dir : str
proposal_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains any model proposal polygons
gt_polygons_dir : str
gt_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains the ground truth polygons
prediction_cat_attrib : str
The column or attribute within the predictions that specifies
Expand Down Expand Up @@ -148,7 +149,7 @@ def precision_calc(
All confidences for each object for each class
"""
ious = []
os.chdir(proposal_polygons_dir)
os.chdir(str(proposal_polygons_dir))
search = "*" + file_format
proposal_geojsons = glob.glob(search)
iou_holder = []
Expand All @@ -166,7 +167,7 @@ def precision_calc(
confidences.append([])

for geojson in tqdm(proposal_geojsons):
ground_truth_poly = os.path.join(gt_polygons_dir, geojson)
ground_truth_poly = Path(gt_polygons_dir) / geojson
if os.path.exists(ground_truth_poly):
ground_truth_gdf = gpd.read_file(ground_truth_poly)
proposal_gdf = gpd.read_file(geojson)
Expand Down Expand Up @@ -241,9 +242,9 @@ def recall_calc(
calculate metric for classes that exist in the ground truth.
Arguments
---------
proposal_polygons_dir : str
proposal_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains any model proposal polygons
gt_polygons_dir : str
gt_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains the ground truth polygons
prediction_cat_attrib : str
The column or attribute within the predictions that specifies
Expand All @@ -270,7 +271,7 @@ def recall_calc(
The mean recall score of recall_by_class
"""
ious = []
os.chdir(gt_polygons_dir)
os.chdir(str(gt_polygons_dir))
search = "*" + file_format
gt_geojsons = glob.glob(search)
iou_holder = []
Expand All @@ -285,7 +286,7 @@ def recall_calc(
for i in range(len(object_subset)):
iou_holder.append([])
for geojson in tqdm(gt_geojsons):
proposal_poly = os.path.join(proposal_polygons_dir, geojson)
proposal_poly = Path(proposal_polygons_dir) / geojson
if os.path.exists(proposal_poly):
proposal_gdf = gpd.read_file(proposal_poly)
ground_truth_gdf = gpd.read_file(geojson)
Expand Down Expand Up @@ -353,9 +354,9 @@ def mF1(
only calculate metric for classes that exist in the ground truth.
Arguments
---------
proposal_polygons_dir : str
proposal_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains any model proposal polygons
gt_polygons_dir : str
gt_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains the ground truth polygons
prediction_cat_attrib : str
The column or attribute within the predictions that specifies
Expand Down Expand Up @@ -480,9 +481,9 @@ def mAP_score(

Arguments
---------
proposal_polygons_dir : str
proposal_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains any model proposal polygons
gt_polygons_dir : str
gt_polygons_dir : `str` or :class:`pathlib.Path`
The path that contains the ground truth polygons
prediction_cat_attrib : str
The column or attribute within the predictions that specifies
Expand Down
10 changes: 6 additions & 4 deletions solaris/raster/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import numpy as np
import rasterio

Expand All @@ -9,7 +11,7 @@ def get_geo_transform(raster_src):

Arguments
---------
raster_src : str, :class:`rasterio.DatasetReader`, or `osgeo.gdal.Dataset`
raster_src : str, :class:`pathlib.Path`, :class:`rasterio.DatasetReader`, or `osgeo.gdal.Dataset`
Path to a raster image with georeferencing data to apply to `geom`.
Alternatively, an opened :class:`rasterio.DatasetReader` object or
:class:`osgeo.gdal.Dataset` object can be provided. Required if not
Expand All @@ -21,7 +23,7 @@ def get_geo_transform(raster_src):
An affine transformation object to the image's location in its CRS.
"""

if isinstance(raster_src, str):
if isinstance(raster_src, (str, Path)):
affine_obj = rasterio.open(raster_src).transform
elif isinstance(raster_src, rasterio.DatasetReader):
affine_obj = raster_src.transform
Expand Down Expand Up @@ -175,7 +177,7 @@ def stitch_images(
# ---------
# array : :class:`numpy.ndarray`
# A numpy array with the shape: [Channels, X, Y] or [X, Y]
# out_name : str
# out_name : str or :class:`pathlib.Path`
# The output name and path for your image
# proj : :class:`gdal.projection`
# A projection, can be extracted from an image opened with gdal with
Expand All @@ -200,7 +202,7 @@ def stitch_images(
# driver = gdal.GetDriverByName("GTiff")
# if len(array.shape) == 2:
# array = array[np.newaxis, ...]
# os.makedirs(os.path.dirname(os.path.abspath(out_name)), exist_ok=True)
# Path(out_name).resolve().parent.mkdir(exist_ok=True)
# dataset = driver.Create(out_name, array.shape[2], array.shape[1], array.shape[0], out_format)
# if verbose is True:
# print("Array Shape, should be [Channels, X, Y] or [X,Y]:", array.shape)
Expand Down
14 changes: 7 additions & 7 deletions solaris/tile/raster_tile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path

import numpy as np
import rasterio
Expand Down Expand Up @@ -71,8 +72,8 @@ class RasterTiler(object):
src_path : `str`
The path or URL to the source dataset. Used for calling
``rio_cogeo.cogeo.cog_validate()``.
dest_dir : `str`
The directory to save the output tiles to. If not
dest_dir : `str` or :class:`pathlib.Path`
The directory to save the output tiles to.
dest_crs : int
The EPSG code for the output images. If not provided, outputs will
keep the same CRS as the source image when ``Tiler.make_tile_images()``
Expand Down Expand Up @@ -129,8 +130,7 @@ def __init__(
if verbose:
print("Initializing Tiler...")
self.dest_dir = dest_dir
if not os.path.exists(self.dest_dir):
os.makedirs(self.dest_dir)
Path(self.dest_dir).mkdir(exist_ok=True)
if dest_crs is not None:
self.dest_crs = _check_crs(dest_crs)
else:
Expand Down Expand Up @@ -180,7 +180,7 @@ def tile(

Arguments
---------
src : :class:`rasterio.io.DatasetReader` or str
src : :class:`rasterio.io.DatasetReader`, str or :class:`pathlib.Path`
The source dataset to tile.
nodata_threshold : float, optional
Nodata percentages greater than this threshold will not be saved as tiles.
Expand Down Expand Up @@ -297,13 +297,13 @@ def tile_generator(

Arguments
---------
src : `str` or :class:`Rasterio.DatasetReader`
src : `str`, :class:`pathlib.Path` or :class:`rasterio.DatasetReader`
The source data to tile from. If this is a "classic"
(non-cloud-optimized) GeoTIFF, the whole image will be loaded in;
if it's cloud-optimized, only the required portions will be loaded
during tiling unless ``force_load_cog=True`` was specified upon
initialization.
dest_dir : str, optional
dest_dir : str or :class:`pathlib.Path`, optional
The path to the destination directory to output images to. If the
path doesn't exist, it will be created. This argument is required
if it wasn't provided during initialization.
Expand Down
Loading