Skip to content

Commit

Permalink
test: update tests with config file
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-nelson-uiuc committed Nov 28, 2024
1 parent 6be5ac8 commit 07f2736
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 62 deletions.
11 changes: 11 additions & 0 deletions src/tidy_tools/tidy/_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
from loguru import logger

logger.remove()
logger.add(sys.stderr, format="{time:HH:mm} | <level>{level}</level> | {message}")


def _logger(message: str, level: str = "info") -> None:
if not hasattr(logger, level):
raise ValueError(f"Logger does not have {level=}")
getattr(logger, level)(message)
5 changes: 5 additions & 0 deletions src/tidy_tools/tidy/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Iterable, Callable


Functions = Callable | Iterable[Callable]
Objects = object | Iterable[object]
1 change: 1 addition & 0 deletions src/tidy_tools/tidy/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def select(self, *selectors: ColumnSelector, strict: bool = True):
return self

def pipe(self, *funcs: Callable):
"""Chain multiple custom transformation functions to be applied iteratively."""
self._data = functools.reduce(
lambda init, func: init.transform(func), funcs, self._data
)
Expand Down
47 changes: 43 additions & 4 deletions src/tidy_tools/tidy/workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,45 @@
from attrs import define
# from typing import Any
# import inspect
# import functools

# from attrs import define, field

@define
class TidyWorkFlow:
pass
# from tidy_tools.tidy._logger import _logger
# from tidy_tools.tidy._types import Functions, Objects


# def identity(obj: Any) -> Any:
# """Return input object as is."""
# return obj


# def metadata_factory() -> dict:
# return dict(
# name="No name provided",
# description="No description provided"
# )

# @define
# class TidyWorkFlow:
# input: Objects
# funcs: Functions
# preprocess: Functions = field(default=identity)
# postprocess: Functions = field(default=identity)
# metadata: dict = field(factory=metadata_factory)

#     def run(self):
#         data = map(self.preprocess, self.input)
#         result = functools.reduce(
#             lambda init, func: transform(init, func),
#             self.funcs,
#             data,
#         )
#         output = self.postprocess(result)
#         return output


# def metadata(self):
# return {
# func.__name__: inspect.getdoc(func)
# for func in self.funcs
# }
58 changes: 58 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import datetime
import pytest

from pyspark.sql import SparkSession, types as T

from tidy_tools.tidy import TidyDataFrame


@pytest.fixture
def spark_fixture():
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
yield spark


@pytest.fixture
def sample_data(spark_fixture):
data = spark_fixture.createDataFrame(
[
{
"name": "Homer",
"birth_date": datetime.date(1956, 5, 12),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Marge",
"birth_date": datetime.date(1956, 10, 1),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Bart",
"birth_date": datetime.date(1979, 4, 1),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Lisa",
"birth_date": datetime.date(1981, 5, 9),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": "Saxophone",
},
],
schema=T.StructType(
[
T.StructField("name", T.StringType(), nullable=False),
T.StructField("birth_date", T.DateType(), nullable=False),
T.StructField("original_air_date", T.TimestampType(), nullable=False),
T.StructField("seasons", T.IntegerType(), nullable=False),
T.StructField("instrument", T.StringType(), nullable=True),
]
),
)
yield TidyDataFrame(data)
Empty file added tests/test_filters.py
Empty file.
58 changes: 0 additions & 58 deletions tests/test_selector.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,4 @@
import pytest
import datetime

from pyspark.sql import SparkSession, types as T

from tidy_tools.core import selector as cs
from tidy_tools.tidy import TidyDataFrame


@pytest.fixture
def spark_fixture():
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
yield spark


@pytest.fixture
def sample_data(spark_fixture):
data = spark_fixture.createDataFrame(
[
{
"name": "Homer",
"birth_date": datetime.date(1956, 5, 12),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Marge",
"birth_date": datetime.date(1956, 10, 1),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Bart",
"birth_date": datetime.date(1979, 4, 1),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": None,
},
{
"name": "Lisa",
"birth_date": datetime.date(1981, 5, 9),
"original_air_date": datetime.datetime(1987, 4, 19, 20, 0, 0),
"seasons": 36,
"instrument": "Saxophone",
},
],
schema=T.StructType(
[
T.StructField("name", T.StringType(), nullable=False),
T.StructField("birth_date", T.DateType(), nullable=False),
T.StructField("original_air_date", T.TimestampType(), nullable=False),
T.StructField("seasons", T.IntegerType(), nullable=False),
T.StructField("instrument", T.StringType(), nullable=True),
]
),
)
yield TidyDataFrame(data)


class TestColumnSelector:
Expand Down

0 comments on commit 07f2736

Please sign in to comment.