From 37d81af2dd25b58d1749da4dd334bc8c30444dbf Mon Sep 17 00:00:00 2001 From: Yilin Xia Date: Wed, 8 May 2024 11:35:49 -0500 Subject: [PATCH 1/5] init duckdb --- colab_logica.py | 14 ++- common/duckdb_logica.py | 94 ++++++++++++++++++++ compiler/dialect_libraries/duckdb_library.py | 60 +++++++++++++ compiler/dialects.py | 65 +++++++++++++- 4 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 common/duckdb_logica.py create mode 100644 compiler/dialect_libraries/duckdb_library.py diff --git a/colab_logica.py b/colab_logica.py index eff76ac3..0575e725 100755 --- a/colab_logica.py +++ b/colab_logica.py @@ -42,6 +42,7 @@ from .parser_py import parse from .common import sqlite3_logica +from .common import duckdb_logica BQ_READY = True # By default. @@ -172,6 +173,10 @@ def RunSQL(sql, engine, connection=None, is_final=False): return df else: psql_logica.PostgresExecute(sql, connection) + elif engine == 'duckdb': + import duckdb + print("\n"+sql) + duckdb.sql(sql).show() elif engine == 'sqlite': try: if is_final: @@ -207,6 +212,12 @@ def __init__(self): def __call__(self, sql, engine, is_final): return RunSQL(sql, engine, self.connection, is_final) +class DuckdbRunner(object): + def __init__(self): + self.connection = sqlite3_logica.SqliteConnect() + def __call__(self, sql, engine, is_final): + return RunSQL(sql, engine, self.connection, is_final) + class PostgresRunner(object): def __init__(self): @@ -278,7 +289,6 @@ def Logica(line, cell, run_query): bar = TabBar(predicates + ['(Log)']) logs_idx = len(predicates) - executions = [] sub_bars = [] ip = IPython.get_ipython() @@ -313,6 +323,8 @@ def Logica(line, cell, run_query): sql_runner = SqliteRunner() elif engine == 'psql': sql_runner = PostgresRunner() + elif engine == 'duckdb': + sql_runner = DuckdbRunner() elif engine == 'bigquery': EnsureAuthenticatedUser() sql_runner = RunSQL diff --git a/common/duckdb_logica.py b/common/duckdb_logica.py new file mode 100644 index 00000000..443aeda6 --- /dev/null +++ b/common/duckdb_logica.py @@ -0,0 +1,94 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import getpass +import json +import os +import re +from decimal import Decimal + +if '.' not in __package__: + from type_inference.research import infer +else: + from ..type_inference.research import infer + + +def PostgresExecute(sql, connection): + import psycopg2 + import psycopg2.extras + cursor = connection.cursor() + try: + cursor.execute(sql) + # Make connection aware of the used types. + types = re.findall(r'-- Logica type: (\w*)', sql) + for t in types: + if t != 'logicarecord893574736': # Empty record. + psycopg2.extras.register_composite(t, cursor, globally=True) + except psycopg2.errors.UndefinedTable as e: + raise infer.TypeErrorCaughtException( + infer.ContextualizedError.BuildNiceMessage( + 'Running SQL.', 'Undefined table used: ' + str(e))) + except psycopg2.Error as e: + connection.rollback() + raise e + return cursor + + +def DigestPsqlType(x): + if isinstance(x, tuple): + return PsqlTypeAsDictionary(x) + if isinstance(x, list) and len(x) > 0: + return PsqlTypeAsList(x) + if isinstance(x, Decimal): + if x.as_integer_ratio()[1] == 1: + return int(x) + else: + return float(x) + return x + + +def PsqlTypeAsDictionary(record): + result = {} + for f in record._asdict(): + a = getattr(record, f) + result[f] = DigestPsqlType(a) + return result + + +def PsqlTypeAsList(a): + return list(map(DigestPsqlType, a)) + + +def ConnectToPostgres(mode): + import psycopg2 + if mode == 'interactive': + print('Please enter PostgreSQL URL, or config in JSON format with fields host, database, user and password.') + connection_str = getpass.getpass() + elif mode == 'environment': + connection_str = os.environ.get('LOGICA_PSQL_CONNECTION') + assert connection_str, ( + 'Please provide PSQL connection parameters ' + 'in LOGICA_PSQL_CONNECTION.') + else: + assert False, 'Unknown mode:' + mode + if connection_str.startswith('postgres'): + connection = psycopg2.connect(connection_str) + else: + connection_json = json.loads(connection_str) + connection = psycopg2.connect(**connection_json) + + connection.autocommit = True + return connection \ No newline at end of file diff --git a/compiler/dialect_libraries/duckdb_library.py b/compiler/dialect_libraries/duckdb_library.py new file mode 100644 index 00000000..284a94d7 --- /dev/null +++ b/compiler/dialect_libraries/duckdb_library.py @@ -0,0 +1,60 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +library = """ +->(left:, right:) = {arg: left, value: right}; +`=`(left:, right:) = right :- left == right; + +Arrow(left, right) = arrow :- + left == arrow.arg, + right == arrow.value; + +PrintToConsole(message) :- 1 == SqlExpr("PrintToConsole({message})", {message:}); + +ArgMin(arr) = Element( + SqlExpr("ArgMin({a}, {v}, 1)", {a:, v:}), 0) :- Arrow(a, v) == arr; + +ArgMax(arr) = Element( + SqlExpr("ArgMax({a}, {v}, 1)", {a:, v:}), 0) :- Arrow(a, v) == arr; + +ArgMinK(arr, k) = + SqlExpr("ArgMin({a}, {v}, {k})", {a:, v:, k:}) :- + Arrow(a, v) == arr; + +ArgMaxK(arr, k) = + SqlExpr("ArgMax({a}, {v}, {k})", {a:, v:, k:}) :- Arrow(a, v) == arr; + +Array(arr) = + SqlExpr("ArgMin({v}, {a}, null)", {a:, v:}) :- Arrow(a, v) == arr; + +ReadFile(filename) = SqlExpr("ReadFile({filename})", {filename:}); + +ReadJson(filename) = ReadFile(filename); + +WriteFile(filename, content:) = SqlExpr("WriteFile({filename}, {content})", + {filename:, content:}); + +Fingerprint(s) = SqlExpr("Fingerprint({s})", {s:}); + +Intelligence(command) = SqlExpr("Intelligence({command})", {command:}); + +AssembleRecord(field_values) = SqlExpr("AssembleRecord({field_values})", {field_values:}); + +DisassembleRecord(record) = SqlExpr("DisassembleRecord({record})", {record:}); + +Char(code) = SqlExpr("CHAR({code})", {code:}); + +""" diff --git a/compiler/dialects.py b/compiler/dialects.py index ba16f61f..45918f3b 100755 --- a/compiler/dialects.py +++ b/compiler/dialects.py @@ -25,6 +25,7 @@ from compiler.dialect_libraries import trino_library from compiler.dialect_libraries import presto_library from compiler.dialect_libraries import databricks_library + from compiler.dialect_libraries import duckdb_library else: from ..compiler.dialect_libraries import bq_library from ..compiler.dialect_libraries import psql_library @@ -32,7 +33,7 @@ from ..compiler.dialect_libraries import trino_library from ..compiler.dialect_libraries import presto_library from ..compiler.dialect_libraries import databricks_library - + from ..compiler.dialect_libraries import duckdb_library def Get(engine): return DIALECTS[engine]() @@ -387,12 +388,72 @@ def GroupBySpecBy(self): def DecorateCombineRule(self, rule, var): return rule +class DuckDB(Dialect): + """DuckDB dialect""" + + def Name(self): + return 'DuckDB' + + def BuiltInFunctions(self): + return { + 'Set': 'DistinctListAgg({0})', + 'Element': "JSON_EXTRACT({0}, '$[' || {1} || ']')", + 'Range': ('(select json_group_array(n) from (with recursive t as' + '(select 0 as n union all ' + 'select n + 1 as n from t where n + 1 < {0}) ' + 'select n from t) where n < {0})'), + 'ValueOfUnnested': '{0}.value', + 'List': 'JSON_GROUP_ARRAY({0})', + 'Size': 'JSON_ARRAY_LENGTH({0})', + 'Join': 'JOIN_STRINGS({0}, {1})', + 'Count': 'COUNT(DISTINCT {0})', + 'StringAgg': 'GROUP_CONCAT(%s)', + 'Sort': 'SortList({0})', + 'MagicalEntangle': 'MagicalEntangle({0}, {1})', + 'Format': 'Printf(%s)', + 'Least': 'MIN(%s)', + 'Greatest': 'MAX(%s)', + 'ToString': 'CAST(%s AS TEXT)', + 'DateAddDay': "DATE({0}, {1} || ' days')", + 'DateDiffDay': "CAST(JULIANDAY({0}) - JULIANDAY({1}) AS INT64)" + } + + def DecorateCombineRule(self, rule, var): + return DecorateCombineRule(rule, var) + + def InfixOperators(self): + return { + '++': '(%s) || (%s)', + '%' : '(%s) %% (%s)', + 'in': 'IN_LIST(%s, %s)' + } + + def Subscript(self, record, subscript, record_is_table): + if record_is_table: + return '%s.%s' % (record, subscript) + else: + return 'JSON_EXTRACT(%s, "$.%s")' % (record, subscript) + + def LibraryProgram(self): + return duckdb_library.library + + def UnnestPhrase(self): + return 'JSON_EACH({0}) as {1}' + + def ArrayPhrase(self): + return 'JSON_ARRAY(%s)' + + def GroupBySpecBy(self): + return 'expr' + + DIALECTS = { 'bigquery': BigQueryDialect, 'sqlite': SqLiteDialect, 'psql': PostgreSQL, 'presto': Presto, 'trino': Trino, - 'databricks': Databricks + 'databricks': Databricks, + 'duckdb': DuckDB, } From a7a99325daa53d8c2b3e09ed2aba07327b3be4d5 Mon Sep 17 00:00:00 2001 From: Yilin Xia Date: Wed, 19 Jun 2024 20:18:30 -0500 Subject: [PATCH 2/5] fix connect --- colab_logica.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/colab_logica.py b/colab_logica.py index 0575e725..c1ca8c4c 100755 --- a/colab_logica.py +++ b/colab_logica.py @@ -39,6 +39,7 @@ import os import pandas +import duckdb from .parser_py import parse from .common import sqlite3_logica @@ -174,9 +175,7 @@ def RunSQL(sql, engine, connection=None, is_final=False): else: psql_logica.PostgresExecute(sql, connection) elif engine == 'duckdb': - import duckdb - print("\n"+sql) - duckdb.sql(sql).show() + return duckdb.sql(sql).df() elif engine == 'sqlite': try: if is_final: @@ -214,7 +213,7 @@ def __call__(self, sql, engine, is_final): class DuckdbRunner(object): def __init__(self): - self.connection = sqlite3_logica.SqliteConnect() + self.connection = duckdb_logica.SqliteConnect() def __call__(self, sql, engine, is_final): return RunSQL(sql, engine, self.connection, is_final) @@ -274,7 +273,6 @@ def Logica(line, cell, run_query): except infer.TypeErrorCaughtException as e: e.ShowMessage() return - engine = program.annotations.Engine() if engine == 'bigquery' and not BQ_READY: @@ -319,6 +317,8 @@ def Logica(line, cell, run_query): color.Warn(predicate + '_sql')) with bar.output_to(logs_idx): + print(sql) + # print(program) if engine == 'sqlite': sql_runner = SqliteRunner() elif engine == 'psql': @@ -331,7 +331,7 @@ def Logica(line, cell, run_query): else: raise Exception('Logica only supports BigQuery, PostgreSQL and SQLite ' 'for now.') - try: + try: result_map = concertina_lib.ExecuteLogicaProgram( executions, sql_runner=sql_runner, sql_engine=engine, display_mode=DISPLAY_MODE) From 9a68b7eb7eb12f9f75c782d8713585eed1854d1a Mon Sep 17 00:00:00 2001 From: Yilin Xia Date: Wed, 19 Jun 2024 20:19:12 -0500 Subject: [PATCH 3/5] argmin argmax --- common/duckdb_logica.py | 8 +++- compiler/dialect_libraries/duckdb_library.py | 40 +++++++++----------- compiler/dialects.py | 16 ++++---- compiler/expr_translate.py | 8 +++- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/common/duckdb_logica.py b/common/duckdb_logica.py index 443aeda6..dc7b72b0 100644 --- a/common/duckdb_logica.py +++ b/common/duckdb_logica.py @@ -19,6 +19,7 @@ import os import re from decimal import Decimal +import sqlite3 if '.' not in __package__: from type_inference.research import infer @@ -91,4 +92,9 @@ def ConnectToPostgres(mode): connection = psycopg2.connect(**connection_json) connection.autocommit = True - return connection \ No newline at end of file + return connection + + +def SqliteConnect(): + con = sqlite3.connect(':memory:') + return con \ No newline at end of file diff --git a/compiler/dialect_libraries/duckdb_library.py b/compiler/dialect_libraries/duckdb_library.py index 284a94d7..3f28d91a 100644 --- a/compiler/dialect_libraries/duckdb_library.py +++ b/compiler/dialect_libraries/duckdb_library.py @@ -24,37 +24,33 @@ PrintToConsole(message) :- 1 == SqlExpr("PrintToConsole({message})", {message:}); -ArgMin(arr) = Element( - SqlExpr("ArgMin({a}, {v}, 1)", {a:, v:}), 0) :- Arrow(a, v) == arr; +ArgMin(arr) = SqlExpr( + "argmin({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr; -ArgMax(arr) = Element( - SqlExpr("ArgMax({a}, {v}, 1)", {a:, v:}), 0) :- Arrow(a, v) == arr; +ArgMax(arr) = SqlExpr( + "argmax({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr; -ArgMinK(arr, k) = - SqlExpr("ArgMin({a}, {v}, {k})", {a:, v:, k:}) :- - Arrow(a, v) == arr; +ArgMaxK(a, l) = SqlExpr( + "(array_agg({arg} order by {value} desc))[1:{lim}]", + {arg: a.arg, value: a.value, lim: l}); -ArgMaxK(arr, k) = - SqlExpr("ArgMax({a}, {v}, {k})", {a:, v:, k:}) :- Arrow(a, v) == arr; +ArgMinK(a, l) = SqlExpr( + "(array_agg({arg} order by {value}))[1:{lim}]", + {arg: a.arg, value: a.value, lim: l}); Array(arr) = - SqlExpr("ArgMin({v}, {a}, null)", {a:, v:}) :- Arrow(a, v) == arr; + SqlExpr("ArgMin({v}, {a})", {a:, v:}) :- Arrow(a, v) == arr; -ReadFile(filename) = SqlExpr("ReadFile({filename})", {filename:}); +RecordAsJson(r) = SqlExpr( + "ROW_TO_JSON({r})", {r:}); -ReadJson(filename) = ReadFile(filename); +Fingerprint(s) = SqlExpr("('x' || substr(md5({s}), 1, 16))::bit(64)::bigint", {s:}); -WriteFile(filename, content:) = SqlExpr("WriteFile({filename}, {content})", - {filename:, content:}); +ReadFile(filename) = SqlExpr("pg_read_file({filename})", {filename:}); -Fingerprint(s) = SqlExpr("Fingerprint({s})", {s:}); +Chr(x) = SqlExpr("Chr({x})", {x:}); -Intelligence(command) = SqlExpr("Intelligence({command})", {command:}); - -AssembleRecord(field_values) = SqlExpr("AssembleRecord({field_values})", {field_values:}); - -DisassembleRecord(record) = SqlExpr("DisassembleRecord({record})", {record:}); - -Char(code) = SqlExpr("CHAR({code})", {code:}); +Num(a) = a; +Str(a) = a; """ diff --git a/compiler/dialects.py b/compiler/dialects.py index 45918f3b..380a5eb3 100755 --- a/compiler/dialects.py +++ b/compiler/dialects.py @@ -397,19 +397,19 @@ def Name(self): def BuiltInFunctions(self): return { 'Set': 'DistinctListAgg({0})', - 'Element': "JSON_EXTRACT({0}, '$[' || {1} || ']')", + 'Element': "array_extract({0}, {1}+1)", 'Range': ('(select json_group_array(n) from (with recursive t as' '(select 0 as n union all ' 'select n + 1 as n from t where n + 1 < {0}) ' 'select n from t) where n < {0})'), - 'ValueOfUnnested': '{0}.value', - 'List': 'JSON_GROUP_ARRAY({0})', + 'ValueOfUnnested': '{0}', + 'List': '[{0}]', 'Size': 'JSON_ARRAY_LENGTH({0})', 'Join': 'JOIN_STRINGS({0}, {1})', 'Count': 'COUNT(DISTINCT {0})', 'StringAgg': 'GROUP_CONCAT(%s)', 'Sort': 'SortList({0})', - 'MagicalEntangle': 'MagicalEntangle({0}, {1})', + 'MagicalEntangle': '{0}', 'Format': 'Printf(%s)', 'Least': 'MIN(%s)', 'Greatest': 'MAX(%s)', @@ -432,16 +432,16 @@ def Subscript(self, record, subscript, record_is_table): if record_is_table: return '%s.%s' % (record, subscript) else: - return 'JSON_EXTRACT(%s, "$.%s")' % (record, subscript) - + return '(%s)' % (record) + def LibraryProgram(self): return duckdb_library.library def UnnestPhrase(self): - return 'JSON_EACH({0}) as {1}' + return 'unnest({0}) as {1}' def ArrayPhrase(self): - return 'JSON_ARRAY(%s)' + return '[%s]' def GroupBySpecBy(self): return 'expr' diff --git a/compiler/expr_translate.py b/compiler/expr_translate.py index 88279f1a..f84f30dd 100755 --- a/compiler/expr_translate.py +++ b/compiler/expr_translate.py @@ -246,7 +246,7 @@ def IntLiteral(self, literal): return str(literal['number']) def StrLiteral(self, literal): - if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite"]: + if self.dialect.Name() in ["PostgreSQL", "Presto", "Trino", "SqLite", "DuckDB"]: # TODO: Do this safely. return '\'%s\'' % (literal['the_string'].replace("'", "''")) @@ -350,6 +350,12 @@ def Record(self, record, record_type=None): for f_v in sorted(record['field_value'], key=lambda x: StrIntKey(x['field']))) return 'ROW(%s)::%s' % (args, record_type) + if self.dialect.Name() == 'DuckDB': + arguments_str = ', '.join( + "%s: %s" % (f_v['field'], + self.ConvertToSql(f_v['value']['expression']) ) + for f_v in record['field_value']) + return '{%s}' % arguments_str return 'STRUCT(%s)' % arguments_str def GenericSqlExpression(self, record): From 3e7adeaf507970a8ce0368b9b540fc8af95bbeb8 Mon Sep 17 00:00:00 2001 From: Yilin Xia Date: Thu, 20 Jun 2024 21:31:09 -0500 Subject: [PATCH 4/5] fix duckdb list --- colab_logica.py | 5 ++++- compiler/dialect_libraries/duckdb_library.py | 8 ++++---- compiler/dialects.py | 13 +++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/colab_logica.py b/colab_logica.py index c1ca8c4c..c390cc52 100755 --- a/colab_logica.py +++ b/colab_logica.py @@ -175,7 +175,10 @@ def RunSQL(sql, engine, connection=None, is_final=False): else: psql_logica.PostgresExecute(sql, connection) elif engine == 'duckdb': - return duckdb.sql(sql).df() + if is_final: + return duckdb.sql(sql).df() + else: + duckdb.sql(sql) elif engine == 'sqlite': try: if is_final: diff --git a/compiler/dialect_libraries/duckdb_library.py b/compiler/dialect_libraries/duckdb_library.py index 3f28d91a..9b9e1adb 100644 --- a/compiler/dialect_libraries/duckdb_library.py +++ b/compiler/dialect_libraries/duckdb_library.py @@ -31,12 +31,12 @@ "argmax({a}, {v})", {a:, v:}) :- Arrow(a, v) == arr; ArgMaxK(a, l) = SqlExpr( - "(array_agg({arg} order by {value} desc))[1:{lim}]", - {arg: a.arg, value: a.value, lim: l}); + "(array_agg({arg_1} order by {value_1} desc))[1:{lim}]", + {arg_1: a.arg, value_1: a.value, lim: l}); ArgMinK(a, l) = SqlExpr( - "(array_agg({arg} order by {value}))[1:{lim}]", - {arg: a.arg, value: a.value, lim: l}); + "(array_agg({arg_1} order by {value_1}))[1:{lim}]", + {arg_1: a.arg, value_1: a.value, lim: l}); Array(arr) = SqlExpr("ArgMin({v}, {a})", {a:, v:}) :- Arrow(a, v) == arr; diff --git a/compiler/dialects.py b/compiler/dialects.py index 380a5eb3..6491fb95 100755 --- a/compiler/dialects.py +++ b/compiler/dialects.py @@ -398,18 +398,18 @@ def BuiltInFunctions(self): return { 'Set': 'DistinctListAgg({0})', 'Element': "array_extract({0}, {1}+1)", - 'Range': ('(select json_group_array(n) from (with recursive t as' + 'Range': ('(select [n] from (with recursive t as' '(select 0 as n union all ' 'select n + 1 as n from t where n + 1 < {0}) ' 'select n from t) where n < {0})'), - 'ValueOfUnnested': '{0}', + 'ValueOfUnnested': '{0}.unnested_pod', 'List': '[{0}]', 'Size': 'JSON_ARRAY_LENGTH({0})', 'Join': 'JOIN_STRINGS({0}, {1})', 'Count': 'COUNT(DISTINCT {0})', 'StringAgg': 'GROUP_CONCAT(%s)', 'Sort': 'SortList({0})', - 'MagicalEntangle': '{0}', + 'MagicalEntangle': '(CASE WHEN {1} = 0 THEN {0} ELSE NULL END)', 'Format': 'Printf(%s)', 'Least': 'MIN(%s)', 'Greatest': 'MAX(%s)', @@ -429,16 +429,13 @@ def InfixOperators(self): } def Subscript(self, record, subscript, record_is_table): - if record_is_table: - return '%s.%s' % (record, subscript) - else: - return '(%s)' % (record) + return '%s.%s' % (record, subscript) def LibraryProgram(self): return duckdb_library.library def UnnestPhrase(self): - return 'unnest({0}) as {1}' + return '(select unnest({0}) as unnested_pod) as {1}' def ArrayPhrase(self): return '[%s]' From 5a412333e5827ec9bea9ecfaad8faf4d51f5e419 Mon Sep 17 00:00:00 2001 From: Yilin Xia Date: Thu, 20 Jun 2024 21:42:03 -0500 Subject: [PATCH 5/5] remove redundant scripts --- colab_logica.py | 2 -- common/duckdb_logica.py | 78 ----------------------------------------- 2 files changed, 80 deletions(-) diff --git a/colab_logica.py b/colab_logica.py index c390cc52..b0f0aa84 100755 --- a/colab_logica.py +++ b/colab_logica.py @@ -320,8 +320,6 @@ def Logica(line, cell, run_query): color.Warn(predicate + '_sql')) with bar.output_to(logs_idx): - print(sql) - # print(program) if engine == 'sqlite': sql_runner = SqliteRunner() elif engine == 'psql': diff --git a/common/duckdb_logica.py b/common/duckdb_logica.py index dc7b72b0..1bd87eee 100644 --- a/common/duckdb_logica.py +++ b/common/duckdb_logica.py @@ -14,87 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import getpass -import json -import os -import re from decimal import Decimal import sqlite3 -if '.' not in __package__: - from type_inference.research import infer -else: - from ..type_inference.research import infer - - -def PostgresExecute(sql, connection): - import psycopg2 - import psycopg2.extras - cursor = connection.cursor() - try: - cursor.execute(sql) - # Make connection aware of the used types. - types = re.findall(r'-- Logica type: (\w*)', sql) - for t in types: - if t != 'logicarecord893574736': # Empty record. - psycopg2.extras.register_composite(t, cursor, globally=True) - except psycopg2.errors.UndefinedTable as e: - raise infer.TypeErrorCaughtException( - infer.ContextualizedError.BuildNiceMessage( - 'Running SQL.', 'Undefined table used: ' + str(e))) - except psycopg2.Error as e: - connection.rollback() - raise e - return cursor - - -def DigestPsqlType(x): - if isinstance(x, tuple): - return PsqlTypeAsDictionary(x) - if isinstance(x, list) and len(x) > 0: - return PsqlTypeAsList(x) - if isinstance(x, Decimal): - if x.as_integer_ratio()[1] == 1: - return int(x) - else: - return float(x) - return x - - -def PsqlTypeAsDictionary(record): - result = {} - for f in record._asdict(): - a = getattr(record, f) - result[f] = DigestPsqlType(a) - return result - - -def PsqlTypeAsList(a): - return list(map(DigestPsqlType, a)) - - -def ConnectToPostgres(mode): - import psycopg2 - if mode == 'interactive': - print('Please enter PostgreSQL URL, or config in JSON format with fields host, database, user and password.') - connection_str = getpass.getpass() - elif mode == 'environment': - connection_str = os.environ.get('LOGICA_PSQL_CONNECTION') - assert connection_str, ( - 'Please provide PSQL connection parameters ' - 'in LOGICA_PSQL_CONNECTION.') - else: - assert False, 'Unknown mode:' + mode - if connection_str.startswith('postgres'): - connection = psycopg2.connect(connection_str) - else: - connection_json = json.loads(connection_str) - connection = psycopg2.connect(**connection_json) - - connection.autocommit = True - return connection - - def SqliteConnect(): con = sqlite3.connect(':memory:') return con \ No newline at end of file