Skip to content

Commit

Permalink
Merge pull request #333 from yilinxia/duckdb
Browse files Browse the repository at this point in the history
DuckDB Implementation
  • Loading branch information
EvgSkv authored Jun 21, 2024
2 parents 248c9a1 + 5a41233 commit 1da5a0d
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 6 deletions.
19 changes: 16 additions & 3 deletions colab_logica.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
import os

import pandas
import duckdb

from .parser_py import parse
from .common import sqlite3_logica
from .common import duckdb_logica

BQ_READY = True # By default.

Expand Down Expand Up @@ -185,6 +187,11 @@ def RunSQL(sql, engine, connection=None, is_final=False):
return df
else:
psql_logica.PostgresExecute(sql, connection)
elif engine == 'duckdb':
if is_final:
return duckdb.sql(sql).df()
else:
duckdb.sql(sql)
elif engine == 'sqlite':
try:
if is_final:
Expand Down Expand Up @@ -220,6 +227,12 @@ def __init__(self):
def __call__(self, sql, engine, is_final):
  """Runs one SQL statement through RunSQL using this runner's connection.

  Args:
    sql: SQL text to execute.
    engine: Engine name, forwarded to RunSQL unchanged.
    is_final: Forwarded to RunSQL unchanged.

  Returns:
    Whatever RunSQL returns for the given arguments.
  """
  return RunSQL(sql, engine, self.connection, is_final)

class DuckdbRunner(object):
  """Callable SQL runner selected when the program's engine is 'duckdb'."""

  def __init__(self):
    # NOTE(review): duckdb_logica.SqliteConnect() returns an in-memory
    # sqlite3 connection, and the visible 'duckdb' branch of RunSQL calls
    # duckdb.sql(sql) without referencing the connection it is given.
    # Confirm whether this attribute is needed at all for DuckDB.
    self.connection = duckdb_logica.SqliteConnect()

  def __call__(self, sql, engine, is_final):
    """Runs one SQL statement through RunSQL.

    Args:
      sql: SQL text to execute.
      engine: Engine name, forwarded to RunSQL unchanged.
      is_final: Forwarded to RunSQL unchanged.

    Returns:
      Whatever RunSQL returns for the given arguments.
    """
    return RunSQL(sql, engine, self.connection, is_final)


class PostgresRunner(object):
def __init__(self):
Expand Down Expand Up @@ -280,7 +293,6 @@ def Logica(line, cell, run_query):
except infer.TypeErrorCaughtException as e:
e.ShowMessage()
return

engine = program.annotations.Engine()

if engine == 'bigquery' and not BQ_READY:
Expand All @@ -295,7 +307,6 @@ def Logica(line, cell, run_query):

bar = TabBar(predicates + ['(Log)'])
logs_idx = len(predicates)

executions = []
sub_bars = []
ip = IPython.get_ipython()
Expand Down Expand Up @@ -334,13 +345,15 @@ def Logica(line, cell, run_query):
sql_runner = SqliteRunner()
elif engine == 'psql':
sql_runner = PostgresRunner()
elif engine == 'duckdb':
sql_runner = DuckdbRunner()
elif engine == 'bigquery':
EnsureAuthenticatedUser()
sql_runner = RunSQL
else:
raise Exception('Logica only supports BigQuery, PostgreSQL and SQLite '
'for now.')
try:
try:
result_map = concertina_lib.ExecuteLogicaProgram(
executions, sql_runner=sql_runner, sql_engine=engine,
display_mode=DISPLAY_MODE)
Expand Down
22 changes: 22 additions & 0 deletions common/duckdb_logica.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/python
#
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from decimal import Decimal
import sqlite3

def SqliteConnect():
  """Opens and returns a new in-memory SQLite connection.

  Despite living in the DuckDB helper module, this uses the standard
  library `sqlite3` driver with the ':memory:' database.

  Returns:
    A sqlite3.Connection object.
  """
  return sqlite3.connect(':memory:')
56 changes: 56 additions & 0 deletions compiler/dialect_libraries/duckdb_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Dialect library source for DuckDB, kept as a single string. The DuckDB
# dialect class returns this text from its LibraryProgram() method.
# The string is program input, so its content must not be reformatted.
#
# NOTE(review): some entries look carried over from other dialects and should
# be confirmed against DuckDB before relying on them:
#   - ReadFile expands to pg_read_file(...), which reads like a PostgreSQL
#     function name.
#   - Array expands to ArgMin({v}, {a}) rather than an array aggregate.
#   - RecordAsJson expands to ROW_TO_JSON(...).
library = """
->(left:, right:) = {arg: left, value: right};
`=`(left:, right:) = right :- left == right;
Arrow(left, right) = arrow :-
left == arrow.arg,
right == arrow.value;
PrintToConsole(message) :- 1 == SqlExpr("PrintToConsole({message})", {message:});
ArgMin(arr) = SqlExpr(
"argmin({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr;
ArgMax(arr) = SqlExpr(
"argmax({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr;
ArgMaxK(a, l) = SqlExpr(
"(array_agg({arg_1} order by {value_1} desc))[1:{lim}]",
{arg_1: a.arg, value_1: a.value, lim: l});
ArgMinK(a, l) = SqlExpr(
"(array_agg({arg_1} order by {value_1}))[1:{lim}]",
{arg_1: a.arg, value_1: a.value, lim: l});
Array(arr) =
SqlExpr("ArgMin({v}, {a})", {a:, v:}) :- Arrow(a, v) == arr;
RecordAsJson(r) = SqlExpr(
"ROW_TO_JSON({r})", {r:});
Fingerprint(s) = SqlExpr("('x' || substr(md5({s}), 1, 16))::bit(64)::bigint", {s:});
ReadFile(filename) = SqlExpr("pg_read_file({filename})", {filename:});
Chr(x) = SqlExpr("Chr({x})", {x:});
Num(a) = a;
Str(a) = a;
"""
62 changes: 60 additions & 2 deletions compiler/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
from compiler.dialect_libraries import trino_library
from compiler.dialect_libraries import presto_library
from compiler.dialect_libraries import databricks_library
from compiler.dialect_libraries import duckdb_library
else:
from ..compiler.dialect_libraries import bq_library
from ..compiler.dialect_libraries import psql_library
from ..compiler.dialect_libraries import sqlite_library
from ..compiler.dialect_libraries import trino_library
from ..compiler.dialect_libraries import presto_library
from ..compiler.dialect_libraries import databricks_library

from ..compiler.dialect_libraries import duckdb_library
def Get(engine):
  """Returns a new dialect instance for the given engine name.

  Args:
    engine: Key into the module-level DIALECTS mapping, e.g. 'duckdb'.

  Returns:
    A freshly constructed instance of the class registered for `engine`.

  Raises:
    KeyError: If `engine` is not a key of DIALECTS.
  """
  dialect_class = DIALECTS[engine]
  return dialect_class()

Expand Down Expand Up @@ -388,12 +389,69 @@ def GroupBySpecBy(self):
def DecorateCombineRule(self, rule, var):
  """Returns `rule` unchanged; this dialect applies no decoration.

  Args:
    rule: The combine rule to (not) decorate.
    var: Unused by this implementation.

  Returns:
    The very same `rule` object that was passed in.
  """
  return rule

class DuckDB(Dialect):
  """SQL dialect description for the DuckDB engine."""

  def Name(self):
    """Returns the dialect name used in dialect-specific checks."""
    return 'DuckDB'

  def BuiltInFunctions(self):
    """Returns SQL templates overriding built-in functions for DuckDB.

    Keys are built-in function names. Values are templates that use either
    positional '{0}', '{1}' placeholders or a single '%s' placeholder.
    """
    # NOTE(review): several templates below (DistinctListAgg, SortList,
    # JOIN_STRINGS, JSON_ARRAY_LENGTH, MagicalEntangle, DATE(x, 'N days'),
    # JULIANDAY) look carried over from the SQLite dialect. Confirm that
    # DuckDB provides them or that they are registered elsewhere.
    return {
        'Set': 'DistinctListAgg({0})',
        # array_extract is called with {1}+1, i.e. a 0-based index is shifted.
        'Element': "array_extract({0}, {1}+1)",
        'Range': ('(select [n] from (with recursive t as'
                  '(select 0 as n union all '
                  'select n + 1 as n from t where n + 1 < {0}) '
                  'select n from t) where n < {0})'),
        'ValueOfUnnested': '{0}.unnested_pod',
        'List': '[{0}]',
        'Size': 'JSON_ARRAY_LENGTH({0})',
        'Join': 'JOIN_STRINGS({0}, {1})',
        'Count': 'COUNT(DISTINCT {0})',
        'StringAgg': 'GROUP_CONCAT(%s)',
        'Sort': 'SortList({0})',
        'MagicalEntangle': '(CASE WHEN {1} = 0 THEN {0} ELSE NULL END)',
        'Format': 'Printf(%s)',
        # In DuckDB MIN/MAX are aggregate functions; the scalar
        # "smallest/largest of the arguments" functions are LEAST/GREATEST.
        'Least': 'LEAST(%s)',
        'Greatest': 'GREATEST(%s)',
        'ToString': 'CAST(%s AS TEXT)',
        'DateAddDay': "DATE({0}, {1} || ' days')",
        'DateDiffDay': "CAST(JULIANDAY({0}) - JULIANDAY({1}) AS INT64)"
    }

  def DecorateCombineRule(self, rule, var):
    """Delegates to the module-level DecorateCombineRule(rule, var)."""
    # Inside a method body this name resolves to the module-level function,
    # not to this method, so the call is not recursive.
    return DecorateCombineRule(rule, var)

  def InfixOperators(self):
    """Returns '%'-style templates for infix operators that need overriding."""
    return {
        '++': '(%s) || (%s)',
        # '%%' is an escaped percent sign in a '%'-format template.
        '%' : '(%s) %% (%s)',
        'in': 'IN_LIST(%s, %s)'
    }

  def Subscript(self, record, subscript, record_is_table):
    """Returns SQL for a field access, `record.subscript`.

    `record_is_table` is accepted for interface compatibility and is not
    used by this dialect.
    """
    return '%s.%s' % (record, subscript)

  def LibraryProgram(self):
    """Returns the DuckDB dialect library source text."""
    return duckdb_library.library

  def UnnestPhrase(self):
    """Returns the template for unnesting list {0} under table alias {1}."""
    return '(select unnest({0}) as unnested_pod) as {1}'

  def ArrayPhrase(self):
    """Returns the '%'-style template for an array literal."""
    return '[%s]'

  def GroupBySpecBy(self):
    """Returns 'expr' as this dialect's GROUP BY specification mode."""
    return 'expr'


# Maps an engine name to the dialect class that Get() instantiates for it.
DIALECTS = {
    'bigquery': BigQueryDialect,
    'sqlite': SqLiteDialect,
    'psql': PostgreSQL,
    'presto': Presto,
    'trino': Trino,
    'databricks': Databricks,
    'duckdb': DuckDB,
}

8 changes: 7 additions & 1 deletion compiler/expr_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def IntLiteral(self, literal):
return str(literal['number'])

def StrLiteral(self, literal):
if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite"]:
if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite", "DuckDB"]:
# TODO: Do this safely.
return '\'%s\'' % (literal['the_string'].replace("'", "''"))

Expand Down Expand Up @@ -354,6 +354,12 @@ def Record(self, record, record_type=None):
for f_v in sorted(record['field_value'],
key=lambda x: StrIntKey(x['field'])))
return 'ROW(%s)::%s' % (args, record_type)
if self.dialect.Name() == 'DuckDB':
arguments_str = ', '.join(
"%s: %s" % (f_v['field'],
self.ConvertToSql(f_v['value']['expression']) )
for f_v in record['field_value'])
return '{%s}' % arguments_str
return 'STRUCT(%s)' % arguments_str

def GenericSqlExpression(self, record):
Expand Down

0 comments on commit 1da5a0d

Please sign in to comment.