Skip to content

Commit

Permalink
Merge pull request #19 from ankandrew/dev-test
Browse files Browse the repository at this point in the history
New models and Improved logging
  • Loading branch information
ankandrew authored Jul 24, 2024
2 parents 4411012 + 99f5c3e commit 4200fbd
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 63 deletions.
12 changes: 12 additions & 0 deletions config/latin_plate_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Config example for Latin plates from 70 countries

# Max number of plate slots supported. This represents the number of model classification heads.
max_plate_slots: 9
# Set of all possible characters for the model output.
alphabet: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_'
# Padding character for plates whose length is smaller than max_plate_slots. It must be present in the alphabet.
pad_char: '_'
# Image height which is fed to the model.
img_height: 70
# Image width which is fed to the model.
img_width: 140
53 changes: 45 additions & 8 deletions fast_plate_ocr/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

import pathlib
import shutil
from datetime import datetime
from typing import Literal

import albumentations as A
import click
Expand All @@ -12,6 +14,7 @@
from keras.src.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from fast_plate_ocr.cli.utils import print_params, print_train_details
from fast_plate_ocr.train.data.augmentation import TRAIN_AUGMENTATION
from fast_plate_ocr.train.data.dataset import LicensePlateDataset
from fast_plate_ocr.train.model.config import load_config_from_yaml
Expand Down Expand Up @@ -80,7 +83,7 @@
)
@click.option(
"--output-dir",
default="./trained-models",
default="./trained_models",
type=click.Path(dir_okay=True, path_type=pathlib.Path),
help="Output directory where model will be saved.",
)
Expand All @@ -100,9 +103,9 @@
@click.option(
"--tensorboard-dir",
"-l",
default="logs",
default="tensorboard_logs",
show_default=True,
type=str,
type=click.Path(path_type=pathlib.Path),
help="The path of the directory where to save the TensorBoard log files.",
)
@click.option(
Expand All @@ -117,8 +120,30 @@
default=60,
show_default=True,
type=int,
help="Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs.",
help="Patience to reduce the learning rate if 'val_plate_acc' doesn't improve within X epochs.",
)
@click.option(
"--reduce-lr-factor",
default=0.85,
show_default=True,
type=float,
help="Reduce the learning rate by this factor when 'val_plate_acc' doesn't improve.",
)
@click.option(
"--activation",
default="relu",
show_default=True,
type=str,
help="Activation function to use.",
)
@click.option(
"--pool-layer",
default="max",
show_default=True,
type=click.Choice(["max", "avg"]),
help="Choose the pooling layer to use.",
)
@print_params(table_title="CLI Training Parameters", c1_title="Parameter", c2_title="Details")
def train(
dense: bool,
config_file: pathlib.Path,
Expand All @@ -131,9 +156,12 @@ def train(
output_dir: pathlib.Path,
epochs: int,
tensorboard: bool,
tensorboard_dir: str,
tensorboard_dir: pathlib.Path,
early_stopping_patience: int,
reduce_lr_patience: int,
reduce_lr_factor: float,
activation: str,
pool_layer: Literal["max", "avg"],
) -> None:
"""
Train the License Plate OCR model.
Expand All @@ -142,6 +170,7 @@ def train(
A.load(augmentation_path, data_format="yaml") if augmentation_path else TRAIN_AUGMENTATION
)
config = load_config_from_yaml(config_file)
print_train_details(train_augmentation, config.model_dump())
train_torch_dataset = LicensePlateDataset(
annotations_file=annotations,
transform=train_augmentation,
Expand Down Expand Up @@ -169,6 +198,8 @@ def train(
dense=dense,
max_plate_slots=config.max_plate_slots,
vocabulary_size=config.vocabulary_size,
activation=activation,
pool_layer=pool_layer,
)
model.compile(
loss=cce_loss(vocabulary_size=config.vocabulary_size),
Expand All @@ -188,13 +219,17 @@ def train(
output_dir.mkdir(parents=True, exist_ok=True)
model_file_path = output_dir / "cnn_ocr-epoch_{epoch:02d}-acc_{val_plate_acc:.3f}.keras"

# Save params and config used for training
shutil.copy(config_file, output_dir / "config.yaml")
A.save(train_augmentation, output_dir / "train_augmentation.yaml", "yaml")

callbacks = [
# Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs
ReduceLROnPlateau(
"val_plate_acc",
patience=reduce_lr_patience,
factor=0.5,
min_lr=1e-5,
factor=reduce_lr_factor,
min_lr=1e-6,
verbose=1,
),
# Stop training when 'val_plate_acc' doesn't improve for X epochs
Expand All @@ -217,7 +252,9 @@ def train(
]

if tensorboard:
callbacks.append(TensorBoard(log_dir=tensorboard_dir))
run_dir = tensorboard_dir / datetime.now().strftime("run_%Y%m%d_%H%M%S")
run_dir.mkdir(parents=True, exist_ok=True)
callbacks.append(TensorBoard(log_dir=run_dir))

model.fit(train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks)

Expand Down
83 changes: 83 additions & 0 deletions fast_plate_ocr/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Utils used for the CLI scripts.
"""

import inspect
import pathlib
from collections.abc import Callable
from functools import wraps
from typing import Any

import albumentations as A
from rich import box
from rich.console import Console
from rich.pretty import Pretty
from rich.table import Table


def print_variables_as_table(
    c1_title: str, c2_title: str, title: str = "Variables Table", **kwargs: Any
) -> None:
    """
    Render the given keyword arguments as a two-column table using the rich library.

    Each keyword becomes one row: the keyword name goes in the first column and the
    pretty-printed value in the second. ``pathlib.Path`` values are converted to ``str``
    before being rendered.

    Args:
        c1_title (str): Header of the first (name) column.
        c2_title (str): Header of the second (value) column.
        title (str): Title shown for the table.
        **kwargs (Any): Names and values to display, one row per entry.
    """
    out = Console()
    out.print("\n")

    grid = Table(title=title, show_header=True, header_style="bold blue", box=box.ROUNDED)
    # Both columns share the same look; only the header text and minimum width differ.
    for heading, width in ((c1_title, 20), (c2_title, 60)):
        grid.add_column(heading, min_width=width, justify="left", style="bold")

    for name, raw in kwargs.items():
        shown = str(raw) if isinstance(raw, pathlib.Path) else raw
        grid.add_row(f"[bold]{name}[/bold]", Pretty(shown))

    out.print(grid)


def print_params(
    table_title: str = "Parameters Table", c1_title: str = "Variable", c2_title: str = "Value"
) -> Callable:
    """
    Build a decorator that prints a function's call arguments as a table before running it.

    On every call of the decorated function, the received positional and keyword arguments
    are bound to the function signature, defaults are filled in, and the resulting
    name/value pairs are passed to ``print_variables_as_table``. The function is then
    invoked with the original arguments and its result is returned unchanged.

    Args:
        table_title (str, optional): Title of the table. Defaults to "Parameters Table".
        c1_title (str, optional): Title of the first column. Defaults to "Variable".
        c2_title (str, optional): Title of the second column. Defaults to "Value".
    Returns:
        Callable: A decorator that adds the parameter printing to the wrapped function.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            call = inspect.signature(func).bind(*args, **kwargs)
            # Include parameters the caller did not pass explicitly.
            call.apply_defaults()
            print_variables_as_table(c1_title, c2_title, table_title, **dict(call.arguments))
            return func(*args, **kwargs)

        return wrapper

    return decorator


def print_train_details(augmentation: A.Compose, config: dict[str, Any]) -> None:
    """
    Print the augmentation pipeline and the configuration to the console using rich.

    Each section is written as a bold blue heading followed by the pretty-printed object
    and a blank separator.

    Args:
        augmentation (A.Compose): Augmentation pipeline to display.
        config (dict[str, Any]): Configuration values to display.
    """
    out = Console()
    sections = (
        ("[bold blue]Augmentation Pipeline:[/bold blue]", augmentation),
        ("[bold blue]Configuration:[/bold blue]", config),
    )

    out.print("\n")
    for heading, payload in sections:
        out.print(heading)
        out.print(Pretty(payload))
        out.print("\n")
2 changes: 1 addition & 1 deletion fast_plate_ocr/train/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def vocabulary_size(self) -> int:
return len(self.alphabet)

@model_validator(mode="after")
def check_passwords_match(self) -> "PlateOCRConfig":
def check_pad_in_alphabet(self) -> "PlateOCRConfig":
if self.pad_char not in self.alphabet:
raise ValueError("Pad character must be present in model alphabet.")
return self
Expand Down
58 changes: 45 additions & 13 deletions fast_plate_ocr/train/model/layer_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""

from keras import regularizers
from keras.activations import relu, relu6
from keras.layers import Activation, BatchNormalization, Conv2D, SeparableConv2D
from keras.src.layers import AveragePooling2D, MaxPooling2D


def block_no_bn(i, k=3, n_c=64, s=1, padding="same"):
def block_no_bn(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"):
x1 = Conv2D(
kernel_size=k,
filters=n_c,
Expand All @@ -16,7 +16,7 @@ def block_no_bn(i, k=3, n_c=64, s=1, padding="same"):
kernel_regularizer=regularizers.l2(0.01),
use_bias=False,
)(i)
x2 = Activation(relu)(x1)
x2 = Activation(activation)(x1)
return x2, x1


Expand All @@ -33,7 +33,7 @@ def block_no_activation(i, k=3, n_c=64, s=1, padding="same"):
return x


def block_bn(i, k=3, n_c=64, s=1, padding="same"):
def block_bn(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"):
x1 = Conv2D(
kernel_size=k,
filters=n_c,
Expand All @@ -43,18 +43,20 @@ def block_bn(i, k=3, n_c=64, s=1, padding="same"):
use_bias=False,
)(i)
x2 = BatchNormalization()(x1)
x2 = Activation(relu)(x2)
x2 = Activation(activation)(x2)
return x2, x1


def block_bn_no_l2(i, k=3, n_c=64, s=1, padding="same"):
def block_bn_no_l2(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"):
x1 = Conv2D(kernel_size=k, filters=n_c, strides=s, padding=padding, use_bias=False)(i)
x2 = BatchNormalization()(x1)
x2 = Activation(relu)(x2)
x2 = Activation(activation)(x2)
return x2, x1


def block_bn_sep_conv_l2(i, k=3, n_c=64, s=1, padding="same", depth_multiplier=1):
def block_bn_sep_conv_l2(
i, k=3, n_c=64, s=1, padding="same", depth_multiplier=1, activation: str = "relu"
):
l2_kernel_reg = regularizers.l2(0.01)
x1 = SeparableConv2D(
kernel_size=k,
Expand All @@ -67,11 +69,11 @@ def block_bn_sep_conv_l2(i, k=3, n_c=64, s=1, padding="same", depth_multiplier=1
pointwise_regularizer=l2_kernel_reg,
)(i)
x2 = BatchNormalization()(x1)
x2 = Activation(relu)(x2)
x2 = Activation(activation)(x2)
return x2, x1


def block_bn_relu6(i, k=3, n_c=64, s=1, padding="same"):
def block_bn_relu6(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu6"):
x1 = Conv2D(
kernel_size=k,
filters=n_c,
Expand All @@ -81,12 +83,42 @@ def block_bn_relu6(i, k=3, n_c=64, s=1, padding="same"):
use_bias=False,
)(i)
x2 = BatchNormalization()(x1)
x2 = Activation(relu6)(x2)
x2 = Activation(activation)(x2)
return x2, x1


def block_bn_relu6_no_l2(i, k=3, n_c=64, s=1, padding="same"):
def block_bn_relu6_no_l2(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu6"):
x1 = Conv2D(kernel_size=k, filters=n_c, strides=s, padding=padding, use_bias=False)(i)
x2 = BatchNormalization()(x1)
x2 = Activation(relu6)(x2)
x2 = Activation(activation)(x2)
return x2, x1


def block_average_conv_down(x, n_c, padding="same", activation: str = "relu"):
    """
    Downsampling block built from average pooling followed by a strided convolution.

    Applies, in order: 2x2 average pooling (stride 1, "valid" padding), a 3x3 convolution
    with stride 2, L2 kernel regularization (0.01) and no bias, batch normalization, and
    the requested activation.

    Args:
        x: Input tensor.
        n_c: Number of convolution filters.
        padding: Padding mode passed to the convolution.
        activation: Activation passed to the final ``Activation`` layer.
    Returns:
        The output tensor of the activation layer.
    """
    pooled = AveragePooling2D(pool_size=2, strides=1, padding="valid")(x)
    strided_conv = Conv2D(
        filters=n_c,
        kernel_size=3,
        strides=2,
        padding=padding,
        kernel_regularizer=regularizers.l2(0.01),
        use_bias=False,
    )
    normalized = BatchNormalization()(strided_conv(pooled))
    return Activation(activation)(normalized)


def block_max_conv_down(x, n_c, padding="same", activation: str = "relu"):
    """
    Downsampling block built from max pooling followed by a strided convolution.

    Applies, in order: 2x2 max pooling (stride 1, "valid" padding), a 3x3 convolution
    with stride 2, L2 kernel regularization (0.01) and no bias, batch normalization, and
    the requested activation.

    Args:
        x: Input tensor.
        n_c: Number of convolution filters.
        padding: Padding mode passed to the convolution.
        activation: Activation passed to the final ``Activation`` layer.
    Returns:
        The output tensor of the activation layer.
    """
    pooled = MaxPooling2D(pool_size=2, strides=1, padding="valid")(x)
    strided_conv = Conv2D(
        filters=n_c,
        kernel_size=3,
        strides=2,
        padding=padding,
        kernel_regularizer=regularizers.l2(0.01),
        use_bias=False,
    )
    normalized = BatchNormalization()(strided_conv(pooled))
    return Activation(activation)(normalized)
Loading

0 comments on commit 4200fbd

Please sign in to comment.