From d4d8f2a07560195bff0b290e30ec1b4c68026d78 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Thu, 10 Jan 2019 09:27:23 -0600 Subject: [PATCH] Ensure all window translation get matching aggregate translation with clear errro Fixes #129 --- NEWS.md | 3 +++ R/translate-sql-base.r | 5 +++++ R/translate-sql-helpers.r | 16 ++++++++++++++++ tests/testthat/test-translate-sql-helpers.r | 9 +++++++++ 4 files changed, 33 insertions(+) diff --git a/NEWS.md b/NEWS.md index 39d3ba05e..875214e3e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # dbplyr (development version) +* Functions that are only available in a windowed (`mutate()`) query now + throw an error when called in a aggregate (`summarise()`) query (#129) + * `na_if()` is translated to `NULLIF()` for all databases (#211). * SQL translation (via `partial_eval()`) now correctly interprets the diff --git a/R/translate-sql-base.r b/R/translate-sql-base.r index fb1c967a6..623840924 100644 --- a/R/translate-sql-base.r +++ b/R/translate-sql-base.r @@ -217,6 +217,11 @@ base_agg <- sql_translator( sum = sql_aggregate("sum"), min = sql_aggregate("min"), max = sql_aggregate("max"), + + # first = sql_prefix("FIRST_VALUE", 1), + # last = sql_prefix("LAST_VALUE", 1), + # nth = sql_prefix("NTH_VALUE", 2), + n_distinct = function(...) { vars <- sql_vector(list(...), parens = FALSE, collapse = ", ") build_sql("COUNT(DISTINCT ", vars, ")") diff --git a/R/translate-sql-helpers.r b/R/translate-sql-helpers.r index e5eedca8e..409bcb371 100644 --- a/R/translate-sql-helpers.r +++ b/R/translate-sql-helpers.r @@ -74,6 +74,11 @@ sql_variant <- function(scalar = sql_translator(), )) } + # An ensure that every window function is flagged in aggregate context + missing <- setdiff(ls(window), ls(aggregate)) + missing_funs <- map(missing, sql_aggregate_win) + env_bind(aggregate, !!!set_names(missing_funs, missing)) + structure( list(scalar = scalar, aggregate = aggregate, window = window), class = "sql_variant" @@ -187,6 +192,17 @@ sql_aggregate_2 <- function(f) { } } +sql_aggregate_win <- function(f) { + force(f) + + function(...) { + stop( + "`", f, "()` is only available in a windowed (`mutate()`) context", + call. = FALSE + ) + } +} + check_na_rm <- function(f, na.rm) { if (identical(na.rm, TRUE)) { diff --git a/tests/testthat/test-translate-sql-helpers.r b/tests/testthat/test-translate-sql-helpers.r index 3b23dbdcd..c36ea5ffa 100644 --- a/tests/testthat/test-translate-sql-helpers.r +++ b/tests/testthat/test-translate-sql-helpers.r @@ -18,6 +18,15 @@ test_that("missing window functions create a warning", { ) }) +test_that("missing aggregate functions filled in", { + sim_scalar <- sql_translator() + sim_agg <- sql_translator() + sim_win <- sql_translator(mean = function() {}) + + trans <- sql_variant(sim_scalar, sim_agg, sim_win) + expect_error(trans$aggregate$mean(), "only available in a window") +}) + test_that("output of print method for sql_variant is correct", { sim_trans <- sql_translator(`+` = sql_infix("+")) expect_known_output(