Skip to content

Commit

Permalink
build: refactor tidydataframe with contextmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-nelson-uiuc committed Nov 30, 2024
1 parent a4843e8 commit 9b73b67
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 62 deletions.
6 changes: 3 additions & 3 deletions src/tidy_tools/core/_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
TidyDataFrame = "TidyDataFrame"

from pyspark.sql import Column, DataFrame
from pyspark.sql import Column, DataFrame, GroupedData


ColumnReference = str | Column
DataFrameReference = DataFrame | Optional[DataFrame]
DataFrameReference = DataFrame | GroupedData
64 changes: 64 additions & 0 deletions src/tidy_tools/frame/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import datetime
import difflib
from contextlib import contextmanager

from attrs import define
from attrs import field
from loguru import logger
from pyspark.sql import types as T


logger.level("ENTER", color="<green>")
logger.level("EXIT", color="<red>")


@define
class TidySnapshot:
operation: str
message: str
schema: T.StructType
dimensions: tuple[int, int]
timestamp: datetime.datetime = field(default=datetime.datetime.now())


@contextmanager
def tidy_context():
"""Define context manager for handling tidy operations."""
context = {"operation_log": [], "snapshots": []}
try:
logger.log("ENTER", ">> Converting data to TidyDataFrame")
yield context
logger.log("EXIT", "<< Returning data as DataFrame")
finally:
for log in context["operation_log"]:
print(log)


def compute_delta(snapshot1: TidySnapshot, snapshot2: TidySnapshot):
# Get schema differences using difflib
schema_diff = compare_schemas(snapshot1.schema, snapshot2.schema)
print("Schema Changes:")
print("\n".join(schema_diff))

# Get dimension (row/column count) differences using difflib
dimension_diff = compare_dimensions(snapshot1.dimensions, snapshot2.dimensions)
print("Dimension Changes:")
print("\n".join(dimension_diff))


def compare_schemas(schema1, schema2):
# Extract column names and types for comparison
cols1 = [f"{field.name}: {field.dataType}" for field in schema1.fields]
cols2 = [f"{field.name}: {field.dataType}" for field in schema2.fields]

# Use difflib to compare column lists
return list(difflib.ndiff(cols1, cols2))


def compare_dimensions(dim1, dim2):
# Compare row and column counts as text for difflib
row_diff = f"Rows: {dim1[0]} -> {dim2[0]}"
col_diff = f"Columns: {dim1[1]} -> {dim2[1]}"

# Using difflib to show dimension changes
return list(difflib.ndiff([row_diff], [col_diff]))
159 changes: 138 additions & 21 deletions src/tidy_tools/frame/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,149 @@
import functools
import inspect
import warnings
from typing import Callable
from typing import Optional

from attrs import define
from attrs import field
from attrs import validators
from loguru import logger
from pyspark.sql import DataFrame
from pyspark.sql import GroupedData
from tidy_tools.core.selector import ColumnSelector
from tidy_tools.frame.context import TidySnapshot


@define
class TidyDataFrame:
_data: str
_data: DataFrame = field(validator=validators.instance_of((DataFrame, GroupedData)))
config: dict = field(factory=dict)
context: Optional[dict] = field(default=None)

def __attrs_post_init__(self):
self.config.setdefault("name", self.__class__.__name__)
self.config.setdefault("count", True)
self.config.setdefault("display", True)
self.config.setdefault("verbose", False)
self.config.setdefault("register_tidy", True)
if self.config.get("register_tidy"):
# TODO: add feature to selectively register modules
# self.__class__.register(tidy_filter)
# self.__class__.register(tidy_select)
# self.__class__.register(tidy_summarize)
pass

def __repr__(self):
return f"{self.config.get('name')} [{self.count():,} rows x {len(self.columns)} cols]"

@classmethod
def register(cls, module):
"""Register external functions as methods of TidyDataFrame."""
for name, func in inspect.getmembers(module, inspect.isfunction):
setattr(cls, name, func)

def _logger(self, operation: Callable, message: str, level: str = "info") -> None:
def _snapshot(self, operation: str, message: str, dimensions: tuple[int, int]):
"""Captures a snapshot of the DataFrame"""
snapshot = TidySnapshot(
operation=operation,
message=message,
schema=self._data.schema,
dimensions=dimensions,
)
if self.context is not None:
self.context["snapshots"].append(snapshot)

def _log(
self,
operation: str = "comment",
message: str = "no message provided",
level: str = "info",
) -> None:
getattr(logger, level)(f"#> {operation:<12}: {message}")
return self

def _record(message: str) -> None:
def decorator(func: Callable):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, func.__name__):
result = func(self, *args, **kwargs)
if self.context:
self._snapshot(
operation=func.__name__,
message=eval(f"f'{message}'"),
dimensions=(self.count(), len(self._data.columns)),
)
self._log(
operation=func.__name__,
message=eval(f"f'{message}'"),
)
return result

return wrapper

return decorator

@property
def columns(self):
"""Returns the raw Spark DataFrame"""
return self._data.columns

@property
def dtypes(self):
"""Return all column names and data types as a list"""
return self._data.dtypes

@property
def describe(self, *cols):
"""Compute basic statistics for numeric and string columns."""
return self._data.describe(*cols)

@property
def schema(self):
return self._data.schema

@property
def columns(self):
return self._data.columns
def data(self):
"""Returns the raw Spark DataFrame"""
logger.info(">> exit: TidyDataFrame context ending.")
return self._data

def display(self, limit: int = 10):
"""
Control execution of display method
This method masks the `pyspark.sql.DataFrame.display` method. This method does not
mask the native PySpark display function.
Often, the `.display()` method will need to be disabled for logging purposes. Similar
to toggling the `.count()` method, users can temporarily disable a DataFrame's
ability to display to the console by passing `toggle_display = True`.
"""
if not self.config.get("display"):
self._log(
operation="display", message="display is toggled off", level="warning"
)
else:
self._data.limit(limit).display()
return self

def select(self, *selectors: ColumnSelector, strict: bool = True):
def show(self, limit: int = 10):
if not self.config.get("display"):
self._log(
operation="show", message="display is toggled off", level="warning"
)
else:
self._data.limit(limit).show()
return self

def count(self, result: Optional[DataFrame] = None) -> int:
"""Retrieve number of rows in DataFrame."""
if not self.config.get("count"):
return 0
if not self.context["snapshots"]:
return self._data.count()
if result:
return result._data.count()
return self.context["snapshots"][-1].dimensions[0]

@_record(message="selected {len(result._data.columns)} columns")
def select(
self, *selectors: ColumnSelector, strict: bool = True, invert: bool = False
):
compare_operator = all if strict else any
selected = set(
[
Expand All @@ -53,17 +154,32 @@ def select(self, *selectors: ColumnSelector, strict: bool = True):
)
]
)
if len(selected) < 1:
warnings.warn("No columns matched the selector(s).")
self._data = self._data.select(*selected)
return self

if invert:
result = self._data.drop(*selected)
else:
result = self._data.select(*selected)
return TidyDataFrame(result, config=self.config, context=self.context)

def drop(self, *selectors: ColumnSelector, strict: bool = True) -> "TidyDataFrame":
return self.select(*selectors, strict=strict, invert=True)

@_record(message="removed {self.count() - self.count(result):,} rows")
def filter(self, condition):
result = self._data.filter(condition)
return TidyDataFrame(result, config=self.config, context=self.context)

@_record(message='added column {args[0] if args else kwargs.get("colName")}')
def withColumn(self, colName, col):
result = self._data.withColumn(colName, col)
return TidyDataFrame(result, config=self.config, context=self.context)

@_record(message="calling pipe operator!!!")
def pipe(self, *funcs: Callable):
"""Chain multiple custom transformation functions to be applied iteratively."""
self._data = functools.reduce(
result = functools.reduce(
lambda init, func: init.transform(func), funcs, self._data
)
return self
return TidyDataFrame(result, config=self.config, context=self.context)

def __getattr__(self, attr):
"""
Expand All @@ -88,11 +204,12 @@ def __getattr__(self, attr):
def wrapper(*args, **kwargs):
result = getattr(self._data, attr)(*args, **kwargs)
if isinstance(result, DataFrame):
self._data = result
self._logger(
operation=attr, message="not yet implemented", level="warning"
)
return self
return TidyDataFrame(
result, config=self.config, context=self.context
)
else:
return self

Expand Down
13 changes: 8 additions & 5 deletions src/tidy_tools/frame/workflow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from attrs import define
from tidy_tools.frame._types import Functions
from tidy_tools.frame._types import Objects


@define
class TidyWorkFlow:
input: Objects
funcs: Functions
class tidyworkflow:
def __enter__(self):
print("Starting")
return self

def __exit__(self, *exc):
print("Finishing")
return False
64 changes: 31 additions & 33 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,40 @@
# TODO: enable pyspark.testing at some point
# from pyspark.testing import assertDataFrameEqual
from tidy_tools.core.filters import filter_nulls
from tidy_tools.frame import TidyDataFrame


class TestFilters:
def test_filter_nulls(self, eits_data):
tidy_data = TidyDataFrame(eits_data)
eits_data.filter_nulls = filter_nulls
tidy_data.filter_nulls = filter_nulls

# test `filter_nulls` is equivalent to `DataFrame.na.drop`
assert eits_data.na.drop(how="any").count() == tidy_data.filter_nulls().count()

assert (
eits_data.na.drop(how="all").count()
== tidy_data.filter_nulls(strict=True).count()
)

columns = [
"title",
"release_year",
"release_date",
"recorded_at",
"tracks",
"duration_minutes",
"rating",
]
assert (
eits_data.na.drop(subset=[columns]).count()
== tidy_data.filter_nulls(*columns).count()
)

columns = ["formats", "producer", "ceritifed_gold", "comments"]
assert (
eits_data.na.drop(subset=[columns]).count()
== tidy_data.filter_nulls(*columns).count()
)
# tidy_data = TidyDataFrame(eits_data)
# tidy_data.filter_nulls = filter_nulls

# # test `filter_nulls` is equivalent to `DataFrame.na.drop`
# assert eits_data.na.drop(how="any").count() == tidy_data.filter_nulls().count()

# assert (
# eits_data.na.drop(how="all").count()
# == tidy_data.filter_nulls(strict=True).count()
# )

# columns = [
# "title",
# "release_year",
# "release_date",
# "recorded_at",
# "tracks",
# "duration_minutes",
# "rating",
# ]
# assert (
# eits_data.na.drop(subset=[columns]).count()
# == tidy_data.filter_nulls(*columns).count()
# )

# columns = ["formats", "producer", "ceritifed_gold", "comments"]
# assert (
# eits_data.na.drop(subset=[columns]).count()
# == tidy_data.filter_nulls(*columns).count()
# )
assert True

def test_filter_regex(self, eits_data):
# tidy_data = TidyDataFrame(eits_data)
Expand Down

0 comments on commit 9b73b67

Please sign in to comment.