From 4d3b9e4ad474b58507a3c27c1d51d6c2d66de956 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Tue, 30 Jul 2024 12:49:43 +0200 Subject: [PATCH] Add PivotTable Transform (#630) --- lumen/transforms/base.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/lumen/transforms/base.py b/lumen/transforms/base.py index 41f53c50a..21a2c9df6 100644 --- a/lumen/transforms/base.py +++ b/lumen/transforms/base.py @@ -544,6 +544,42 @@ def apply(self, table: DataFrame) -> DataFrame: return pivot_table +class PivotTable(Transform): + """ + `PivotTable` applies pandas.pivot_table` to the data. + """ + + values = param.ListSelector(default=[], doc=""" + Column or columns to aggregate.""") + + index = param.ListSelector(default=[], doc=""" + Column, Grouper, array, or list of the previous + Keys to group by on the pivot table index. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values.""") + + columns = param.ListSelector(default=[], doc=""" + Column, Grouper, array, or list of the previous + Keys to group by on the pivot table column. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values.""") + + aggfunc = param.String(default="mean", doc=""" + Function, list of functions, dict, default 'mean'""") + + _field_params: ClassVar[List[str]] = ['values', 'index', 'columns'] + + def apply(self, table: DataFrame) -> DataFrame: + values = self.values if len(self.values) > 1 else self.values[0] + columns = self.columns if len(self.columns) > 1 else self.columns[0] + return pd.pivot_table( + table, values=values, index=self.index, columns=columns, + aggfunc=self.aggfunc + ) + + class Melt(Transform): """ `Melt` applies the `pandas.melt` operation given the `id_vars` and `value_vars`.