Skip to content

Commit

Permalink
rows_*() fixes (#1347)
Browse files Browse the repository at this point in the history
* Rename `get_col_types()` to `db_col_types()`

* Pass table instead of string to `db_col_types()`

* Rename to `target_table()`

* Document `db_col_types()`
  • Loading branch information
mgirlich authored Aug 8, 2023
1 parent 5fa4410 commit 6a43c63
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 72 deletions.
12 changes: 8 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ S3method(copy_to,src_sql)
S3method(count,tbl_lazy)
S3method(cross_join,tbl_lazy)
S3method(db_analyze,DBIConnection)
S3method(db_col_types,DBIConnection)
S3method(db_col_types,MariaDBConnection)
S3method(db_col_types,MySQL)
S3method(db_col_types,MySQLConnection)
S3method(db_col_types,PostgreSQL)
S3method(db_col_types,PqConnection)
S3method(db_col_types,TestConnection)
S3method(db_collect,DBIConnection)
S3method(db_compute,DBIConnection)
S3method(db_connection_describe,DBIConnection)
Expand Down Expand Up @@ -135,10 +142,6 @@ S3method(format,ident)
S3method(format,sql)
S3method(format,src_sql)
S3method(full_join,tbl_lazy)
S3method(get_col_types,DBIConnection)
S3method(get_col_types,MariaDBConnection)
S3method(get_col_types,PqConnection)
S3method(get_col_types,TestConnection)
S3method(group_by,tbl_lazy)
S3method(group_size,tbl_sql)
S3method(group_vars,tbl_lazy)
Expand Down Expand Up @@ -429,6 +432,7 @@ export(build_sql)
export(copy_inline)
export(copy_lahman)
export(copy_nycflights13)
export(db_col_types)
export(db_collect)
export(db_compute)
export(db_connection_describe)
Expand Down
11 changes: 11 additions & 0 deletions R/backend-mysql.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ db_connection_describe.MySQL <- db_connection_describe.MariaDBConnection
#' @export
db_connection_describe.MySQLConnection <- db_connection_describe.MariaDBConnection

#' @export
db_col_types.MariaDBConnection <- function(con, table, call) {
table <- as_table_ident(table, error_call = call)
col_info_df <- DBI::dbGetQuery(con, glue_sql2(con, "SHOW COLUMNS FROM {.tbl table};"))
set_names(col_info_df[["Type"]], col_info_df[["Field"]])
}
#' @export
db_col_types.MySQL <- db_col_types.MariaDBConnection
#' @export
db_col_types.MySQLConnection <- db_col_types.MariaDBConnection

#' @export
sql_translation.MariaDBConnection <- function(con) {
sql_variant(
Expand Down
13 changes: 13 additions & 0 deletions R/backend-postgres.R
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,17 @@ db_supports_table_alias_with_as.PostgreSQL <- function(con) {
TRUE
}

#' @export
db_col_types.PqConnection <- function(con, table, call) {
table <- as_table_ident(table, error_call = call)
res <- DBI::dbSendQuery(con, glue_sql2(con, "SELECT * FROM {.tbl table} LIMIT 0"))
on.exit(DBI::dbClearResult(res))
DBI::dbFetch(res, n = 0)
col_info_df <- DBI::dbColumnInfo(res)
set_names(col_info_df[[".typname"]], col_info_df[["name"]])
}

#' @export
db_col_types.PostgreSQL <- db_col_types.PqConnection

utils::globalVariables(c("strpos", "%::%", "%FROM%", "%ILIKE%", "DATE", "EXTRACT", "TO_CHAR", "string_agg", "%~*%", "%~%", "MONTH", "DOY", "DATE_TRUNC", "INTERVAL", "FLOOR", "WEEK"))
22 changes: 22 additions & 0 deletions R/db.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#' * `dbplyr_edition()` declares which version of the dbplyr API you want.
#' See below for more details.
#'
#' * `db_col_types()` returns the column types of a table.
#'
#' @section dbplyr 2.0.0:
#' dbplyr 2.0.0 renamed a number of generics so that they could be cleanly moved
#' from dplyr to dbplyr. If you have an existing backend, you'll need to rename
Expand Down Expand Up @@ -80,6 +82,26 @@ db_sql_render.DBIConnection <- function(con, sql, ..., cte = FALSE, sql_options
sql_render(sql, con = con, ..., sql_options = sql_options)
}

#' @rdname db-misc
#' @export
db_col_types <- function(con, table, call) {
if (is_null(table)) {
return(NULL)
}

UseMethod("db_col_types")
}

#' @export
db_col_types.TestConnection <- function(con, table, call) {
NULL
}

#' @export
db_col_types.DBIConnection <- function(con, table, call) {
NULL
}

#' Options for generating SQL
#'
#' @param cte If `FALSE`, the default, subqueries are used. If `TRUE` common
Expand Down
85 changes: 25 additions & 60 deletions R/rows.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ rows_insert.tbl_lazy <- function(x,
method = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

conflict <- rows_check_conflict(conflict)

Expand All @@ -111,10 +111,10 @@ rows_insert.tbl_lazy <- function(x,

returning_cols <- rows_check_returning(x, returning, enexpr(returning))

if (!is_null(name)) {
if (!is_null(table)) {
sql <- sql_query_insert(
con = remote_con(x),
table = name,
table = table,
from = sql_render(y, remote_con(x), lvl = 1),
insert_cols = colnames(y),
by = by,
Expand Down Expand Up @@ -154,17 +154,17 @@ rows_append.tbl_lazy <- function(x,
returning = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

rows_check_containment(x, y)
y <- rows_auto_copy(x, y, copy = copy)

returning_cols <- rows_check_returning(x, returning, enexpr(returning))

if (!is_null(name)) {
if (!is_null(table)) {
sql <- sql_query_append(
con = remote_con(x),
table = name,
table = table,
from = sql_render(y, remote_con(x), lvl = 1),
insert_cols = colnames(y),
...,
Expand Down Expand Up @@ -202,7 +202,7 @@ rows_update.tbl_lazy <- function(x,
returning = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

rows_check_containment(x, y)
y <- rows_auto_copy(x, y, copy = copy)
Expand All @@ -219,7 +219,7 @@ rows_update.tbl_lazy <- function(x,
returning_cols <- rows_check_returning(x, returning, enexpr(returning))


if (!is_null(name)) {
if (!is_null(table)) {
# TODO handle `returning_cols` here
if (is_empty(new_columns)) {
return(invisible(x))
Expand All @@ -234,7 +234,7 @@ rows_update.tbl_lazy <- function(x,

sql <- sql_query_update_from(
con = con,
table = name,
table = table,
from = sql_render(y, remote_con(y), lvl = 1),
by = by,
update_values = update_values,
Expand Down Expand Up @@ -282,7 +282,7 @@ rows_patch.tbl_lazy <- function(x,
returning = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

rows_check_containment(x, y)
y <- rows_auto_copy(x, y, copy = copy)
Expand All @@ -298,7 +298,7 @@ rows_patch.tbl_lazy <- function(x,

returning_cols <- rows_check_returning(x, returning, enexpr(returning))

if (!is_null(name)) {
if (!is_null(table)) {
# TODO handle `returning_cols` here
if (is_empty(new_columns)) {
return(invisible(x))
Expand All @@ -308,14 +308,14 @@ rows_patch.tbl_lazy <- function(x,

update_cols <- setdiff(colnames(y), by)
update_values <- sql_coalesce(
sql_table_prefix(con, update_cols, name),
sql_table_prefix(con, update_cols, table),
sql_table_prefix(con, update_cols, "...y")
)
update_values <- set_names(update_values, update_cols)

sql <- sql_query_update_from(
con = con,
table = name,
table = table,
from = sql_render(y, remote_con(y), lvl = 1),
by = by,
update_values = update_values,
Expand Down Expand Up @@ -371,7 +371,7 @@ rows_upsert.tbl_lazy <- function(x,
method = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

rows_check_containment(x, y)
y <- rows_auto_copy(x, y, copy = copy)
Expand All @@ -385,15 +385,15 @@ rows_upsert.tbl_lazy <- function(x,

new_columns <- setdiff(colnames(y), by)

if (!is_null(name)) {
if (!is_null(table)) {
# TODO use `rows_insert()` here
if (is_empty(new_columns)) {
return(invisible(x))
}

sql <- sql_query_upsert(
con = remote_con(x),
table = name,
table = table,
from = sql_render(y, remote_con(x), lvl = 1),
by = by,
update_cols = setdiff(colnames(y), by),
Expand Down Expand Up @@ -446,7 +446,7 @@ rows_delete.tbl_lazy <- function(x,
returning = NULL) {
check_dots_empty()
rows_check_in_place(x, in_place)
name <- target_table_name(x, in_place)
table <- target_table(x, in_place)

rows_check_containment(x, y)
y <- rows_auto_copy(x, y, copy = copy)
Expand All @@ -466,10 +466,10 @@ rows_delete.tbl_lazy <- function(x,
inform(message, class = c("dplyr_message_delete_extra_cols", "dplyr_message"))
}

if (!is_null(name)) {
if (!is_null(table)) {
sql <- sql_query_delete(
con = remote_con(x),
table = name,
table = table,
from = sql_render(y, remote_con(x), lvl = 2),
by = by,
...,
Expand Down Expand Up @@ -704,18 +704,18 @@ tick <- function(x) {

# other helpers -----------------------------------------------------------

target_table_name <- function(x, in_place) {
target_table <- function(x, in_place) {
# Never touch target table with `in_place = FALSE`
if (!is_true(in_place)) {
return(NULL)
}

name <- remote_name(x)
if (is_null(name)) {
table <- remote_table(x)
if (is_null(table)) {
cli_abort("Can't determine name for target table. Set {.code in_place = FALSE} to return a lazy table.")
}

ident(name)
table
}

rows_prep <- function(con, table, from, by, lvl = 0) {
Expand Down Expand Up @@ -747,8 +747,8 @@ rows_auto_copy <- function(x, y, copy, call = caller_env()) {
return(y)
}

name <- remote_name(x)
x_types <- get_col_types(remote_con(x), name, call)
table <- remote_table(x)
x_types <- db_col_types(remote_con(x), table, call)

if (!is_null(x_types)) {
rows_check_containment(x, y, error_call = call)
Expand All @@ -758,41 +758,6 @@ rows_auto_copy <- function(x, y, copy, call = caller_env()) {
auto_copy(x, y, copy = copy, types = x_types)
}

get_col_types <- function(con, name, call) {
if (is_null(name)) {
return(NULL)
}

UseMethod("get_col_types")
}

#' @export
get_col_types.TestConnection <- function(con, name, call) {
NULL
}

#' @export
get_col_types.DBIConnection <- function(con, name, call) {
NULL
}

#' @export
get_col_types.PqConnection <- function(con, name, call) {
name <- as_table_ident(name, error_call = call)
res <- DBI::dbSendQuery(con, glue_sql2(con, "SELECT * FROM {.tbl name} LIMIT 0"))
on.exit(DBI::dbClearResult(res))
DBI::dbFetch(res, n = 0)
col_info_df <- DBI::dbColumnInfo(res)
set_names(col_info_df[[".typname"]], col_info_df[["name"]])
}

#' @export
get_col_types.MariaDBConnection <- function(con, name, call) {
name <- as_table_ident(name, error_call = call)
col_info_df <- DBI::dbGetQuery(con, glue_sql2(con, "SHOW COLUMNS FROM {.tbl name};"))
set_names(col_info_df[["Type"]], col_info_df[["Field"]])
}

rows_get_or_execute <- function(x, sql, returning_cols, call = caller_env()) {
con <- remote_con(x)
msg <- "Can't modify database table {.val {remote_name(x)}}."
Expand Down
4 changes: 4 additions & 0 deletions man/db-misc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/rows.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
(df %>% mutate(x = x + 1) %>% rows_insert(df, by = "x", conflict = "ignore",
in_place = TRUE))
Condition
Error in `target_table_name()`:
Error in `target_table()`:
! Can't determine name for target table. Set `in_place = FALSE` to return a lazy table.

---
Expand Down Expand Up @@ -275,7 +275,7 @@
(df %>% mutate(x = x + 1) %>% rows_update(df, by = "x", unmatched = "ignore",
in_place = TRUE))
Condition
Error in `target_table_name()`:
Error in `target_table()`:
! Can't determine name for target table. Set `in_place = FALSE` to return a lazy table.

# `rows_update()` works with `in_place = FALSE`
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-backend-mysql.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,6 @@ test_that("casts `y` column for local df", {
expect_equal(tbl(con, "df_x") %>% collect(), out)

types_expected <- c(id = "bigint(20)", val = "bigint(20)", ltext = "longtext")
expect_equal(get_col_types(con, table2), types_expected)
expect_equal(get_col_types(con, in_schema("test", "df_x")), types_expected)
expect_equal(db_col_types(con, table2), types_expected)
expect_equal(db_col_types(con, in_schema("test", "df_x")), types_expected)
})
Loading

0 comments on commit 6a43c63

Please sign in to comment.