Skip to content

Commit

Permalink
Add keep column(s) param
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Nov 17, 2024
1 parent 34ee4ee commit 2d6e17f
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 34 deletions.
18 changes: 14 additions & 4 deletions crates/polars-ops/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ pub trait DataFrameOps: IntoDf {
/// +------+------+------+--------+--------+--------+---------+---------+---------+
/// ```
#[cfg(feature = "to_dummies")]
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame> {
self._to_dummies(None, separator, drop_first)
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
self._to_dummies(None, separator, drop_first, keep_columns)
}

#[cfg(feature = "to_dummies")]
Expand All @@ -80,8 +85,9 @@ pub trait DataFrameOps: IntoDf {
columns: Vec<&str>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
self._to_dummies(Some(columns), separator, drop_first)
self._to_dummies(Some(columns), separator, drop_first, keep_columns)
}

#[cfg(feature = "to_dummies")]
Expand All @@ -90,6 +96,7 @@ pub trait DataFrameOps: IntoDf {
columns: Option<Vec<&str>>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PolarsResult<DataFrame> {
use crate::series::ToDummies;

Expand All @@ -105,7 +112,10 @@ pub trait DataFrameOps: IntoDf {
df.get_columns()
.par_iter()
.map(|s| match set.contains(s.name().as_str()) {
true => s.as_materialized_series().to_dummies(separator, drop_first),
true => {
s.as_materialized_series()
.to_dummies(separator, drop_first, keep_columns)
},
false => Ok(s.clone().into_frame()),
})
.collect::<PolarsResult<Vec<_>>>()
Expand Down
60 changes: 38 additions & 22 deletions crates/polars-ops/src/series/ops/to_dummies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,54 @@ type DummyType = i32;
type DummyCa = Int32Chunked;

pub trait ToDummies {
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame>;
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PolarsResult<DataFrame>;
}

impl ToDummies for Series {
fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult<DataFrame> {
fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PolarsResult<DataFrame> {
let sep = separator.unwrap_or("_");
let col_name = self.name();
let groups = self.group_tuples(true, drop_first)?;

// SAFETY: groups are in bounds
let columns = unsafe { self.agg_first(&groups) };
let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize);
let columns = columns
.map(|(av, group)| {
// strings are formatted with extra \" \" in polars, so we
// extract the string
let name = if let Some(s) = av.get_str() {
format_pl_smallstr!("{col_name}{sep}{s}")
} else {
// other types don't have this formatting issue
format_pl_smallstr!("{col_name}{sep}{av}")
};

let ca = match group {
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
GroupsIndicator::Slice([offset, len]) => {
dummies_helper_slice(offset, len, self.len(), name)
},
};
ca.into_column()
})
.collect::<Vec<_>>();
let columns = columns.map(|(av, group)| {
// strings are formatted with extra \" \" in polars, so we
// extract the string
let name = if let Some(s) = av.get_str() {
format_pl_smallstr!("{col_name}{sep}{s}")
} else {
// other types don't have this formatting issue
format_pl_smallstr!("{col_name}{sep}{av}")
};

let ca = match group {
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
GroupsIndicator::Slice([offset, len]) => {
dummies_helper_slice(offset, len, self.len(), name)
},
};
ca.into_column()
});

let columns = if keep_column {
std::iter::once(self.clone().into_column())
.chain(columns)
.collect()
} else {
columns.collect()
};

// SAFETY: `dummies_helper` functions preserve `self.len()` length
unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) }
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-python/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,22 +550,24 @@ impl PyDataFrame {
Ok(s.map(|s| s.into()))
}

#[pyo3(signature = (columns, separator, drop_first=false))]
#[pyo3(signature = (columns, separator, drop_first=false, keep_columns=false))]
pub fn to_dummies(
&self,
py: Python,
columns: Option<Vec<String>>,
separator: Option<&str>,
drop_first: bool,
keep_columns: bool,
) -> PyResult<Self> {
let df = py
.allow_threads(|| match columns {
Some(cols) => self.df.columns_to_dummies(
cols.iter().map(|x| x as &str).collect(),
separator,
drop_first,
keep_columns,
),
None => self.df.to_dummies(separator, drop_first),
None => self.df.to_dummies(separator, drop_first, keep_columns),
})
.map_err(PyPolarsErr::from)?;
Ok(df.into())
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,16 @@ impl PySeries {
Ok(s.into())
}

#[pyo3(signature = (separator, drop_first=false))]
#[pyo3(signature = (separator, drop_first=false, keep_column=false))]
fn to_dummies(
&self,
py: Python,
separator: Option<&str>,
drop_first: bool,
keep_column: bool,
) -> PyResult<PyDataFrame> {
let df = py
.allow_threads(|| self.series.to_dummies(separator, drop_first))
.allow_threads(|| self.series.to_dummies(separator, drop_first, keep_column))
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}
Expand Down
18 changes: 17 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9682,6 +9682,7 @@ def to_dummies(
*,
separator: str = "_",
drop_first: bool = False,
keep_columns: bool = False,
) -> DataFrame:
"""
Convert categorical variables into dummy/indicator variables.
Expand All @@ -9695,6 +9696,8 @@ def to_dummies(
Separator/delimiter used when generating column names.
drop_first
Remove the first category from the variables being encoded.
keep_columns
Retain columns used to generated dummy columns.
Examples
--------
Expand Down Expand Up @@ -9727,6 +9730,17 @@ def to_dummies(
│ 1 ┆ 1 ┆ 1 │
└───────┴───────┴───────┘
>>> df.to_dummies(keep_columns=True)
shape: (2, 9)
┌─────┬───────┬───────┬─────┬───┬───────┬─────┬───────┬───────┐
│ foo ┆ foo_1 ┆ foo_2 ┆ bar ┆ … ┆ bar_4 ┆ ham ┆ ham_a ┆ ham_b │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ u8 ┆ u8 ┆ i64 ┆ ┆ u8 ┆ str ┆ u8 ┆ u8 │
╞═════╪═══════╪═══════╪═════╪═══╪═══════╪═════╪═══════╪═══════╡
│ 1 ┆ 1 ┆ 0 ┆ 3 ┆ … ┆ 0 ┆ a ┆ 1 ┆ 0 │
│ 2 ┆ 0 ┆ 1 ┆ 4 ┆ … ┆ 1 ┆ b ┆ 0 ┆ 1 │
└─────┴───────┴───────┴─────┴───┴───────┴─────┴───────┴───────┘
>>> import polars.selectors as cs
>>> df.to_dummies(cs.integer(), separator=":")
shape: (2, 5)
Expand All @@ -9752,7 +9766,9 @@ def to_dummies(
"""
if columns is not None:
columns = _expand_selectors(self, columns)
return self._from_pydf(self._df.to_dummies(columns, separator, drop_first))
return self._from_pydf(
self._df.to_dummies(columns, separator, drop_first, keep_columns)
)

def unique(
self,
Expand Down
23 changes: 20 additions & 3 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,7 +2148,11 @@ def quantile(
return self._s.quantile(quantile, interpolation)

def to_dummies(
self, *, separator: str = "_", drop_first: bool = False
self,
*,
separator: str = "_",
drop_first: bool = False,
keep_column: bool = False,
) -> DataFrame:
"""
Get dummy/indicator variables.
Expand All @@ -2159,6 +2163,8 @@ def to_dummies(
Separator/delimiter used when generating column names.
drop_first
Remove the first category from the variable being encoded.
keep_column
Retain column used to generated dummy columns.
Examples
--------
Expand Down Expand Up @@ -2186,8 +2192,19 @@ def to_dummies(
│ 1 ┆ 0 │
│ 0 ┆ 1 │
└─────┴─────┘
"""
return wrap_df(self._s.to_dummies(separator, drop_first))
>>> s.to_dummies(keep_column=True)
shape: (3, 4)
┌─────┬─────┬─────┬─────┐
│ a ┆ a_1 ┆ a_2 ┆ a_3 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ u8 ┆ u8 ┆ u8 │
╞═════╪═════╪═════╪═════╡
│ 1 ┆ 1 ┆ 0 ┆ 0 │
│ 2 ┆ 0 ┆ 1 ┆ 0 │
│ 3 ┆ 0 ┆ 0 ┆ 1 │
└─────┴─────┴─────┴─────┘
"""
return wrap_df(self._s.to_dummies(separator, drop_first, keep_column))

@unstable()
def cut(
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,25 @@ def test_to_dummies_drop_first() -> None:
]


def test_to_dummies_keep_columns() -> None:
df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5]})
dummies = df.to_dummies(keep_columns=True)

expected = pl.DataFrame(
{
"A": ["a", "b", "c"],
"A_a": pl.Series([1, 0, 0], dtype=pl.UInt8),
"A_b": pl.Series([0, 1, 0], dtype=pl.UInt8),
"A_c": pl.Series([0, 0, 1], dtype=pl.UInt8),
"B": [1, 3, 5],
"B_1": pl.Series([1, 0, 0], dtype=pl.UInt8),
"B_3": pl.Series([0, 1, 0], dtype=pl.UInt8),
"B_5": pl.Series([0, 0, 1], dtype=pl.UInt8),
}
)
assert_frame_equal(dummies, expected)


def test_to_pandas(df: pl.DataFrame) -> None:
# pyarrow cannot deal with unsigned dictionary integer yet.
# pyarrow cannot convert a time64 w/ non-zero nanoseconds
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,16 @@ def test_to_dummies_drop_first() -> None:
assert_frame_equal(result, expected)


def test_to_dummies_keep_column() -> None:
s = pl.Series("a", [1, 2, 3])
result = s.to_dummies(keep_column=True)
expected = pl.DataFrame(
{"a": [1, 2, 3], "a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]},
schema={"a": pl.Int64, "a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8},
)
assert_frame_equal(result, expected)


def test_to_dummies_null_clash_19096() -> None:
with pytest.raises(
DuplicateError, match="column with name '_null' has more than one occurrence"
Expand Down

0 comments on commit 2d6e17f

Please sign in to comment.