fit calibrators at fit.container() #12

Merged 2 commits on May 2, 2024
52 changes: 33 additions & 19 deletions R/adjust-numeric-calibration.R
@@ -1,8 +1,11 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_linear()].
#' @param type Character. One of `"linear"`, `"isotonic"`, or
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
#' package [probably::cal_estimate_linear()],
#' [probably::cal_estimate_isotonic()], or
#' [probably::cal_estimate_isotonic_boot()], respectively.

Review comment on `@param type` (Contributor Author):
When putting together #13, I ultimately felt that this should be a different argument from the `type` supplied to `container()`. I'd propose `method`, but will wait to Replace All until there's more feedback here.

#' @examples
#' library(modeldata)
#' library(probably)
@@ -14,27 +17,24 @@
#'
#' dat
#'
#' # calibrate numeric predictions
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(reg_cal)
#' adjust_numeric_calibration(type = "linear")
#'
#' # "train" container
#' # train container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr, dat)
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, calibrator) {
check_container(x)
check_required(calibrator)
if (!inherits(calibrator, "cal_regression")) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
not {.obj_type_friendly {calibrator}}."
adjust_numeric_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "numeric")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
arg_match0(
type,
c("linear", "isotonic", "isotonic_boot")
)
}

@@ -43,7 +43,7 @@ adjust_numeric_calibration <- function(x, calibrator) {
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
@@ -67,19 +67,33 @@ print.numeric_calibration <- function(x, ...) {

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
# todo: adjust_numeric_calibration() should take arguments to pass to
# cal_estimate_* via dots
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
.data = data,
truth = container$columns$outcome,
estimate = container$columns$estimate,
.ns = "probably"
)
)

new_operation(
class(object),
inputs = object$inputs,
outputs = object$outputs,
arguments = object$arguments,
results = list(),
results = list(fit = fit),
trained = TRUE
)
}

#' @export
predict.numeric_calibration <- function(object, new_data, container, ...) {
probably::cal_apply(new_data, object$argument$calibrator)
probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
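
The fit-time dispatch in `fit.numeric_calibration()` builds a namespaced call from the `type` string and then evaluates it. A minimal sketch of just the call construction, assuming only that rlang is installed: `build_cal_call()` is a hypothetical helper that does not exist in this PR, and it passes symbols (`dat`, `y`, `y_pred`) where the PR passes the data and the stored column entries.

```r
library(rlang)

# Hypothetical helper (not part of this PR): build the call to the
# probably function that matches `type`, without evaluating it.
build_cal_call <- function(type, data, truth, estimate) {
  call2(
    paste0("cal_estimate_", type),
    .data = data,
    truth = truth,
    estimate = estimate,
    .ns = "probably"
  )
}

cal_call <- build_cal_call(
  type = "linear",
  data = quote(dat),
  truth = quote(y),
  estimate = quote(y_pred)
)

# Inspect the constructed call. eval_bare(cal_call) would run it, which
# requires probably to be installed and `dat` to exist in the environment.
print(cal_call)
```

Building the call separately from evaluating it makes the dispatch easy to inspect in tests before any calibration model is actually fit.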
45 changes: 31 additions & 14 deletions R/adjust-probability-calibration.R
@@ -1,18 +1,19 @@
#' Re-calibrate classification probability predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_logistic()].
#' @param type Character. One of `"logistic"`, `"multinomial"`,
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
#' [probably::cal_estimate_multinomial()], etc., respectively.
#' @export
adjust_probability_calibration <- function(x, calibrator) {
check_container(x)
cls <- c("cal_binary", "cal_multinomial")
check_required(calibrator)
if (!inherits_any(calibrator, cls)) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_binary> or <cal_multinomial> object](probably::cal_estimate_logistic)}, \\
not {.obj_type_friendly {calibrator}}."
adjust_probability_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "probability")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
arg_match(
type,
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
)
}

@@ -21,7 +22,7 @@ adjust_probability_calibration <- function(x, calibrator) {
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
@@ -45,19 +46,35 @@ print.probability_calibration <- function(x, ...) {

#' @export
fit.probability_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
# todo: adjust_probability_calibration() should take arguments to pass to
# cal_estimate_* via dots
# to-do: add argument specifying `prop` in initial_split
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
.data = data,
# todo: make getters for the entries in `columns`
truth = container$columns$outcome,
estimate = container$columns$estimate,
.ns = "probably"
)
)

new_operation(
class(object),
inputs = object$inputs,
outputs = object$outputs,
arguments = object$arguments,
results = list(),
results = list(fit = fit),
trained = TRUE
)
}

#' @export
predict.probability_calibration <- function(object, new_data, container, ...) {
probably::cal_apply(new_data, object$argument$calibrator)
probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
2 changes: 1 addition & 1 deletion R/container.R
@@ -130,7 +130,7 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),

num_oper <- length(object$operations)
for (op in seq_len(num_oper)) {
object$operations[[op]] <- fit(object$operations[[op]], data, object)
object$operations[[op]] <- fit(object$operations[[op]], .data, object)
.data <- predict(object$operations[[op]], .data, object)
}
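
The one-line change in this hunk matters because each operation is fit on the output of the operations before it, so the loop has to thread the same `.data` object through both `fit()` and `predict()`. A toy sketch of that chaining pattern in base R follows. `center`, `scale_max`, and `fit_chain()` are hypothetical stand-ins for illustration, not container operations.

```r
# Hypothetical toy operations: each has a `fit` step that learns a statistic
# from the data and an `apply` step that transforms the data with it.
center <- list(
  fit = function(d) mean(d$x),
  apply = function(d, stat) {
    d$x <- d$x - stat
    d
  }
)
scale_max <- list(
  fit = function(d) max(abs(d$x)),
  apply = function(d, stat) {
    d$x <- d$x / stat
    d
  }
)

fit_chain <- function(ops, .data) {
  stats <- vector("list", length(ops))
  for (i in seq_along(ops)) {
    # Fit on the current `.data`, i.e. the output of all earlier operations.
    stats[[i]] <- ops[[i]]$fit(.data)
    # Update `.data` so the next operation sees the transformed values.
    .data <- ops[[i]]$apply(.data, stats[[i]])
  }
  list(stats = stats, data = .data)
}

res <- fit_chain(list(center, scale_max), data.frame(x = c(1, 2, 6)))
```

Here the second operation learns its statistic from the centered values rather than the raw ones, which is the behavior the loop in `fit.container()` relies on.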

85 changes: 83 additions & 2 deletions R/utils.R
@@ -49,14 +49,95 @@ is_container <- function(x) {
}

# ad-hoc checking --------------------------------------------------------------
check_container <- function(x, call = caller_env(), arg = caller_arg(x)) {
check_container <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
if (!is_container(x)) {
cli::cli_abort(
cli_abort(
"{.arg {arg}} should be a {.help [{.cls container}](container::container)}, \\
not {.obj_type_friendly {x}}.",
call = call
)
}

# check that the type of calibration ("numeric" or "probability") is
# compatible with the container type
if (!is.null(calibration_type)) {
container_type <- x$type
switch(
container_type,
regression =
check_calibration_type(calibration_type, "numeric", container_type, call = call),
binary = , multinomial =
check_calibration_type(calibration_type, "probability", container_type, call = call)
)
}

invisible()
}

check_calibration_type <- function(calibration_type, calibration_type_expected,
container_type, call) {
if (!identical(calibration_type, calibration_type_expected)) {
cli_abort(
"A {.field {container_type}} container is incompatible with the operation \\
{.fun {paste0('adjust_', calibration_type, '_calibration')}}.",
call = call
)
}
}

types_regression <- c("linear", "isotonic", "isotonic_boot")
types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
# a check function to be called when a container is being `fit()`ted.
# by the time a container is fitted, we have:
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
# * this argument has already been checked to agree with the kind of
# `adjust_*()` function via `arg_match0()`.
# * `container_type`, the `type` argument either specified in `container()`
# or inferred in `fit.container()`.
Review comment on `check_type()` (Contributor Author):
There are a couple of gaps in this logic depending on when and how different variants of the type are supplied. Will sit tight on closing those up until we address #13.

check_type <- function(adjust_type,
container_type,
arg = caller_arg(adjust_type),
call = caller_env()) {
# if no `adjust_type` was supplied, infer a reasonable one based on the
# `container_type`
if (is.null(adjust_type)) {
switch(
container_type,
regression = return("linear"),
binary = return("logistic"),
multiclass = return("multinomial")
)
}

switch(
container_type,
regression = arg_match0(
adjust_type,
types_regression,
arg_nm = arg,
error_call = call
),

Review comment on lines +118 to +119 (Member):
Nice one

binary = arg_match0(
adjust_type,
types_binary,
arg_nm = arg,
error_call = call
),
multiclass = arg_match0(
adjust_type,
types_multiclass,
arg_nm = arg,
error_call = call
),
arg_match0(
adjust_type,
unique(c(types_regression, types_binary, types_multiclass)),
arg_nm = arg,
error_call = call
)
)

adjust_type
}
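
The two-stage logic of `check_type()` (pick a default when no type was supplied, otherwise validate against the types allowed for the container) can be sketched compactly with a lookup table. This is an illustrative simplification, not the PR's implementation: `types_by_container` and `resolve_type()` are hypothetical names, the default is assumed to be the first entry of each vector (which matches the defaults above), and an unknown container type with no supplied type raises an error here.

```r
library(rlang)

# Hypothetical lookup table mirroring types_regression, types_binary,
# and types_multiclass from the diff above.
types_by_container <- list(
  regression = c("linear", "isotonic", "isotonic_boot"),
  binary     = c("logistic", "beta", "isotonic", "isotonic_boot"),
  multiclass = c("multinomial", "beta", "isotonic", "isotonic_boot")
)

resolve_type <- function(adjust_type, container_type) {
  # NULL when `container_type` is not a known name in the list.
  allowed <- types_by_container[[container_type]]

  if (is.null(adjust_type)) {
    if (is.null(allowed)) {
      stop("Cannot infer a default type for container type: ", container_type)
    }
    # Assumption: the first listed type is the default for that container.
    return(allowed[[1]])
  }

  if (is.null(allowed)) {
    # Unknown container type: accept any type known to any container.
    allowed <- unique(unlist(types_by_container, use.names = FALSE))
  }

  # Errors informatively if `adjust_type` is not one of `allowed`.
  arg_match0(adjust_type, allowed, arg_nm = "type")
}

resolve_type(NULL, "binary")
resolve_type("isotonic", "regression")
```

Keeping the allowed values in one table means the default and the validation set cannot drift apart, at the cost of making the default depend on ordering.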

18 changes: 9 additions & 9 deletions man/adjust_numeric_calibration.Rd

8 changes: 5 additions & 3 deletions man/adjust_probability_calibration.Rd

17 changes: 9 additions & 8 deletions tests/testthat/_snaps/adjust-numeric-calibration.md
@@ -1,35 +1,36 @@
# adjustment printing

Code
ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal)
ctr_reg %>% adjust_numeric_calibration()
Message

-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:

* Re-calibrate numeric predictions.

# errors informatively with bad input

Code
adjust_numeric_calibration(ctr_reg)
adjust_numeric_calibration(ctr_reg, "boop")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` is absent but must be supplied.
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "boop".

---

Code
adjust_numeric_calibration(ctr_reg, "boop")
container("classification", "binary") %>% adjust_numeric_calibration("linear")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a string.
! A binary container is incompatible with the operation `adjust_numeric_calibration()`.

---

Code
adjust_numeric_calibration(ctr_cls, dummy_cls_cal)
container("regression", "regression") %>% adjust_numeric_calibration("binary")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a <cal_binary> object.
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
i Did you mean "linear"?
