From cbb46b6d863a465a19a06020ec3020078e529d02 Mon Sep 17 00:00:00 2001
From: David Vegh
Date: Tue, 22 Aug 2023 09:37:25 +0200
Subject: [PATCH] Fixed: add missing warning when `max_rows` is exceeded

---
 src/ipyvizzu/data/converters/df/converter.py     | 12 ++++++++++++
 src/ipyvizzu/data/converters/pandas/converter.py |  2 +-
 src/ipyvizzu/data/converters/spark/converter.py  |  2 +-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/ipyvizzu/data/converters/df/converter.py b/src/ipyvizzu/data/converters/df/converter.py
index 0579911a..9f5c2709 100644
--- a/src/ipyvizzu/data/converters/df/converter.py
+++ b/src/ipyvizzu/data/converters/df/converter.py
@@ -4,6 +4,7 @@
 
 from abc import abstractmethod
 from typing import List
+import warnings
 
 from ipyvizzu.data.converters.converter import ToSeriesListConverter
 from ipyvizzu.data.converters.df.type_alias import DataFrame
@@ -49,6 +50,17 @@ def _get_series_from_column(self, column_name: str) -> Series:
         values, infer_type = self._convert_to_series_values_and_type(column_name)
         return self._convert_to_series(column_name, values, infer_type)
 
+    def _is_max_rows_exceeded(self, row_number: int) -> bool:
+        if row_number > self._max_rows:
+            warnings.warn(
+                "The number of rows of the dataframe exceeds the set `max_rows`; "
+                f"the dataframe is randomly sampled to the set value ({self._max_rows}).",
+                UserWarning,
+                stacklevel=2,
+            )
+            return True
+        return False
+
     @abstractmethod
     def _get_sampled_df(self, df: DataFrame) -> DataFrame:
         """
diff --git a/src/ipyvizzu/data/converters/pandas/converter.py b/src/ipyvizzu/data/converters/pandas/converter.py
index 0a06ca15..1f80f3b4 100644
--- a/src/ipyvizzu/data/converters/pandas/converter.py
+++ b/src/ipyvizzu/data/converters/pandas/converter.py
@@ -110,7 +110,7 @@ def _convert_to_df(self, series: "pandas.Series") -> "pandas.Dataframe":  # type
 
     def _get_sampled_df(self, df: "pandas.DataFrame") -> "pandas.DataFrame":  # type: ignore
         row_number = len(df)
-        if row_number > self._max_rows:
+        if self._is_max_rows_exceeded(row_number):
             frac = self._max_rows / row_number
             sampled_df = df.sample(
                 replace=False,
diff --git a/src/ipyvizzu/data/converters/spark/converter.py b/src/ipyvizzu/data/converters/spark/converter.py
index 2580de8b..37755fd7 100644
--- a/src/ipyvizzu/data/converters/spark/converter.py
+++ b/src/ipyvizzu/data/converters/spark/converter.py
@@ -68,7 +68,7 @@ def _get_sampled_df(
         self, df: "pyspark.sql.DataFrame"  # type: ignore
     ) -> "pyspark.sql.DataFrame":  # type: ignore
         row_number = df.count()
-        if row_number > self._max_rows:
+        if self._is_max_rows_exceeded(row_number):
             fraction = self._max_rows / row_number
             sample_df = df.sample(withReplacement=False, fraction=fraction, seed=42)
             return sample_df.limit(self._max_rows)
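
Reviewer note (not part of the patch): a minimal sketch of how the new warning could be observed from user code. It assumes the public `Data.add_df()` entry point forwards its `max_rows` argument to the pandas converter patched above; the dataframe contents and the threshold value are illustrative only.

# Sketch only: exercise the new warning path through the public API.
# Assumption: Data.add_df() passes `max_rows` through to the pandas converter.
import warnings

import pandas as pd

from ipyvizzu import Data

df = pd.DataFrame({"x": range(1_000)})  # more rows than the max_rows used below

data = Data()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    data.add_df(df, max_rows=100)  # should hit _is_max_rows_exceeded() and sample

assert any(issubclass(w.category, UserWarning) for w in caught)

Before this patch, the dataframe was silently sampled down to `max_rows`; the added `_is_max_rows_exceeded()` helper keeps the sampling decision and the warning in one place so both converters stay consistent.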