feat: cacheable PathDataFrameBuilder
egorchakov committed Mar 5, 2025
1 parent ebcdab6 commit 2dc5c1d
Showing 2 changed files with 40 additions and 18 deletions.
pyproject.toml: 4 changes (2 additions, 2 deletions)
@@ -1,6 +1,6 @@
 [project]
 name = "rbyte"
-version = "0.14.1"
+version = "0.14.2"
 description = "Multimodal PyTorch dataset library"
 authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
 maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
@@ -16,7 +16,7 @@ dependencies = [
   "cachetools>=5.5.1",
   "structlog>=25.1.0",
   "tqdm>=4.67.1",
-  "pipefunc[autodoc]>=0.57.0",
+  "pipefunc[autodoc]>=0.57.1",
   "xxhash>=3.5.0",
 ]
 readme = "README.md"
src/rbyte/io/path/dataframe_builder.py: 54 changes (38 additions, 16 deletions)
@@ -1,17 +1,22 @@
 import os
 from collections.abc import Iterator, Mapping
 from functools import cached_property
-from typing import final
+from typing import Self, final
 
 import polars as pl
+from optree import tree_map
 from polars._typing import PolarsDataType  # noqa: PLC2701
 from polars.datatypes import (
     DataType,  # pyright: ignore[reportUnusedImport] # noqa: F401
     DataTypeClass,  # pyright: ignore[reportUnusedImport] # noqa: F401
 )
-from pydantic import ConfigDict, DirectoryPath, validate_call
+from polars.polars import dtype_str_repr  # pyright: ignore[reportUnknownVariableType]
+from pydantic import DirectoryPath, field_serializer, model_validator, validate_call
 from structlog import get_logger
 from structlog.contextvars import bound_contextvars
+from xxhash import xxh3_64_hexdigest as digest
+
+from rbyte.config.base import BaseModel
 
 logger = get_logger(__name__)
 
@@ -27,25 +32,40 @@ def scantree(path: str) -> Iterator[str]:
             yield entry.path
 
 
-@final
-class PathDataFrameBuilder:
-    __name__ = __qualname__
+class Config(BaseModel):
+    fields: Fields
+    pattern: str
 
-    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
-    def __init__(self, *, fields: Fields, pattern: str) -> None:
-        if set(fields) != set(
-            pl.Series(dtype=pl.String).str.extract_groups(pattern).struct.fields
+    @model_validator(mode="after")
+    def _validate_model(self) -> Self:
+        if set(self.fields) != set(
+            pl.Series(dtype=pl.String).str.extract_groups(self.pattern).struct.fields
         ):
             logger.error(
                 msg := "field keys don't match pattern groups",
-                fields=fields,
-                pattern=pattern,
+                fields=self.fields,
+                pattern=self.pattern,
             )
 
             raise ValueError(msg)
 
-        self._fields = fields
-        self._pattern = pattern
+        return self
+
+    @field_serializer("fields", when_used="json")
+    @staticmethod
+    def _serialize_fields(fields: Fields) -> Mapping[str, str | None]:
+        return tree_map(dtype_str_repr, fields)  # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownVariableType, reportReturnType]
+
+
+@final
+class PathDataFrameBuilder:
+    __name__ = __qualname__
+
+    def __init__(self, *, fields: Fields, pattern: str) -> None:
+        self._config = Config(fields=fields, pattern=pattern)
+
+    def __pipefunc_hash__(self) -> str:  # noqa: PLW3201
+        return digest(self._config.model_dump_json())
 
     @validate_call
     def __call__(self, path: DirectoryPath) -> pl.DataFrame:
@@ -63,18 +83,20 @@ def _build(self, path: str) -> pl.DataFrame:
                 pl.col("path")
                 .str.strip_prefix(path)
                 .str.strip_prefix("/")
-                .str.extract_groups(self._pattern)
+                .str.extract_groups(self._config.pattern)
                 .alias("groups")
             )
             .unnest("groups")
             .drop_nulls()
-            .select(self._fields)
+            .select(self._config.fields)
             .cast(self._schema, strict=True)  # pyright: ignore[reportArgumentType]
             .collect()
         )
 
     @cached_property
     def _schema(self) -> Mapping[str, PolarsDataType]:
         return {
-            name: dtype for name, dtype in self._fields.items() if dtype is not None
+            name: dtype
+            for name, dtype in self._config.fields.items()
+            if dtype is not None
         }
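
For context, a minimal usage sketch of the now-cacheable builder (not part of the commit). The field names, pattern, and dtypes below are hypothetical, and the sketch assumes Fields maps regex group names to optional polars dtypes, as the diff suggests. The point of the change is that two builders constructed from equal arguments now report the same __pipefunc_hash__, derived from the JSON-serialized Config rather than from object identity, which is presumably what lets pipefunc reuse cached results.

import polars as pl

from rbyte.io.path.dataframe_builder import PathDataFrameBuilder

# Hypothetical fields/pattern: the field keys must match the named capture
# groups in the pattern; dtypes are optional and only drive the final cast.
kwargs = {
    "fields": {"camera": None, "frame_idx": pl.UInt32},
    "pattern": r"(?P<camera>[^/]+)/(?P<frame_idx>\d+)\.jpg",
}

a = PathDataFrameBuilder(**kwargs)
b = PathDataFrameBuilder(**kwargs)

# Equal configs serialize to the same JSON, so the xxh3 digests match and the
# two instances are interchangeable as far as caching is concerned.
assert a.__pipefunc_hash__() == b.__pipefunc_hash__()

# Building the frame index from a directory works as before:
# df = a("/path/to/dataset")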
