diff --git a/colab_logica.py b/colab_logica.py index 458e90a..0703e92 100755 --- a/colab_logica.py +++ b/colab_logica.py @@ -39,9 +39,11 @@ import os import pandas +import duckdb from .parser_py import parse from .common import sqlite3_logica +from .common import duckdb_logica BQ_READY = True # By default. @@ -185,6 +187,11 @@ def RunSQL(sql, engine, connection=None, is_final=False): return df else: psql_logica.PostgresExecute(sql, connection) + elif engine == 'duckdb': + if is_final: + return duckdb.sql(sql).df() + else: + duckdb.sql(sql) elif engine == 'sqlite': try: if is_final: @@ -220,6 +227,12 @@ def __init__(self): def __call__(self, sql, engine, is_final): return RunSQL(sql, engine, self.connection, is_final) +class DuckdbRunner(object): + def __init__(self): + self.connection = duckdb_logica.SqliteConnect() + def __call__(self, sql, engine, is_final): + return RunSQL(sql, engine, self.connection, is_final) + class PostgresRunner(object): def __init__(self): @@ -280,7 +293,6 @@ def Logica(line, cell, run_query): except infer.TypeErrorCaughtException as e: e.ShowMessage() return - engine = program.annotations.Engine() if engine == 'bigquery' and not BQ_READY: @@ -295,7 +307,6 @@ def Logica(line, cell, run_query): bar = TabBar(predicates + ['(Log)']) logs_idx = len(predicates) - executions = [] sub_bars = [] ip = IPython.get_ipython() @@ -334,13 +345,15 @@ def Logica(line, cell, run_query): sql_runner = SqliteRunner() elif engine == 'psql': sql_runner = PostgresRunner() + elif engine == 'duckdb': + sql_runner = DuckdbRunner() elif engine == 'bigquery': EnsureAuthenticatedUser() sql_runner = RunSQL else: raise Exception('Logica only supports BigQuery, PostgreSQL and SQLite ' 'for now.') - try: + try: result_map = concertina_lib.ExecuteLogicaProgram( executions, sql_runner=sql_runner, sql_engine=engine, display_mode=DISPLAY_MODE) diff --git a/common/duckdb_logica.py b/common/duckdb_logica.py new file mode 100644 index 0000000..1bd87ee --- /dev/null +++ b/common/duckdb_logica.py @@ -0,0 +1,22 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from decimal import Decimal +import sqlite3 + +def SqliteConnect(): + con = sqlite3.connect(':memory:') + return con \ No newline at end of file diff --git a/compiler/dialect_libraries/duckdb_library.py b/compiler/dialect_libraries/duckdb_library.py new file mode 100644 index 0000000..9b9e1ad --- /dev/null +++ b/compiler/dialect_libraries/duckdb_library.py @@ -0,0 +1,56 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +library = """ +->(left:, right:) = {arg: left, value: right}; +`=`(left:, right:) = right :- left == right; + +Arrow(left, right) = arrow :- + left == arrow.arg, + right == arrow.value; + +PrintToConsole(message) :- 1 == SqlExpr("PrintToConsole({message})", {message:}); + +ArgMin(arr) = SqlExpr( + "argmin({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr; + +ArgMax(arr) = SqlExpr( + "argmax({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr; + +ArgMaxK(a, l) = SqlExpr( + "(array_agg({arg_1} order by {value_1} desc))[1:{lim}]", + {arg_1: a.arg, value_1: a.value, lim: l}); + +ArgMinK(a, l) = SqlExpr( + "(array_agg({arg_1} order by {value_1}))[1:{lim}]", + {arg_1: a.arg, value_1: a.value, lim: l}); + +Array(arr) = + SqlExpr("ArgMin({v}, {a})", {a:, v:}) :- Arrow(a, v) == arr; + +RecordAsJson(r) = SqlExpr( + "ROW_TO_JSON({r})", {r:}); + +Fingerprint(s) = SqlExpr("('x' || substr(md5({s}), 1, 16))::bit(64)::bigint", {s:}); + +ReadFile(filename) = SqlExpr("pg_read_file({filename})", {filename:}); + +Chr(x) = SqlExpr("Chr({x})", {x:}); + +Num(a) = a; +Str(a) = a; + +""" diff --git a/compiler/dialects.py b/compiler/dialects.py index 45f9a43..3e16598 100755 --- a/compiler/dialects.py +++ b/compiler/dialects.py @@ -25,6 +25,7 @@ from compiler.dialect_libraries import trino_library from compiler.dialect_libraries import presto_library from compiler.dialect_libraries import databricks_library + from compiler.dialect_libraries import duckdb_library else: from ..compiler.dialect_libraries import bq_library from ..compiler.dialect_libraries import psql_library @@ -32,7 +33,7 @@ from ..compiler.dialect_libraries import trino_library from ..compiler.dialect_libraries import presto_library from ..compiler.dialect_libraries import databricks_library - + from ..compiler.dialect_libraries import duckdb_library def Get(engine): return DIALECTS[engine]() @@ -388,12 +389,69 @@ def GroupBySpecBy(self): def DecorateCombineRule(self, rule, var): return rule +class DuckDB(Dialect): + """DuckDB dialect""" + + def Name(self): + return 'DuckDB' + + def BuiltInFunctions(self): + return { + 'Set': 'DistinctListAgg({0})', + 'Element': "array_extract({0}, {1}+1)", + 'Range': ('(select [n] from (with recursive t as' + '(select 0 as n union all ' + 'select n + 1 as n from t where n + 1 < {0}) ' + 'select n from t) where n < {0})'), + 'ValueOfUnnested': '{0}.unnested_pod', + 'List': '[{0}]', + 'Size': 'JSON_ARRAY_LENGTH({0})', + 'Join': 'JOIN_STRINGS({0}, {1})', + 'Count': 'COUNT(DISTINCT {0})', + 'StringAgg': 'GROUP_CONCAT(%s)', + 'Sort': 'SortList({0})', + 'MagicalEntangle': '(CASE WHEN {1} = 0 THEN {0} ELSE NULL END)', + 'Format': 'Printf(%s)', + 'Least': 'MIN(%s)', + 'Greatest': 'MAX(%s)', + 'ToString': 'CAST(%s AS TEXT)', + 'DateAddDay': "DATE({0}, {1} || ' days')", + 'DateDiffDay': "CAST(JULIANDAY({0}) - JULIANDAY({1}) AS INT64)" + } + + def DecorateCombineRule(self, rule, var): + return DecorateCombineRule(rule, var) + + def InfixOperators(self): + return { + '++': '(%s) || (%s)', + '%' : '(%s) %% (%s)', + 'in': 'IN_LIST(%s, %s)' + } + + def Subscript(self, record, subscript, record_is_table): + return '%s.%s' % (record, subscript) + + def LibraryProgram(self): + return duckdb_library.library + + def UnnestPhrase(self): + return '(select unnest({0}) as unnested_pod) as {1}' + + def ArrayPhrase(self): + return '[%s]' + + def GroupBySpecBy(self): + return 'expr' + + DIALECTS = { 'bigquery': BigQueryDialect, 'sqlite': SqLiteDialect, 'psql': PostgreSQL, 'presto': Presto, 'trino': Trino, - 'databricks': Databricks + 'databricks': Databricks, + 'duckdb': DuckDB, } diff --git a/compiler/expr_translate.py b/compiler/expr_translate.py index a2da2f6..d2ddc47 100755 --- a/compiler/expr_translate.py +++ b/compiler/expr_translate.py @@ -246,7 +246,7 @@ def IntLiteral(self, literal): return str(literal['number']) def StrLiteral(self, literal): - if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite"]: + if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite", "DuckDB"]: # TODO: Do this safely. return '\'%s\'' % (literal['the_string'].replace("'", "''")) @@ -354,6 +354,12 @@ def Record(self, record, record_type=None): for f_v in sorted(record['field_value'], key=lambda x: StrIntKey(x['field']))) return 'ROW(%s)::%s' % (args, record_type) + if self.dialect.Name() == 'DuckDB': + arguments_str = ', '.join( + "%s: %s" % (f_v['field'], + self.ConvertToSql(f_v['value']['expression']) ) + for f_v in record['field_value']) + return '{%s}' % arguments_str return 'STRUCT(%s)' % arguments_str def GenericSqlExpression(self, record):