Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add keep_column(s) params to to_dummies #14844

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions crates/polars-ops/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ pub trait DataFrameOps: IntoDf {
/// +------+------+------+--------+--------+--------+---------+---------+---------+
/// ```
#[cfg(feature = "to_dummies")]
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame> {
self._to_dummies(None, separator, drop_first)
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
self._to_dummies(None, separator, drop_first, keep_columns)
}

#[cfg(feature = "to_dummies")]
Expand All @@ -80,8 +85,9 @@ pub trait DataFrameOps: IntoDf {
columns: Vec<&str>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
self._to_dummies(Some(columns), separator, drop_first)
self._to_dummies(Some(columns), separator, drop_first, keep_columns)
}

#[cfg(feature = "to_dummies")]
Expand All @@ -90,6 +96,7 @@ pub trait DataFrameOps: IntoDf {
columns: Option<Vec<&str>>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
use crate::series::ToDummies;

Expand All @@ -105,7 +112,10 @@ pub trait DataFrameOps: IntoDf {
df.get_columns()
.par_iter()
.map(|s| match set.contains(s.name().as_str()) {
true => s.as_materialized_series().to_dummies(separator, drop_first),
true => {
s.as_materialized_series()
.to_dummies(separator, drop_first, keep_columns)
},
false => Ok(s.clone().into_frame()),
})
.collect::<PolarsResult<Vec<_>>>()
Expand Down
60 changes: 38 additions & 22 deletions crates/polars-ops/src/series/ops/to_dummies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,54 @@ type DummyType = i32;
type DummyCa = Int32Chunked;

pub trait ToDummies {
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame>;
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PolarsResult<DataFrame>;
}

impl ToDummies for Series {
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame> {
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PolarsResult<DataFrame> {
let sep = separator.unwrap_or("_");
let col_name = self.name();
let groups = self.group_tuples(true, drop_first)?;

// SAFETY: groups are in bounds
let columns = unsafe { self.agg_first(&groups) };
let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize);
let columns = columns
.map(|(av, group)| {
// strings are formatted with extra \" \" in polars, so we
// extract the string
let name = if let Some(s) = av.get_str() {
format_pl_smallstr!("{col_name}{sep}{s}")
} else {
// other types don't have this formatting issue
format_pl_smallstr!("{col_name}{sep}{av}")
};

let ca = match group {
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
GroupsIndicator::Slice([offset, len]) => {
dummies_helper_slice(offset, len, self.len(), name)
},
};
ca.into_column()
})
.collect::<Vec<_>>();
let columns = columns.map(|(av, group)| {
// strings are formatted with extra \" \" in polars, so we
// extract the string
let name = if let Some(s) = av.get_str() {
format_pl_smallstr!("{col_name}{sep}{s}")
} else {
// other types don't have this formatting issue
format_pl_smallstr!("{col_name}{sep}{av}")
};

let ca = match group {
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
GroupsIndicator::Slice([offset, len]) => {
dummies_helper_slice(offset, len, self.len(), name)
},
};
ca.into_column()
});

let columns = if keep_column {
std::iter::once(self.clone().into_column())
.chain(columns)
.collect()
} else {
columns.collect()
};

// SAFETY: `dummies_helper` functions preserve `self.len()` length
unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) }
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-python/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,22 +550,24 @@ impl PyDataFrame {
Ok(s.map(|s| s.into()))
}

#[pyo3(signature = (columns, separator, drop_first=false))]
#[pyo3(signature = (columns, separator, drop_first=false, keep_columns=false))]
pub fn to_dummies(
&self,
py: Python,
columns: Option<Vec<String>>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PyResult<Self> {
let df = py
.allow_threads(|| match columns {
Some(cols) => self.df.columns_to_dummies(
cols.iter().map(|x| x as &str).collect(),
separator,
drop_first,
keep_columns,
),
None => self.df.to_dummies(separator, drop_first),
None => self.df.to_dummies(separator, drop_first, keep_columns),
})
.map_err(PyPolarsErr::from)?;
Ok(df.into())
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,16 @@ impl PySeries {
Ok(s.into())
}

#[pyo3(signature = (separator, drop_first=false))]
#[pyo3(signature = (separator, drop_first=false, keep_column=false))]
fn to_dummies(
&self,
py: Python,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PyResult<PyDataFrame> {
let df = py
.allow_threads(|| self.series.to_dummies(separator, drop_first))
.allow_threads(|| self.series.to_dummies(separator, drop_first, keep_column))
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}
Expand Down
18 changes: 17 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9682,6 +9682,7 @@ def to_dummies(
*,
separator: str = "_",
drop_first: bool = False,
keep_columns: bool = False,
) -> DataFrame:
"""
Convert categorical variables into dummy/indicator variables.
Expand All @@ -9695,6 +9696,8 @@ def to_dummies(
Separator/delimiter used when generating column names.
drop_first
Remove the first category from the variables being encoded.
keep_columns
Retain columns used to generated dummy columns.

Examples
--------
Expand Down Expand Up @@ -9727,6 +9730,17 @@ def to_dummies(
│ 1 ┆ 1 ┆ 1 │
└───────┴───────┴───────┘

>>> df.to_dummies(keep_columns=True)
shape: (2, 9)
┌─────┬───────┬───────┬─────┬───┬───────┬─────┬───────┬───────┐
│ foo ┆ foo_1 ┆ foo_2 ┆ bar ┆ … ┆ bar_4 ┆ ham ┆ ham_a ┆ ham_b │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ u8 ┆ u8 ┆ i64 ┆ ┆ u8 ┆ str ┆ u8 ┆ u8 │
╞═════╪═══════╪═══════╪═════╪═══╪═══════╪═════╪═══════╪═══════╡
│ 1 ┆ 1 ┆ 0 ┆ 3 ┆ … ┆ 0 ┆ a ┆ 1 ┆ 0 │
│ 2 ┆ 0 ┆ 1 ┆ 4 ┆ … ┆ 1 ┆ b ┆ 0 ┆ 1 │
└─────┴───────┴───────┴─────┴───┴───────┴─────┴───────┴───────┘

>>> import polars.selectors as cs
>>> df.to_dummies(cs.integer(), separator=":")
shape: (2, 5)
Expand All @@ -9752,7 +9766,9 @@ def to_dummies(
"""
if columns is not None:
columns = _expand_selectors(self, columns)
return self._from_pydf(self._df.to_dummies(columns, separator, drop_first))
return self._from_pydf(
self._df.to_dummies(columns, separator, drop_first, keep_columns)
)

def unique(
self,
Expand Down
23 changes: 20 additions & 3 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,7 +2148,11 @@ def quantile(
return self._s.quantile(quantile, interpolation)

def to_dummies(
self, *, separator: str = "_", drop_first: bool = False
self,
*,
separator: str = "_",
drop_first: bool = False,
keep_column: bool = False,
) -> DataFrame:
"""
Get dummy/indicator variables.
Expand All @@ -2159,6 +2163,8 @@ def to_dummies(
Separator/delimiter used when generating column names.
drop_first
Remove the first category from the variable being encoded.
keep_column
Retain column used to generated dummy columns.

Examples
--------
Expand Down Expand Up @@ -2186,8 +2192,19 @@ def to_dummies(
│ 1 ┆ 0 │
│ 0 ┆ 1 │
└─────┴─────┘
"""
return wrap_df(self._s.to_dummies(separator, drop_first))
>>> s.to_dummies(keep_column=True)
shape: (3, 4)
┌─────┬─────┬─────┬─────┐
│ a ┆ a_1 ┆ a_2 ┆ a_3 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ u8 ┆ u8 ┆ u8 │
╞═════╪═════╪═════╪═════╡
│ 1 ┆ 1 ┆ 0 ┆ 0 │
│ 2 ┆ 0 ┆ 1 ┆ 0 │
│ 3 ┆ 0 ┆ 0 ┆ 1 │
└─────┴─────┴─────┴─────┘
"""
return wrap_df(self._s.to_dummies(separator, drop_first, keep_column))

@unstable()
def cut(
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,25 @@ def test_to_dummies_drop_first() -> None:
]


def test_to_dummies_keep_columns() -> None:
df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5]})
dummies = df.to_dummies(keep_columns=True)

expected = pl.DataFrame(
{
"A": ["a", "b", "c"],
"A_a": pl.Series([1, 0, 0], dtype=pl.UInt8),
"A_b": pl.Series([0, 1, 0], dtype=pl.UInt8),
"A_c": pl.Series([0, 0, 1], dtype=pl.UInt8),
"B": [1, 3, 5],
"B_1": pl.Series([1, 0, 0], dtype=pl.UInt8),
"B_3": pl.Series([0, 1, 0], dtype=pl.UInt8),
"B_5": pl.Series([0, 0, 1], dtype=pl.UInt8),
}
)
assert_frame_equal(dummies, expected)


def test_to_pandas(df: pl.DataFrame) -> None:
# pyarrow cannot deal with unsigned dictionary integer yet.
# pyarrow cannot convert a time64 w/ non-zero nanoseconds
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,16 @@ def test_to_dummies_drop_first() -> None:
assert_frame_equal(result, expected)


def test_to_dummies_keep_column() -> None:
s = pl.Series("a", [1, 2, 3])
result = s.to_dummies(keep_column=True)
expected = pl.DataFrame(
{"a": [1, 2, 3], "a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]},
schema={"a": pl.Int64, "a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8},
)
assert_frame_equal(result, expected)


def test_to_dummies_null_clash_19096() -> None:
with pytest.raises(
DuplicateError, match="column with name '_null' has more than one occurrence"
Expand Down