@@ -84,7 +84,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
84
84
if col .offset != 0 :
85
85
raise NotImplementedError ("column.offset > 0 not handled yet" )
86
86
87
- if col .describe_null not in (0 , 1 ):
87
+ if col .describe_null [ 0 ] not in (0 , 1 ):
88
88
raise NotImplementedError ("Null values represented as masks or "
89
89
"sentinel values not handled yet" )
90
90
@@ -230,19 +230,19 @@ def offset(self) -> int:
230
230
return 0
231
231
232
232
@property
233
- def dtype (self ) -> Tuple [int , int , str , str ]:
233
+ def dtype (self ) -> Tuple [enum . IntEnum , int , str , str ]:
234
234
"""
235
235
Dtype description as a tuple ``(kind, bit-width, format string, endianness)``
236
236
237
237
Kind :
238
238
239
- - 0 : signed integer
240
- - 1 : unsigned integer
241
- - 2 : IEEE floating point
242
- - 20 : boolean
243
- - 21 : string ( UTF-8)
244
- - 22 : datetime
245
- - 23 : categorical
239
+ - INT = 0
240
+ - UINT = 1
241
+ - FLOAT = 2
242
+ - BOOL = 20
243
+ - STRING = 21 # UTF-8
244
+ - DATETIME = 22
245
+ - CATEGORICAL = 23
246
246
247
247
Bit-width : the number of bits as an integer
248
248
Format string : data type description format string in Apache Arrow C
@@ -273,15 +273,25 @@ def dtype(self) -> Tuple[int, int, str, str]:
273
273
# Note: 'c' (complex) not handled yet (not in array spec v1).
274
274
# 'b', 'B' (bytes), 'S', 'a', (old-style string) 'V' (void) not handled
275
275
# datetime and timedelta both map to datetime (is timedelta handled?)
276
- _np_kinds = {'i' : 0 , 'u' : 1 , 'f' : 2 , 'b' : 20 , 'O' : 21 , 'U' : 21 ,
277
- 'M' : 22 , 'm' : 22 }
276
+ _k = _DtypeKind
277
+ _np_kinds = {'i' : _k .INT , 'u' : _k .UINT , 'f' : _k .FLOAT , 'b' : _k .BOOL ,
278
+ 'U' : _k .STRING ,
279
+ 'M' : _k .DATETIME , 'm' : _k .DATETIME }
278
280
kind = _np_kinds .get (dtype .kind , None )
279
281
if kind is None :
280
- raise NotImplementedError ("Data type {} not handled" .format (dtype ))
282
+ # Not a NumPy dtype. Check if it's a categorical maybe
283
+ if isinstance (dtype , pd .CategoricalDtype ):
284
+ kind = 23
285
+ else :
286
+ raise ValueError (f"Data type { dtype } not supported by exchange"
287
+ "protocol" )
288
+
289
+ if kind not in (_k .INT , _k .UINT , _k .FLOAT , _k .BOOL , _k .CATEGORICAL ):
290
+ raise NotImplementedError (f"Data type { dtype } not handled yet" )
281
291
282
292
bitwidth = dtype .itemsize * 8
283
293
format_str = dtype .str
284
- endianness = dtype .byteorder
294
+ endianness = dtype .byteorder if not kind == _k . CATEGORICAL else '='
285
295
return (kind , bitwidth , format_str , endianness )
286
296
287
297
@@ -324,19 +334,26 @@ def describe_null(self) -> Tuple[int, Any]:
324
334
325
335
Value : if kind is "sentinel value", the actual value. None otherwise.
326
336
"""
337
+ _k = _DtypeKind
327
338
kind = self .dtype [0 ]
328
- if kind == 2 :
339
+ value = None
340
+ if kind == _k .FLOAT :
329
341
null = 1 # np.nan
330
- elif kind == 22 :
342
+ elif kind == _k . DATETIME :
331
343
null = 1 # np.datetime64('NaT')
332
- elif kind in (0 , 1 , 20 ):
344
+ elif kind in (_k . INT , _k . UINT , _k . BOOL ):
333
345
# TODO: check if extension dtypes are used once support for them is
334
346
# implemented in this procotol code
335
347
null = 0 # integer and boolean dtypes are non-nullable
348
+ elif kind == _k .CATEGORICAL :
349
+ # Null values for categoricals are stored as `-1` sentinel values
350
+ # in the category date (e.g., `col.values.codes` is int8 np.ndarray)
351
+ null = 2
352
+ value = - 1
336
353
else :
337
- raise NotImplementedError ('TODO ' )
354
+ raise NotImplementedError (f'Data type { self . dtype } not yet supported ' )
338
355
339
- return null
356
+ return null , value
340
357
341
358
@property
342
359
def null_count (self ) -> int :
@@ -469,8 +486,16 @@ def test_noncontiguous_columns():
469
486
#df2 = from_dataframe(df)
470
487
471
488
489
+ def test_categorical_dtype ():
490
+ df = pd .DataFrame ({"A" : [1 , 2 , 3 , 1 ]})
491
+ df ["B" ] = df ["A" ].astype ("category" )
492
+ df .at [1 , 'B' ] = np .nan # Set one item to null
493
+ df2 = from_dataframe (df )
494
+
495
+
472
496
if __name__ == '__main__' :
473
497
test_float_only ()
474
498
test_mixed_intfloat ()
475
499
test_noncontiguous_columns ()
500
+ test_categorical_dtype ()
476
501
0 commit comments