From 9eb563a60cc075f373fb2b0dc00d9980fbdac484 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 12 Jul 2023 14:21:32 -0700 Subject: [PATCH] ARROW-176 Nested extension objects are not handled in auto schema --- bindings/python/pymongoarrow/api.py | 3 +- bindings/python/pymongoarrow/context.py | 3 +- bindings/python/pymongoarrow/lib.pyx | 46 +++++++++++++------------ bindings/python/test/test_arrow.py | 23 +++++++++---- bindings/python/test/test_bson.py | 3 +- bindings/python/test/test_builders.py | 3 +- 6 files changed, 49 insertions(+), 32 deletions(-) diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index 8b753845..e14d78e1 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -26,11 +26,12 @@ from pymongo.common import MAX_WRITE_BATCH_SIZE from pymongoarrow.context import PyMongoArrowContext from pymongoarrow.errors import ArrowWriteError -from pymongoarrow.lib import process_bson_stream from pymongoarrow.result import ArrowWriteResult from pymongoarrow.schema import Schema from pymongoarrow.types import _validate_schema, get_numpy_type +from pymongoarrow.lib import process_bson_stream + __all__ = [ "aggregate_arrow_all", "find_arrow_all", diff --git a/bindings/python/pymongoarrow/context.py b/bindings/python/pymongoarrow/context.py index f6eb82d8..f8c60747 100644 --- a/bindings/python/pymongoarrow/context.py +++ b/bindings/python/pymongoarrow/context.py @@ -13,6 +13,8 @@ # limitations under the License. from bson.codec_options import DEFAULT_CODEC_OPTIONS from pyarrow import Table, timestamp +from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap + from pymongoarrow.lib import ( BinaryBuilder, BoolBuilder, @@ -27,7 +29,6 @@ ObjectIdBuilder, StringBuilder, ) -from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap _TYPE_TO_BUILDER_CLS = { _BsonArrowTypes.int32: Int32Builder, diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx index c407be47..704078fb 100644 --- a/bindings/python/pymongoarrow/lib.pyx +++ b/bindings/python/pymongoarrow/lib.pyx @@ -26,7 +26,7 @@ from math import isnan # Python imports import bson import numpy as np -from pyarrow import timestamp, struct, field, scalar, FixedSizeBinaryScalar, StructScalar, array +from pyarrow import timestamp, struct, field from pyarrow.lib import ( tobytes, StructType, int32, int64, float64, string, bool_, list_ ) @@ -485,15 +485,11 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase): cdef shared_ptr[CArray] out with nogil: self.builder.get().Finish(&out) - result = pyarrow_wrap_array(out) - for x in result: - print("CORRECT: ", result.type, type(result), x.type, type(x)) return pyarrow_wrap_array(out).cast(ObjectIdType()) cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self): return self.builder - cdef class Int32Builder(_ArrayBuilderBase): cdef: shared_ptr[CInt32Builder] builder @@ -725,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo): field_builder = Decimal128Builder() elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary: field_builder = BinaryBuilder(field_type.subtype) + elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code: + field_builder = CodeBuilder() else: field_builder = StringBuilder() return field_builder @@ -735,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase): shared_ptr[CStructBuilder] builder object dtype object context + object builder_map def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None): cdef StringBuilder field_builder @@ -747,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase): self.context = context = PyMongoArrowContext(None, {}) context.tzinfo = tzinfo - builder_map = context.builder_map + self.builder_map = context.builder_map for field in dtype: field_builder = get_field_builder(field, tzinfo) - builder_map[field.name.encode('utf-8')] = field_builder + self.builder_map[field.name.encode('utf-8')] = field_builder c_field_builders.push_back(field_builder.builder) self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders)) @@ -784,26 +783,29 @@ cdef class DocumentBuilder(_ArrayBuilderBase): cdef shared_ptr[CArray] out with nogil: self.builder.get().Finish(&out) - wrapped = pyarrow_wrap_array(out) - python_out = [] - for original in wrapped: + + struct_array = pyarrow_wrap_array(out) + for struct_def in struct_array: new_types = [] - new_names = list(original.keys()) - for fname, ftype in original.items(): - #new_names.append(fname) - if isinstance(ftype, FixedSizeBinaryScalar) and ftype.type.byte_width == 12: # ObjectIdType - print("TYPE: ", ftype, ftype.type, type(ftype)) + new_names = list(struct_def.keys()) + for fname, ftype in struct_def.items(): + if type(self.builder_map[fname.encode('utf-8')]).__name__ == ObjectIdBuilder.__name__: # ObjectIdType new_ftype = ObjectIdType() - #print("TYPE: ", new_ftype, new_ftype.storage_type, type(new_ftype)) - #print("ARRAY: ", array([(fname, ftype)])) + new_types.append(new_ftype) + elif type(self.builder_map[fname.encode('utf-8')]).__name__ == Decimal128Builder.__name__: # Decimal128Type + new_ftype = Decimal128Type_() + new_types.append(new_ftype) + elif type(self.builder_map[fname.encode('utf-8')]).__name__ == BinaryBuilder.__name__: # BinaryType + new_ftype = BinaryType(self.dtype.field(fname).type.subtype) + new_types.append(new_ftype) + elif type(self.builder_map[fname.encode('utf-8')]).__name__ == CodeBuilder.__name__: # CodeType + new_ftype = CodeType() new_types.append(new_ftype) else: new_types.append(ftype.type) - python_out.append(struct(zip(new_names, new_types))) - print("BEFORE: ", wrapped, wrapped.type, type(wrapped)) - print("AFTER: ", python_out[0], python_out, type(python_out[0])) - print("AFTER AFTER: ", array([], type=python_out[0])) - return wrapped + + new_dtype = struct(dict(zip(new_names, new_types))) + return struct_array.cast(new_dtype) cdef shared_ptr[CStructBuilder] unwrap(self): return self.builder diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index b9878fc9..f7678147 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -21,7 +21,7 @@ import pyarrow import pymongo -from bson import Binary, CodecOptions, Decimal128, ObjectId +from bson import Binary, Code, CodecOptions, Decimal128, ObjectId from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_ from pyarrow import schema as ArrowSchema from pyarrow import string, struct, timestamp @@ -650,15 +650,26 @@ def test_nested_contradicting_unused_schema(self): out = func(self.coll, {} if func == find_arrow_all else [], schema=schema) self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}]) - def test_nested_bson_objectId(self): - object_id = ObjectId() - data = {'_id': object_id, 'id1': object_id, 'obj': {'id2': object_id, 'id3': object_id}} + def test_nested_bson_extension_types(self): + data = { + "obj": { + "obj_id": ObjectId(), + "dec_128": Decimal128("0.0005"), + "binary": Binary(b"123"), + "code": Code(""), + } + } self.coll.drop() self.coll.insert_one(data) out = find_arrow_all(self.coll, {}) - for schema_field in out.schema: - self.assertIsInstance(schema_field.type, ObjectIdType) + obj_schema_type = out.field("obj").type + + self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType) + self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type) + self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType) + self.assertIsInstance(obj_schema_type.field("code").type, CodeType) + class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase): def run_find(self, *args, **kwargs): diff --git a/bindings/python/test/test_bson.py b/bindings/python/test/test_bson.py index de6ae76c..64977b24 100644 --- a/bindings/python/test/test_bson.py +++ b/bindings/python/test/test_bson.py @@ -16,10 +16,11 @@ import pyarrow as pa from bson import Decimal128, Int64, InvalidBSON, encode from pymongoarrow.context import PyMongoArrowContext -from pymongoarrow.lib import process_bson_stream from pymongoarrow.schema import Schema from pymongoarrow.types import ObjectId, ObjectIdType, int64, string +from pymongoarrow.lib import process_bson_stream + class TestBsonToArrowConversionBase(TestCase): def setUp(self): diff --git a/bindings/python/test/test_builders.py b/bindings/python/test/test_builders.py index 4cb96748..fa987038 100644 --- a/bindings/python/test/test_builders.py +++ b/bindings/python/test/test_builders.py @@ -17,6 +17,8 @@ from bson import Binary, Code, Decimal128, ObjectId from pyarrow import Array, bool_, field, int32, int64, list_, struct, timestamp +from pymongoarrow.types import ObjectIdType + from pymongoarrow.lib import ( BinaryBuilder, BoolBuilder, @@ -31,7 +33,6 @@ ObjectIdBuilder, StringBuilder, ) -from pymongoarrow.types import ObjectIdType class IntBuildersTestMixin: