Skip to content

Commit 3874410

Browse files
topepo‘topepo’simonpcouchEmilHvitfeldt
authored
first pass at the post-processing container (#1)
Co-authored-by: ‘topepo’ <‘[email protected]’> Co-authored-by: simonpcouch <[email protected]> Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent aa5ac35 commit 3874410

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1862
-10
lines changed

DESCRIPTION

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,28 @@ Authors@R: c(
66
person("Hannah", "Frick", , "[email protected]", role = "aut"),
77
person("Emil", "Hvitfeldt", , "[email protected]", role = "aut"),
88
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
9-
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
9+
person("Posit Software, PBC", role = c("cph", "fnd"))
1010
)
1111
Description: Sandbox for a postprocessor object.
1212
License: MIT + file LICENSE
13+
URL: https://github.com/tidymodels/container
14+
BugReports: https://github.com/tidymodels/container/issues
15+
Imports:
16+
cli,
17+
dplyr,
18+
generics,
19+
hardhat,
20+
probably (>= 1.0.3.9000),
21+
purrr,
22+
rlang (>= 1.1.0),
23+
tibble,
24+
tidyselect
1325
Suggests:
26+
modeldata,
1427
testthat (>= 3.0.0)
28+
Remotes:
29+
tidymodels/probably
1530
Config/testthat/edition: 3
1631
Encoding: UTF-8
1732
Roxygen: list(markdown = TRUE)
1833
RoxygenNote: 7.3.1
19-
URL: https://github.com/tidymodels/container
20-
BugReports: https://github.com/tidymodels/container/issues
21-
Imports:
22-
cli,
23-
rlang (>= 1.1.0)

NAMESPACE

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,63 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(fit,container)
4+
S3method(fit,equivocal_zone)
5+
S3method(fit,numeric_calibration)
6+
S3method(fit,numeric_range)
7+
S3method(fit,predictions_custom)
8+
S3method(fit,probability_calibration)
9+
S3method(fit,probability_threshold)
10+
S3method(predict,container)
11+
S3method(predict,equivocal_zone)
12+
S3method(predict,numeric_calibration)
13+
S3method(predict,numeric_range)
14+
S3method(predict,predictions_custom)
15+
S3method(predict,probability_calibration)
16+
S3method(predict,probability_threshold)
17+
S3method(print,container)
18+
S3method(print,equivocal_zone)
19+
S3method(print,numeric_calibration)
20+
S3method(print,numeric_range)
21+
S3method(print,predictions_custom)
22+
S3method(print,probability_calibration)
23+
S3method(print,probability_threshold)
24+
S3method(required_pkgs,equivocal_zone)
25+
S3method(required_pkgs,numeric_calibration)
26+
S3method(required_pkgs,numeric_range)
27+
S3method(required_pkgs,predictions_custom)
28+
S3method(required_pkgs,probability_calibration)
29+
S3method(required_pkgs,probability_threshold)
30+
S3method(tunable,equivocal_zone)
31+
S3method(tunable,numeric_calibration)
32+
S3method(tunable,numeric_range)
33+
S3method(tunable,predictions_custom)
34+
S3method(tunable,probability_calibration)
35+
S3method(tunable,probability_threshold)
36+
export("%>%")
37+
export(adjust_equivocal_zone)
38+
export(adjust_numeric_calibration)
39+
export(adjust_numeric_range)
40+
export(adjust_predictions_custom)
41+
export(adjust_probability_calibration)
42+
export(adjust_probability_threshold)
43+
export(container)
44+
export(extract_parameter_dials)
45+
export(extract_parameter_set_dials)
46+
export(fit)
47+
export(required_pkgs)
48+
export(tidy)
49+
export(tunable)
50+
export(tune_args)
351
import(rlang)
452
importFrom(cli,cli_abort)
553
importFrom(cli,cli_inform)
654
importFrom(cli,cli_warn)
55+
importFrom(dplyr,"%>%")
56+
importFrom(generics,fit)
57+
importFrom(generics,required_pkgs)
58+
importFrom(generics,tidy)
59+
importFrom(generics,tunable)
60+
importFrom(generics,tune_args)
61+
importFrom(hardhat,extract_parameter_dials)
62+
importFrom(hardhat,extract_parameter_set_dials)
63+
importFrom(stats,predict)

R/adjust-equivocal-zone.R

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#' Apply an equivocal zone to a binary classification model.
#'
#' @param x A [container()].
#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The
#' value is the size of the buffer around the threshold.
#' @param threshold A numeric value (between zero and one) or [hardhat::tune()].
#' The value is the class probability threshold that the buffer surrounds.
#' @return A [container()] whose operations are those of `x` followed by the
#' new, untrained equivocal zone operation.
#' @examples
#' library(dplyr)
#' library(modeldata)
#'
#' post_obj <-
#'   container(mode = "classification") %>%
#'   adjust_equivocal_zone(value = 1 / 4)
#'
#'
#' post_res <- fit(
#'   post_obj,
#'   two_class_example,
#'   outcome = c(truth),
#'   estimate = c(predicted),
#'   probabilities = c(Class1, Class2)
#' )
#'
#' predict(post_res, two_class_example)
#' @export
adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
  check_container(x)
  # tune() placeholders are accepted as-is; only concrete values are
  # range-checked.
  if (!is_tune(value)) {
    check_number_decimal(value, min = 0, max = 1 / 2)
  }
  if (!is_tune(threshold)) {
    check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10)
  }

  op <-
    new_operation(
      "equivocal_zone",
      inputs = "probability",
      outputs = "class",
      arguments = list(value = value, threshold = threshold),
      results = list(),
      trained = FALSE
    )

  new_container(
    mode = x$mode,
    type = x$type,
    operations = c(x$operations, list(op)),
    # Carry the existing column mapping forward; the predict methods read
    # it back as `container$columns`.
    columns = x$columns,
    ptype = x$ptype,
    call = current_env()
  )
}
54+
55+
#' @export
56+
print.equivocal_zone <- function(x, ...) {
  # A size marked for tuning has no concrete value to display, so check for
  # tune() first.
  if (is_tune(x$arguments$value)) {
    cli::cli_bullets(c("*" = "Add equivocal zone of optimized size."))
  } else {
    # `trained` is a scalar flag, so branch with `if` instead of the
    # vectorised ifelse(); isTRUE() treats a missing or NA flag as untrained.
    trn <- if (isTRUE(x$trained)) " [trained]" else ""
    cli::cli_bullets(c(
      "*" = "Add equivocal zone of size
      {signif(x$arguments$value, digits = 3)}.{trn}"
    ))
  }
  # Return the operation invisibly so printing composes with pipes.
  invisible(x)
}
70+
71+
#' @export
72+
fit.equivocal_zone <- function(object, data, container = NULL, ...) {
  # Nothing is estimated from `data` here: the operation is rebuilt from its
  # own class, inputs, outputs and arguments, with `trained` set to TRUE.
  trained_op <- new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(),
    trained = TRUE
  )
  trained_op
}
82+
83+
#' @export
84+
predict.equivocal_zone <- function(object, new_data, container, ...) {
  # Column names are looked up in the container's column mapping.
  est_nm <- container$columns$estimate
  # Only the first probability column is used.
  # NOTE(review): presumably the probability of the first class level; confirm.
  prob_nm <- container$columns$probabilities[1]
  lvls <- levels(new_data[[est_nm]])
  cls_pred <- probably::make_two_class_pred(
    new_data[[prob_nm]],
    levels = lvls,
    buffer = object$arguments$value,
    threshold = object$arguments$threshold
  )
  # The estimate column is overwritten in place with the new class predictions.
  new_data[[est_nm]] <- cls_pred # todo convert to factor?
  new_data
}
98+
99+
#' @export
100+
required_pkgs.equivocal_zone <- function(x, ...) {
  # Package names reported for this operation; `x` is not inspected.
  pkgs <- c("container", "probably")
  pkgs
}
103+
104+
#' @export
105+
tunable.equivocal_zone <- function(x, ...) {
  # One entry per column of the returned tibble; `call_info` names the dials
  # function associated with the parameter. `x` is not inspected.
  param_info <- list(
    name = "buffer",
    call_info = list(list(pkg = "dials", fun = "buffer")),
    source = "container",
    component = "equivocal_zone",
    component_id = "equivocal_zone"
  )
  tibble::new_tibble(param_info)
}
114+
115+
# todo missing methods:
116+
# todo tune_args
117+
# todo tidy
118+
# todo extract_parameter_set_dials

R/adjust-numeric-calibration.R

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_linear()].
#' @return A [container()] whose operations are those of `x` followed by the
#' new, untrained numeric calibration operation.
#' @examples
#' library(modeldata)
#' library(probably)
#' library(tibble)
#'
#' # create example data
#' set.seed(1)
#' dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100))
#'
#' dat
#'
#' # calibrate numeric predictions
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
#'
#' # specify calibration
#' reg_ctr <-
#'   container(mode = "regression") %>%
#'   adjust_numeric_calibration(reg_cal)
#'
#' # "train" container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, calibrator) {
  check_container(x)
  check_required(calibrator)
  # Only objects inheriting from "cal_regression" are accepted.
  if (!inherits(calibrator, "cal_regression")) {
    cli_abort(
      "{.arg calibrator} should be a \\
      {.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
      not {.obj_type_friendly {calibrator}}."
    )
  }

  op <-
    new_operation(
      "numeric_calibration",
      inputs = "numeric",
      outputs = "numeric",
      arguments = list(calibrator = calibrator),
      results = list(),
      trained = FALSE
    )

  new_container(
    mode = x$mode,
    type = x$type,
    operations = c(x$operations, list(op)),
    # Carry the existing column mapping forward instead of dropping it.
    columns = x$columns,
    ptype = x$ptype,
    call = current_env()
  )
}
60+
61+
#' @export
62+
print.numeric_calibration <- function(x, ...) {
  # `trained` is a scalar flag, so branch with `if` instead of the vectorised
  # ifelse(); isTRUE() treats a missing or NA flag as untrained.
  trn <- if (isTRUE(x$trained)) " [trained]" else ""
  cli::cli_bullets(c("*" = "Re-calibrate numeric predictions.{trn}"))
  # Return the operation invisibly so printing composes with pipes.
  invisible(x)
}
67+
68+
#' @export
69+
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
  # The calibrator arrives pre-trained in `object$arguments`, so `data` is not
  # used: the operation is rebuilt unchanged except for `trained = TRUE`.
  trained_op <- new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(),
    trained = TRUE
  )
  trained_op
}
79+
80+
#' @export
81+
predict.numeric_calibration <- function(object, new_data, container, ...) {
  # The calibrator is stored under `arguments`; spell the field out in full
  # rather than relying on `$` partial matching of a shortened name.
  probably::cal_apply(new_data, object$arguments$calibrator)
}
84+
85+
# todo probably needs required_pkgs methods for cal objects
86+
#' @export
87+
required_pkgs.numeric_calibration <- function(x, ...) {
  # Package names reported for this operation; `x` is not inspected.
  pkgs <- c("container", "probably")
  pkgs
}
90+
91+
#' @export
92+
tunable.numeric_calibration <- function(x, ...) {
  # Returns the package-level `no_param` object, which is defined outside this
  # file (presumably an empty tunable-parameter tibble; confirm).
  empty_params <- no_param
  empty_params
}
95+
96+
# todo missing methods:
97+
# todo tune_args
98+
# todo tidy
99+
# todo extract_parameter_set_dials

0 commit comments

Comments
 (0)