Cli errors for validation_set and vfold. #532

Merged (7 commits) on Sep 11, 2024

1 change: 1 addition & 0 deletions NAMESPACE
@@ -416,6 +416,7 @@ export(validation_time_split)
export(vfold_cv)
import(vctrs)
importFrom(cli,cli_abort)
+importFrom(cli,cli_warn)
importFrom(dplyr,"%>%")
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_)
2 changes: 1 addition & 1 deletion NEWS.md
@@ -2,7 +2,7 @@

* The new `inner_split()` function and its methods for various resamples is for usage in tune to create a inner resample of the analysis set to fit the preprocessor and model on one part and the post-processor on the other part (#483, #488, #489).

-* Started moving error messages to cli (#499, #502). With contributions from @PriKalra (#523, #526, #528, #530, #531) and @JamesHWade (#518).
+* Started moving error messages to cli (#499, #502). With contributions from @PriKalra (#523, #526, #528, #530, #531, #532) and @JamesHWade (#518).

* Fixed example for `nested_cv()` (@seb09, #520).

2 changes: 1 addition & 1 deletion R/rsample-package.R
@@ -3,7 +3,7 @@

## usethis namespace: start
#' @importFrom lifecycle deprecated
-#' @importFrom cli cli_abort
+#' @importFrom cli cli_abort cli_warn
## usethis namespace: end
NULL

2 changes: 1 addition & 1 deletion R/tidy.R
@@ -119,7 +119,7 @@ tidy.nested_cv <- function(x, unique_ind = TRUE, ...) {

inner_id <- grep("^id", names(inner_tidy))
if (length(inner_id) != length(id_cols)) {
-rlang::abort("Cannot merge tidy data sets")
+cli_abort("Cannot merge tidy data sets.")
}
names(inner_tidy)[inner_id] <- id_cols
full_join(outer_tidy, inner_tidy, by = id_cols)
8 changes: 4 additions & 4 deletions R/validation_set.R
@@ -91,8 +91,8 @@ validation.val_split <- function(x, ...) {
#' @rdname validation_set
#' @export
testing.val_split <- function(x, ...) {
-rlang::abort(
-"The testing data is not part of the validation set object.",
-i = "It is part of the result of the initial 3-way split, e.g., with `initial_validation_split()`."
-)
+cli_abort(c(
+"The testing data is not part of the validation set object.",
+"i" = "It is part of the result of the initial 3-way split, e.g., with {.fun initial_validation_split}."
+))
}
39 changes: 19 additions & 20 deletions R/vfold.R
@@ -81,8 +81,8 @@ vfold_cv <- function(data, v = 10, repeats = 1,
)
} else {
if (v == nrow(data)) {
-rlang::abort(
-glue::glue("Repeated resampling when `v` is {v} would create identical resamples")
+cli_abort(
+"Repeated resampling when {.arg v} is {v} would create identical resamples."
)
}
for (i in 1:repeats) {
@@ -225,14 +225,13 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance =
split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance, strata = strata, pool = pool)
} else {
if (is.null(v)) {
-rlang::abort(
-"Repeated resampling when `v` is `NULL` would create identical resamples"
+cli_abort(
+"Repeated resampling when {.arg v} is {.val NULL} would create identical resamples."
)
}
if (v == length(unique(getElement(data, group)))) {
-rlang::abort(
-glue::glue("Repeated resampling when `v` is {v} would create identical resamples")
-)
+cli_abort("Repeated resampling when {.arg v} is {.val {v}} would create identical resamples.")
+
}
for (i in 1:repeats) {
tmp <- group_vfold_splits(data = data, group = group, v = v, balance = balance, strata = strata, pool = pool)
@@ -288,21 +287,20 @@ group_vfold_splits <- function(data, group, v = NULL, balance, strata = NULL, po
)$count
)
message <- c(
-"Leaving `v = NULL` while using stratification will set `v` to the number of groups present in the least common stratum."
+"Leaving {.code v = NULL} while using stratification will set {.arg v} to the number of groups present in the least common stratum."
)

if (max_v < 5) {
-rlang::abort(c(
+cli_abort(c(
message,
-x = glue::glue("The least common stratum only had {max_v} groups, which may not be enough for cross-validation."),
-i = "Set `v` explicitly to override this error."
-),
-call = rlang::caller_env())
+"*" = "The least common stratum only had {.val {max_v}} groups, which may not be enough for cross-validation.",
+"i" = "Set {.arg v} explicitly to override this error."
+), call = rlang::caller_env())
}

-rlang::warn(c(
+cli_warn(c(
message,
-i = "Set `v` explicitly to override this warning."
+i = "Set {.arg v} explicitly to override this warning."
),
call = rlang::caller_env())
}
@@ -334,10 +332,11 @@ add_vfolds <- function(x, v) {

check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) {
if (!is.numeric(v) || length(v) != 1 || v < 2) {
-rlang::abort("`v` must be a single positive integer greater than 1", call = call)
+cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call)
} else if (v > max_v) {
-rlang::abort(
-glue::glue("The number of {rows} is less than `v = {v}`"), call = call
+cli_abort(
+"The number of {rows} is less than {.arg v} = {.val {v}}.",
+call = call
)
} else if (prevent_loo && isTRUE(v == max_v)) {
cli_abort(c(
@@ -364,14 +363,14 @@ check_grouped_strata <- function(group, strata, pool, data) {

if (nrow(vctrs::vec_unique(grouped_table)) !=
nrow(vctrs::vec_unique(grouped_table["group"]))) {
-rlang::abort("`strata` must be constant across all members of each `group`.")
+cli_abort("{.arg strata} must be constant across all members of each {.arg group}.")
}

strata
}

check_repeats <- function(repeats, call = rlang::caller_env()) {
if (!is.numeric(repeats) || length(repeats) != 1 || repeats < 1) {
-rlang::abort("`repeats` must be a single positive integer", call = call)
+cli_abort("{.arg repeats} must be a single positive integer.", call = call)
}
}
10 changes: 5 additions & 5 deletions tests/testthat/_snaps/clustering.md
@@ -4,11 +4,11 @@

---

-`v` must be a single positive integer greater than 1
+`v` must be a single positive integer greater than 1.

---

-The number of rows is less than `v = 500`
+The number of rows is less than `v` = 500.

---

@@ -20,15 +20,15 @@
clustering_cv(Orange, v = 1, vars = "Tree")
Condition
Error in `clustering_cv()`:
-! `v` must be a single positive integer greater than 1
+! `v` must be a single positive integer greater than 1.

---

-`repeats` must be a single positive integer
+`repeats` must be a single positive integer.

---

-`repeats` must be a single positive integer
+`repeats` must be a single positive integer.

---

1 change: 1 addition & 0 deletions tests/testthat/_snaps/validation_set.md
@@ -5,4 +5,5 @@
Condition
Error in `testing()`:
! The testing data is not part of the validation set object.
+i It is part of the result of the initial 3-way split, e.g., with `initial_validation_split()`.

20 changes: 10 additions & 10 deletions tests/testthat/_snaps/vfold.md
@@ -9,31 +9,31 @@

# bad args

-`v` must be a single positive integer greater than 1
+`v` must be a single positive integer greater than 1.

---

-`v` must be a single positive integer greater than 1
+`v` must be a single positive integer greater than 1.

---

-`v` must be a single positive integer greater than 1
+`v` must be a single positive integer greater than 1.

---

-The number of rows is less than `v = 500`
+The number of rows is less than `v` = 500.

---

-Repeated resampling when `v` is 150 would create identical resamples
+Repeated resampling when `v` is 150 would create identical resamples.

---

-`repeats` must be a single positive integer
+`repeats` must be a single positive integer.

---

-`repeats` must be a single positive integer
+`repeats` must be a single positive integer.

---

@@ -67,19 +67,19 @@

# grouping -- bad args

-Repeated resampling when `v` is 4 would create identical resamples
+Repeated resampling when `v` is 4 would create identical resamples.

---

-Repeated resampling when `v` is `NULL` would create identical resamples
+Repeated resampling when `v` is "NULL" would create identical resamples.

---

Code
group_vfold_cv(Orange, v = 1, group = "Tree")
Condition
Error in `group_vfold_cv()`:
-! `v` must be a single positive integer greater than 1
+! `v` must be a single positive integer greater than 1.

# grouping -- other balance methods
