Skip to content

Commit

Permalink
issue #1209: tidymodels indexing (#1227)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock authored Oct 5, 2024
1 parent 6061942 commit 936b792
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: marginaleffects
Title: Predictions, Comparisons, Slopes, Marginal Means, and Hypothesis Tests
Version: 0.22.0.5
Version: 0.22.0.6
Authors@R:
c(person(given = "Vincent",
family = "Arel-Bundock",
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Bugs:
* `hypotheses(joint = TRUE)` would throw an error if sample sizes could not be computed, even if they were not needed. Thanks to Noah Greifer.
* `hypotheses(joint = TRUE)` respects the `vcov` argument. Thanks to @kennchua for report #1214.
* `ordbetareg` models in `glmmTMB` are now supported. Thanks to @jgeller112 for code contribution #1221.
* `tidymodels()`: Indexing overrode the value of predictors in the output data frame. The numerical estimates were unaffected. Thanks to @agmath for report #1209.

## 0.22.0

Expand Down
4 changes: 2 additions & 2 deletions R/methods_tidymodels.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ get_predict.model_fit <- function(model, newdata, type = NULL, ...) {

if (type == "numeric") {
v <- intersect(c(".pred", ".pred_res"), colnames(out))[1]
out <- data.frame(rowid = seq_along(out), estimate = out[[v]])
out <- data.frame(rowid = seq_len(nrow(out)), estimate = out[[v]])

} else if (type == "class") {
out <- data.frame(rowid = seq_along(out), estimate = out[[".pred_class"]])
out <- data.frame(rowid = seq_len(nrow(out)), estimate = out[[".pred_class"]])

} else if (type == "prob") {
colnames(out) <- substr(colnames(out), 7, nchar(colnames(out)))
Expand Down
18 changes: 18 additions & 0 deletions inst/tinytest/test-pkg-tidymodels.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ p <- plot_comparisons(fit,
expect_inherits(p, "data.frame")


# Issue 1209
nobs <- 50
my_data <- tibble(
x = runif(nobs, 0, 10),
y = -(x - 11)^2 + 100 + rnorm(nobs, 0, 25)
)
lr_spec <- linear_reg()
lr_rec <- recipe(y ~ x, data = my_data) |>
step_poly(x, degree = 2)
lr_wf <- workflow() |>
add_model(lr_spec) |>
add_recipe(lr_rec)
lr_fit <- lr_wf |>
fit(my_data)
mfx <- slopes(lr_fit, newdata = my_data, variable = "x")
expect_equivalent(mfx$x, my_data$x)
expect_equivalent(mfx$y, my_data$y)



rm(list = ls())
Expand Down

0 comments on commit 936b792

Please sign in to comment.