Skip to content

Commit

Permalink
Return result sets instead of map
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Jul 6, 2023
1 parent b9bc020 commit 4bf6dd6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 29 deletions.
47 changes: 24 additions & 23 deletions lib/adbc_connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ defmodule Adbc.Connection do
Documentation for `Adbc.Connection`.
"""

# TODO: Add prepared queries
# TODO: Result sets
# TODO: Documentation everywhere (including options)
# TODO: Review driver API
# TODO: Tests with postgresql

@type t :: GenServer.server()
@type result_set :: map
@type result_set :: Adbc.Result.t()

use GenServer
import Adbc.Helper, only: [error_to_exception: 1]
Expand Down Expand Up @@ -45,16 +46,16 @@ defmodule Adbc.Connection do
TODO.
"""
def query(conn, query, params \\ []) when is_binary(query) and is_list(params) do
stream_lock(conn, {:query, query, params}, &stream_results/1)
stream_lock(conn, {:query, query, params}, &stream_results/2)
end

@doc """
TODO.
"""
def query_pointer(conn, query, params \\ [], fun)
when is_binary(query) and is_list(params) and is_function(fun) do
stream_lock(conn, {:query, query, params}, fn stream_ref ->
fun.(Adbc.Nif.adbc_arrow_array_stream_get_pointer(stream_ref))
stream_lock(conn, {:query, query, params}, fn stream_ref, rows_affected ->
fun.(Adbc.Nif.adbc_arrow_array_stream_get_pointer(stream_ref), rows_affected)
end)
end

Expand Down Expand Up @@ -87,7 +88,7 @@ defmodule Adbc.Connection do
@spec get_info(t(), list(non_neg_integer())) ::
{:ok, result_set} | {:error, Exception.t()}
def get_info(conn, info_codes \\ []) when is_list(info_codes) do
stream_lock(conn, {:adbc_connection_get_info, [info_codes]}, &stream_results/1)
stream_lock(conn, {:adbc_connection_get_info, [info_codes]}, &stream_results/2)
end

@doc """
Expand Down Expand Up @@ -190,7 +191,7 @@ defmodule Adbc.Connection do
opts[:column_name]
]

stream_lock(conn, {:adbc_connection_get_objects, args}, &stream_results/1)
stream_lock(conn, {:adbc_connection_get_objects, args}, &stream_results/2)
end

@doc """
Expand All @@ -205,14 +206,14 @@ defmodule Adbc.Connection do
@spec get_table_types(t) ::
{:ok, result_set} | {:error, Exception.t()}
def get_table_types(conn) do
stream_lock(conn, {:adbc_connection_get_table_types, []}, &stream_results/1)
stream_lock(conn, {:adbc_connection_get_table_types, []}, &stream_results/2)
end

defp stream_lock(conn, command, fun) do
case GenServer.call(conn, {:stream_lock, command}, :infinity) do
{:ok, conn, unlock_ref, stream_ref} ->
{:ok, conn, unlock_ref, stream_ref, rows_affected} ->
try do
fun.(stream_ref)
fun.(stream_ref, rows_affected)
after
GenServer.cast(conn, {:unlock, unlock_ref})
end
Expand All @@ -222,18 +223,17 @@ defmodule Adbc.Connection do
end
end

defp stream_results(reference) do
stream_results(reference, %{})
end
defp stream_results(reference, -1), do: stream_results(reference, %{}, nil)
defp stream_results(reference, num_rows), do: stream_results(reference, %{}, num_rows)

defp stream_results(reference, acc) do
defp stream_results(reference, acc, num_rows) do
case Adbc.Nif.adbc_arrow_array_stream_next(reference) do
{:ok, results, done} ->
acc = Map.merge(acc, Map.new(results), fn _k, v1, v2 -> v1 ++ v2 end)

case done do
0 -> stream_results(reference, acc)
1 -> {:ok, acc}
0 -> stream_results(reference, acc, num_rows)
1 -> {:ok, %Adbc.Result{data: acc, num_rows: num_rows}}
end

{:error, reason} ->
Expand Down Expand Up @@ -301,9 +301,9 @@ defmodule Adbc.Connection do
{pid, _} = from

case handle_command(command, state.conn) do
{:ok, stream_ref} when is_reference(stream_ref) ->
{:ok, stream_ref, rows_affected} when is_reference(stream_ref) ->
unlock_ref = Process.monitor(pid)
GenServer.reply(from, {:ok, self(), unlock_ref, stream_ref})
GenServer.reply(from, {:ok, self(), unlock_ref, stream_ref, rows_affected})
%{state | lock: {unlock_ref, stream_ref}, queue: queue}

{:error, error} ->
Expand All @@ -318,14 +318,15 @@ defmodule Adbc.Connection do
defp handle_command({:query, query, params}, conn) do
with {:ok, stmt} <- Adbc.Nif.adbc_statement_new(conn),
:ok <- Adbc.Nif.adbc_statement_set_sql_query(stmt, query),
:ok <- maybe_bind(stmt, params),
{:ok, stream_ref, _rows_affected} <- Adbc.Nif.adbc_statement_execute_query(stmt) do
{:ok, stream_ref}
:ok <- maybe_bind(stmt, params) do
Adbc.Nif.adbc_statement_execute_query(stmt)
end
end

defp handle_command({name, args}, conn) do
apply(Adbc.Nif, name, [conn | args])
with {:ok, stream_ref} <- apply(Adbc.Nif, name, [conn | args]) do
{:ok, stream_ref, nil}
end
end

defp maybe_bind(_stmt, []), do: :ok
Expand Down
8 changes: 8 additions & 0 deletions lib/adbc_result.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
defmodule Adbc.Result do
defstruct [:num_rows, :data]

@type t :: %Adbc.Result{
num_rows: non_neg_integer() | nil,
data: %{optional(binary) => list(term)}
}
end
17 changes: 11 additions & 6 deletions test/adbc_connection_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,28 @@ defmodule Adbc.Connection.Test do
describe "get_table_types" do
test "get table types from a connection", %{db: db} do
conn = start_supervised!({Connection, database: db})
assert {:ok, %{"table_type" => ["table", "view"]}} = Connection.get_table_types(conn)

assert {:ok, %Adbc.Result{data: %{"table_type" => ["table", "view"]}}} =
Connection.get_table_types(conn)
end
end

describe "query" do
test "select", %{db: db} do
conn = start_supervised!({Connection, database: db})

assert {:ok, %{"num" => [123]}} = Connection.query(conn, "SELECT 123 as num")
assert {:ok, %Adbc.Result{data: %{"num" => [123]}}} =
Connection.query(conn, "SELECT 123 as num")

assert {:ok, %{"num" => [123], "bool" => [1]}} =
assert {:ok, %Adbc.Result{data: %{"num" => [123], "bool" => [1]}}} =
Connection.query(conn, "SELECT 123 as num, true as bool")
end

test "select with parameters", %{db: db} do
conn = start_supervised!({Connection, database: db})
assert {:ok, %{"num" => [579]}} = Connection.query(conn, "SELECT 123 + ? as num", [456])

assert {:ok, %Adbc.Result{data: %{"num" => [579]}}} =
Connection.query(conn, "SELECT 123 + ? as num", [456])
end

test "fails on invalid query", %{db: db} do
Expand All @@ -105,7 +110,7 @@ defmodule Adbc.Connection.Test do
conn = start_supervised!({Connection, database: db})

assert_raise RuntimeError, fn ->
Connection.query_pointer(conn, "SELECT 1", fn _pointer ->
Connection.query_pointer(conn, "SELECT 1", fn _pointer, _num_rows ->
raise "oops"
end)
end
Expand All @@ -119,7 +124,7 @@ defmodule Adbc.Connection.Test do

child =
spawn(fn ->
Connection.query_pointer(conn, "SELECT 1", fn _pointer ->
Connection.query_pointer(conn, "SELECT 1", fn _pointer, _num_rows ->
send(parent, :ready)
Process.sleep(:infinity)
end)
Expand Down

0 comments on commit 4bf6dd6

Please sign in to comment.