diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 5fca50dc209..0d719894fe3 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -23,6 +23,7 @@ class Connection(object): """ _closed = False + _session = None def __init__(self, cursorclass=Cursor): @@ -39,6 +40,8 @@ def __init__(self, cursorclass=Cursor): self.connect() def connect(self): + from chdb import session as chs + self._session = chs.Session() self._closed = False self._execute_command("select 1;") self._read_query_result() @@ -55,6 +58,7 @@ def close(self): if self._closed: raise err.Error("Already closed") self._closed = True + self._session = None @property def open(self): @@ -119,8 +123,7 @@ def _execute_command(self, sql): if DEBUG: print("DEBUG: query:", sql) try: - import chdb - self._resp = chdb.query(sql, output_format="JSON").data() + self._resp = self._session.query(sql, fmt="JSON").data() except Exception as error: raise err.InterfaceError("query err: %s" % error) @@ -181,6 +184,10 @@ def __init__(self, connection): self.has_next = None def read(self): + # Handle empty responses (for instance from CREATE TABLE) + if self.connection.resp is None: + return + try: data = json.loads(self.connection.resp) except Exception as error: diff --git a/tests/test_dbapi.py b/tests/test_dbapi.py index fbc4cb98158..898cfaf2ea8 100644 --- a/tests/test_dbapi.py +++ b/tests/test_dbapi.py @@ -24,6 +24,26 @@ def test_select_version(self): print(data) self.assertRegex(data[0], expected_clickhouse_version_pattern) + def test_insert_and_read_data(self): + conn = dbapi.connect() + cur = conn.cursor() + cur.execute("CREATE DATABASE IF NOT EXISTS test_db ENGINE = Atomic") + cur.execute("USE test_db") + cur.execute(""" + CREATE TABLE rate ( + day Date, + value Int32 + ) ENGINE = Log""") + + # Insert values + cur.execute("INSERT INTO rate VALUES ('2024-01-01', 24)") + cur.execute("INSERT INTO rate VALUES ('2024-01-02', 72)") + + # Read values + cur.execute("SELECT value FROM rate ORDER BY day DESC") + rows = cur.fetchall() + self.assertEqual(rows, ((72,), (24,))) + def test_select_chdb_version(self): ver = dbapi.get_client_info() # chDB version liek '0.12.0' ver_tuple = dbapi.chdb_version # chDB version tuple like ('0', '12', '0')