Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start duckdb implementation of dremio #683

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions lumen/sources/dremio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from functools import reduce

from pyarrow import flight


class DremioClientAuthMiddleware(flight.ClientMiddleware):
"""
A ClientMiddleware that extracts the bearer token from
the authorization header returned by the Dremio
Flight Server Endpoint.

Parameters
----------
factory : ClientHeaderAuthMiddlewareFactory
The factory to set call credentials if an
authorization header with bearer token is
returned by the Dremio server.
"""

def __init__(self, factory):
self.factory = factory

def received_headers(self, headers):
if self.factory.call_credential:
return

auth_header_key = "authorization"

authorization_header = reduce(
lambda result, header: (
header[1] if header[0] == auth_header_key else result
),
headers.items(),
)
if not authorization_header:
raise Exception("Did not receive authorization header back from server.")
bearer_token = authorization_header[1][0]
self.factory.set_call_credential(
[b"authorization", bearer_token.encode("utf-8")]
)


class DremioClientAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
"""A factory that creates DremioClientAuthMiddleware(s)."""

def __init__(self):
self.call_credential = []

def start_call(self, info):
return DremioClientAuthMiddleware(self)

def set_call_credential(self, call_credential):
self.call_credential = call_credential


class HttpDremioClientAuthHandler(flight.ClientAuthHandler):

def __init__(self, username, password):
super(flight.ClientAuthHandler, self).__init__()
self.basic_auth = flight.BasicAuth(username, password)
self.token = None

def authenticate(self, outgoing, incoming):
auth = self.basic_auth.serialize()
outgoing.write(auth)
self.token = incoming.read()

def get_token(self):
return self.token
156 changes: 154 additions & 2 deletions lumen/sources/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def get_tables(self):
def get_sql_expr(self, table: str):
if isinstance(self.tables, dict):
table = self.tables[table]
if '(' not in table and ')' not in table:
table = f'"{table}"'
if '(' not in table and ')' not in table and '"' not in table:
table = f"'{table}'"
if 'select ' in table.lower():
sql_expr = table
else:
Expand All @@ -250,6 +250,10 @@ def get_sql_expr(self, table: str):
@cached
def get(self, table, **query):
query.pop('__dask', None)

# duckdb does not support "Ahierachy"."Btable"
if '."' in table:
table = table.replace('"', '')
sql_expr = self.get_sql_expr(table)
sql_transforms = query.pop('sql_transforms', [])
conditions = list(query.items())
Expand All @@ -274,6 +278,11 @@ def get_schema(
schemas = {}
sql_limit = SQLLimit(limit=limit or 1)
for entry in tables:

# duckdb does not support "Ahierachy"."Btable"
if '."' in entry:
entry = entry.replace('"', '')

if not self.load_schema:
schemas[entry] = {}
continue
Expand Down Expand Up @@ -314,3 +323,146 @@ def get_schema(
schema[col]['inclusiveMinimum'] = cast(minmax_data[f'{col}_min'].iloc[0])
schema[col]['inclusiveMaximum'] = cast(minmax_data[f'{col}_max'].iloc[0])
return schemas if table is None else schemas[table]


class DremioDuckDBSource(DuckDBSource):
"""
DremioDuckDBSource provides a simple wrapper around the DuckDB SQL
connector, extended to connect to a Dremio server via Apache Arrow Flight.
"""

cert = param.String(default="Path to certificate file", doc="Path to the certificate file.")

dremio_uri = param.String(doc="URI of the Dremio server.")

tls = param.Boolean(default=True, doc="Enable encryption (TLS).")

username = param.String(default=None, doc="Dremio username.")

password = param.String(default=None, doc="Dremio password or token.")

dialect = 'dremio'

def __init__(self, **params):
from pyarrow import flight

from lumen.sources.dremio_utils import (
DremioClientAuthMiddlewareFactory, HttpDremioClientAuthHandler,
)

super().__init__(**params)

protocol, hostname, username, password = self._process_uri(
tls=self.tls, username=self.username, password=self.password)

dremio_client_auth_middleware = DremioClientAuthMiddlewareFactory()
connection_args = {'middleware': [dremio_client_auth_middleware]}

if self.tls:
with open(self.cert) as f:
certs = f.read()
connection_args["tls_root_certs"] = certs

dremio_client_auth_middleware = DremioClientAuthMiddlewareFactory()
connection_args = {'middleware': [dremio_client_auth_middleware]}
if self.tls:
connection_args["tls_root_certs"] = certs
self._dremio_client = flight.FlightClient(f'{protocol}://{hostname}', **connection_args)
auth_options = flight.FlightCallOptions()
try:
bearer_token = self._dremio_client.authenticate_basic_token(username, password)
self._headers = [bearer_token]
except Exception as e:
if self.tls:
raise e
handler = HttpDremioClientAuthHandler(username, password)
self._dremio_client.authenticate(handler, options=auth_options)
self._headers = []


def _process_uri(self, tls=False, username=None, password=None):
"""
Extracts hostname, protocol, username and passworrd from URI

Parameters
----------
uri: str or None
Connection string in the form username:password@hostname:port
tls: boolean
Whether TLS is enabled
username: str or None
Username if not supplied as part of the URI
password: str or None
Password if not supplied as part of the URI
"""
uri = self.dremio_uri
if "://" in uri:
protocol, uri = uri.split("://")
else:
protocol = "grpc+tls" if tls else "grpc+tcp"
if "@" in uri:
if username or password:
raise ValueError(
"Dremio URI must not include username and password "
"if they were supplied explicitly."
)
userinfo, hostname = uri.split("@")
username, password = userinfo.split(":")
elif not (username and password):
raise ValueError(
"Dremio URI must include username and password "
"or they must be provided explicitly."
)
else:
hostname = uri
return protocol, hostname, username, password

def _execute_dremio_sql(self, sql_expr, as_pandas=False):
"""
Executes a SQL expression on Dremio via Apache Arrow Flight.
"""
from pyarrow import flight

flight_desc = flight.FlightDescriptor.for_command(sql_expr)
options = flight.FlightCallOptions(headers=self._headers)
flight_info = self._dremio_client.get_flight_info(flight_desc, options)
reader = self._dremio_client.do_get(flight_info.endpoints[0].ticket, options)
data_table = reader.read_all() if not as_pandas else reader.read_pandas()
return data_table

def _ingest_table(self, table):
"""
Ingests a table from Dremio to DuckDB.
"""
from pyarrow.dataset import dataset as arrow_dataset

sql_expr = self.get_sql_expr(table)
data_table = self._execute_dremio_sql(sql_expr)
arrow_ds = arrow_dataset(source=[data_table])
self._connection.from_arrow(arrow_ds).to_view(table.replace('"', ''))
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved

def get_tables(self):
if isinstance(self.tables, (dict, list)):
return list(self.tables)

databases = self._execute_dremio_sql("SHOW DATABASES", as_pandas=True)["SCHEMA_NAME"]

all_tables = []
for database in databases:
if not database[0].isalpha():
continue
try:
sql_expr = f"SHOW TABLES IN {database}"
reader = self._execute_dremio_sql(sql_expr, as_pandas=True)
except Exception:
pass
tables = reader.read_pandas()["TABLE_NAME"]
for table in tables:
all_tables.append(f'"{database}"."{table}"')
return all_tables

def get(self, table, **query):
ingested_tables = self._connection.execute("SHOW TABLES").fetchdf()["name"]
if table.replace('"', "") not in ingested_tables:
self._ingest_table(table)
return super().get(table, **query)
Loading