Skip to content

Commit

Permalink
Add option to visualize only predictions with low char prob
Browse files Browse the repository at this point in the history
  • Loading branch information
ankandrew committed May 6, 2024
1 parent 2d194e2 commit 3207fd5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 31 deletions.
39 changes: 18 additions & 21 deletions fast_plate_ocr/cli/visualize_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

import logging
import pathlib
from contextlib import nullcontext

import click
import cv2
import keras
import numpy as np

import fast_plate_ocr.common.utils
from fast_plate_ocr.train.model.config import load_config_from_yaml
from fast_plate_ocr.train.utilities import utils
from fast_plate_ocr.train.utilities.utils import postprocess_model_output

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
Expand Down Expand Up @@ -42,27 +41,26 @@
type=click.Path(exists=True, dir_okay=True, file_okay=False, path_type=pathlib.Path),
help="Directory containing the images to make predictions from.",
)
@click.option(
"-t",
"--time",
default=True,
is_flag=True,
help="Log time taken to run predictions.",
)
@click.option(
"-l",
"--low-conf-thresh",
type=float,
default=0.2,
default=0.35,
show_default=True,
help="Threshold for displaying low confidence characters.",
)
@click.option(
"-l",
"--filter-conf",
type=float,
help="Display plates that any of the plate characters are below this number.",
)
def visualize_predictions(
model_path: pathlib.Path,
config_file: pathlib.Path,
img_dir: pathlib.Path,
low_conf_thresh: float,
time: bool,
filter_conf: float | None,
):
"""
Visualize OCR model predictions on unlabeled data.
Expand All @@ -75,20 +73,19 @@ def visualize_predictions(
img_dir, width=config.img_width, height=config.img_height
)
for image in images:
with (
fast_plate_ocr.common.utils.log_time_taken("Prediction time") if time else nullcontext()
):
x = np.expand_dims(image, 0)
prediction = model(x, training=False)
prediction = keras.ops.stop_gradient(prediction).numpy()
utils.display_predictions(
image=image,
x = np.expand_dims(image, 0)
prediction = model(x, training=False)
prediction = keras.ops.stop_gradient(prediction).numpy()
plate, probs = postprocess_model_output(
prediction=prediction,
alphabet=config.alphabet,
plate_slots=config.max_plate_slots,
max_plate_slots=config.max_plate_slots,
vocab_size=config.vocabulary_size,
low_conf_thresh=low_conf_thresh,
)
if not filter_conf or (filter_conf and np.any(probs < filter_conf)):
utils.display_predictions(
image=image, plate=plate, probs=probs, low_conf_thresh=low_conf_thresh
)
cv2.destroyAllWindows()


Expand Down
12 changes: 2 additions & 10 deletions fast_plate_ocr/train/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,13 @@ def low_confidence_positions(probs, thresh=0.3) -> npt.NDArray:

def display_predictions(
image: npt.NDArray,
prediction: npt.NDArray,
alphabet: str,
plate_slots: int,
vocab_size: int,
plate: str,
probs: npt.NDArray,
low_conf_thresh: float,
) -> None:
"""
Display plate and corresponding prediction.
"""
plate, probs = postprocess_model_output(
prediction=prediction,
alphabet=alphabet,
max_plate_slots=plate_slots,
vocab_size=vocab_size,
)
plate_str = "".join(plate)
logging.info("Plate: %s", plate_str)
logging.info("Confidence: %s", probs)
Expand Down

0 comments on commit 3207fd5

Please sign in to comment.