feat: ✨ ClearML training loss logging (#1844)
odulcy-mindee authored Jan 16, 2025
1 parent 3db0300 commit 38efc1b
Showing 10 changed files with 378 additions and 108 deletions.
64 changes: 49 additions & 15 deletions references/classification/train_pytorch_character.py
@@ -110,11 +110,16 @@ def record_lr(
return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
for images, targets in pbar:
@@ -141,6 +146,12 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
scheduler.step()

pbar.set_description(f"Training loss: {train_loss.item():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
)
iteration += 1


@torch.no_grad()
@@ -318,35 +329,48 @@ def main(args):
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

config = {
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": args.optim,
"framework": "pytorch",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
}

# W&B
if args.wb:
import wandb

run = wandb.init(
name=exp_name,
project="character-classification",
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": args.optim,
"framework": "pytorch",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
},
config=config,
)

# ClearML
if args.clearml:
from clearml import Task

task = Task.init(project_name="docTR/character-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

# Create loss queue
min_loss = np.inf
# Training loop
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler)
fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
)

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -361,6 +385,15 @@ def main(args):
"val_loss": val_loss,
"acc": acc,
})

# ClearML
if args.clearml:
from clearml import Logger

logger = Logger.current_logger()
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break
@@ -420,6 +453,7 @@ def parse_args():
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases")
parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub")
parser.add_argument(
"--pretrained",
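Note: the iteration-level logging added to fit_one_epoch boils down to the pattern below. This is a minimal sketch, not the script itself: the task name and the dummy loss values are placeholders, and a module-level counter stands in for the global iteration used by the training scripts.

from clearml import Logger, Task

# A ClearML task must exist before Logger.current_logger() has anywhere to report.
# The scripts pass reuse_last_task_id=False so every run gets a fresh task.
task = Task.init(
    project_name="docTR/character-classification",
    task_name="demo-run",  # placeholder; the scripts use f"{args.arch}_{current_time}"
    reuse_last_task_id=False,
)

iteration = 0  # step counter shared across epochs, bumped once per batch


def log_train_loss(loss_value: float) -> None:
    # Mirrors the per-batch call added inside fit_one_epoch.
    global iteration
    Logger.current_logger().report_scalar(
        title="Training Loss", series="train_loss", value=loss_value, iteration=iteration
    )
    iteration += 1


# Placeholder loop standing in for the tqdm loop over train_loader.
for batch_loss in (0.92, 0.71, 0.55):
    log_train_loss(batch_loss)
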
64 changes: 49 additions & 15 deletions references/classification/train_pytorch_orientation.py
@@ -121,11 +121,16 @@ def record_lr(
return lr_recorder[: len(loss_recorder)], loss_recorder


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
for images, targets in pbar:
@@ -152,6 +157,12 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
scheduler.step()

pbar.set_description(f"Training loss: {train_loss.item():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
)
iteration += 1


@torch.no_grad()
@@ -324,35 +335,48 @@ def main(args):
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

config = {
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.arch,
"input_size": input_size,
"optimizer": args.optim,
"framework": "pytorch",
"classes": CLASSES,
"scheduler": args.sched,
"pretrained": args.pretrained,
}

# W&B
if args.wb:
import wandb

run = wandb.init(
name=exp_name,
project="orientation-classification",
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.arch,
"input_size": input_size,
"optimizer": args.optim,
"framework": "pytorch",
"classes": CLASSES,
"scheduler": args.sched,
"pretrained": args.pretrained,
},
config=config,
)

# ClearML
if args.clearml:
from clearml import Task

task = Task.init(project_name="docTR/orientation-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

# Create loss queue
min_loss = np.inf
# Training loop
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler)
fit_one_epoch(
model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
)

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -367,6 +391,15 @@ def main(args):
"val_loss": val_loss,
"acc": acc,
})

# ClearML
if args.clearml:
from clearml import Logger

logger = Logger.current_logger()
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break
@@ -410,6 +443,7 @@ def parse_args():
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases")
parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub")
parser.add_argument(
"--pretrained",
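Note: the validation-side reporting added at the end of each epoch uses the same report_scalar call, but keyed by the epoch index rather than the batch counter, so ClearML draws one point per epoch. A minimal sketch with invented metric values:

from clearml import Logger, Task

Task.init(
    project_name="docTR/orientation-classification",
    task_name="demo-val-logging",  # placeholder experiment name
    reuse_last_task_id=False,
)
logger = Logger.current_logger()

# Placeholder results; in the scripts these come from evaluate(model, val_loader, batch_transforms).
history = [(0.81, 0.62), (0.64, 0.74), (0.52, 0.80)]  # (val_loss, acc) per epoch

for epoch, (val_loss, acc) in enumerate(history):
    logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
    logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)
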
18 changes: 16 additions & 2 deletions references/classification/train_tensorflow_character.py
@@ -96,7 +96,12 @@ def apply_grads(optimizer, grads, model):
optimizer.apply_gradients(zip(grads, model.trainable_weights))


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False, clearml_log=False):
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
for images, targets in pbar:
@@ -111,6 +116,12 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
apply_grads(optimizer, grads, model)

pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.numpy().mean(), iteration=iteration
)
iteration += 1


def evaluate(model, val_loader, batch_transforms):
@@ -315,6 +326,8 @@ def main(args):

task = Task.init(project_name="docTR/character-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

# Create loss queue
min_loss = np.inf
@@ -323,7 +336,7 @@ def main(args):
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp)
fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp, args.clearml)

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -346,6 +359,7 @@ def main(args):
logger = Logger.current_logger()
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break
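Note: in the TensorFlow scripts the ClearML task and the config artifact appear as unchanged context in this hunk, while the PyTorch scripts gain the same setup in this commit. The registration step amounts to roughly the following; only the keys and call shapes are taken from the diff, the values are placeholders.

from clearml import Task

config = {
    "learning_rate": 0.001,  # placeholder; the scripts use args.lr
    "epochs": 10,
    "batch_size": 64,
    "architecture": "example_arch",  # placeholder; the scripts use args.arch
    "framework": "pytorch",
}

task = Task.init(
    project_name="docTR/character-classification",
    task_name="example_arch_20250116-000000",  # exp_name pattern: f"{args.arch}_{current_time}"
    reuse_last_task_id=False,
)

# Attach the hyperparameter dict to the task as a browsable artifact.
task.upload_artifact("config", config)
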
18 changes: 16 additions & 2 deletions references/classification/train_tensorflow_orientation.py
@@ -110,7 +110,12 @@ def apply_grads(optimizer, grads, model):
optimizer.apply_gradients(zip(grads, model.trainable_weights))


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False, clearml_log=False):
if clearml_log:
from clearml import Logger

logger = Logger.current_logger()

# Iterate over the batches of the dataset
pbar = tqdm(train_loader, position=1)
for images, targets in pbar:
@@ -125,6 +130,12 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
apply_grads(optimizer, grads, model)

pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}")
if clearml_log:
global iteration
logger.report_scalar(
title="Training Loss", series="train_loss", value=train_loss.numpy().mean(), iteration=iteration
)
iteration += 1


def evaluate(model, val_loader, batch_transforms):
@@ -324,6 +335,8 @@ def main(args):

task = Task.init(project_name="docTR/orientation-classification", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)
global iteration
iteration = 0

# Create loss queue
min_loss = np.inf
@@ -332,7 +345,7 @@ def main(args):
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
for epoch in range(args.epochs):
fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp)
fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp, args.clearml)

# Validation loop at the end of each epoch
val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -355,6 +368,7 @@ def main(args):
logger = Logger.current_logger()
logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
logger.report_scalar(title="Accuracy", series="acc", value=acc, iteration=epoch)

if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break
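Note: enabling the logging is one boolean flag threaded through the call chain. The PyTorch scripts add a --clearml store-true argument and forward it to fit_one_epoch as clearml_log; the TensorFlow scripts forward args.clearml positionally. A trimmed sketch of that wiring, with the training internals stubbed out:

import argparse


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler,
                  amp=False, clearml_log=False):
    # Stub with the signature the PyTorch scripts use after this change.
    if clearml_log:
        from clearml import Logger  # imported lazily so clearml stays an optional dependency

        Logger.current_logger()  # the real loop calls report_scalar once per batch


def parse_args():
    parser = argparse.ArgumentParser(description="trimmed parser; the real scripts define many more args")
    parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # In the real scripts, model/loaders/optimizer/scheduler are built first, then:
    # fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler,
    #               amp=args.amp, clearml_log=args.clearml)
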
