Skip to content

Commit

Permalink
ARROW-176 Nested extension objects are not handled in auto schema
Browse files: browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Jul 12, 2023
1 parent ec5ccb4 commit 9eb563a
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 32 deletions.
3 changes: 2 additions & 1 deletion bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
from pymongo.common import MAX_WRITE_BATCH_SIZE
from pymongoarrow.context import PyMongoArrowContext
from pymongoarrow.errors import ArrowWriteError
from pymongoarrow.lib import process_bson_stream
from pymongoarrow.result import ArrowWriteResult
from pymongoarrow.schema import Schema
from pymongoarrow.types import _validate_schema, get_numpy_type

from pymongoarrow.lib import process_bson_stream

__all__ = [
"aggregate_arrow_all",
"find_arrow_all",
Expand Down
3 changes: 2 additions & 1 deletion bindings/python/pymongoarrow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from pyarrow import Table, timestamp
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap

from pymongoarrow.lib import (
BinaryBuilder,
BoolBuilder,
Expand All @@ -27,7 +29,6 @@
ObjectIdBuilder,
StringBuilder,
)
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap

_TYPE_TO_BUILDER_CLS = {
_BsonArrowTypes.int32: Int32Builder,
Expand Down
46 changes: 24 additions & 22 deletions bindings/python/pymongoarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ from math import isnan
# Python imports
import bson
import numpy as np
from pyarrow import timestamp, struct, field, scalar, FixedSizeBinaryScalar, StructScalar, array
from pyarrow import timestamp, struct, field
from pyarrow.lib import (
tobytes, StructType, int32, int64, float64, string, bool_, list_
)
Expand Down Expand Up @@ -485,15 +485,11 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
result = pyarrow_wrap_array(out)
for x in result:
print("CORRECT: ", result.type, type(result), x.type, type(x))
return pyarrow_wrap_array(out).cast(ObjectIdType())

cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
return self.builder


cdef class Int32Builder(_ArrayBuilderBase):
cdef:
shared_ptr[CInt32Builder] builder
Expand Down Expand Up @@ -725,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
field_builder = Decimal128Builder()
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
field_builder = BinaryBuilder(field_type.subtype)
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code:
field_builder = CodeBuilder()
else:
field_builder = StringBuilder()
return field_builder
Expand All @@ -735,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
shared_ptr[CStructBuilder] builder
object dtype
object context
object builder_map

def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
cdef StringBuilder field_builder
Expand All @@ -747,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):

self.context = context = PyMongoArrowContext(None, {})
context.tzinfo = tzinfo
builder_map = context.builder_map
self.builder_map = context.builder_map

for field in dtype:
field_builder = <StringBuilder>get_field_builder(field, tzinfo)
builder_map[field.name.encode('utf-8')] = field_builder
self.builder_map[field.name.encode('utf-8')] = field_builder
c_field_builders.push_back(<shared_ptr[CArrayBuilder]>field_builder.builder)

self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
Expand Down Expand Up @@ -784,26 +783,29 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
wrapped = pyarrow_wrap_array(out)
python_out = []
for original in wrapped:

struct_array = pyarrow_wrap_array(out)
for struct_def in struct_array:
new_types = []
new_names = list(original.keys())
for fname, ftype in original.items():
#new_names.append(fname)
if isinstance(ftype, FixedSizeBinaryScalar) and ftype.type.byte_width == 12: # ObjectIdType
print("TYPE: ", ftype, ftype.type, type(ftype))
new_names = list(struct_def.keys())
for fname, ftype in struct_def.items():
if type(self.builder_map[fname.encode('utf-8')]).__name__ == ObjectIdBuilder.__name__: # ObjectIdType
new_ftype = ObjectIdType()
#print("TYPE: ", new_ftype, new_ftype.storage_type, type(new_ftype))
#print("ARRAY: ", array([(fname, ftype)]))
new_types.append(new_ftype)
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == Decimal128Builder.__name__: # Decimal128Type
new_ftype = Decimal128Type_()
new_types.append(new_ftype)
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == BinaryBuilder.__name__: # BinaryType
new_ftype = BinaryType(self.dtype.field(fname).type.subtype)
new_types.append(new_ftype)
elif type(self.builder_map[fname.encode('utf-8')]).__name__ == CodeBuilder.__name__: # CodeType
new_ftype = CodeType()
new_types.append(new_ftype)
else:
new_types.append(ftype.type)
python_out.append(struct(zip(new_names, new_types)))
print("BEFORE: ", wrapped, wrapped.type, type(wrapped))
print("AFTER: ", python_out[0], python_out, type(python_out[0]))
print("AFTER AFTER: ", array([], type=python_out[0]))
return wrapped

new_dtype = struct(dict(zip(new_names, new_types)))
return struct_array.cast(new_dtype)

cdef shared_ptr[CStructBuilder] unwrap(self):
return self.builder
Expand Down
23 changes: 17 additions & 6 deletions bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pyarrow
import pymongo
from bson import Binary, CodecOptions, Decimal128, ObjectId
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_
from pyarrow import schema as ArrowSchema
from pyarrow import string, struct, timestamp
Expand Down Expand Up @@ -650,15 +650,26 @@ def test_nested_contradicting_unused_schema(self):
out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])

def test_nested_bson_objectId(self):
object_id = ObjectId()
data = {'_id': object_id, 'id1': object_id, 'obj': {'id2': object_id, 'id3': object_id}}
def test_nested_bson_extension_types(self):
data = {
"obj": {
"obj_id": ObjectId(),
"dec_128": Decimal128("0.0005"),
"binary": Binary(b"123"),
"code": Code(""),
}
}

self.coll.drop()
self.coll.insert_one(data)
out = find_arrow_all(self.coll, {})
for schema_field in out.schema:
self.assertIsInstance(schema_field.type, ObjectIdType)
obj_schema_type = out.field("obj").type

self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
self.assertIsInstance(obj_schema_type.field("code").type, CodeType)


class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
def run_find(self, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion bindings/python/test/test_bson.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import pyarrow as pa
from bson import Decimal128, Int64, InvalidBSON, encode
from pymongoarrow.context import PyMongoArrowContext
from pymongoarrow.lib import process_bson_stream
from pymongoarrow.schema import Schema
from pymongoarrow.types import ObjectId, ObjectIdType, int64, string

from pymongoarrow.lib import process_bson_stream


class TestBsonToArrowConversionBase(TestCase):
def setUp(self):
Expand Down
3 changes: 2 additions & 1 deletion bindings/python/test/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from bson import Binary, Code, Decimal128, ObjectId
from pyarrow import Array, bool_, field, int32, int64, list_, struct, timestamp
from pymongoarrow.types import ObjectIdType

from pymongoarrow.lib import (
BinaryBuilder,
BoolBuilder,
Expand All @@ -31,7 +33,6 @@
ObjectIdBuilder,
StringBuilder,
)
from pymongoarrow.types import ObjectIdType


class IntBuildersTestMixin:
Expand Down

0 comments on commit 9eb563a

Please sign in to comment.