From a6e4fb726c7c34fdfa59c3741ef98ddc6f3b9e32 Mon Sep 17 00:00:00 2001 From: Kirill Simonov Date: Sat, 3 Jun 2023 00:08:04 -0400 Subject: [PATCH] Wrap catalog functions (#374) * Wrap SQLTables and SQLColumns catalog functions * Add tests for catalog functions --- docs/src/index.md | 6 ++++++ src/API.jl | 34 +++++++++++++++++++++++++++++ src/ODBC.jl | 1 + src/catalog.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 5 files changed, 100 insertions(+) create mode 100644 src/catalog.jl diff --git a/docs/src/index.md b/docs/src/index.md index 4dad14f..9cc268e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -171,6 +171,12 @@ DBInterface.executemultiple ODBC.load ``` +### Catalog functions +```@docs +ODBC.tables +ODBC.columns +``` + ### ODBC administrative functions ```@docs ODBC.drivers diff --git a/src/API.jl b/src/API.jl index 8f5be79..f92fc40 100644 --- a/src/API.jl +++ b/src/API.jl @@ -564,6 +564,40 @@ function diagnostics(h::Handle) return String(take!(io)) end +function catalogstr(str::AbstractString) + buf = cwstring(str) + return (buf, length(buf)) +end + +catalogstr(::Union{Nothing, Missing}) = + return (C_NULL, 0) + +function SQLTables(stmt::Ptr{Cvoid}, catalogname, schemaname, tablename, tabletype) + c, clen = catalogstr(catalogname) + s, slen = catalogstr(schemaname) + t, tlen = catalogstr(tablename) + tt, ttlen = catalogstr(tabletype) + @odbc(:SQLTablesW, + (Ptr{Cvoid}, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT), + stmt, c, clen, s, slen, t, tlen, tt, ttlen) +end + +tables(stmt::Handle, catalogname, schemaname, tablename, tabletype) = + @checksuccess stmt SQLTables(getptr(stmt), catalogname, schemaname, tablename, tabletype) + +function SQLColumns(stmt::Ptr{Cvoid}, catalogname, schemaname, tablename, columnname) + c, clen = catalogstr(catalogname) + s, slen = catalogstr(schemaname) + t, tlen = catalogstr(tablename) + col, collen = catalogstr(columnname) + @odbc(:SQLColumnsW, + (Ptr{Cvoid}, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT, Ptr{SQLWCHAR}, SQLSMALLINT), + stmt, c, clen, s, slen, t, tlen, col, collen) +end + +columns(stmt::Handle, catalogname, schemaname, tablename, columnname) = + @checksuccess stmt SQLColumns(getptr(stmt), catalogname, schemaname, tablename, columnname) + macro checkinst(expr) esc(quote ret = $expr diff --git a/src/ODBC.jl b/src/ODBC.jl index 41d86ac..fbf612f 100644 --- a/src/ODBC.jl +++ b/src/ODBC.jl @@ -8,6 +8,7 @@ include("API.jl") include("utils.jl") include("dbinterface.jl") include("load.jl") +include("catalog.jl") """ ODBC.setdebug(debug::Bool=true, tracefile::String=joinpath(tempdir(), "odbc.log")) diff --git a/src/catalog.jl b/src/catalog.jl new file mode 100644 index 0000000..b35a3e4 --- /dev/null +++ b/src/catalog.jl @@ -0,0 +1,55 @@ +# Catalog functions. + +""" + tables(conn; catalogname=nothing, schemaname=nothing, tablename=nothing, tabletype=nothing) -> ODBC.Cursor + +Find tables by the given criteria. This function returns a `Cursor` object that +produces one row per matching table. + +Search criteria include: + * `catalogname`: search pattern for catalog names + * `schemaname`: search pattern for schema names + * `tablename`: search pattern for table names + * `tabletypes`: comma-separated list of table types + +A search pattern may contain an underscore (`_`) to represent any single character +and a percent sign (`%`) to represent any sequence of zero or more characters. +Use an escape character (driver-specific, but usually `\\`) to include underscores, +percent signs, and escape characters as literals. +""" +function tables(conn; catalogname=nothing, schemaname=nothing, tablename=nothing, tabletype=nothing) + clear!(conn) + stmt = API.Handle(API.SQL_HANDLE_STMT, API.getptr(conn.dbc)) + conn.stmts[stmt] = 0 + conn.cursorstmt = stmt + API.enableasync(stmt) + API.tables(stmt, catalogname, schemaname, tablename, tabletype) + return Cursor(stmt) +end + +""" + columns(conn; catalogname=nothing, schemaname=nothing, tablename=nothing, columnname=nothing) -> ODBC.Cursor + +Find columns by the given criteria. This function returns a `Cursor` object that +produces one row per matching column. + +Search criteria include: + * `catalogname`: name of the catalog + * `schemaname`: search pattern for schema names + * `tablename`: search pattern for table names + * `columnname`: search pattern for column names + +A search pattern may contain an underscore (`_`) to represent any single character +and a percent sign (`%`) to represent any sequence of zero or more characters. +Use an escape character (driver-specific, but usually `\\`) to include underscores, +percent signs, and escape characters as literals. +""" +function columns(conn; catalogname=nothing, schemaname=nothing, tablename=nothing, columnname=nothing) + clear!(conn) + stmt = API.Handle(API.SQL_HANDLE_STMT, API.getptr(conn.dbc)) + conn.stmts[stmt] = 0 + conn.cursorstmt = stmt + API.enableasync(stmt) + API.columns(stmt, catalogname, schemaname, tablename, columnname) + return Cursor(stmt) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9647fcc..9b4aaad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -283,6 +283,10 @@ DBInterface.execute(conn, "INSERT INTO big_decimal (`dec`) VALUES (1234567890123 ret = DBInterface.execute(conn, "select * from big_decimal") |> columntable @test ret.dec[1] == d128"1.2345678901234567891e17" +ret = ODBC.tables(conn, tablename="emp%") |> columntable +@test ret.TABLE_NAME == ["Employee", "Employee2", "Employee_copy"] +ret = ODBC.columns(conn, tablename="emp%", columnname="望研") |> columntable +@test ret.COLUMN_NAME == ["望研"] DBInterface.execute(conn, """DROP USER IF EXISTS 'authtest'""") DBInterface.execute(conn, """CREATE USER 'authtest' IDENTIFIED BY 'authtestpw'""")