From 2dc5c1d2e43945213325e9543fc987f1ea97b6a7 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Wed, 5 Mar 2025 12:53:14 +0100 Subject: [PATCH] feat: cacheable PathDataFrameBuilder --- pyproject.toml | 4 +- src/rbyte/io/path/dataframe_builder.py | 54 ++++++++++++++++++-------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96945d6..f89f126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.14.1" +version = "0.14.2" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -16,7 +16,7 @@ dependencies = [ "cachetools>=5.5.1", "structlog>=25.1.0", "tqdm>=4.67.1", - "pipefunc[autodoc]>=0.57.0", + "pipefunc[autodoc]>=0.57.1", "xxhash>=3.5.0", ] readme = "README.md" diff --git a/src/rbyte/io/path/dataframe_builder.py b/src/rbyte/io/path/dataframe_builder.py index 3d48649..e87880f 100644 --- a/src/rbyte/io/path/dataframe_builder.py +++ b/src/rbyte/io/path/dataframe_builder.py @@ -1,17 +1,22 @@ import os from collections.abc import Iterator, Mapping from functools import cached_property -from typing import final +from typing import Self, final import polars as pl +from optree import tree_map from polars._typing import PolarsDataType # noqa: PLC2701 from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) -from pydantic import ConfigDict, DirectoryPath, validate_call +from polars.polars import dtype_str_repr # pyright: ignore[reportUnknownVariableType] +from pydantic import DirectoryPath, field_serializer, model_validator, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars +from xxhash import xxh3_64_hexdigest as digest + +from rbyte.config.base import BaseModel logger = get_logger(__name__) @@ -27,25 +32,40 @@ def scantree(path: str) -> Iterator[str]: yield entry.path -@final -class PathDataFrameBuilder: - __name__ = __qualname__ +class Config(BaseModel): + fields: Fields + pattern: str - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) - def __init__(self, *, fields: Fields, pattern: str) -> None: - if set(fields) != set( - pl.Series(dtype=pl.String).str.extract_groups(pattern).struct.fields + @model_validator(mode="after") + def _validate_model(self) -> Self: + if set(self.fields) != set( + pl.Series(dtype=pl.String).str.extract_groups(self.pattern).struct.fields ): logger.error( msg := "field keys don't match pattern groups", - fields=fields, - pattern=pattern, + fields=self.fields, + pattern=self.pattern, ) raise ValueError(msg) - self._fields = fields - self._pattern = pattern + return self + + @field_serializer("fields", when_used="json") + @staticmethod + def _serialize_fields(fields: Fields) -> Mapping[str, str | None]: + return tree_map(dtype_str_repr, fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownVariableType, reportReturnType] + + +@final +class PathDataFrameBuilder: + __name__ = __qualname__ + + def __init__(self, *, fields: Fields, pattern: str) -> None: + self._config = Config(fields=fields, pattern=pattern) + + def __pipefunc_hash__(self) -> str: # noqa: PLW3201 + return digest(self._config.model_dump_json()) @validate_call def __call__(self, path: DirectoryPath) -> pl.DataFrame: @@ -63,12 +83,12 @@ def _build(self, path: str) -> pl.DataFrame: pl.col("path") .str.strip_prefix(path) .str.strip_prefix("/") - .str.extract_groups(self._pattern) + .str.extract_groups(self._config.pattern) .alias("groups") ) .unnest("groups") .drop_nulls() - .select(self._fields) + .select(self._config.fields) .cast(self._schema, strict=True) # pyright: ignore[reportArgumentType] .collect() ) @@ -76,5 +96,7 @@ def _build(self, path: str) -> pl.DataFrame: @cached_property def _schema(self) -> Mapping[str, PolarsDataType]: return { - name: dtype for name, dtype in self._fields.items() if dtype is not None + name: dtype + for name, dtype in self._config.fields.items() + if dtype is not None }