Skip to content

Commit

Permalink
ARROW-176 Nested extension objects are not handled in auto schema (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp authored Jul 13, 2023
1 parent bbdfef0 commit 1cde144
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
33 changes: 29 additions & 4 deletions bindings/python/pymongoarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
# Expose the underlying C++ builder: returns the shared_ptr[CFixedSizeBinaryBuilder]
# stored in self.builder.
cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
return self.builder


cdef class Int32Builder(_ArrayBuilderBase):
cdef:
shared_ptr[CInt32Builder] builder
Expand Down Expand Up @@ -722,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
field_builder = Decimal128Builder()
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
field_builder = BinaryBuilder(field_type.subtype)
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code:
field_builder = CodeBuilder()
else:
field_builder = StringBuilder()
return field_builder
Expand All @@ -732,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
shared_ptr[CStructBuilder] builder
object dtype
object context
object builder_map

def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
cdef StringBuilder field_builder
Expand All @@ -744,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):

self.context = context = PyMongoArrowContext(None, {})
context.tzinfo = tzinfo
builder_map = context.builder_map
self.builder_map = context.builder_map

for field in dtype:
field_builder = <StringBuilder>get_field_builder(field, tzinfo)
builder_map[field.name.encode('utf-8')] = field_builder
self.builder_map[field.name.encode('utf-8')] = field_builder
c_field_builders.push_back(<shared_ptr[CArrayBuilder]>field_builder.builder)

self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
Expand Down Expand Up @@ -781,7 +783,30 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
return pyarrow_wrap_array(out)

struct_array = pyarrow_wrap_array(out)
for struct_def in struct_array:
new_types = []
new_names = list(struct_def.keys())
for fname, ftype in struct_def.items():
builder_instance = self.builder_map[fname.encode('utf-8')]
if isinstance(builder_instance, ObjectIdBuilder): # ObjectIdType
new_ftype = ObjectIdType()
new_types.append(new_ftype)
elif isinstance(builder_instance, Decimal128Builder): # Decimal128Type
new_ftype = Decimal128Type_()
new_types.append(new_ftype)
elif isinstance(builder_instance, BinaryBuilder): # BinaryType
new_ftype = BinaryType(self.dtype.field(fname).type.subtype)
new_types.append(new_ftype)
elif isinstance(builder_instance, CodeBuilder): # CodeType
new_ftype = CodeType()
new_types.append(new_ftype)
else:
new_types.append(ftype.type)

new_dtype = struct(zip(new_names, new_types))
return struct_array.cast(new_dtype)

# Expose the underlying C++ builder: returns the shared_ptr[CStructBuilder]
# stored in self.builder (the struct builder created in __cinit__).
cdef shared_ptr[CStructBuilder] unwrap(self):
return self.builder
Expand Down
22 changes: 21 additions & 1 deletion bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pyarrow
import pymongo
from bson import Binary, CodecOptions, Decimal128, ObjectId
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_
from pyarrow import schema as ArrowSchema
from pyarrow import string, struct, timestamp
Expand Down Expand Up @@ -650,6 +650,26 @@ def test_nested_contradicting_unused_schema(self):
out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])

def test_nested_bson_extension_types(self):
# Regression test: BSON extension values nested inside a sub-document must
# come back with their extension types when no schema is supplied.
# One document whose "obj" sub-document holds one value of each BSON type
# under test (ObjectId, Decimal128, Binary, Code).
data = {
"obj": {
"obj_id": ObjectId(),
"dec_128": Decimal128("0.0005"),
"binary": Binary(b"123"),
"code": Code(""),
}
}

# Start from an empty collection so the inserted document is the only input.
self.coll.drop()
self.coll.insert_one(data)
# No schema argument is passed, so the result's field types are whatever
# find_arrow_all derives on its own.
out = find_arrow_all(self.coll, {})
# Type of the nested "obj" field; its child fields are inspected below.
obj_schema_type = out.field("obj").type

# Each nested field must carry the matching extension type class.
self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
self.assertIsInstance(obj_schema_type.field("code").type, CodeType)


class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
def run_find(self, *args, **kwargs):
Expand Down

0 comments on commit 1cde144

Please sign in to comment.