Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify workflow #399

Draft
wants to merge 17 commits into
base: v0.2.0
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
remove extraneous code in favour of NextMethod
dajmcdon committed Sep 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 233e12e6440ebec93125e948d50e2f53a9e6274c
147 changes: 24 additions & 123 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
@@ -95,23 +95,10 @@ add_epi_recipe <- function(
#' @rdname add_epi_recipe
#' @export
remove_epi_recipe <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_preprocessor_recipe(x)) {
rlang::warn("The workflow has no recipe preprocessor to remove.")
}

actions <- x$pre$actions
actions[["recipe"]] <- NULL

new_epi_workflow(
pre = workflows:::new_stage_pre(actions = actions),
fit = x$fit,
post = x$post,
trained = FALSE
)
workflows::remove_recipe(x)
}


#' @rdname add_epi_recipe
#' @export
update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
@@ -180,15 +167,21 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)
adjust_epi_recipe.epi_workflow <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {

update_epi_recipe(x, recipe, blueprint = blueprint)
rec <- adjust_epi_recipe(
workflows::extract_preprocessor(x), which_step, ...
)
update_epi_recipe(x, rec, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_recipe <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {
if (!(is.numeric(which_step) || is.character(which_step))) {
cli::cli_abort(
c("`which_step` must be a number or a character.",
@@ -294,109 +287,17 @@ kill_levels <- function(x, keys) {

#' @export
print.epi_recipe <- function(x, form_width = 30, ...) {
cli::cli_div(theme = list(.pkg = list("vec-trunc" = Inf, "vec-last" = ", ")))

cli::cli_h1("Epi Recipe")
cli::cli_h3("Inputs")

tab <- table(x$var_info$role, useNA = "ifany")
tab <- stats::setNames(tab, names(tab))
names(tab)[is.na(names(tab))] <- "undeclared role"

roles <- c("outcome", "predictor", "case_weights", "undeclared role")

tab <- c(
tab[names(tab) == roles[1]],
tab[names(tab) == roles[2]],
tab[names(tab) == roles[3]],
sort(tab[!names(tab) %in% roles], TRUE),
tab[names(tab) == roles[4]]
)

cli::cli_text("Number of variables by role")

spaces_needed <- max(nchar(names(tab))) - nchar(names(tab)) +
max(nchar(tab)) - nchar(tab)

cli::cli_verbatim(
glue::glue("{names(tab)}: {strrep('\ua0', spaces_needed)}{tab}")
)

if ("tr_info" %in% names(x)) {
cli::cli_h3("Training information")
nmiss <- x$tr_info$nrows - x$tr_info$ncomplete
nrows <- x$tr_info$nrows

cli::cli_text(
"Training data contained {nrows} data points and {cli::no(nmiss)} \\
incomplete row{?s}."
)
}

if (!is.null(x$steps)) {
cli::cli_h3("Operations")
}

fmt <- cli::cli_fmt({
for (step in x$steps) {
print(step, form_width = form_width)
}
})
cli::cli_ol(fmt)
cli::cli_end()

invisible(x)
}

# Currently only used in the workflow printing
print_preprocessor_recipe <- function(x, ...) {
recipe <- workflows::extract_preprocessor(x)
steps <- recipe$steps
n_steps <- length(steps)
cli::cli_text("{n_steps} Recipe step{?s}.")

if (n_steps == 0L) {
return(invisible(x))
}

step_names <- map_chr(steps, workflows:::pull_step_name)

if (n_steps <= 10L) {
cli::cli_ol(step_names)
return(invisible(x))
}

extra_steps <- n_steps - 10L
step_names <- step_names[1:10]

cli::cli_ol(step_names)
cli::cli_bullets("... and {extra_steps} more step{?s}.")
invisible(x)
}

print_preprocessor <- function(x) {
has_preprocessor_formula <- workflows:::has_preprocessor_formula(x)
has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x)
has_preprocessor_variables <- workflows:::has_preprocessor_variables(x)

no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe &&
!has_preprocessor_variables

if (no_preprocessor) {
return(invisible(x))
}

cli::cli_rule("Preprocessor")
cli::cli_text("")

if (has_preprocessor_formula) {
workflows:::print_preprocessor_formula(x)
}
if (has_preprocessor_recipe) {
print_preprocessor_recipe(x)
}
if (has_preprocessor_variables) {
workflows:::print_preprocessor_variables(x)
}
o <- cli::cli_fmt(NextMethod())
# Fix up the recipe name
rr <- unlist(strsplit(o[2], "Recipe"))
len <- nchar(rr[2])
h1_tail <- paste0(substr(rr[2], 1, len / 2 - 10), substr(rr[2], len / 2, len))
o[2] <- paste0(rr[1], "Epi Recipe", h1_tail)

# Number the operations
ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o))
rep_ops <- sub("\033[36m•\033[39m ", "", o[ops], fixed = TRUE) # kills the •
o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops)
cli::cli_bullets(o)
invisible(x)
}
21 changes: 8 additions & 13 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
@@ -32,18 +32,13 @@
#'
#' wf
epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) {
out <- workflows::workflow(spec = spec)
class(out) <- c("epi_workflow", class(out))
out <- workflows::workflow(preprocessor = preprocessor, spec = spec)

if (is_epi_recipe(preprocessor)) {
out <- add_epi_recipe(out, preprocessor)
} else if (!is_null(preprocessor)) {
out <- workflows:::add_preprocessor(out, preprocessor)
}
if (!is_null(postprocessor)) {
out <- add_postprocessor(out, postprocessor)
}

class(out) <- c("epi_workflow", class(out))
out
}

@@ -162,11 +157,14 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
}
components <- list()
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
components$forged <- hardhat::forge(
new_data,
blueprint = components$mold$blueprint
)
components$keys <- grab_forged_keys(components$forged, object, new_data)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components <- apply_frosting(
object, components, new_data, type = type, opts = opts, ...
)
components$predictions
}

@@ -216,10 +214,7 @@ new_epi_workflow <- function(

#' @export
print.epi_workflow <- function(x, ...) {
print_header(x)
print_preprocessor(x)
# workflows:::print_case_weights(x)
print_model(x)
NextMethod()
print_postprocessor(x)
invisible(x)
}