diff --git a/clickhouse_sqlalchemy/drivers/native/base.py b/clickhouse_sqlalchemy/drivers/native/base.py index b8cc37d..1e455c6 100644 --- a/clickhouse_sqlalchemy/drivers/native/base.py +++ b/clickhouse_sqlalchemy/drivers/native/base.py @@ -7,6 +7,7 @@ from ..base import ( ClickHouseDialect, ClickHouseExecutionContextBase, ClickHouseSQLCompiler, ) +from ...sql.dml import Insert from sqlalchemy.engine.interfaces import ExecuteStyle from sqlalchemy import __version__ as sqlalchemy_version @@ -20,10 +21,14 @@ class ClickHouseExecutionContext(ClickHouseExecutionContextBase): def pre_exec(self): + if not self.isinsert: + return # Always do executemany on INSERT with VALUES clause. - if (self.isinsert and self.compiled.statement.select is None and - self.parameters != [{}]): + if self.compiled.statement.select is None and self.parameters != [{}]: self.execute_style = ExecuteStyle.EXECUTEMANY + if (isinstance(self.compiled.statement, Insert) and + self.compiled.statement._values_iterator): + self.parameters = self.compiled.statement._values_iterator class ClickHouseNativeSQLCompiler(ClickHouseSQLCompiler): diff --git a/clickhouse_sqlalchemy/sql/__init__.py b/clickhouse_sqlalchemy/sql/__init__.py index 7516b59..9ba0a30 100644 --- a/clickhouse_sqlalchemy/sql/__init__.py +++ b/clickhouse_sqlalchemy/sql/__init__.py @@ -1,6 +1,6 @@ from .schema import Table, MaterializedView from .selectable import Select, select +from .dml import Insert, insert - -__all__ = ('Table', 'MaterializedView', 'Select', 'select') +__all__ = ('Table', 'MaterializedView', 'Select', 'select', 'Insert', 'insert') diff --git a/clickhouse_sqlalchemy/sql/dml.py b/clickhouse_sqlalchemy/sql/dml.py new file mode 100644 index 0000000..b633af8 --- /dev/null +++ b/clickhouse_sqlalchemy/sql/dml.py @@ -0,0 +1,15 @@ +from sqlalchemy.sql.dml import Insert as BaseInsert + +__all__ = ('Insert', 'insert') + + +class Insert(BaseInsert): + _values_iterator: None + + def values_iterator(self, columns, iterator): + self._values_iterator = iterator + self._multi_values = ([{column: None for column in columns}],) + return self + + +insert = Insert diff --git a/docs/features.rst b/docs/features.rst index 20dce7e..d98103a 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -694,6 +694,23 @@ INSERT FROM SELECT statement: .from_select(['day', 'value'], select_query) ) +Streaming insert: + + .. code-block:: python + from datetime import datetime + from clickhouse_sqlalchemy import sql + + def generator(): + for i in range(100): + yield [datetime.now(), 1, i] + + session.execute( + sql.insert(Statistics).values_iterator( + [Statistics.date, Statistics.sign, Statistics.grouping], + generator() + ) + ) + UPDATE and DELETE ----------------- diff --git a/tests/sql/test_insert.py b/tests/sql/test_insert.py index 26d2fdc..65bd96f 100644 --- a/tests/sql/test_insert.py +++ b/tests/sql/test_insert.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, literal_column, select -from clickhouse_sqlalchemy import types, Table, engines +from clickhouse_sqlalchemy import types, Table, engines, sql from tests.testcase import NativeSessionTestCase from tests.util import require_server_version @@ -39,3 +39,40 @@ def test_insert_map(self): rv = self.session.execute(select(table.c.x)).scalar() self.assertEqual(rv, dict_map) + + @require_server_version(19, 3, 3) + def test_insert_iterator(self): + table = Table( + 't', self.metadata(), + Column('x', types.String, primary_key=True), + engines.Log() + ) + + def generator(): + yield ["foo"] + yield ["bar"] + + with self.create_table(table): + query = sql.insert(table).values_iterator([table.c.x], generator()) + self.session.execute(query) + + result = list(self.session.execute(select(table.c.x))) + self.assertListEqual(result, [('foo',), ('bar',)]) + + @require_server_version(19, 3, 3) + def test_insert_iterator_list(self): + table = Table( + 't', self.metadata(), + Column('x', types.String, primary_key=True), + engines.Log() + ) + + with self.create_table(table): + query = sql.insert(table).values_iterator( + [table.c.x], + [["foo"], ["bar"]] + ) + self.session.execute(query) + + result = list(self.session.execute(select(table.c.x))) + self.assertListEqual(result, [('foo',), ('bar',)])