diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index cccbca09..c61303c0 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -19,10 +19,10 @@ jobs: - "3.10" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip @@ -50,12 +50,12 @@ jobs: strategy: matrix: include: + # - python-version: "3.8" + # test-db-env: "postgres" + # pip-extra: "sqlalchemy <2" - python-version: "3.8" test-db-env: "postgres" - pip-extra: "sqlalchemy <2" - - python-version: "3.8" - test-db-env: "postgres" - pip-extra: "sqlalchemy >2" + pip-extra: "'sqlalchemy>2'" # - python-version: "3.8" # test-db-env: "sqlite" # - python-version: "3.9" @@ -68,13 +68,13 @@ jobs: # test-db-env: "sqlite" - python-version: "3.11" test-db-env: "postgres" - pip-extra: "sqlalchemy <2" + pip-extra: '"sqlalchemy<2" "pandas<2.2"' - python-version: "3.11" test-db-env: "postgres" - pip-extra: "sqlalchemy >2" + pip-extra: '"sqlalchemy>2"' - python-version: "3.11" test-db-env: "sqlite" - pip-extra: "sqlalchemy >2" + pip-extra: '"sqlalchemy>2"' services: # Label used to access the service container @@ -107,17 +107,17 @@ jobs: - 6333:6333 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip - name: Install dependencies run: | - pip install "${{ matrix.pip-extra }}" ".[sqlite,excel,milvus,gcsfs,s3fs,redis,qdrant,gcp]" "pytest<8" "pytest_cases" + pip install ${{ matrix.pip-extra }} ".[sqlite,excel,milvus,gcsfs,s3fs,redis,qdrant,gcp]" "pytest<8" "pytest_cases" - name: Test with pytest run: | diff --git a/.github/workflows/test_examples.yaml b/.github/workflows/test_examples.yaml index c16799c0..96176804 100644 --- a/.github/workflows/test_examples.yaml +++ b/.github/workflows/test_examples.yaml @@ -28,10 +28,10 @@ jobs: - RayExecutor steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip diff --git a/CHANGELOG.md b/CHANGELOG.md index 33ac0eda..f6a6e11f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,23 @@ +# 0.13.13 + +* Add `ComputeStep.get_status` method +* Remove restriction for Pandas < 2.2 + +# 0.13.12 + +* Add processing of an empty response in `QdrantStore` +* Add optional `index_schema` to `QdrantStore` +* Add redis cluster mode support in `RedisStore` + +# 0.13.11 + +* Remove logging to database (`datapipe_events` table) from `EventLogger` + # 0.13.10 * Fix compatibility with SQLalchemy < 2 (ColumnClause in typing) * Fix compatibility with Ray and SQLalchemy > 2 (serialization of Table) +* (post.1) Fix dependencies for MacOS; deprecate Python 3.8 # 0.13.9 diff --git a/README.md b/README.md index 480a4e3d..3c02d8ec 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ # Datapipe -`datapipe` is a real-time, incremental ETL library for Python with record-level dependency tracking. +[Datapipe](https://datapipe.dev/) is a real-time, incremental ETL library for Python with record-level dependency tracking. The library is designed for describing data processing pipelines and is capable of tracking dependencies for each record in the pipeline. This ensures that tasks within the pipeline receive only the data that has been modified, thereby improving the overall efficiency of data handling. +https://datapipe.dev/ + # Development At the moment these branches are active: diff --git a/datapipe/cli.py b/datapipe/cli.py index a420e450..753b2fbe 100644 --- a/datapipe/cli.py +++ b/datapipe/cli.py @@ -348,19 +348,16 @@ def list(ctx: click.Context, status: bool) -> None: # noqa extra_args = {} if status: - if len(step.input_dts) > 0: - try: - if isinstance(step, BaseBatchTransformStep): - changed_idx_count = step.get_changed_idx_count(ds=app.ds) - - if changed_idx_count > 0: - extra_args[ - "changed_idx_count" - ] = f"[red]{changed_idx_count}[/red]" - - except NotImplementedError: - # Currently we do not support empty join_keys - extra_args["changed_idx_count"] = "[red]N/A[/red]" + try: + step_status = step.get_status(ds=app.ds) + extra_args["total_idx_count"] = str(step_status.total_idx_count) + extra_args["changed_idx_count"] = ( + f"[red]{step_status.changed_idx_count}[/red]" + ) + except NotImplementedError: + # Currently we do not support empty join_keys + extra_args["total_idx_count"] = "[red]N/A[/red]" + extra_args["changed_idx_count"] = "[red]N/A[/red]" rprint(to_human_repr(step, extra_args=extra_args)) rprint("") diff --git a/datapipe/compute.py b/datapipe/compute.py index 7255867c..8c0ff0c1 100644 --- a/datapipe/compute.py +++ b/datapipe/compute.py @@ -40,6 +40,13 @@ def get_datatable(self, ds: DataStore, name: str) -> DataTable: return ds.get_or_create_table(name=name, table_store=self.catalog[name].store) +@dataclass +class StepStatus: + name: str + total_idx_count: int + changed_idx_count: int + + class ComputeStep: """ Шаг вычислений в графе вычислений. @@ -91,6 +98,9 @@ def name(self) -> str: def labels(self) -> Labels: return self._labels if self._labels else [] + def get_status(self, ds: DataStore) -> StepStatus: + raise NotImplementedError + # TODO: move to lints def validate(self) -> None: inp_p_keys_arr = [set(inp.primary_keys) for inp in self.input_dts if inp] diff --git a/datapipe/datatable.py b/datapipe/datatable.py index 6c42e9bd..6f34215a 100644 --- a/datapipe/datatable.py +++ b/datapipe/datatable.py @@ -609,9 +609,7 @@ def __init__( create_meta_table: bool = False, ) -> None: self.meta_dbconn = meta_dbconn - self.event_logger = EventLogger( - self.meta_dbconn, create_table=create_meta_table - ) + self.event_logger = EventLogger() self.tables: Dict[str, DataTable] = {} self.create_meta_table = create_meta_table diff --git a/datapipe/event_logger.py b/datapipe/event_logger.py index a678a591..5cc763c9 100644 --- a/datapipe/event_logger.py +++ b/datapipe/event_logger.py @@ -1,60 +1,14 @@ import logging -from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import Optional -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.sql import func -from sqlalchemy.sql.schema import Column, Table -from sqlalchemy.sql.sqltypes import JSON, DateTime, Integer, String from traceback_with_variables import format_exc from datapipe.run_config import RunConfig logger = logging.getLogger("datapipe.event_logger") -if TYPE_CHECKING: - from datapipe.store.database import DBConn - - -class EventTypes(Enum): - STATE = "state" - ERROR = "error" - - -class StepEventTypes(Enum): - RUN_FULL_COMPLETE = "run_full_complete" - class EventLogger: - def __init__(self, dbconn: "DBConn", create_table: bool = False): - self.dbconn = dbconn - - self.events_table = Table( - "datapipe_events", - dbconn.sqla_metadata, - Column("id", Integer, primary_key=True, autoincrement=True), - Column("event_ts", DateTime, server_default=func.now()), - Column("type", String(100)), - Column("event", JSON if dbconn.con.name == "sqlite" else JSONB), - ) - - self.step_events_table = Table( - "datapipe_step_events", - dbconn.sqla_metadata, - Column("id", Integer, primary_key=True, autoincrement=True), - Column("step", String(100)), - Column("event_ts", DateTime, server_default=func.now()), - Column("event", String(100)), - Column("event_payload", JSON if dbconn.con.name == "sqlite" else JSONB), - ) - - if create_table: - self.events_table.create(self.dbconn.con, checkfirst=True) - self.step_events_table.create(self.dbconn.con, checkfirst=True) - - def __reduce__(self) -> Tuple[Any, ...]: - return self.__class__, (self.dbconn,) - def log_state( self, table_name, @@ -66,34 +20,9 @@ def log_state( ): logger.debug( f'Table "{table_name}": added = {added_count}; updated = {updated_count}; ' - f"deleted = {deleted_count}, processed_count = {deleted_count}" + f"deleted = {deleted_count}, processed_count = {processed_count}" ) - if run_config is not None: - meta = { - "labels": run_config.labels, - "filters": run_config.filters, - } - else: - meta = {} - - ins = self.events_table.insert().values( - type=EventTypes.STATE.value, - event={ - "meta": meta, - "data": { - "table_name": table_name, - "added_count": added_count, - "updated_count": updated_count, - "deleted_count": deleted_count, - "processed_count": processed_count, - }, - }, - ) - - with self.dbconn.con.begin() as con: - con.execute(ins) - def log_error( self, type, @@ -106,29 +35,8 @@ def log_error( logger.error( f'Error in step {run_config.labels.get("step_name")}: {type} {message}\n{description}' ) - meta = { - "labels": run_config.labels, - "filters": run_config.filters, - } else: logger.error(f"Error: {type} {message}\n{description}") - meta = {} - - ins = self.events_table.insert().values( - type=EventTypes.ERROR.value, - event={ - "meta": meta, - "data": { - "type": type, - "message": message, - "description": description, - "params": params, - }, - }, - ) - - with self.dbconn.con.begin() as con: - con.execute(ins) def log_exception( self, @@ -148,11 +56,3 @@ def log_step_full_complete( step_name: str, ) -> None: logger.debug(f"Step {step_name} is marked complete") - - ins = self.step_events_table.insert().values( - step=step_name, - event=StepEventTypes.RUN_FULL_COMPLETE.value, - ) - - with self.dbconn.con.begin() as con: - con.execute(ins) diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py index 8a1fd2f6..6ed989ba 100644 --- a/datapipe/step/batch_transform.py +++ b/datapipe/step/batch_transform.py @@ -42,7 +42,7 @@ from sqlalchemy.sql.expression import select from tqdm_loggable.auto import tqdm -from datapipe.compute import Catalog, ComputeStep, PipelineStep +from datapipe.compute import Catalog, ComputeStep, PipelineStep, StepStatus from datapipe.datatable import DataStore, DataTable, MetaTable from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor from datapipe.run_config import LabelDict, RunConfig @@ -77,8 +77,7 @@ def __call__( input_dts: List[DataTable], run_config: Optional[RunConfig] = None, kwargs: Optional[Dict[str, Any]] = None, - ) -> TransformResult: - ... + ) -> TransformResult: ... BatchTransformFunc = Callable[..., TransformResult] @@ -517,6 +516,13 @@ def _apply_filters_to_run_config( run_config.filters = filters return run_config + def get_status(self, ds: DataStore) -> StepStatus: + return StepStatus( + name=self.name, + total_idx_count=self.meta_table.get_metadata_size(), + changed_idx_count=self.get_changed_idx_count(ds), + ) + def get_changed_idx_count( self, ds: DataStore, diff --git a/datapipe/store/qdrant.py b/datapipe/store/qdrant.py index 18ffc879..8d39d154 100644 --- a/datapipe/store/qdrant.py +++ b/datapipe/store/qdrant.py @@ -22,12 +22,27 @@ class QdrantStore(TableStore): Args: name (str): name of the Qdrant collection - url (str): url of the Qdrant server (if using with api_key, - you should explicitly specify port 443, by default qdrant uses 6333) - schema (DataSchema): Describes data that will be stored in the Qdrant collection - pk_field (str): name of the primary key field in the schema, used to identify records - embedding_field (str): name of the field in the schema that contains the vector representation of the record - collection_params (CollectionParams): parameters for creating a collection in Qdrant + + url (str): url of the Qdrant server (if using with api_key, you should + explicitly specify port 443, by default qdrant uses 6333) + + schema (DataSchema): Describes data that will be stored in the Qdrant + collection + + pk_field (str): name of the primary key field in the schema, used to + identify records + + embedding_field (str): name of the field in the schema that contains the + vector representation of the record + + collection_params (CollectionParams): parameters for creating a + collection in Qdrant + + index_schema (dict): {field_name: field_schema} - field(s) in payload + that will be used to create an index on. For data types and field + schema, check + https://qdrant.tech/documentation/concepts/indexing/#payload-index + api_key (Optional[str]): api_key for Qdrant server """ @@ -39,6 +54,7 @@ def __init__( pk_field: str, embedding_field: str, collection_params: CollectionParams, + index_schema: Optional[dict] = None, api_key: Optional[str] = None, ): super().__init__() @@ -55,13 +71,23 @@ def __init__( pk_columns = [column for column in self.schema if column.primary_key] if len(pk_columns) != 1 and pk_columns[0].name != pk_field: - raise ValueError("Incorrect prymary key columns in schema") + raise ValueError("Incorrect primary key columns in schema") - self.paylods_filelds = [ + self.payloads_filelds = [ column.name for column in self.schema if column.name != self.embedding_field ] - def __init(self): + self.index_field = {} + if index_schema: + # check if index field is present in schema + for field, field_schema in index_schema.items(): + if field not in self.payloads_filelds: + raise ValueError( + f"Index field `{field}` ({field_schema}) not found in payload schema" + ) + self.index_field = index_schema + + def __init_collection(self): self.client = QdrantClient(url=self.url, api_key=self._api_key) try: self.client.get_collection(self.name) @@ -71,9 +97,25 @@ def __init(self): collection_name=self.name, create_collection=self.collection_params ) + def __init_indexes(self): + """ + Checks on collection's payload indexes and adds them from index_field, if necessary. + Schema checks are not performed. + """ + payload_schema = self.client.get_collection(self.name).payload_schema + for field, field_schema in self.index_field.items(): + if field not in payload_schema.keys(): + self.client.create_payload_index( + collection_name=self.name, + field_name=field, + field_schema=field_schema, + ) + def __check_init(self): if not self.inited: - self.__init() + self.__init_collection() + if self.index_field: + self.__init_indexes() self.inited = True def __get_ids(self, df): @@ -107,7 +149,7 @@ def insert_rows(self, df: DataDF) -> None: vectors=df[self.embedding_field].apply(list).to_list(), payloads=cast( List[Dict[str, Any]], - df[self.paylods_filelds].to_dict(orient="records"), + df[self.payloads_filelds].to_dict(orient="records"), ), ), wait=True, @@ -146,6 +188,9 @@ def read_rows(self, idx: Optional[IndexDF] = None) -> DataDF: records = [] assert response.result is not None + if len(response.result) == 0: + return pd.DataFrame(columns=[column.name for column in self.schema]) + for point in response.result: record = point.payload @@ -169,6 +214,8 @@ class QdrantShardedStore(TableStore): schema (DataSchema): Describes data that will be stored in the Qdrant collection embedding_field (str): name of the field in the schema that contains the vector representation of the record collection_params (CollectionParams): parameters for creating a collection in Qdrant + index_schema (dict): {field_name: field_schema} - field(s) in payload that will be used to create an index on. + For data types and field schema, check https://qdrant.tech/documentation/concepts/indexing/#payload-index api_key (Optional[str]): api_key for Qdrant server """ @@ -179,6 +226,7 @@ def __init__( schema: DataSchema, embedding_field: str, collection_params: CollectionParams, + index_schema: Optional[dict] = None, api_key: Optional[str] = None, ): super().__init__() @@ -193,9 +241,20 @@ def __init__( self.client: Optional[QdrantClient] = None self.pk_fields = [column.name for column in self.schema if column.primary_key] - self.paylods_filelds = [ + self.payloads_filelds = [ column.name for column in self.schema if column.name != self.embedding_field ] + + self.index_field = {} + if index_schema: + # check if index field is present in schema + for field, field_schema in index_schema.items(): + if field not in self.payloads_filelds: + raise ValueError( + f"Index field `{field}` ({field_schema}) not found in payload schema" + ) + self.index_field = index_schema + self.name_params = re.findall(r"\{([^/]+?)\}", self.name_pattern) if not len(self.pk_fields): @@ -213,12 +272,28 @@ def __init_collection(self, name): collection_name=name, create_collection=self.collection_params ) + def __init_indexes(self, name): + """ + Checks on collection's payload indexes and adds them from index_field, if necessary. + Schema checks are not performed. + """ + payload_schema = self.client.get_collection(name).payload_schema + for field, field_schema in self.index_field.items(): + if field not in payload_schema.keys(): + self.client.create_payload_index( + collection_name=name, + field_name=field, + field_schema=field_schema, + ) + def __check_init(self, name): if not self.client: self.client = QdrantClient(url=self.url, api_key=self._api_key) if name not in self.inited_collections: self.__init_collection(name) + if self.index_field: + self.__init_indexes(name) self.inited_collections.add(name) def __get_ids(self, df): @@ -264,7 +339,7 @@ def insert_rows(self, df: DataDF) -> None: vectors=gdf[self.embedding_field].apply(list).to_list(), payloads=cast( List[Dict[str, Any]], - df[self.paylods_filelds].to_dict(orient="records"), + df[self.payloads_filelds].to_dict(orient="records"), ), ), wait=True, diff --git a/datapipe/store/redis.py b/datapipe/store/redis.py index 7b2c8fb6..d3113afc 100644 --- a/datapipe/store/redis.py +++ b/datapipe/store/redis.py @@ -3,6 +3,7 @@ import pandas as pd from redis.client import Redis +from redis.cluster import RedisCluster from sqlalchemy import Column from datapipe.store.database import MetaKey @@ -24,10 +25,13 @@ def _to_itertuples(df: DataDF, colnames): class RedisStore(TableStore): def __init__( - self, connection: str, name: str, data_sql_schema: List[Column] + self, connection: str, name: str, data_sql_schema: List[Column], cluster_mode: bool = False ) -> None: self.connection = connection - self.redis_connection = Redis.from_url(connection, decode_responses=True) + if not cluster_mode: + self.redis_connection: Union[Redis, RedisCluster] = Redis.from_url(connection, decode_responses=True) + else: + self.redis_connection = RedisCluster.from_url(connection, decode_responses=True) self.name = name self.data_sql_schema = data_sql_schema diff --git a/pyproject.toml b/pyproject.toml index 8dbd946a..6762650d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "datapipe-core" -version = "0.13.10" +version = "0.13.13" description = "`datapipe` is a realtime incremental ETL library for Python application" readme = "README.md" repository = "https://github.com/epoch8/datapipe" @@ -10,14 +10,14 @@ packages = [ ] [tool.poetry.dependencies] -python = ">=3.8,<3.12" +python = ">3.8,<3.12" fsspec = ">=2021.11.1" gcsfs = {version=">=2021.11.1", optional=true} s3fs = {version=">=2021.11.1", optional=true} # TODO Fix incompatibility between sqlalchemy < 2 and pandas 2.2 -pandas = ">=1.2.0, <2.2" +pandas = ">=1.2.0" numpy = ">=1.21.0, <2.0" SQLAlchemy = ">=1.4.25, <3.0.0" @@ -30,7 +30,7 @@ cityhash = "^0.4.2" # TODO 0.14: make it optional Pillow = "^10.0.0" -epoch8-tqdm-loggable = "^0.1.4" +tqdm-loggable = "^0.2" traceback-with-variables = "^2.0.4" pymilvus = {version="^2.0.2", optional=true} @@ -44,8 +44,8 @@ xlrd = {version=">=2.0.1", optional=true} openpyxl = {version=">=3.0.7", optional=true} redis = {version="^4.3.4", optional=true} -pysqlite3-binary = {version="^0.5.0", optional=true} -sqlalchemy-pysqlite3-binary = {version="^0.0.4", optional=true} +pysqlite3-binary = {version="^0.5.0", optional=true, markers="sys_platform != 'darwin'"} +sqlalchemy-pysqlite3-binary = {version="^0.0.4", optional=true, markers="sys_platform != 'darwin'"} qdrant-client = {version="^1.1.7", optional=true} click = ">=7.1.2" diff --git a/tests/test_qdrant_store.py b/tests/test_qdrant_store.py index 17113469..f922dd7b 100644 --- a/tests/test_qdrant_store.py +++ b/tests/test_qdrant_store.py @@ -3,7 +3,7 @@ import pandas as pd from qdrant_client.models import Distance, VectorParams -from sqlalchemy import ARRAY, Float, Integer +from sqlalchemy import ARRAY, Float, Integer, String from sqlalchemy.sql.schema import Column from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps @@ -20,7 +20,9 @@ def extract_id(df: pd.DataFrame) -> pd.DataFrame: def generate_data() -> Generator[pd.DataFrame, None, None]: - yield pd.DataFrame({"id": [1], "embedding": [[0.1]]}) + yield pd.DataFrame( + {"id": [1], "embedding": [[0.1]], "str_payload": ["foo"], "int_payload": [42]} + ) def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None: @@ -34,6 +36,8 @@ def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None: schema=[ Column("id", Integer, primary_key=True), Column("embedding", ARRAY(Float, dimensions=1)), + Column("str_payload", String), + Column("int_payload", Integer), ], collection_params=CollectionParams( vectors=VectorParams( @@ -43,6 +47,14 @@ def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None: ), pk_field="id", embedding_field="embedding", + index_schema={ + "str_payload": "keyword", + "int_payload": { + "type": "integer", + "lookup": False, + "range": True, + }, + }, ) ), "output": Table(