Skip to content

Commit 81ec86e

Browse files
committed
Make the roundtripping for a categorical column work
1 parent 1b6ef4e commit 81ec86e

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

protocol/dataframe_protocol.py

-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,6 @@ def get_mask(self) -> Buffer:
313313
"""
314314
pass
315315

316-
# # NOTE: not needed unless one considers nested dtypes
317316
# def get_children(self) -> Iterable[Column]:
318317
# """
319318
# Children columns underneath the column, each object in this iterator

protocol/pandas_implementation.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,12 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
9797
raise NotImplementedError("Null values represented as masks or "
9898
"sentinel values not handled yet")
9999

100+
_buffer, _dtype = col.get_data_buffer()
101+
return buffer_to_ndarray(_buffer, _dtype)
102+
103+
104+
def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
100105
# Handle the dtype
101-
_dtype = col.dtype
102106
kind = _dtype[0]
103107
bitwidth = _dtype[1]
104108
_k = _DtypeKind
@@ -113,7 +117,6 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
113117

114118
# No DLPack yet, so need to construct a new ndarray from the data pointer
115119
# and size in the buffer plus the dtype on the column
116-
_buffer = col.get_data_buffer()
117120
ctypes_type = np.ctypeslib.as_ctypes_type(column_dtype)
118121
data_pointer = ctypes.cast(_buffer.ptr, ctypes.POINTER(ctypes_type))
119122

@@ -134,11 +137,12 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
134137
if not is_dict:
135138
raise NotImplementedError('Non-dictionary categoricals not supported yet')
136139

137-
# FIXME: this is cheating, can't use `_col` (just testing now)
140+
# If you want to cheat for testing (can't use `_col` in real-world code):
138141
# categories = col._col.values.categories.values
139142
# codes = col._col.values.codes
140143
categories = np.asarray(list(mapping.values()))
141-
codes = col.get_data_buffer() # this is broken; don't have dtype info for buffer
144+
codes_buffer, codes_dtype = col.get_data_buffer()
145+
codes = buffer_to_ndarray(codes_buffer, codes_dtype)
142146
values = categories[codes]
143147

144148
# Seems like Pandas can only construct with non-null values, so need to
@@ -314,6 +318,12 @@ def dtype(self) -> Tuple[enum.IntEnum, int, str, str]:
314318
and nested (list, struct, map, union) dtypes.
315319
"""
316320
dtype = self._col.dtype
321+
return self._dtype_from_pandasdtype(dtype)
322+
323+
def _dtype_from_pandasdtype(self, dtype) -> Tuple[enum.IntEnum, int, str, str]:
324+
"""
325+
See `self.dtype` for details
326+
"""
317327
# Note: 'c' (complex) not handled yet (not in array spec v1).
318328
# 'b', 'B' (bytes), 'S', 'a', (old-style string) 'V' (void) not handled
319329
# datetime and timedelta both map to datetime (is timedelta handled?)
@@ -430,20 +440,22 @@ def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable['_PandasColumn
430440
"""
431441
return (self,)
432442

433-
def get_data_buffer(self) -> _PandasBuffer:
443+
def get_data_buffer(self) -> Tuple[_PandasBuffer, Any]: # Any is for self.dtype tuple
434444
"""
435445
Return the buffer containing the data.
436446
"""
437447
_k = _DtypeKind
438448
if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
439449
buffer = _PandasBuffer(self._col.to_numpy())
450+
dtype = self.dtype
440451
elif self.dtype[0] == _k.CATEGORICAL:
441-
# FIXME: losing the dtype info here - see `convert_categorical_column`
442-
buffer = _PandasBuffer(self._col.values.codes)
452+
codes = self._col.values.codes
453+
buffer = _PandasBuffer(codes)
454+
dtype = self._dtype_from_pandasdtype(codes.dtype)
443455
else:
444456
raise NotImplementedError(f"Data type {self._col.dtype} not handled yet")
445457

446-
return buffer
458+
return buffer, dtype
447459

448460
def get_mask(self) -> _PandasBuffer:
449461
"""

0 commit comments

Comments
 (0)