diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index d72ff4488251..9b0d71c017dc 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -70,8 +70,13 @@ pub trait DataFrameOps: IntoDf { /// +------+------+------+--------+--------+--------+---------+---------+---------+ /// ``` #[cfg(feature = "to_dummies")] - fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult { - self._to_dummies(None, separator, drop_first) + fn to_dummies( + &self, + separator: Option<&str>, + drop_first: bool, + keep_columns: bool, + ) -> PolarsResult { + self._to_dummies(None, separator, drop_first, keep_columns) } #[cfg(feature = "to_dummies")] @@ -80,8 +85,9 @@ pub trait DataFrameOps: IntoDf { columns: Vec<&str>, separator: Option<&str>, drop_first: bool, + keep_columns: bool, ) -> PolarsResult { - self._to_dummies(Some(columns), separator, drop_first) + self._to_dummies(Some(columns), separator, drop_first, keep_columns) } #[cfg(feature = "to_dummies")] @@ -90,6 +96,7 @@ pub trait DataFrameOps: IntoDf { columns: Option>, separator: Option<&str>, drop_first: bool, + keep_columns: bool, ) -> PolarsResult { use crate::series::ToDummies; @@ -105,7 +112,10 @@ pub trait DataFrameOps: IntoDf { df.get_columns() .par_iter() .map(|s| match set.contains(s.name().as_str()) { - true => s.as_materialized_series().to_dummies(separator, drop_first), + true => { + s.as_materialized_series() + .to_dummies(separator, drop_first, keep_columns) + }, false => Ok(s.clone().into_frame()), }) .collect::>>() diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index eb2cf3a228c1..3a0b1940f62d 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -13,11 +13,21 @@ type DummyType = i32; type DummyCa = Int32Chunked; pub trait ToDummies { - fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult; + fn to_dummies( + &self, + separator: Option<&str>, + drop_first: bool, + keep_column: bool, + ) -> PolarsResult; } impl ToDummies for Series { - fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult { + fn to_dummies( + &self, + separator: Option<&str>, + drop_first: bool, + keep_column: bool, + ) -> PolarsResult { let sep = separator.unwrap_or("_"); let col_name = self.name(); let groups = self.group_tuples(true, drop_first)?; @@ -25,26 +35,32 @@ impl ToDummies for Series { // SAFETY: groups are in bounds let columns = unsafe { self.agg_first(&groups) }; let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize); - let columns = columns - .map(|(av, group)| { - // strings are formatted with extra \" \" in polars, so we - // extract the string - let name = if let Some(s) = av.get_str() { - format_pl_smallstr!("{col_name}{sep}{s}") - } else { - // other types don't have this formatting issue - format_pl_smallstr!("{col_name}{sep}{av}") - }; - - let ca = match group { - GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name), - GroupsIndicator::Slice([offset, len]) => { - dummies_helper_slice(offset, len, self.len(), name) - }, - }; - ca.into_column() - }) - .collect::>(); + let columns = columns.map(|(av, group)| { + // strings are formatted with extra \" \" in polars, so we + // extract the string + let name = if let Some(s) = av.get_str() { + format_pl_smallstr!("{col_name}{sep}{s}") + } else { + // other types don't have this formatting issue + format_pl_smallstr!("{col_name}{sep}{av}") + }; + + let ca = match group { + GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name), + GroupsIndicator::Slice([offset, len]) => { + dummies_helper_slice(offset, len, self.len(), name) + }, + }; + ca.into_column() + }); + + let columns = if keep_column { + std::iter::once(self.clone().into_column()) + .chain(columns) + .collect() + } else { + columns.collect() + }; // SAFETY: `dummies_helper` functions preserve `self.len()` length unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) } diff --git a/crates/polars-python/src/dataframe/general.rs b/crates/polars-python/src/dataframe/general.rs index e866e7db1004..cda8ddfb557a 100644 --- a/crates/polars-python/src/dataframe/general.rs +++ b/crates/polars-python/src/dataframe/general.rs @@ -550,13 +550,14 @@ impl PyDataFrame { Ok(s.map(|s| s.into())) } - #[pyo3(signature = (columns, separator, drop_first=false))] + #[pyo3(signature = (columns, separator, drop_first=false, keep_columns=false))] pub fn to_dummies( &self, py: Python, columns: Option>, separator: Option<&str>, drop_first: bool, + keep_columns: bool, ) -> PyResult { let df = py .allow_threads(|| match columns { @@ -564,8 +565,9 @@ impl PyDataFrame { cols.iter().map(|x| x as &str).collect(), separator, drop_first, + keep_columns, ), - None => self.df.to_dummies(separator, drop_first), + None => self.df.to_dummies(separator, drop_first, keep_columns), }) .map_err(PyPolarsErr::from)?; Ok(df.into()) diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index 3134f5354f09..6426e9786d1c 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -336,15 +336,16 @@ impl PySeries { Ok(s.into()) } - #[pyo3(signature = (separator, drop_first=false))] + #[pyo3(signature = (separator, drop_first=false, keep_column=false))] fn to_dummies( &self, py: Python, separator: Option<&str>, drop_first: bool, + keep_column: bool, ) -> PyResult { let df = py - .allow_threads(|| self.series.to_dummies(separator, drop_first)) + .allow_threads(|| self.series.to_dummies(separator, drop_first, keep_column)) .map_err(PyPolarsErr::from)?; Ok(df.into()) } diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 662d09e83d5b..fe5bd409aa89 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -9682,6 +9682,7 @@ def to_dummies( *, separator: str = "_", drop_first: bool = False, + keep_columns: bool = False, ) -> DataFrame: """ Convert categorical variables into dummy/indicator variables. @@ -9695,6 +9696,8 @@ def to_dummies( Separator/delimiter used when generating column names. drop_first Remove the first category from the variables being encoded. + keep_columns + Retain columns used to generated dummy columns. Examples -------- @@ -9727,6 +9730,17 @@ def to_dummies( │ 1 ┆ 1 ┆ 1 │ └───────┴───────┴───────┘ + >>> df.to_dummies(keep_columns=True) + shape: (2, 9) + ┌─────┬───────┬───────┬─────┬───┬───────┬─────┬───────┬───────┐ + │ foo ┆ foo_1 ┆ foo_2 ┆ bar ┆ … ┆ bar_4 ┆ ham ┆ ham_a ┆ ham_b │ + │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ u8 ┆ u8 ┆ i64 ┆ ┆ u8 ┆ str ┆ u8 ┆ u8 │ + ╞═════╪═══════╪═══════╪═════╪═══╪═══════╪═════╪═══════╪═══════╡ + │ 1 ┆ 1 ┆ 0 ┆ 3 ┆ … ┆ 0 ┆ a ┆ 1 ┆ 0 │ + │ 2 ┆ 0 ┆ 1 ┆ 4 ┆ … ┆ 1 ┆ b ┆ 0 ┆ 1 │ + └─────┴───────┴───────┴─────┴───┴───────┴─────┴───────┴───────┘ + >>> import polars.selectors as cs >>> df.to_dummies(cs.integer(), separator=":") shape: (2, 5) @@ -9752,7 +9766,9 @@ def to_dummies( """ if columns is not None: columns = _expand_selectors(self, columns) - return self._from_pydf(self._df.to_dummies(columns, separator, drop_first)) + return self._from_pydf( + self._df.to_dummies(columns, separator, drop_first, keep_columns) + ) def unique( self, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index d86c9d29cd0f..37fd5748aaa1 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2148,7 +2148,11 @@ def quantile( return self._s.quantile(quantile, interpolation) def to_dummies( - self, *, separator: str = "_", drop_first: bool = False + self, + *, + separator: str = "_", + drop_first: bool = False, + keep_column: bool = False, ) -> DataFrame: """ Get dummy/indicator variables. @@ -2159,6 +2163,8 @@ def to_dummies( Separator/delimiter used when generating column names. drop_first Remove the first category from the variable being encoded. + keep_column + Retain column used to generated dummy columns. Examples -------- @@ -2186,8 +2192,19 @@ def to_dummies( │ 1 ┆ 0 │ │ 0 ┆ 1 │ └─────┴─────┘ - """ - return wrap_df(self._s.to_dummies(separator, drop_first)) + >>> s.to_dummies(keep_column=True) + shape: (3, 4) + ┌─────┬─────┬─────┬─────┐ + │ a ┆ a_1 ┆ a_2 ┆ a_3 │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ u8 ┆ u8 ┆ u8 │ + ╞═════╪═════╪═════╪═════╡ + │ 1 ┆ 1 ┆ 0 ┆ 0 │ + │ 2 ┆ 0 ┆ 1 ┆ 0 │ + │ 3 ┆ 0 ┆ 0 ┆ 1 │ + └─────┴─────┴─────┴─────┘ + """ + return wrap_df(self._s.to_dummies(separator, drop_first, keep_column)) @unstable() def cut( diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index c375e1952347..31d753a3c4a7 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -806,6 +806,25 @@ def test_to_dummies_drop_first() -> None: ] +def test_to_dummies_keep_columns() -> None: + df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5]}) + dummies = df.to_dummies(keep_columns=True) + + expected = pl.DataFrame( + { + "A": ["a", "b", "c"], + "A_a": pl.Series([1, 0, 0], dtype=pl.UInt8), + "A_b": pl.Series([0, 1, 0], dtype=pl.UInt8), + "A_c": pl.Series([0, 0, 1], dtype=pl.UInt8), + "B": [1, 3, 5], + "B_1": pl.Series([1, 0, 0], dtype=pl.UInt8), + "B_3": pl.Series([0, 1, 0], dtype=pl.UInt8), + "B_5": pl.Series([0, 0, 1], dtype=pl.UInt8), + } + ) + assert_frame_equal(dummies, expected) + + def test_to_pandas(df: pl.DataFrame) -> None: # pyarrow cannot deal with unsigned dictionary integer yet. # pyarrow cannot convert a time64 w/ non-zero nanoseconds diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 0ed478b9aa83..aa34c9d09b78 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1357,6 +1357,16 @@ def test_to_dummies_drop_first() -> None: assert_frame_equal(result, expected) +def test_to_dummies_keep_column() -> None: + s = pl.Series("a", [1, 2, 3]) + result = s.to_dummies(keep_column=True) + expected = pl.DataFrame( + {"a": [1, 2, 3], "a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]}, + schema={"a": pl.Int64, "a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8}, + ) + assert_frame_equal(result, expected) + + def test_to_dummies_null_clash_19096() -> None: with pytest.raises( DuplicateError, match="column with name '_null' has more than one occurrence"