Use common add_df method, use Protocol
veghdev committed Aug 15, 2023
1 parent 98b05fd commit 8926117
Showing 11 changed files with 134 additions and 67 deletions.
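The gist of the change: the separate `add_spark_df` method is gone, and `add_df` now accepts a `pyspark` `DataFrame` as well, selecting the right converter at runtime through `Protocol`-based `isinstance` checks. A minimal before/after sketch (column names and values are illustrative, not taken from this commit):

```python
import pandas as pd

from ipyvizzu import Data

data = Data()

# pandas input, unchanged:
pd_df = pd.DataFrame({"Genres": ["Pop", "Rock"], "Popularity": [114, 96]})
data.add_df(pd_df)

# pyspark input -- previously data.add_spark_df(spark_df), now simply:
# data.add_df(spark_df)
```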
4 changes: 2 additions & 2 deletions docs/tutorial/data.md
@@ -323,7 +323,7 @@ data.add_df(df)
 ### Using `pyspark` DataFrame

 Use
-[`add_spark_df`](../reference/ipyvizzu/animation.md#ipyvizzu.animation.Data.add_spark_df)
+[`add_df`](../reference/ipyvizzu/animation.md#ipyvizzu.animation.Data.add_df)
 method for adding `pyspark` DataFrame to
 [`Data`](../reference/ipyvizzu/animation.md#ipyvizzu.animation.Data).

@@ -363,7 +363,7 @@ spark_data = [
 df = spark.createDataFrame(spark_data, spark_schema)

 data = Data()
-data.add_spark_df(df)
+data.add_df(df)
 ```

 !!! note
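Since the hunk view truncates the tutorial snippet, here is a self-contained sketch of the new call path (the schema and values are illustrative, and a local Spark session is assumed):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from ipyvizzu import Data

spark = SparkSession.builder.appName("ipyvizzu-tutorial").getOrCreate()

spark_schema = StructType(
    [
        StructField("Genres", StringType(), True),
        StructField("Popularity", IntegerType(), True),
    ]
)
spark_data = [("Pop", 114), ("Rock", 96)]
df = spark.createDataFrame(spark_data, spark_schema)

data = Data()
data.add_df(df)  # add_spark_df(df) before this commit
```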
4 changes: 1 addition & 3 deletions src/ipyvizzu/__init__.py
@@ -103,9 +103,7 @@
 ]


-if sys.version_info >= (3, 7):
-    pass
-else:
+if sys.version_info < (3, 7):
     # TODO: remove once support for Python 3.6 is dropped
     warnings.warn(
         "Python 3.6 support will be dropped in future versions.",
82 changes: 35 additions & 47 deletions src/ipyvizzu/animation.py
@@ -3,17 +3,16 @@
 import abc
 import json
 from os import PathLike
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 import warnings

 import jsonschema  # type: ignore

 from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE
 from ipyvizzu.data.converters.df.defaults import MAX_ROWS
-from ipyvizzu.data.converters.numpy.converter import NumpyArrayConverter
-from ipyvizzu.data.converters.pandas.converter import PandasDataFrameConverter
-from ipyvizzu.data.converters.spark.converter import SparkDataFrameConverter
-from ipyvizzu.data.converters.numpy.type_alias import ColumnName, ColumnDtype
+from ipyvizzu.data.converters.numpy import ColumnDtype, ColumnName, NumpyArrayConverter
+from ipyvizzu.data.converters.pandas import PandasDataFrameConverter
+from ipyvizzu.data.converters.spark import SparkDataFrame, SparkDataFrameConverter
 from ipyvizzu.data.type_alias import (
     DimensionValue,
     NestedMeasureValues,
@@ -277,28 +276,35 @@ def add_measure(

     def add_df(
         self,
-        df: Optional[Union["pandas.DataFrame", "pandas.Series"]],  # type: ignore
+        df: Optional[  # type: ignore
+            Union[
+                "pandas.DataFrame",
+                "pandas.Series",
+                "pyspark.sql.DataFrame",
+            ]
+        ],
         default_measure_value: MeasureValue = NAN_MEASURE,
         default_dimension_value: DimensionValue = NAN_DIMENSION,
         max_rows: int = MAX_ROWS,
         include_index: Optional[str] = None,
     ) -> None:
         """
-        Add a `pandas` `DataFrame` or `Series` to an existing
-        [Data][ipyvizzu.animation.Data] class instance.
+        Add a `pandas` `DataFrame`, `Series` or a `pyspark` `DataFrame`
+        to an existing [Data][ipyvizzu.animation.Data] class instance.

         Args:
             df:
-                The `pandas` `DataFrame` or `Series` to add.
+                The `pandas` `DataFrame`, `Series` or the `pyspark` `DataFrame` to add.
             default_measure_value:
                 The default measure value to fill empty values. Defaults to 0.
             default_dimension_value:
                 The default dimension value to fill empty values. Defaults to an empty string.
             max_rows: The maximum number of rows to include in the converted series list.
                 If the `df` contains more rows,
-                a random sample of the given number of rows will be taken.
+                a random sample of the given number of rows (approximately) will be taken.
             include_index:
                 Add the data frame's index as a column with the given name. Defaults to `None`.
+                (Cannot be used with `pyspark` `DataFrame`.)

         Example:
             Adding a data frame to a [Data][ipyvizzu.animation.Data] class instance:
@@ -317,13 +323,25 @@ def add_df(
         # pylint: disable=too-many-arguments

         if not isinstance(df, type(None)):
-            converter = PandasDataFrameConverter(
-                df,
-                default_measure_value,
-                default_dimension_value,
-                max_rows,
-                include_index,
-            )
+            arguments = {
+                "df": df,
+                "default_measure_value": default_measure_value,
+                "default_dimension_value": default_dimension_value,
+                "max_rows": max_rows,
+                "include_index": include_index,
+            }
+            Converter: Union[
+                Type[PandasDataFrameConverter], Type[SparkDataFrameConverter]
+            ] = PandasDataFrameConverter
+            if isinstance(df, SparkDataFrame):
+                Converter = SparkDataFrameConverter
+                if arguments["include_index"] is not None:
+                    raise ValueError(
+                        "`include_index` cannot be used with `pyspark` `DataFrame`"
+                    )
+                del arguments["include_index"]
+
+            converter = Converter(**arguments)  # type: ignore
             series_list = converter.get_series_list()
             self.add_series_list(series_list)

@@ -469,36 +487,6 @@ def add_np_array(
         series_list = converter.get_series_list()
         self.add_series_list(series_list)

-    def add_spark_df(
-        self,
-        df: Optional["pyspark.sql.DataFrame"],  # type: ignore
-        default_measure_value: MeasureValue = NAN_MEASURE,
-        default_dimension_value: DimensionValue = NAN_DIMENSION,
-        max_rows: int = MAX_ROWS,
-    ) -> None:
-        """
-        Add a `pyspark` `DataFrame` to an existing
-        [Data][ipyvizzu.animation.Data] class instance.
-
-        Args:
-            df:
-                The `pyspark` `DataFrame` to add.
-            default_measure_value:
-                The default measure value to fill empty values. Defaults to 0.
-            default_dimension_value:
-                The default dimension value to fill empty values. Defaults to an empty string.
-            max_rows: The maximum number of rows to include in the converted series list.
-                If the `df` contains more rows,
-                a random sample of the given number of rows (approximately) will be taken.
-        """
-
-        if not isinstance(df, type(None)):
-            converter = SparkDataFrameConverter(
-                df, default_measure_value, default_dimension_value, max_rows
-            )
-            series_list = converter.get_series_list()
-            self.add_series_list(series_list)
-
     def _add_named_value(
         self,
         dest: str,
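One behavioral consequence of the merged signature deserves a usage note: `include_index` only makes sense for `pandas` input, and combining it with a `pyspark` `DataFrame` now raises a `ValueError`. A short sketch (data illustrative):

```python
import pandas as pd

from ipyvizzu import Data

data = Data()

# pandas: the index can be materialized as an extra column.
pd_df = pd.DataFrame({"Popularity": [114, 96]}, index=["Pop", "Rock"])
data.add_df(pd_df, include_index="Genres")

# pyspark frames have no index, so the same keyword raises ValueError:
# data.add_df(spark_df, include_index="Genres")
```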
10 changes: 10 additions & 0 deletions src/ipyvizzu/data/converters/numpy/__init__.py
@@ -1,3 +1,13 @@
 """
 This module provides modules for numpy converter.
 """
+
+from ipyvizzu.data.converters.numpy.converter import NumpyArrayConverter
+from ipyvizzu.data.converters.numpy.type_alias import (
+    Index,
+    Name,
+    DType,
+    ColumnName,
+    ColumnDtype,
+    ColumnConfig,
+)
3 changes: 3 additions & 0 deletions src/ipyvizzu/data/converters/pandas/__init__.py
@@ -1,3 +1,6 @@
 """
 This module provides modules for pandas converter.
 """
+
+from ipyvizzu.data.converters.pandas.converter import PandasDataFrameConverter
+from ipyvizzu.data.converters.pandas.protocol import PandasDataFrame, PandasSeries
3 changes: 2 additions & 1 deletion src/ipyvizzu/data/converters/pandas/converter.py
@@ -10,6 +10,7 @@
 from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE
 from ipyvizzu.data.converters.df.defaults import MAX_ROWS
 from ipyvizzu.data.converters.df.converter import DataFrameConverter
+from ipyvizzu.data.converters.pandas.protocol import PandasSeries
 from ipyvizzu.data.infer_type import InferType
 from ipyvizzu.data.type_alias import (
     DimensionValue,
@@ -57,7 +58,7 @@ def __init__(
         super().__init__(default_measure_value, default_dimension_value, max_rows)
         self._pd = self._get_pandas()
         self._df = self._get_sampled_df(
-            self._convert_to_df(df) if isinstance(df, self._pd.Series) else df
+            self._convert_to_df(df) if isinstance(df, PandasSeries) else df
         )
         self._include_index = include_index

36 changes: 36 additions & 0 deletions src/ipyvizzu/data/converters/pandas/protocol.py
@@ -0,0 +1,36 @@
+"""
+This module provides protocol classes for pandas data frame converter.
+"""
+
+from typing import Any, Callable, Sequence
+from typing_extensions import Protocol, runtime_checkable
+
+
+@runtime_checkable
+class PandasDataFrame(Protocol):
+    """
+    Represents a pandas DataFrame Protocol.
+    """
+
+    # pylint: disable=too-few-public-methods
+
+    index: Any
+    columns: Sequence[str]
+    sample: Callable[..., Any]
+    __len__: Callable[[], int]
+    __getitem__: Callable[[Any], Any]
+
+
+@runtime_checkable
+class PandasSeries(Protocol):
+    """
+    Represents a pandas Series Protocol.
+    """
+
+    # pylint: disable=too-few-public-methods
+
+    index: Any
+    values: Any
+    dtype: Any
+    __len__: Callable[[], int]
+    __getitem__: Callable[[Any], Any]
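Because these protocols are `runtime_checkable`, `isinstance` only tests for the presence of the declared members, so no `pandas` import is needed at check time; this is what the new `isinstance(df, PandasSeries)` check in `converter.py` above relies on. A small demonstration (assuming `pandas` is installed for the demo itself):

```python
import pandas as pd

from ipyvizzu.data.converters.pandas.protocol import PandasSeries

# A real Series exposes index, values, dtype, __len__ and __getitem__,
# so it satisfies the structural check:
print(isinstance(pd.Series([1, 2, 3]), PandasSeries))  # True

# A DataFrame has index and values but `dtypes` instead of `dtype`,
# so the same check correctly rejects it:
print(isinstance(pd.DataFrame({"a": [1]}), PandasSeries))  # False
```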
3 changes: 3 additions & 0 deletions src/ipyvizzu/data/converters/spark/__init__.py
@@ -1,3 +1,6 @@
 """
 This module provides modules for pyspark converter.
 """
+
+from ipyvizzu.data.converters.spark.converter import SparkDataFrameConverter
+from ipyvizzu.data.converters.spark.protocol import SparkDataFrame
23 changes: 23 additions & 0 deletions src/ipyvizzu/data/converters/spark/protocol.py
@@ -0,0 +1,23 @@
+"""
+This module provides protocol classes for pyspark data frame converter.
+"""
+
+from typing import Any, Callable, Sequence
+from typing_extensions import Protocol, runtime_checkable
+
+
+@runtime_checkable
+class SparkDataFrame(Protocol):
+    """
+    Represents a pyspark DataFrame Protocol.
+    """
+
+    # pylint: disable=too-few-public-methods
+
+    columns: Sequence[str]
+    count: Callable[..., int]
+    sample: Callable[..., Any]
+    limit: Callable[..., Any]
+    select: Callable[..., Any]
+    withColumn: Callable[..., Any]
+    rdd: Any
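The protocol members hint at how a converter can down-sample large frames: `count` to size the input and `sample` to thin it. A rough sketch of that logic, not the library's actual implementation; note that `pyspark`'s `sample(fraction=...)` keeps each row independently, which is why `add_df`'s docstring promises only approximately `max_rows` rows:

```python
from typing import Any


def sampled(df: Any, max_rows: int) -> Any:
    """Down-sample df using only members the SparkDataFrame protocol guarantees."""
    row_count = df.count()
    if row_count <= max_rows:
        return df
    # Bernoulli sampling: the result has roughly max_rows rows, not exactly.
    return df.sample(fraction=max_rows / row_count)
```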
29 changes: 17 additions & 12 deletions tests/test_data/test_pyspark.py
@@ -20,26 +20,26 @@ def tearDownClass(cls) -> None:
         super().tearDownClass()
         cls.spark.stop()

-    def test_add_spark_df_if_pyspark_not_installed(self) -> None:
+    def test_add_df_if_pyspark_not_installed(self) -> None:
         with RaiseImportError.module_name("pyspark"):
             with self.assertRaises(ImportError):
-                self.data.add_spark_df(self.spark.createDataFrame([], StructType([])))
+                self.data.add_df(self.spark.createDataFrame([], StructType([])))

-    def test_add_spark_df_with_none(self) -> None:
-        self.data.add_spark_df(None)
+    def test_add_df_with_none(self) -> None:
+        self.data.add_df(None)
         self.assertEqual(
             {"data": {}},
             self.data.build(),
         )

-    def test_add_spark_df_with_empty_df(self) -> None:
-        self.data.add_spark_df(self.spark.createDataFrame([], StructType([])))
+    def test_add_df_with_empty_df(self) -> None:
+        self.data.add_df(self.spark.createDataFrame([], StructType([])))
         self.assertEqual(
             {"data": {}},
             self.data.build(),
         )

-    def test_add_spark_df_with_df(self) -> None:
+    def test_add_df_with_df(self) -> None:
         schema = StructType(
             [
                 StructField("DimensionSeries", StringType(), True),
@@ -51,13 +51,13 @@ def test_add_spark_df_with_df(self) -> None:
             ("2", 4),
         ]
         df = self.spark.createDataFrame(df_data, schema)
-        self.data.add_spark_df(df)
+        self.data.add_df(df)
         self.assertEqual(
             self.ref_pd_series,
             self.data.build(),
         )

-    def test_add_spark_df_with_df_contains_na(self) -> None:
+    def test_add_df_with_df_contains_na(self) -> None:
         schema = StructType(
             [
                 StructField("DimensionSeries", StringType(), True),
@@ -69,13 +69,18 @@ def test_add_spark_df_with_df_contains_na(self) -> None:
             (None, None),
         ]
         df = self.spark.createDataFrame(df_data, schema)
-        self.data.add_spark_df(df)
+        self.data.add_df(df)
         self.assertEqual(
             self.ref_pd_series_with_nan,
             self.data.build(),
         )

-    def test_add_spark_df_with_df_and_max_rows(self) -> None:
+    def test_add_df_with_df_and_with_include_index(self) -> None:
+        df = self.spark.createDataFrame([], StructType([]))
+        with self.assertRaises(ValueError):
+            self.data.add_df(df, include_index="Index")
+
+    def test_add_df_with_df_and_max_rows(self) -> None:
         max_rows = 2

         dimension_data = ["0", "1", "2", "3", "4"]
@@ -91,7 +96,7 @@ def test_add_spark_df_with_df_and_max_rows(self) -> None:
             ]
         )
         df = self.spark.createDataFrame(df_data, schema)
-        self.data.add_spark_df(df, max_rows=max_rows)
+        self.data.add_df(df, max_rows=max_rows)

         data_series = self.data.build()["data"]["series"]

4 changes: 2 additions & 2 deletions tests/test_docs/tutorial/test_data.py
@@ -135,7 +135,7 @@ def tearDownClass(cls) -> None:
         super().tearDownClass()
         cls.spark.stop()

-    def test_add_spark_df(self) -> None:
+    def test_add_df(self) -> None:
         spark_schema = StructType(
             [
                 StructField("Genres", StringType(), True),
@@ -158,7 +158,7 @@ def test_add_spark_df(self) -> None:
             ("Metal", "Experimental", 58),
         ]
         df = self.spark.createDataFrame(spark_data, spark_schema)
-        self.data.add_spark_df(df)
+        self.data.add_df(df)
         self.assertEqual(
             self.ref_pd_df_by_series,
             self.data.build(),
