Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new Enum categorical data type which allows a fixed set of categories #11822

Merged
merged 22 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ pub fn create_categorical_chunked_listbuilder(
rev_map: Arc<RevMapping>,
) -> Box<dyn ListBuilderTrait> {
match &*rev_map {
RevMapping::Enum(_, h) => Box::new(ListEnumCategoricalChunkedBuilder::new(
name,
capacity,
values_capacity,
(*rev_map).clone(),
*h,
)),
RevMapping::Local(_, h) => Box::new(ListLocalCategoricalChunkedBuilder::new(
name,
capacity,
Expand All @@ -24,6 +31,57 @@ pub fn create_categorical_chunked_listbuilder(
}
}

/// List builder for series whose inner dtype is an `Enum` categorical.
///
/// The category indices are collected as plain `u32` values; the categorical
/// dtype is only attached to the result when the builder is finished.
struct ListEnumCategoricalChunkedBuilder {
/// Builder that accumulates the `u32` list values.
inner: ListPrimitiveChunkedBuilder<UInt32Type>,
/// Reverse mapping that is cloned into the inner dtype of the finished list.
rev_map: RevMapping,
/// Hash every appended series' enum mapping is compared against.
hash: u128,
}

impl ListEnumCategoricalChunkedBuilder {
    /// Create a builder for a list column named `name`.
    ///
    /// `capacity` and `values_capacity` are forwarded unchanged to the
    /// underlying `u32` list builder. `rev_map` and `hash` are stored as
    /// given; the caller is expected to pass the hash that belongs to
    /// `rev_map`.
    pub(super) fn new(
        name: &str,
        capacity: usize,
        values_capacity: usize,
        rev_map: RevMapping,
        hash: u128,
    ) -> Self {
        // The physical representation of the categories is UInt32.
        let inner =
            ListPrimitiveChunkedBuilder::new(name, capacity, values_capacity, DataType::UInt32);

        Self {
            inner,
            rev_map,
            hash,
        }
    }
}

impl ListBuilderTrait for ListEnumCategoricalChunkedBuilder {
    /// Append `s` as one list element.
    ///
    /// # Errors
    /// Returns a `ComputeError` when `s` is not a categorical with a reverse
    /// mapping, when its mapping is not an `Enum`, or when the hash of its
    /// enum mapping differs from the one this builder was created with.
    fn append_series(&mut self, s: &Series) -> PolarsResult<()> {
        match s.dtype() {
            DataType::Categorical(Some(other_map)) => match &**other_map {
                RevMapping::Enum(_, other_hash) => {
                    // Equal hashes are taken to mean identical enum variants.
                    polars_ensure!(*other_hash == self.hash, ComputeError: "Can not combine enums with different variants");
                },
                _ => {
                    polars_bail!(ComputeError: "Can not combine enum with categorical, consider casting to one of the two")
                },
            },
            _ => polars_bail!(ComputeError: "expected categorical type"),
        }
        self.inner.append_series(s)
    }

    /// Append a null list element.
    fn append_null(&mut self) {
        self.inner.append_null()
    }

    /// Finish the builder and return a list whose inner dtype is the
    /// categorical described by the stored reverse mapping.
    fn finish(&mut self) -> ListChunked {
        let mut list = self.inner.finish();
        let enum_dtype = DataType::Categorical(Some(Arc::new(self.rev_map.clone())));
        // SAFETY: the values were collected as `u32`; every appended series was
        // checked against this builder's enum hash. NOTE(review): assumes
        // `set_dtype` only relabels the logical dtype — confirm.
        unsafe { list.set_dtype(DataType::List(Box::new(enum_dtype))) }
        list
    }
}

struct ListLocalCategoricalChunkedBuilder {
inner: ListPrimitiveChunkedBuilder<UInt32Type>,
idx_lookup: PlHashMap<KeyWrapper, ()>,
Expand Down
26 changes: 20 additions & 6 deletions crates/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,26 @@ impl ChunkCast for Utf8Chunked {
fn cast(&self, data_type: &DataType) -> PolarsResult<Series> {
match data_type {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let iter = self.into_iter();
let mut builder = CategoricalChunkedBuilder::new(self.name(), self.len());
builder.drain_iter(iter);
let ca = builder.finish();
Ok(ca.into_series())
DataType::Categorical(rev_map) => match rev_map {
None => {
// Safety: length is correct
let iter =
unsafe { self.downcast_iter().flatten().trust_my_length(self.len()) };
let mut builder = CategoricalChunkedBuilder::new(self.name(), self.len());
builder.drain_iter(iter);
let ca = builder.finish();
Ok(ca.into_series())
},
Some(rev_map) => {
polars_ensure!(rev_map.is_enum(), InvalidOperation: "casting to a non-enum variant with rev map is not supported for the user");
CategoricalChunked::from_utf8_to_enum(self, rev_map.get_categories()).map(
|ca| {
let mut s = ca.into_series();
s.rename(self.name());
s
},
)
},
},
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields),
Expand Down
110 changes: 90 additions & 20 deletions crates/polars-core/src/chunked_array/logical/categorical/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ahash::RandomState;
use arrow::array::*;
use arrow::legacy::trusted_len::TrustedLenPush;
use hashbrown::hash_map::{Entry, RawEntryMut};
use polars_utils::iter::EnumerateIdxTrait;

use crate::datatypes::PlHashMap;
use crate::hashing::_HASHMAP_INIT_SIZE;
Expand Down Expand Up @@ -54,6 +55,8 @@ pub enum RevMapping {
Global(PlHashMap<u32, u32>, Utf8Array<i64>, u32),
/// Utf8Array: caches the string values and a hash of all values for quick comparison
Local(Utf8Array<i64>, u128),
/// Utf8Array: fixed user defined array of categories which caches the string values
Enum(Utf8Array<i64>, u128),
}

impl Debug for RevMapping {
Expand All @@ -65,6 +68,9 @@ impl Debug for RevMapping {
RevMapping::Local(_, _) => {
write!(f, "local")
},
RevMapping::Enum(_, _) => {
write!(f, "enum")
},
}
}
}
Expand All @@ -90,23 +96,36 @@ impl RevMapping {
}

pub fn is_local(&self) -> bool {
!self.is_global()
matches!(self, Self::Local(_, _))
}

/// Returns `true` if this mapping is the [`RevMapping::Enum`] variant.
pub fn is_enum(&self) -> bool {
    matches!(self, Self::Enum(..))
}

/// Get the categories in this [`RevMapping`]
pub fn get_categories(&self) -> &Utf8Array<i64> {
match self {
Self::Global(_, a, _) => a,
Self::Local(a, _) => a,
Self::Local(a, _) | Self::Enum(a, _) => a,
}
}

pub fn build_local(categories: Utf8Array<i64>) -> RevMapping {
fn build_hash(categories: &Utf8Array<i64>) -> u128 {
let hash_builder = RandomState::with_seed(0);
let value_hash = hash_builder.hash_one(categories.values().as_slice());
let offset_hash = hash_builder.hash_one(categories.offsets().as_slice());
let combined = (value_hash as u128) << 64 | (offset_hash as u128);
RevMapping::Local(categories, combined)
(value_hash as u128) << 64 | (offset_hash as u128)
}

/// Build a [`RevMapping::Enum`] from `categories`, pairing the array with
/// the hash computed by `build_hash`.
pub fn build_enum(categories: Utf8Array<i64>) -> Self {
    // Hash first: `categories` is moved into the variant afterwards.
    let digest = Self::build_hash(&categories);
    Self::Enum(categories, digest)
}

/// Build a [`RevMapping::Local`] from `categories`, pairing the array with
/// the hash computed by `build_hash`.
pub fn build_local(categories: Utf8Array<i64>) -> Self {
    // Hash first: `categories` is moved into the variant afterwards.
    let digest = Self::build_hash(&categories);
    Self::Local(categories, digest)
}

/// Get the length of the [`RevMapping`]
Expand All @@ -123,7 +142,7 @@ impl RevMapping {
let idx = *map.get(&idx).unwrap();
a.value(idx as usize)
},
Self::Local(a, _) => a.value(idx as usize),
Self::Local(a, _) | Self::Enum(a, _) => a.value(idx as usize),
}
}

Expand All @@ -133,7 +152,7 @@ impl RevMapping {
let idx = *map.get(&idx)?;
a.get(idx as usize)
},
Self::Local(a, _) => a.get(idx as usize),
Self::Local(a, _) | Self::Enum(a, _) => a.get(idx as usize),
}
}

Expand All @@ -149,7 +168,7 @@ impl RevMapping {
let idx = *map.get(&idx).unwrap();
a.value_unchecked(idx as usize)
},
Self::Local(a, _) => a.value_unchecked(idx as usize),
Self::Local(a, _) | Self::Enum(a, _) => a.value_unchecked(idx as usize),
}
}
/// Check if the categoricals have a compatible mapping
Expand All @@ -158,6 +177,7 @@ impl RevMapping {
match (self, other) {
(RevMapping::Global(_, _, l), RevMapping::Global(_, _, r)) => *l == *r,
(RevMapping::Local(_, l_hash), RevMapping::Local(_, r_hash)) => l_hash == r_hash,
(RevMapping::Enum(_, l_hash), RevMapping::Enum(_, r_hash)) => l_hash == r_hash,
_ => false,
}
}
Expand All @@ -183,7 +203,8 @@ impl RevMapping {
.find(|(_k, &v)| (unsafe { a.value_unchecked(v as usize) } == value))
.map(|(k, _v)| *k)
},
Self::Local(a, _) => {

Self::Local(a, _) | Self::Enum(a, _) => {
// Safety: within bounds
unsafe { (0..a.len()).find(|idx| a.value_unchecked(*idx) == value) }
.map(|idx| idx as u32)
Expand Down Expand Up @@ -315,7 +336,15 @@ impl<'a> CategoricalChunkedBuilder<'a> {
}

/// Build a global string cached [`CategoricalChunked`] from a local [`Dictionary`].
pub(super) fn global_map_from_local(&mut self, keys: &UInt32Array, values: Utf8Array<i64>) {
pub(super) fn global_map_from_local<I, J>(
&mut self,
keys: I,
capacity: usize,
values: Utf8Array<i64>,
) where
I: IntoIterator<Item = J> + Send + Sync,
J: IntoIterator<Item = Option<u32>>,
{
// locally we don't need a hashmap because all categories are 1 integer apart,
// so the index is local and the value is global
let mut local_to_global: Vec<u32> = Vec::with_capacity(values.len());
Expand Down Expand Up @@ -346,14 +375,18 @@ impl<'a> CategoricalChunkedBuilder<'a> {
let mut global_to_local = PlHashMap::with_capacity(local_to_global.len());

let compute_cats = || {
keys.into_iter()
.map(|opt_k| {
opt_k.map(|cat| {
debug_assert!((*cat as usize) < local_to_global.len());
*unsafe { local_to_global.get_unchecked(*cat as usize) }
})
})
.collect::<UInt32Vec>()
let mut result = UInt32Vec::with_capacity(capacity);

let iters = keys.into_iter();
for iter in iters.into_iter() {
for opt_value in iter {
result.push(opt_value.map(|cat| {
debug_assert!((cat as usize) < local_to_global.len());
*unsafe { local_to_global.get_unchecked(cat as usize) }
}));
}
}
result
};

let (_, cats) = POOL.join(
Expand Down Expand Up @@ -448,8 +481,9 @@ impl<'a> CategoricalChunkedBuilder<'a> {
if using_string_cache() {
if let RevMappingBuilder::Local(ref mut mut_arr) = self.reverse_mapping {
let arr: Utf8Array<_> = std::mem::take(mut_arr).into();
let keys = std::mem::take(&mut self.cat_builder).into();
self.global_map_from_local(&keys, arr);
let keys: UInt32Array = std::mem::take(&mut self.cat_builder).into();
let capacity = keys.len();
self.global_map_from_local([keys.into_iter()], capacity, arr);
}
}

Expand Down Expand Up @@ -514,6 +548,42 @@ impl CategoricalChunked {

CategoricalChunked::from_cats_and_rev_map_unchecked(cats, Arc::new(rev_map))
}

/// Create a [`CategoricalChunked`] from a fixed list of categories and a List of strings.
/// This will error if a string is not in the fixed list of categories
pub fn from_utf8_to_enum(
values: &Utf8Chunked,
categories: &Utf8Array<i64>,
) -> PolarsResult<CategoricalChunked> {
polars_ensure!(categories.null_count() == 0, ComputeError: "categories can not contain null values");

// Build a mapping string -> idx
let mut map = PlHashMap::with_capacity(categories.len());
for (idx, cat) in categories.values_iter().enumerate_idx() {
#[allow(clippy::unnecessary_cast)]
map.insert(cat, idx as u32);
}
// Find idx of every value in the map
let ca_idx: UInt32Chunked = values
.into_iter()
.map(|opt_s: Option<&str>| {
opt_s
.map(|s| {
map.get(s).copied().ok_or_else(
|| polars_err!(ComputeError: "value '{}' is not present in Enum: {:?}",s,categories),
)
})
.transpose()
})
.collect::<Result<UInt32Chunked, PolarsError>>()?;
let rev_map = RevMapping::build_enum(categories.clone());
c-peters marked this conversation as resolved.
Show resolved Hide resolved
unsafe {
Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
ca_idx,
Arc::new(rev_map),
))
}
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl From<&CategoricalChunked> for DictionaryArray<u32> {
false,
);
match map {
RevMapping::Local(arr, _) => {
RevMapping::Local(arr, _) | RevMapping::Enum(arr, _) => {
// Safety:
// the keys are in bounds
unsafe {
Expand Down Expand Up @@ -53,7 +53,7 @@ impl From<&CategoricalChunked> for DictionaryArray<i64> {
match map {
// Safety:
// the keys are in bounds
RevMapping::Local(arr, _) => unsafe {
RevMapping::Local(arr, _) | RevMapping::Enum(arr, _) => unsafe {
DictionaryArray::try_new_unchecked(
dtype,
cast(keys, &ArrowDataType::Int64)
Expand Down Expand Up @@ -93,7 +93,12 @@ impl CategoricalChunked {
) -> Self {
if using_string_cache() {
let mut builder = CategoricalChunkedBuilder::new(name, keys.len());
builder.global_map_from_local(keys, values.clone());
let capacity = keys.len();
builder.global_map_from_local(
[keys.iter().map(|v| v.copied())],
capacity,
values.clone(),
);
builder.finish()
} else {
CategoricalChunked::from_chunks_original(
Expand Down
Loading