Skip to content

Commit 9eb563a

Browse files
committed
ARROW-176 Nested extension objects are not handled in auto schema
1 parent ec5ccb4 commit 9eb563a

File tree

6 files changed: 49 additions and 32 deletions.

bindings/python/pymongoarrow/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@
2626
from pymongo.common import MAX_WRITE_BATCH_SIZE
2727
from pymongoarrow.context import PyMongoArrowContext
2828
from pymongoarrow.errors import ArrowWriteError
29-
from pymongoarrow.lib import process_bson_stream
3029
from pymongoarrow.result import ArrowWriteResult
3130
from pymongoarrow.schema import Schema
3231
from pymongoarrow.types import _validate_schema, get_numpy_type
3332

33+
from pymongoarrow.lib import process_bson_stream
34+
3435
__all__ = [
3536
"aggregate_arrow_all",
3637
"find_arrow_all",

bindings/python/pymongoarrow/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
from bson.codec_options import DEFAULT_CODEC_OPTIONS
1515
from pyarrow import Table, timestamp
16+
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
17+
1618
from pymongoarrow.lib import (
1719
BinaryBuilder,
1820
BoolBuilder,
@@ -27,7 +29,6 @@
2729
ObjectIdBuilder,
2830
StringBuilder,
2931
)
30-
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
3132

3233
_TYPE_TO_BUILDER_CLS = {
3334
_BsonArrowTypes.int32: Int32Builder,

bindings/python/pymongoarrow/lib.pyx

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ from math import isnan
2626
# Python imports
2727
import bson
2828
import numpy as np
29-
from pyarrow import timestamp, struct, field, scalar, FixedSizeBinaryScalar, StructScalar, array
29+
from pyarrow import timestamp, struct, field
3030
from pyarrow.lib import (
3131
tobytes, StructType, int32, int64, float64, string, bool_, list_
3232
)
@@ -485,15 +485,11 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
485485
cdef shared_ptr[CArray] out
486486
with nogil:
487487
self.builder.get().Finish(&out)
488-
result = pyarrow_wrap_array(out)
489-
for x in result:
490-
print("CORRECT: ", result.type, type(result), x.type, type(x))
491488
return pyarrow_wrap_array(out).cast(ObjectIdType())
492489

493490
cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
494491
return self.builder
495492

496-
497493
cdef class Int32Builder(_ArrayBuilderBase):
498494
cdef:
499495
shared_ptr[CInt32Builder] builder
@@ -725,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
725721
field_builder = Decimal128Builder()
726722
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
727723
field_builder = BinaryBuilder(field_type.subtype)
724+
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code:
725+
field_builder = CodeBuilder()
728726
else:
729727
field_builder = StringBuilder()
730728
return field_builder
@@ -735,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
735733
shared_ptr[CStructBuilder] builder
736734
object dtype
737735
object context
736+
object builder_map
738737

739738
def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
740739
cdef StringBuilder field_builder
@@ -747,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
747746

748747
self.context = context = PyMongoArrowContext(None, {})
749748
context.tzinfo = tzinfo
750-
builder_map = context.builder_map
749+
self.builder_map = context.builder_map
751750

752751
for field in dtype:
753752
field_builder = <StringBuilder>get_field_builder(field, tzinfo)
754-
builder_map[field.name.encode('utf-8')] = field_builder
753+
self.builder_map[field.name.encode('utf-8')] = field_builder
755754
c_field_builders.push_back(<shared_ptr[CArrayBuilder]>field_builder.builder)
756755

757756
self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
@@ -784,26 +783,29 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
784783
cdef shared_ptr[CArray] out
785784
with nogil:
786785
self.builder.get().Finish(&out)
787-
wrapped = pyarrow_wrap_array(out)
788-
python_out = []
789-
for original in wrapped:
786+
787+
struct_array = pyarrow_wrap_array(out)
788+
for struct_def in struct_array:
790789
new_types = []
791-
new_names = list(original.keys())
792-
for fname, ftype in original.items():
793-
#new_names.append(fname)
794-
if isinstance(ftype, FixedSizeBinaryScalar) and ftype.type.byte_width == 12: # ObjectIdType
795-
print("TYPE: ", ftype, ftype.type, type(ftype))
790+
new_names = list(struct_def.keys())
791+
for fname, ftype in struct_def.items():
792+
if type(self.builder_map[fname.encode('utf-8')]).__name__ == ObjectIdBuilder.__name__: # ObjectIdType
796793
new_ftype = ObjectIdType()
797-
#print("TYPE: ", new_ftype, new_ftype.storage_type, type(new_ftype))
798-
#print("ARRAY: ", array([(fname, ftype)]))
794+
new_types.append(new_ftype)
795+
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == Decimal128Builder.__name__: # Decimal128Type
796+
new_ftype = Decimal128Type_()
797+
new_types.append(new_ftype)
798+
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == BinaryBuilder.__name__: # BinaryType
799+
new_ftype = BinaryType(self.dtype.field(fname).type.subtype)
800+
new_types.append(new_ftype)
801+
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == CodeBuilder.__name__: # CodeType
802+
new_ftype = CodeType()
799803
new_types.append(new_ftype)
800804
else:
801805
new_types.append(ftype.type)
802-
python_out.append(struct(zip(new_names, new_types)))
803-
print("BEFORE: ", wrapped, wrapped.type, type(wrapped))
804-
print("AFTER: ", python_out[0], python_out, type(python_out[0]))
805-
print("AFTER AFTER: ", array([], type=python_out[0]))
806-
return wrapped
806+
807+
new_dtype = struct(dict(zip(new_names, new_types)))
808+
return struct_array.cast(new_dtype)
807809

808810
cdef shared_ptr[CStructBuilder] unwrap(self):
809811
return self.builder

bindings/python/test/test_arrow.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pyarrow
2323
import pymongo
24-
from bson import Binary, CodecOptions, Decimal128, ObjectId
24+
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
2525
from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_
2626
from pyarrow import schema as ArrowSchema
2727
from pyarrow import string, struct, timestamp
@@ -650,15 +650,26 @@ def test_nested_contradicting_unused_schema(self):
650650
out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
651651
self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])
652652

653-
def test_nested_bson_objectId(self):
654-
object_id = ObjectId()
655-
data = {'_id': object_id, 'id1': object_id, 'obj': {'id2': object_id, 'id3': object_id}}
653+
def test_nested_bson_extension_types(self):
654+
data = {
655+
"obj": {
656+
"obj_id": ObjectId(),
657+
"dec_128": Decimal128("0.0005"),
658+
"binary": Binary(b"123"),
659+
"code": Code(""),
660+
}
661+
}
656662

657663
self.coll.drop()
658664
self.coll.insert_one(data)
659665
out = find_arrow_all(self.coll, {})
660-
for schema_field in out.schema:
661-
self.assertIsInstance(schema_field.type, ObjectIdType)
666+
obj_schema_type = out.field("obj").type
667+
668+
self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
669+
self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
670+
self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
671+
self.assertIsInstance(obj_schema_type.field("code").type, CodeType)
672+
662673

663674
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
664675
def run_find(self, *args, **kwargs):

bindings/python/test/test_bson.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import pyarrow as pa
1717
from bson import Decimal128, Int64, InvalidBSON, encode
1818
from pymongoarrow.context import PyMongoArrowContext
19-
from pymongoarrow.lib import process_bson_stream
2019
from pymongoarrow.schema import Schema
2120
from pymongoarrow.types import ObjectId, ObjectIdType, int64, string
2221

22+
from pymongoarrow.lib import process_bson_stream
23+
2324

2425
class TestBsonToArrowConversionBase(TestCase):
2526
def setUp(self):

bindings/python/test/test_builders.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from bson import Binary, Code, Decimal128, ObjectId
1919
from pyarrow import Array, bool_, field, int32, int64, list_, struct, timestamp
20+
from pymongoarrow.types import ObjectIdType
21+
2022
from pymongoarrow.lib import (
2123
BinaryBuilder,
2224
BoolBuilder,
@@ -31,7 +33,6 @@
3133
ObjectIdBuilder,
3234
StringBuilder,
3335
)
34-
from pymongoarrow.types import ObjectIdType
3536

3637

3738
class IntBuildersTestMixin:

0 commit comments

Comments
 (0)