Skip to content

Commit

Permalink
Merge pull request #278 from tidymodels/consolidate-contr_one_hot
Browse files Browse the repository at this point in the history
Consolidate `contr_one_hot()`
  • Loading branch information
hfrick authored Jan 28, 2025
2 parents 7aaa32f + d9c7474 commit 986abc6
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 16 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hardhat
Title: Construct Modeling Packages
Version: 1.4.0.9002
Version: 1.4.0.9003
Authors@R: c(
person("Hannah", "Frick", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-6049-5258")),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export(check_outcomes_are_univariate)
export(check_prediction_size)
export(check_predictors_are_numeric)
export(check_quantile_levels)
export(contr_one_hot)
export(create_modeling_package)
export(default_formula_blueprint)
export(default_recipe_blueprint)
Expand Down
18 changes: 14 additions & 4 deletions R/model-matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,28 +171,38 @@ model_matrix_one_hot <- function(terms, data, ..., call = caller_env()) {
#' This contrast function produces a model matrix that has indicator columns for
#' each level of each factor.
#'
#' @param n A vector of character factor levels or the number of unique levels.
#' @param n A vector of character factor levels (of length >= 1) or the number
#' of unique levels (>= 1).
#' @param contrasts This argument is for backwards compatibility and only the
#' default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
#' default of `FALSE` is supported.
#'
#' @includeRmd man/rmd/one-hot.Rmd details
#'
#' @return A diagonal matrix that is `n`-by-`n`.
#'
#' @keywords internal
#' @export
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
if (sparse) {
cli::cli_warn("{.code sparse = TRUE} not implemented for {.fn contr_one_hot}.")
cli::cli_warn("{.code sparse = TRUE} not implemented for {.fun contr_one_hot}.")
}

if (!contrasts) {
cli::cli_warn("{.code contrasts = FALSE} not implemented for {.fn contr_one_hot}.")
cli::cli_warn(
"{.code contrasts = FALSE} not implemented for {.fun contr_one_hot}."
)
}

if (is.character(n)) {
if (length(n) < 1) {
cli::cli_abort("{.arg n} cannot be empty.")
}
names <- n
n <- length(names)
} else if (is.numeric(n)) {
check_number_whole(n, min = 1)
n <- as.integer(n)

if (length(n) != 1L) {
Expand All @@ -201,7 +211,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {

names <- as.character(seq_len(n))
} else {
cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.")
check_number_whole(n, min = 1)
}

out <- diag(n)
Expand Down
71 changes: 70 additions & 1 deletion man/contr_one_hot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions man/rmd/one-hot.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
```{r load, include = FALSE}
library(dplyr)
```

By default, `model.matrix()` generates binary indicator variables for factor predictors. When the formula does not remove an intercept, an incomplete set of indicators is created; no indicator is made for the first level of the factor.

For example, `species` and `island` both have three levels but `model.matrix()` creates two indicator variables for each:

```{r ref-cell}
library(dplyr)
library(modeldata)
data(penguins)
levels(penguins$species)
levels(penguins$island)
model.matrix(~ species + island, data = penguins) %>%
colnames()
```

For a formula with no intercept, the first factor is expanded to indicators for _all_ factor levels, but all other factors are expanded to indicators for all but one level (as above):

```{r hybrid}
model.matrix(~ 0 + species + island, data = penguins) %>%
colnames()
```

For inference, this hybrid encoding can be problematic.

To generate all indicators, use this contrast:

```{r one-hot}
# Switch out the contrast method
old_contr <- options("contrasts")$contrasts
new_contr <- old_contr
new_contr["unordered"] <- "contr_one_hot"
options(contrasts = new_contr)
model.matrix(~ species + island, data = penguins) %>%
colnames()
options(contrasts = old_contr)
```

Removing the intercept here does not affect the factor encodings.


42 changes: 34 additions & 8 deletions tests/testthat/_snaps/model-matrix.md
Original file line number Diff line number Diff line change
@@ -1,38 +1,64 @@
# `contr_one_hot()` input checks

Code
contr_one_hot(n = 1, sparse = TRUE)
contr_one_hot(n = 2, sparse = TRUE)
Condition
Warning:
`sparse = TRUE` not implemented for `contr_one_hot()`.
Output
1
1 1
1 2
1 1 0
2 0 1

---

Code
contr_one_hot(n = 1, contrasts = FALSE)
contr_one_hot(n = 2, contrasts = FALSE)
Condition
Warning:
`contrasts = FALSE` not implemented for `contr_one_hot()`.
Output
1
1 1
1 2
1 1 0
2 0 1

---

Code
contr_one_hot(n = 1:2)
Condition
Error in `contr_one_hot()`:
! `n` must have length 1 when an integer is provided.
! `n` must be a whole number, not an integer vector.

---

Code
contr_one_hot(n = list(1:2))
Condition
Error in `contr_one_hot()`:
! `n` must be a character vector or an integer of size 1.
! `n` must be a whole number, not a list.

---

Code
contr_one_hot(character(0))
Condition
Error in `contr_one_hot()`:
! `n` cannot be empty.

---

Code
contr_one_hot(-1)
Condition
Error in `contr_one_hot()`:
! `n` must be a whole number larger than or equal to 1, not the number -1.

---

Code
contr_one_hot(list())
Condition
Error in `contr_one_hot()`:
! `n` must be a whole number, not an empty list.

28 changes: 26 additions & 2 deletions tests/testthat/test-model-matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,36 @@ test_that("`model_matrix()` strips all attributes from the `model.matrix()` resu
})

# Snapshot tests for the argument validation in `contr_one_hot()`. The expected
# warnings, errors and printed matrices are recorded in
# tests/testthat/_snaps/model-matrix.md, keyed to the exact expressions below.
test_that("`contr_one_hot()` input checks", {
# Unsupported `sparse` / `contrasts` values only warn; the matrix is still
# returned and printed in the snapshot.
expect_snapshot(contr_one_hot(n = 1, sparse = TRUE))
expect_snapshot(contr_one_hot(n = 1, contrasts = FALSE))
expect_snapshot(contr_one_hot(n = 2, sparse = TRUE))
expect_snapshot(contr_one_hot(n = 2, contrasts = FALSE))

# A numeric `n` with more than one element is an error.
expect_snapshot(error = TRUE, {
contr_one_hot(n = 1:2)
})
# Input that is neither character nor numeric is an error.
expect_snapshot(error = TRUE, {
contr_one_hot(n = list(1:2))
})
# An empty character vector of levels is an error.
expect_snapshot(error = TRUE, {
contr_one_hot(character(0))
})
# A numeric `n` below 1 is an error.
expect_snapshot(error = TRUE, {
contr_one_hot(-1)
})
# An empty list is an error as well.
expect_snapshot(error = TRUE, {
contr_one_hot(list())
})
})

# `contr_one_hot()` accepts either a level count or a character vector of
# levels. In both cases the result must carry the levels as row and column
# names, and every row and every column must sum to one.
test_that("one-hot encoding contrasts", {
  # Numeric input: labels are the level indices rendered as strings.
  int_labels <- as.character(1:12)
  int_mat <- contr_one_hot(12)
  expect_equal(colnames(int_mat), int_labels)
  expect_equal(rownames(int_mat), int_labels)
  expect_true(all(rowSums(int_mat) == 1))
  expect_true(all(colSums(int_mat) == 1))

  # Character input: labels are the supplied levels themselves.
  chr_labels <- letters[1:12]
  chr_mat <- contr_one_hot(chr_labels)
  expect_equal(colnames(chr_mat), chr_labels)
  expect_equal(rownames(chr_mat), chr_labels)
  expect_true(all(rowSums(chr_mat) == 1))
  expect_true(all(colSums(chr_mat) == 1))
})

0 comments on commit 986abc6

Please sign in to comment.