Skip to content

Commit

Permalink
Wrap the FlightSQL server native class in a Python class that manages…
Browse files Browse the repository at this point in the history
… its own thread (#480)

* Create wrapper functions for FlightSqlServer methods, add pybindings to server

* Add gil release call guard in binding, create async function for serve

* Wrap FlightSQL native class in Python class with own thread

* Remove unused import

* Address formatting and pylint errors

* Add Union from typing, reformat flight_sql_server.py

* Address comments: add types, move server start, use Optional

---------

Co-authored-by: Sophie Zhang <[email protected]>
  • Loading branch information
sopzha and Sophie Zhang authored Mar 16, 2024
1 parent 238c230 commit 5c0106d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 10 deletions.
7 changes: 6 additions & 1 deletion cpp/pybind/brad_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,10 @@ PYBIND11_MODULE(pybind_brad_server, m) {

brad_server
.def(py::init<>())
.def("create", &brad::BradFlightSqlServer::Create);
.def("create", &brad::BradFlightSqlServer::Create)
.def("init", &brad::BradFlightSqlServer::InitWrapper)
.def("serve",
&brad::BradFlightSqlServer::ServeWrapper,
py::call_guard<py::gil_scoped_release>())
.def("shutdown", &brad::BradFlightSqlServer::ShutdownWrapper);
}
26 changes: 20 additions & 6 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,26 @@ BradFlightSqlServer::~BradFlightSqlServer() = default;
std::shared_ptr<BradFlightSqlServer>
BradFlightSqlServer::Create() {
// std::shared_ptr<BradFlightSqlServer> result(new BradFlightSqlServer());
std::shared_ptr<BradFlightSqlServer> result =
std::make_shared<BradFlightSqlServer>();
for (const auto &id_to_result : GetSqlInfoResultMap()) {
result->RegisterSqlInfo(id_to_result.first, id_to_result.second);
}
return result;
std::shared_ptr<BradFlightSqlServer> result =
std::make_shared<BradFlightSqlServer>();
for (const auto &id_to_result : GetSqlInfoResultMap()) {
result->RegisterSqlInfo(id_to_result.first, id_to_result.second);
}
return result;
}

void BradFlightSqlServer::InitWrapper(const std::string &host, int port) {
auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie();
arrow::flight::FlightServerOptions options(location);
this->Init(options);
}

void BradFlightSqlServer::ServeWrapper() {
this->Serve();
}

void BradFlightSqlServer::ShutdownWrapper() {
this->Shutdown(nullptr);
}

arrow::Result<std::unique_ptr<FlightInfo>>
Expand Down
6 changes: 6 additions & 0 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {

static std::shared_ptr<BradFlightSqlServer> Create();

void InitWrapper(const std::string &host, int port);

void ServeWrapper();

void ShutdownWrapper();

arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
GetFlightInfoStatement(
const arrow::flight::ServerCallContext &context,
Expand Down
26 changes: 26 additions & 0 deletions src/brad/front_end/flight_sql_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging
import threading

# pylint: disable-next=import-error,no-name-in-module,unused-import
import brad.native.pybind_brad_server as brad_server

logger = logging.getLogger(__name__)


class BradFlightSqlServer:
def __init__(self, host: str, port: int) -> None:
self._flight_sql_server = brad_server.BradFlightSqlServer()
self._flight_sql_server.init(host, port)
self._thread = threading.Thread(name="BradFlightSqlServer", target=self._serve)

def start(self) -> None:
self._thread.start()

def stop(self) -> None:
logger.info("BRAD FlightSQL server stopping...")
self._flight_sql_server.shutdown()
self._thread.join()
logger.info("BRAD FlightSQL server stopped.")

def _serve(self) -> None:
self._flight_sql_server.serve()
16 changes: 13 additions & 3 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def __init__(
output_queue: mp.Queue,
):
if BradFrontEnd.native_server_is_supported():
# pylint: disable-next=import-error,no-name-in-module
import brad.native.pybind_brad_server as brad_server
from brad.front_end.flight_sql_server import BradFlightSqlServer

self._flight_sql_server = brad_server.BradFlightSqlServer.create()
self._flight_sql_server: Optional[BradFlightSqlServer] = (
BradFlightSqlServer(host="0.0.0.0", port=31337)
)
else:
self._flight_sql_server = None

Expand Down Expand Up @@ -191,6 +192,10 @@ def __init__(

async def serve_forever(self):
await self._run_setup()

# Start FlightSQL server
self._flight_sql_server.start()

try:
grpc_server = grpc.aio.server()
brad_grpc.add_BradServicer_to_server(BradGrpc(self), grpc_server)
Expand Down Expand Up @@ -281,6 +286,11 @@ async def _set_up_router(self) -> None:

async def _run_teardown(self):
logger.debug("Starting BRAD front end _run_teardown()")

# Shutdown FlightSQL server
if self._flight_sql_server:
self._flight_sql_server.stop()

await self._sessions.end_all_sessions()

# Important for unblocking our message reader thread.
Expand Down

0 comments on commit 5c0106d

Please sign in to comment.