-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement uncount()
#352
Comments
Looks like we would need to make |
There is already a tidyr issue tidyverse/tidyr#1071 for this. In dbplyr we simply added |
Should we add |
It might take some time for tidyverse/tidyr#1101 to be merged (the interface of most of these functions is not completely stable yet). Therefore, I created a separate PR tidyverse/tidyr#1358 for |
Here's an initial implementation that I think covers everything. Once tidyverse/tidyr#1358 is merged we can create this as an `uncount()` method. Edit: I'm not sure if there's a better way to deal with creating the `.id` column.
library(data.table)
library(dtplyr)
library(dplyr, warn.conflicts = FALSE)
library(tidyr, warn.conflicts = FALSE)
#' Duplicate rows according to a weights column (dplyr-verb prototype)
#'
#' Prototype translation of `tidyr::uncount()` for data.table-backed input,
#' expressed with `slice()`, `mutate()` and `select()`.
#'
#' @param data A data.table (or dtplyr lazy table) to expand.
#' @param weights Unquoted column giving how many copies of each row to make.
#' @param ... Unused.
#' @param .remove If `TRUE` (the default), drop the `weights` column from the
#'   result.
#' @param .id Optional string. When supplied, a column of that name is added
#'   that numbers the copies of each original row (`1, 2, ...`).
#' @return `data` after the slice / mutate / select steps have been applied.
dtplyr_uncount <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- enquo(weights)
  needs_id <- !is.null(.id)

  if (needs_id) {
    # The per-row counts have to be captured before slicing: afterwards the
    # weights column is itself repeated, so it no longer has one entry per
    # original row. The pulled values are inlined into the generated call.
    .reps <- pull(data, !!weights)
  }

  # seq_len(.N) instead of 1:.N: on a zero-row table 1:.N is c(1, 0), which
  # makes rep() fail with an invalid `times` argument.
  out <- slice(data, rep(seq_len(.N), !!weights))

  if (needs_id) {
    out <- mutate(out, !!.id := sequence(!!.reps))
  }

  if (.remove) {
    out <- select(out, -!!weights)
  }

  out
}
df <- data.table(x = c("a", "b"), n = c(1, 2))
df %>%
dtplyr_uncount(n)
#> Source: local data table [3 x 1]
#> Call: `_DT1`[rep(1:.N, n)[between(rep(1:.N, n), -.N, .N)], .(x)]
#>
#> x
#> <chr>
#> 1 a
#> 2 b
#> 3 b
#>
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results
df %>%
dtplyr_uncount(n, .id = "id", .remove = FALSE)
#> Source: local data table [3 x 3]
#> Call: `_DT2`[rep(1:.N, n)[between(rep(1:.N, n), -.N, .N)]][, `:=`(id = sequence(c(1,
#> 2)))]
#>
#> x n id
#> <chr> <dbl> <int>
#> 1 a 1 1
#> 2 b 2 1
#> 3 b 2 2
#>
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results |
I think this might be a better way - it avoids the `pull()` call.
library(data.table)
library(dtplyr)
library(dplyr, warn.conflicts = FALSE)
library(tidyr, warn.conflicts = FALSE)
library(purrr)
step_subset_j <- dtplyr:::step_subset_j
step_subset <- dtplyr:::step_subset
#' Duplicate rows according to a weights column (step-based prototype)
#'
#' Builds the data.table call directly with dtplyr's internal `step_subset()`
#' / `step_subset_j()` constructors instead of chaining dplyr verbs.
#'
#' Without `.id` the rows are expanded through `i` (`rep(seq_len(.N), w)`).
#' With `.id` every kept column is expanded through `j` (`rep(col, w)`) and
#' the id column is computed as `sequence(w)`.
#'
#' @param data A dtplyr lazy table (must expose `$vars`).
#' @param weights Unquoted column or expression giving how many copies of
#'   each row to make.
#' @param ... Unused.
#' @param .remove If `TRUE` (the default) and `weights` is a bare column
#'   name, drop that column from the result.
#' @param .id Optional string. When supplied, a column of that name is added
#'   that numbers the copies of each original row.
#' @return A new dtplyr step, re-grouped by the groups `data` had on entry.
dtplyr_uncount <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- quo_squash(enquo(weights))

  # Work on ungrouped data and restore the grouping at the end.
  groups <- group_vars(data)
  has_groups <- length(groups) > 0
  if (has_groups) {
    data <- ungroup(data)
  }

  # Only a bare column name can be removed; as_name() errors on calls and
  # constants (e.g. `n + 1` or `2`), so those leave the columns untouched.
  drop_weights <- .remove && is.symbol(weights)
  kept <- data$vars
  if (drop_weights) {
    kept <- setdiff(kept, as_name(weights))
  }

  if (is.null(.id)) {
    j <- if (drop_weights) call2(".", !!!syms(kept)) else NULL
    out <- step_subset(
      data,
      vars = kept,
      # seq_len(.N) instead of 1:.N so a zero-row table does not turn into
      # rep(c(1, 0), <empty>), which is an error.
      i = expr(rep(seq_len(.N), !!weights)),
      j = j
    )
  } else {
    cols <- map(syms(kept), ~ expr(rep(!!.x, !!weights)))
    names(cols) <- kept
    cols <- append(cols, exprs(!!.id := sequence(!!weights)))
    out <- step_subset_j(
      data,
      # The id column is part of the output, so it must be listed in the
      # step's variables or later verbs will not find it.
      vars = c(kept, .id),
      j = call2(".", !!!cols)
    )
  }

  if (has_groups) {
    out <- group_by(out, !!!syms(groups))
  }

  out
}
df <- lazy_dt(data.table(x = c("a", "b"), n = c(1, 2)))
df %>%
dtplyr_uncount(n)
#> Source: local data table [3 x 1]
#> Call: `_DT1`[rep(1:.N, n), .(x)]
#>
#> x
#> <chr>
#> 1 a
#> 2 b
#> 3 b
#>
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results
df %>%
dtplyr_uncount(n, .id = "id", .remove = FALSE)
#> Source: local data table [3 x 3]
#> Call: `_DT1`[, .(x = rep(x, n), n = rep(n, n), id = sequence(n))]
#>
#> x n id
#> <chr> <dbl> <int>
#> 1 a 1 1
#> 2 b 2 1
#> 3 b 2 2
#>
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results |
Would it be faster to just always use the second branch of your function, where you `rep()` every column inside `j`?
devtools::load_all('/Users/mbp/Documents/GitHub/dtp')
#> ℹ Loading dtplyr
#> Warning: package 'dplyr' was built under R version 4.1.2
# dtplyr_uncount <- [as above]
#' Duplicate rows according to a weights column (always via `j`)
#'
#' Variant of the step-based prototype that always expands the kept columns
#' inside `j` (`rep(col, w)`), optionally adding an id column computed as
#' `sequence(w)`.
#'
#' @param data A dtplyr lazy table (must expose `$vars`).
#' @param weights Unquoted column or expression giving how many copies of
#'   each row to make.
#' @param ... Unused.
#' @param .remove If `TRUE` (the default) and `weights` is a bare column
#'   name, drop that column from the result.
#' @param .id Optional string. When supplied, a column of that name is added
#'   that numbers the copies of each original row.
#' @return A new dtplyr step, re-grouped by the groups `data` had on entry.
dtplyr_uncount2 <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- quo_squash(enquo(weights))

  # Work on ungrouped data and restore the grouping at the end.
  groups <- group_vars(data)
  has_groups <- length(groups) > 0
  if (has_groups) {
    data <- ungroup(data)
  }

  kept <- data$vars
  # Only a bare column name can be removed; as_name() errors on calls and
  # constants, so those leave the columns untouched.
  if (.remove && is.symbol(weights)) {
    kept <- setdiff(kept, as_name(weights))
  }

  cols <- map(syms(kept), ~ expr(rep(!!.x, !!weights)))
  names(cols) <- kept
  if (!is.null(.id)) {
    cols <- append(cols, exprs(!!.id := sequence(!!weights)))
  }

  out <- step_subset_j(
    data,
    # c(kept, NULL) is just `kept`, so this only adds the id column when one
    # was requested; it has to be listed for later verbs to find it.
    vars = c(kept, .id),
    j = call2(".", !!!cols)
  )

  if (has_groups) {
    out <- group_by(out, !!!syms(groups))
  }

  out
}
df <- data.table(x = c("a", "b"), n = c(1, 2)) %>%
mutate(n = c(1, 2)*1e6) %>%
lazy_dt()
library(bench)
mark(
a = df %>% dtplyr_uncount(n) %>% collect,
b = df %>% dtplyr_uncount2(n) %>% collect)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 2 × 6
#> expression min median `itr/sec` mem_alloc `gc/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
#> 1 a 55.1ms 60.4ms 13.9 69.4MB 21.8
#> 2 b 19.2ms 31.7ms 25.7 46.1MB 19.8
Created on 2022-05-12 by the reprex package (v2.0.1)
Looks like it. I hadn’t gotten around to testing it yet. I’m a bit surprised, to be honest - I would have assumed the simple slice/select would be much more efficient.
The text was updated successfully, but these errors were encountered: