Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/early stopping #7

Merged
merged 7 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export("%>%")
export(accelerator)
export(fit)
export(luz_callback)
export(luz_callback_early_stopping)
export(luz_callback_metrics)
export(luz_callback_progress)
export(luz_callback_train_valid)
Expand Down
110 changes: 109 additions & 1 deletion R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ default_callbacks <- function() {
#' A `luz_callback` that can be passed to [fit.luz_module_generator()].
#' @family luz_callbacks
#' @export
luz_callback <- function(name, ..., private = NULL, active = NULL, parent_env = parent.frame(),
luz_callback <- function(name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(),
inherit = NULL) {
make_class(
name = name,
Expand Down Expand Up @@ -263,4 +263,112 @@ luz_callback_train_valid <- luz_callback(
}
)

#' Early stopping callback
#'
#' Stops training when a monitored metric stops improving
#'
#' @param monitor A string in the format `<set>_<metric>` where `<set>` can be
#' 'train' or 'valid' and `<metric>` can be the abbreviation of any metric
#' that you are tracking during training.
#' @param min_delta Minimum improvement to reset the patience counter.
#' @param patience Number of epochs without improving until stopping training.
#' @param mode Specifies the direction that is considered an improvement. By default
#' 'min' is used. Can also be 'max' (higher is better) and 'zero'
#' (closer to zero is better).
#' @param baseline An initial value that will be used as the best seen value
#' in the beginning. Model will stop training if no better than baseline value
#' is found in the first `patience` epochs.
#'
#' @note
#' This callback adds a `on_early_stopping` callback that can be used to
#' call callbacks as soon as the model stopped training.
#'
#' @note
#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when
#' early stopping.
#'
#' @returns
#' A `luz_callback` that does early stopping.
#'
#' @examples
#' cb <- luz_callback_early_stopping()
#'
#' @family luz_callbacks
#' @export
luz_callback_early_stopping <- luz_callback(
  name = "early_stopping_callback",
  # Stores the configuration on the callback instance and resets the patience
  # counter. When `baseline` is supplied it seeds the best value seen so far;
  # otherwise `current_best` stays unset until the first epoch ends.
  initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0,
                        mode = "min", baseline = NULL) {
    self$monitor <- monitor
    self$min_delta <- min_delta
    self$patience <- patience
    self$mode <- mode
    self$baseline <- baseline

    if (!is.null(self$baseline))
      self$current_best <- baseline

    self$patience_counter <- 0L
  },
  # Registers a handler for the "early_stopping" condition class on
  # `ctx$handlers`. The handler fires the `on_early_stopping` callbacks and
  # returns `NULL` invisibly.
  on_fit_begin = function() {
    ctx$handlers <- append(ctx$handlers, list(
      early_stopping = function(err) {
        ctx$call_callbacks("on_early_stopping")
        invisible(NULL)
      }
    ))
  },
  # Compares the monitored quantity against the best value seen so far,
  # updates the patience counter and signals "early_stopping" once the counter
  # reaches `patience`.
  on_epoch_end = function() {

    qty <- self$find_quantity()
    if (is.null(self$current_best))
      self$current_best <- qty

    if (self$compare(qty, self$current_best)) {
      # the new quantity is better than the previous best
      self$current_best <- qty
      self$patience_counter <- 0L
    } else {
      # the quantity did not improve
      self$patience_counter <- self$patience_counter + 1L
    }

    if (self$patience_counter >= self$patience) {
      rlang::signal("Early stopping", class = "early_stopping")
    }

  },
  # Informs the user at which epoch training was stopped.
  on_early_stopping = function() {
    inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}"))
  },
  # Resolves `self$monitor` (split on "_") into a single numeric value.
  # The first piece selects the set, the second the quantity ("loss" or a
  # metric name) and an optional third piece selects the entry to read
  # (defaults to "opt"). Aborts when the result is not of length 1.
  find_quantity = function() {
    o <- strsplit(self$monitor, "_")[[1]]
    set <- o[[1]]
    qty <- o[[2]]
    opt <- if (length(o) >= 3) o[[3]] else "opt"

    out <- if (qty == "loss") {
      # losses are read from the last element recorded for this set
      as.numeric(utils::tail(ctx$losses[[set]], 1)[[1]][[opt]])
    } else {
      as.numeric(ctx$records$metrics[[set]][[qty]][[opt]])
    }

    if (length(out) != 1)
      rlang::abort(glue::glue("Expected monitored metric to be length 1, got {length(out)}"))

    out
  },
  # Returns TRUE when `new` is better than `old` according to `self$mode`,
  # requiring the improvement to exceed `self$min_delta`.
  compare = function(new, old) {
    out <- if (self$mode == "min") {
      (old - self$min_delta) > new
    } else if (self$mode == "max") {
      (new - self$min_delta) > old
    } else if (self$mode == "zero") {
      # closer to zero is better: compare magnitudes of old and *new* values
      (abs(old) - self$min_delta) > abs(new)
    } else {
      # fail with a clear message instead of an obscure error in as.array()
      rlang::abort(glue::glue(
        "Unknown mode '{self$mode}'. Expected one of 'min', 'max' or 'zero'."
      ))
    }

    as.array(out)
  }
)


87 changes: 49 additions & 38 deletions R/module.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,48 +193,53 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
call_all_callbacks(ctx$callbacks, name)
}

ctx$call_callbacks("on_fit_begin")
ctx$handlers <- list()

for (epoch in seq_len(ctx$epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L
ctx$call_callbacks("on_epoch_begin")
ctx$call_callbacks("on_fit_begin")
rlang::with_handlers(
!!! ctx$handlers,
.expr = {
for (epoch in seq_len(ctx$epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L
ctx$call_callbacks("on_epoch_begin")

ctx$call_callbacks("on_train_begin")
ctx$call_callbacks("on_train_begin")

coro::loop(for (batch in ctx$data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L
coro::loop(for (batch in ctx$data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L

ctx$call_callbacks("on_train_batch_begin")
step()
ctx$call_callbacks("on_train_batch_end")
})
ctx$call_callbacks("on_train_batch_begin")
step()
ctx$call_callbacks("on_train_batch_end")
})

ctx$call_callbacks("on_train_end")
ctx$call_callbacks("on_train_end")

if (!is.null(ctx$valid_data)) {
if (!is.null(ctx$valid_data)) {

ctx$call_callbacks("on_valid_begin")
ctx$call_callbacks("on_valid_begin")

ctx$iter <- 0L
torch::with_no_grad({
coro::loop(for (batch in ctx$valid_data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L
ctx$iter <- 0L
torch::with_no_grad({
coro::loop(for (batch in ctx$valid_data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L

ctx$call_callbacks("on_valid_batch_begin")
step()
ctx$call_callbacks("on_valid_batch_end")
})
})
ctx$call_callbacks("on_valid_batch_begin")
step()
ctx$call_callbacks("on_valid_batch_end")
})
})

ctx$call_callbacks("on_valid_end")
ctx$call_callbacks("on_valid_end")

}
}

ctx$call_callbacks("on_epoch_end")
}
ctx$call_callbacks("on_epoch_end")
}
})

ctx$call_callbacks("on_fit_end")
structure(
Expand All @@ -251,7 +256,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
#' @importFrom stats predict
#' @export
predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(),
accelerator = NULL) {
accelerator = NULL) {

ctx <- object$ctx

Expand All @@ -274,20 +279,26 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(),
else
stack <- pars$stack

ctx$handlers <- list()
ctx$output <- list()
ctx$callbacks <- initialize_callbacks(callbacks, ctx)

predict_fn <- if (is.null(ctx$model$predict)) ctx$model else ctx$model$predict

torch::with_no_grad({
ctx$call_callbacks("on_predict_begin")
coro::loop(for(batch in data) {
ctx$batch <- batch
ctx$input <- batch[[1]]
ctx$call_callbacks("on_predict_batch_begin")
ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input))
ctx$call_callbacks("on_predict_batch_end")
})
rlang::with_handlers(
!!! ctx$handlers,
.expr = {
coro::loop(for(batch in data) {
ctx$batch <- batch
ctx$input <- batch[[1]]
ctx$call_callbacks("on_predict_batch_begin")
ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input))
ctx$call_callbacks("on_predict_batch_end")
})
}
)
ctx$call_callbacks("on_predict_end")
})

Expand Down
2 changes: 2 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,5 @@ make_class <- function(name, ..., private, active, inherit, parent_env, .init_fu
attr(f, "r6_class") <- r6_class
f
}


1 change: 1 addition & 0 deletions man/ctx.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions man/luz_callback_early_stopping.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_metrics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_progress.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_train_valid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/rmd/ctx.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,7 @@ The `ctx` object is used in luz to share information between the training loop a
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `losses` | `list()` tracking losses over time. See also `help("luz_callback_metrics")` |
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `handlers` | A named `list()` of handlers that is passed to `rlang::with_handlers()` during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. |
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

: Context attributes
Loading