Skip to content

Commit cfabb9f

Browse files
committed
Make roundtripping with categorical dype work (with some cheating)
1 parent 552b794 commit cfabb9f

File tree

1 file changed

+46
-10
lines changed

1 file changed

+46
-10
lines changed

protocol/pandas_implementation.py

+46-10
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,16 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
6262
# We need a dict of columns here, with each column being a numpy array (at
6363
# least for now, deal with non-numpy dtypes later).
6464
columns = dict()
65+
_k = _DtypeKind
6566
for name in df.column_names():
66-
columns[name] = convert_column_to_ndarray(df.get_column_by_name(name))
67+
col = df.get_column_by_name(name)
68+
if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
69+
# Simple numerical or bool dtype, turn into numpy array
70+
columns[name] = convert_column_to_ndarray(col)
71+
elif col.dtype[0] == _k.CATEGORICAL:
72+
columns[name] = convert_categorical_column(col)
73+
else:
74+
raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")
6775

6876
return pd.DataFrame(columns)
6977

@@ -80,6 +88,7 @@ class _DtypeKind(enum.IntEnum):
8088

8189
def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
8290
"""
91+
Convert an int, uint, float or bool column to a numpy array
8392
"""
8493
if col.offset != 0:
8594
raise NotImplementedError("column.offset > 0 not handled yet")
@@ -117,6 +126,32 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
117126
return x
118127

119128

129+
def convert_categorical_column(col : ColumnObject) -> pd.Series:
130+
"""
131+
Convert a categorical column to a Series instance
132+
"""
133+
ordered, is_dict, mapping = col.describe_categorical
134+
if not is_dict:
135+
raise NotImplementedError('Non-dictionary categoricals not supported yet')
136+
137+
# FIXME: this is cheating, can't use `_col` (just testing now)
138+
categories = col._col.values.categories.values
139+
codes = col._col.values.codes
140+
values = categories[codes]
141+
142+
# Deal with null values
143+
null_kind = col.describe_null[0]
144+
if null_kind == 2: # sentinel value
145+
sentinel = col.describe_null[1]
146+
147+
# Seems like Pandas can only construct with non-null values, so need to
148+
# null out the nulls later
149+
cat = pd.Categorical(values, categories=categories, ordered=ordered)
150+
series = pd.Series(cat)
151+
series[codes == sentinel] = np.nan
152+
return series
153+
154+
120155
def __dataframe__(cls, nan_as_null : bool = False) -> dict:
121156
"""
122157
The public method to attach to pd.DataFrame
@@ -324,13 +359,14 @@ def describe_categorical(self) -> Dict[str, Any]:
324359
"categorical dtype!")
325360

326361
ordered = self._col.dtype.ordered
327-
is_dictionary = False
328-
# NOTE: this shows the children approach is better, transforming this
329-
# to a "mapping" dict would be inefficient
362+
is_dictionary = True
363+
# NOTE: this shows the children approach is better, transforming
364+
# `categories` to a "mapping" dict is inefficient
330365
codes = self._col.values.codes # ndarray, length `self.size`
331366
# categories.values is ndarray of length n_categories
332-
categories = self._col.values.categories
333-
return ordered, is_dictionary, None
367+
categories = self._col.values.categories.values
368+
mapping = {ix: val for ix, val in enumerate(categories)}
369+
return ordered, is_dictionary, mapping
334370

335371
@property
336372
def describe_null(self) -> Tuple[int, Any]:
@@ -402,7 +438,7 @@ def get_mask(self) -> _PandasBuffer:
402438
403439
Raises RuntimeError if null representation is not a bit or byte mask.
404440
"""
405-
null = self.describe_null()
441+
null, value = self.describe_null
406442
if null == 0:
407443
msg = "This column is non-nullable so does not have a mask"
408444
elif null == 1:
@@ -501,7 +537,7 @@ def test_noncontiguous_columns():
501537

502538

503539
def test_categorical_dtype():
504-
df = pd.DataFrame({"A": [1, 2, 3, 1]})
540+
df = pd.DataFrame({"A": [1, 2, 5, 1]})
505541
df["B"] = df["A"].astype("category")
506542
df.at[1, 'B'] = np.nan # Set one item to null
507543

@@ -511,15 +547,15 @@ def test_categorical_dtype():
511547
assert col.null_count == 1
512548
assert col.describe_null == (2, -1) # sentinel value -1
513549
assert col.num_chunks() == 1
514-
assert col.describe_categorical == (False, False, None)
550+
assert col.describe_categorical == (False, True, {0: 1, 1: 2, 2: 5})
515551

516552
df2 = from_dataframe(df)
517553
tm.assert_frame_equal(df, df2)
518554

519555

520556
if __name__ == '__main__':
557+
test_categorical_dtype()
521558
test_float_only()
522559
test_mixed_intfloat()
523560
test_noncontiguous_columns()
524-
test_categorical_dtype()
525561

0 commit comments

Comments
 (0)