simple predictor for using gj as library
TankredO committed Jun 20, 2022
1 parent 408e845 commit b3137de
Showing 3 changed files with 170 additions and 11 deletions.
24 changes: 15 additions & 9 deletions ginjinn/ginjinn_config/ginjinn_config.py
@@ -21,7 +21,8 @@
'instance-segmentation',
]

-class GinjinnConfiguration: #pylint: disable=too-many-arguments,too-many-instance-attributes
+
+class GinjinnConfiguration: # pylint: disable=too-many-arguments,too-many-instance-attributes
'''GinJinn configuration class.
A class representing the configuration of a GinJinn project.
@@ -51,6 +52,7 @@ class GinjinnConfiguration:
InvalidGinjinnConfigurationError
If any of the general configuration is contradictionary or malformed.
'''

def __init__(
self,
project_dir: str,
@@ -60,8 +62,9 @@ def __init__(
training_configuration: GinjinnTrainingConfiguration,
augmentation_configuration: GinjinnAugmentationConfiguration,
detectron_configuration: GinjinnDetectronConfiguration = GinjinnDetectronConfiguration(),
-        options_configuration: GinjinnOptionsConfiguration =
-            GinjinnOptionsConfiguration.from_dictionary({}),
+        options_configuration: GinjinnOptionsConfiguration = GinjinnOptionsConfiguration.from_dictionary(
+            {}
+        ),
):
self.project_dir = project_dir
self.task = task
@@ -134,9 +137,7 @@ def from_dictionary(cls, config: dict):
config['input'],
project_dir=project_dir,
)
-        model_configuration = GinjinnModelConfiguration.from_dictionary(
-            config['model']
-        )
+        model_configuration = GinjinnModelConfiguration.from_dictionary(config['model'])
training_configuration = GinjinnTrainingConfiguration.from_dictionary(
config.get('training', {})
)
@@ -162,7 +163,10 @@ def from_dictionary(cls, config: dict):
)

@classmethod
-    def from_config_file(cls, file_path: str):
+    def from_config_file(
+        cls,
+        file_path: str,
+    ):
'''Build GinjinnConfiguration from YAML configuration file.
Parameters
@@ -201,6 +205,8 @@ def _check(self):

model_tasks = MODELS[self.model.name]['tasks']
if not self.task in model_tasks:
-            err_msg = f'Task "{self.task}" is incompatible with model ' +\
-                f'"{self.model.name}" (available tasks: {", ".join(model_tasks)}).'
+            err_msg = (
+                f'Task "{self.task}" is incompatible with model '
+                + f'"{self.model.name}" (available tasks: {", ".join(model_tasks)}).'
+            )
raise InvalidGinjinnConfigurationError(err_msg)
2 changes: 1 addition & 1 deletion ginjinn/predictor/__init__.py
@@ -1,4 +1,4 @@
''' Predictor module
'''

-from .predictors import GinjinnPredictor
+from .predictors import GinjinnPredictor, SimpleGinjinnPredictor
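
With this change the package also re-exports SimpleGinjinnPredictor, so it can be imported directly when using GinJinn as a library:

from ginjinn.predictor import SimpleGinjinnPredictor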
155 changes: 154 additions & 1 deletion ginjinn/predictor/predictors.py
@@ -3,6 +3,7 @@
"""

import datetime
import yaml
import glob
import json
import os
@@ -20,7 +21,7 @@
import ginjinn.segmentation_refinement as refine
import torch
from ginjinn.data_reader.data_reader import get_class_names
-from ginjinn.ginjinn_config import GinjinnConfiguration
+from ginjinn.ginjinn_config import GinjinnConfiguration, GinjinnModelConfiguration
from ginjinn.utils.utils import bbox_from_polygons
import warnings

@@ -509,3 +510,155 @@ def draw_instance_predictions_gj(
alpha=alpha,
)
return self.output


class SimpleGinjinnPredictor:
    '''A class for predicting from a trained Detectron2 model.

    Parameters
    ----------
cfg : CfgNode
Detectron2 configuration object
class_names : list of str
Ordered list of object class names
task : str
"bbox-detection" or "instance-segmentation"
'''

def __init__(self, cfg: CfgNode, class_names: List[str], task: str):
self.class_names = class_names
self.d2_cfg = cfg
self.task = task
self.d2_predictor = DefaultPredictor(self.d2_cfg)

@classmethod
def from_ginjinn_config(
cls,
gj_cfg: GinjinnConfiguration,
checkpoint_name: str = "model_final.pth",
) -> "SimpleGinjinnPredictor":
"""
        Build a SimpleGinjinnPredictor from a GinjinnConfiguration instead of
        a Detectron2 configuration.

        Parameters
        ----------
        gj_cfg : GinjinnConfiguration
            GinJinn configuration of a trained project.
        checkpoint_name : str
            Name of the checkpoint to use.

        Returns
        -------
        SimpleGinjinnPredictor
"""

d2_cfg = gj_cfg.to_detectron2_config()
d2_cfg.MODEL.WEIGHTS = os.path.join(d2_cfg.OUTPUT_DIR, checkpoint_name)

return cls(d2_cfg, get_class_names(gj_cfg.project_dir), gj_cfg.task)

@classmethod
def from_ginjinn_project(
cls,
ginjinn_project_dir: str,
checkpoint_name: str = "model_final.pth",
) -> "SimpleGinjinnPredictor":
"""
        Build a SimpleGinjinnPredictor from a GinJinn2 project directory
        instead of a Detectron2 configuration.

        Parameters
        ----------
        ginjinn_project_dir : str
            GinJinn2 project directory containing ginjinn_config.yaml and the
            trained model outputs.
        checkpoint_name : str
            Name of the checkpoint to use.

        Returns
        -------
        SimpleGinjinnPredictor
"""
config_file = os.path.join(ginjinn_project_dir, 'ginjinn_config.yaml')
with open(config_file) as f:
d = yaml.safe_load(f)

task = d['task']
model_cfg = GinjinnModelConfiguration.from_dictionary(d['model'])
d2_cfg = model_cfg.to_detectron2_config()
d2_cfg.MODEL.WEIGHTS = os.path.join(
ginjinn_project_dir, 'outputs', checkpoint_name
)

return cls(d2_cfg, get_class_names(ginjinn_project_dir), task)

def predict(
self,
image: np.ndarray,
threshold: Union[float, int] = 0.8,
seg_refinement: bool = False,
refinement_device: str = "cuda:0",
refinement_method: str = "full",
):
"""
        Predict instances for a single image.

        Parameters
        ----------
        image : np.ndarray
            Input image as a NumPy array (BGR, as expected by Detectron2's
            DefaultPredictor).
        threshold : float or int, default=0.8
            Minimum score of predicted instances.
        seg_refinement : bool, default=False
            If True, predictions are postprocessed with CascadePSP.
            This option only works for instance segmentation.
        refinement_device : str, default="cuda:0"
            CPU or CUDA device for refinement with CascadePSP.
        refinement_method : str, default="full"
            If set to "fast", the local refinement step will be skipped.

        Returns
        -------
        classes, boxes, masks, scores
            NumPy arrays of predicted classes, bounding boxes, masks
            (None for bbox-detection), and scores.
"""
self.d2_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
self.d2_cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold

predictions = self.d2_predictor(image)

# convert to numpy arrays
boxes = (
predictions["instances"].get_fields()["pred_boxes"].to("cpu").tensor.numpy()
)
classes = (
predictions["instances"].get_fields()["pred_classes"].to("cpu").numpy()
)
scores = predictions["instances"].get_fields()["scores"].to("cpu").numpy()

if self.task == "instance-segmentation":
masks = (
predictions["instances"].get_fields()["pred_masks"].to("cpu").numpy()
)
else:
masks = None # np.array([])

torch.cuda.empty_cache()

        if self.task == "instance-segmentation" and seg_refinement:
            refiner = refine.Refiner(device=refinement_device)
            for i_mask, mask in enumerate(masks):
                masks[i_mask] = refiner.refine(
                    image,
                    mask.astype("int") * 255,
                    fast=refinement_method == "fast",
                )

if self.task == "instance-segmentation":
for i_mask, mask in enumerate(masks):
# recalculate bounding boxes
x_any = masks[i_mask].any(axis=0)
y_any = masks[i_mask].any(axis=1)
                x = np.where(x_any)[0]
                y = np.where(y_any)[0]
if len(x) > 0 and len(y) > 0:
boxes[i_mask] = [x[0], y[0], x[-1] + 1, y[-1] + 1]
else:
boxes[i_mask] = [0, 0, 0, 0]

torch.cuda.empty_cache()

return classes, boxes, masks, scores
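
A minimal end-to-end usage sketch of the new predictor (the project directory "my_project" and the image file "example.jpg" below are placeholders, not part of this commit):

import cv2

from ginjinn.predictor import SimpleGinjinnPredictor

# build a predictor from a trained GinJinn2 project directory
predictor = SimpleGinjinnPredictor.from_ginjinn_project("my_project")

# Detectron2's DefaultPredictor expects a BGR image array, which cv2.imread returns
image = cv2.imread("example.jpg")

classes, boxes, masks, scores = predictor.predict(image, threshold=0.5)
# for instance segmentation, masks could additionally be refined via
# predictor.predict(image, threshold=0.5, seg_refinement=True)
for class_idx, box, score in zip(classes, boxes, scores):
    print(predictor.class_names[class_idx], box, score)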
