From 1cde144af09c7b53b47ff4087ce349ec44979a3b Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Thu, 13 Jul 2023 15:23:31 -0700
Subject: [PATCH] ARROW-176 Nested extension objects are not handled in auto schema (#166)

---
Review notes (not part of the commit; free text placed here, between the
three-dash line and the diff, is skipped when the patch is applied):

* Layout: the line structure of this patch was reconstructed from a copy in
  which every line break had been collapsed. Indentation and the position of
  blank context lines are inferred from the code and from the hunk line
  counts. TODO: confirm against the upstream commit before applying. The
  author address is absent from the "From:" header in that copy and has not
  been filled in.
* DocumentBuilder.finish(): as reconstructed, the extension types are
  recomputed for every element of struct_array, and only the lists built for
  the last element reach struct(...). For a zero-length array the loop body
  never runs, so new_names / new_types look unassigned at that point.
  NOTE(review): verify, and consider deriving the types once from the struct
  type's fields instead of from each row.

 bindings/python/pymongoarrow/lib.pyx | 33 ++++++++++++++++++++++++----
 bindings/python/test/test_arrow.py   | 22 ++++++++++++++++++-
 2 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx
index 1e4e516b..33cfac95 100644
--- a/bindings/python/pymongoarrow/lib.pyx
+++ b/bindings/python/pymongoarrow/lib.pyx
@@ -490,7 +490,6 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
     cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
         return self.builder
 
-
 cdef class Int32Builder(_ArrayBuilderBase):
     cdef:
         shared_ptr[CInt32Builder] builder
@@ -722,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
         field_builder = Decimal128Builder()
     elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
         field_builder = BinaryBuilder(field_type.subtype)
+    elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code:
+        field_builder = CodeBuilder()
     else:
         field_builder = StringBuilder()
     return field_builder
@@ -732,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
         shared_ptr[CStructBuilder] builder
         object dtype
         object context
+        object builder_map
 
     def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
         cdef StringBuilder field_builder
@@ -744,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
 
         self.context = context = PyMongoArrowContext(None, {})
         context.tzinfo = tzinfo
-        builder_map = context.builder_map
+        self.builder_map = context.builder_map
 
         for field in dtype:
             field_builder = get_field_builder(field, tzinfo)
-            builder_map[field.name.encode('utf-8')] = field_builder
+            self.builder_map[field.name.encode('utf-8')] = field_builder
             c_field_builders.push_back(field_builder.builder)
 
         self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
@@ -781,7 +783,30 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
         cdef shared_ptr[CArray] out
         with nogil:
             self.builder.get().Finish(&out)
-        return pyarrow_wrap_array(out)
+
+        struct_array = pyarrow_wrap_array(out)
+        for struct_def in struct_array:
+            new_types = []
+            new_names = list(struct_def.keys())
+            for fname, ftype in struct_def.items():
+                builder_instance = self.builder_map[fname.encode('utf-8')]
+                if isinstance(builder_instance, ObjectIdBuilder): # ObjectIdType
+                    new_ftype = ObjectIdType()
+                    new_types.append(new_ftype)
+                elif isinstance(builder_instance, Decimal128Builder): # Decimal128Type
+                    new_ftype = Decimal128Type_()
+                    new_types.append(new_ftype)
+                elif isinstance(builder_instance, BinaryBuilder): # BinaryType
+                    new_ftype = BinaryType(self.dtype.field(fname).type.subtype)
+                    new_types.append(new_ftype)
+                elif isinstance(builder_instance, CodeBuilder): # CodeType
+                    new_ftype = CodeType()
+                    new_types.append(new_ftype)
+                else:
+                    new_types.append(ftype.type)
+
+        new_dtype = struct(zip(new_names, new_types))
+        return struct_array.cast(new_dtype)
 
     cdef shared_ptr[CStructBuilder] unwrap(self):
         return self.builder
diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py
index b6315087..f7678147 100644
--- a/bindings/python/test/test_arrow.py
+++ b/bindings/python/test/test_arrow.py
@@ -21,7 +21,7 @@
 
 import pyarrow
 import pymongo
-from bson import Binary, CodecOptions, Decimal128, ObjectId
+from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
 from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_
 from pyarrow import schema as ArrowSchema
 from pyarrow import string, struct, timestamp
@@ -650,6 +650,26 @@ def test_nested_contradicting_unused_schema(self):
             out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
             self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])
 
+    def test_nested_bson_extension_types(self):
+        data = {
+            "obj": {
+                "obj_id": ObjectId(),
+                "dec_128": Decimal128("0.0005"),
+                "binary": Binary(b"123"),
+                "code": Code(""),
+            }
+        }
+
+        self.coll.drop()
+        self.coll.insert_one(data)
+        out = find_arrow_all(self.coll, {})
+        obj_schema_type = out.field("obj").type
+
+        self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
+        self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
+        self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
+        self.assertIsInstance(obj_schema_type.field("code").type, CodeType)
+
 
 class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
     def run_find(self, *args, **kwargs):