Skip to content

Commit

Permalink
ARROW-176 Nested extension objects are not handled in auto schema
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Jul 12, 2023
1 parent bbdfef0 commit fec6998
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
14 changes: 12 additions & 2 deletions bindings/python/pymongoarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ from math import isnan
# Python imports
import bson
import numpy as np
from pyarrow import timestamp, struct, field
from pyarrow import timestamp, struct, field, scalar, FixedSizeBinaryScalar, StructScalar
from pyarrow.lib import (
tobytes, StructType, int32, int64, float64, string, bool_, list_
)
Expand Down Expand Up @@ -485,6 +485,9 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
result = pyarrow_wrap_array(out)
for x in result:
print("CORRECT: ", x)
return pyarrow_wrap_array(out).cast(ObjectIdType())

cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
Expand Down Expand Up @@ -781,7 +784,14 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
return pyarrow_wrap_array(out)
wrapped = pyarrow_wrap_array(out)
python_out = []
for original in wrapped:
for fname, ftype in original.items():
if isinstance(ftype, FixedSizeBinaryScalar) and ftype.type.byte_width == 12: # ObjectIdType
print("BEFORE: ", fname, ftype)
print("AFTER: ", fname, ftype.cast(ObjectIdType()))
return wrapped

# Return this builder's underlying `builder` attribute, typed as a shared
# pointer to the C++ struct builder (CStructBuilder).
cdef shared_ptr[CStructBuilder] unwrap(self):
return self.builder
Expand Down
9 changes: 9 additions & 0 deletions bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,15 @@ def test_nested_contradicting_unused_schema(self):
out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])

def test_nested_bson_objectId(self):
    """ObjectIds keep their extension type at top level and when nested.

    Inserts one document holding the same ObjectId at the top level
    (``_id``, ``id1``) and inside a sub-document (``obj.id2``,
    ``obj.id3``), reads it back without an explicit schema, and checks
    that every ObjectId-valued field is typed as ``ObjectIdType``.
    """
    object_id = ObjectId()
    data = {'_id': object_id, 'id1': object_id, 'obj': {'id2': object_id, 'id3': object_id}}

    self.coll.drop()
    self.coll.insert_one(data)
    out = find_arrow_all(self.coll, {})

    # Top-level ObjectId fields.
    for name in ('_id', 'id1'):
        self.assertIsInstance(out.schema.field(name).type, ObjectIdType)

    # 'obj' is a nested document, so its own type is a struct rather than
    # ObjectIdType; the extension type has to be checked on its children.
    nested_fields = list(out.schema.field('obj').type)
    self.assertEqual({nested.name for nested in nested_fields}, {'id2', 'id3'})
    for nested in nested_fields:
        self.assertIsInstance(nested.type, ObjectIdType)

class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
def run_find(self, *args, **kwargs):
Expand Down

0 comments on commit fec6998

Please sign in to comment.