Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-243 Handle column of fields with "null" values only #241

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `query`: A mapping containing the query to use for the find operation.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand Down Expand Up @@ -122,8 +122,8 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
against which to run the ``aggregate`` operation.
- `pipeline`: A list of aggregation pipeline stages.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand Down Expand Up @@ -177,8 +177,8 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `query`: A mapping containing the query to use for the find operation.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -198,8 +198,8 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `pipeline`: A list of aggregation pipeline stages.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand Down Expand Up @@ -240,8 +240,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `query`: A mapping containing the query to use for the find operation.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand Down Expand Up @@ -271,8 +271,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `query`: A mapping containing the query to use for the find operation.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand Down Expand Up @@ -338,8 +338,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `query`: A mapping containing the query to use for the find operation.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -361,8 +361,8 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
against which to run the ``find`` operation.
- `pipeline`: A list of aggregation pipeline stages.
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the first
document in the result set.
If the schema is not given, it will be inferred using the data in the
result set.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/pymongoarrow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Int32Builder,
Int64Builder,
ListBuilder,
NullBuilder,
ObjectIdBuilder,
StringBuilder,
)
Expand All @@ -49,6 +50,7 @@
_BsonArrowTypes.code: CodeBuilder,
_BsonArrowTypes.date32: Date32Builder,
_BsonArrowTypes.date64: Date64Builder,
_BsonArrowTypes.null: NullBuilder,
}
except ImportError:
pass
Expand Down
62 changes: 58 additions & 4 deletions bindings/python/pymongoarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N
# Placeholder numbers for the date types.
cdef uint8_t ARROW_TYPE_DATE32 = 100
cdef uint8_t ARROW_TYPE_DATE64 = 101
cdef uint8_t ARROW_TYPE_NULL = 102

_builder_type_map = {
BSON_TYPE_INT32: Int32Builder,
Expand All @@ -80,6 +81,7 @@ _builder_type_map = {
BSON_TYPE_CODE: CodeBuilder,
ARROW_TYPE_DATE32: Date32Builder,
ARROW_TYPE_DATE64: Date64Builder,
ARROW_TYPE_NULL: NullBuilder
}

_field_type_map = {
Expand Down Expand Up @@ -177,6 +179,7 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
cdef Py_ssize_t count = 0
cdef uint8_t byte_order_status = 0
cdef map[cstring, void *] builder_map
cdef map[cstring, void *] missing_builders
cdef map[cstring, void*].iterator it
cdef bson_subtype_t subtype
cdef int32_t val32
Expand All @@ -197,6 +200,7 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
cdef DocumentBuilder doc_builder
cdef Date32Builder date32_builder
cdef Date64Builder date64_builder
cdef NullBuilder null_builder

# Build up a map of the builders.
for key, value in context.builder_map.items():
Expand All @@ -219,10 +223,6 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
builder = None
if arr_value_builder is not None:
builder = arr_value_builder
else:
it = builder_map.find(key)
if it != builder_map.end():
builder = <_ArrayBuilderBase>builder_map[key]

if builder is None:
it = builder_map.find(key)
Expand All @@ -233,9 +233,16 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
# Get the appropriate builder for the current field.
value_t = bson_iter_type(&doc_iter)
builder_type = _builder_type_map.get(value_t)

# Keep the key in missing builders until we find it.
if builder_type is None:
missing_builders[key] = <void *>None
continue

it = missing_builders.find(key)
if it != builder_map.end():
missing_builders.erase(key)

# Handle the parameterized builders.
if builder_type == DatetimeBuilder and context.tzinfo is not None:
arrow_type = timestamp('ms', tz=context.tzinfo)
Expand Down Expand Up @@ -410,6 +417,9 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
binary_builder.append_null()
else:
binary_builder.append_raw(<char*>val_buf, val_buf_len)
elif ftype == ARROW_TYPE_NULL:
null_builder = builder
null_builder.append_null()
else:
raise PyMongoArrowError('unknown ftype {}'.format(ftype))

Expand All @@ -422,6 +432,17 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
if len(builder) != count:
builder.append_null()
preincrement(it)

# Any missing fields that are left must be null fields.
it = missing_builders.begin()
while it != missing_builders.end():
builder = NullBuilder()
context.builder_map[key] = builder
null_builder = builder
for _ in range(count):
null_builder.append_null()
preincrement(it)

finally:
bson_reader_destroy(stream_reader)

Expand Down Expand Up @@ -724,6 +745,37 @@ cdef class Date32Builder(_ArrayBuilderBase):
return self.builder


cdef class NullBuilder(_ArrayBuilderBase):
cdef:
shared_ptr[CNullBuilder] builder

def __cinit__(self, MemoryPool memory_pool=None):
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
self.builder.reset(new CNullBuilder(pool))
self.type_marker = ARROW_TYPE_NULL

cdef append_raw(self, void* value):
self.builder.get().AppendNull()

cpdef append(self, value):
self.builder.get().AppendNull()

cpdef append_null(self):
self.builder.get().AppendNull()

def __len__(self):
return self.builder.get().length()

cpdef finish(self):
cdef shared_ptr[CArray] out
with nogil:
self.builder.get().Finish(&out)
return pyarrow_wrap_array(out)

cdef shared_ptr[CNullBuilder] unwrap(self):
return self.builder


cdef class BoolBuilder(_ArrayBuilderBase):
cdef:
shared_ptr[CBooleanBuilder] builder
Expand Down Expand Up @@ -817,6 +869,8 @@ cdef object get_field_builder(object field, object tzinfo):
field_builder = ListBuilder(field_type, tzinfo)
elif _atypes.is_large_list(field_type):
field_builder = ListBuilder(field_type, tzinfo)
elif _atypes.is_null(field_type):
field_builder = NullBuilder()
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.objectid:
field_builder = ObjectIdBuilder()
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.decimal128:
Expand Down
3 changes: 3 additions & 0 deletions bindings/python/pymongoarrow/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ cdef extern from "arrow/builder.h" namespace "arrow" nogil:
int32_t num_values()
shared_ptr[CDataType] type()

cdef cppclass CNullBuilder" arrow::NullBuilder"(CArrayBuilder):
CNullBuilder(CMemoryPool* pool)


cdef extern from "arrow/type_fwd.h" namespace "arrow" nogil:
shared_ptr[CDataType] fixed_size_binary(int32_t byte_width)
10 changes: 8 additions & 2 deletions bindings/python/pymongoarrow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
float64,
int64,
list_,
null,
string,
struct,
timestamp,
Expand Down Expand Up @@ -55,6 +56,7 @@ class _BsonArrowTypes(enum.Enum):
code = 12
date32 = 13
date64 = 14
null = 15


# Custom Extension Types.
Expand Down Expand Up @@ -260,6 +262,7 @@ def get_numpy_type(type):
_atypes.is_int64: _BsonArrowTypes.int64,
_atypes.is_float64: _BsonArrowTypes.double,
_atypes.is_timestamp: _BsonArrowTypes.datetime,
_atypes.is_null: _BsonArrowTypes.null,
_is_objectid: _BsonArrowTypes.objectid,
_is_decimal128: _BsonArrowTypes.decimal128,
_is_binary: _BsonArrowTypes.binary,
Expand All @@ -276,7 +279,7 @@ def get_numpy_type(type):


def _is_typeid_supported(typeid):
return typeid in _TYPE_NORMALIZER_FACTORY
return typeid in _TYPE_NORMALIZER_FACTORY or typeid is None


def _normalize_typeid(typeid, field_name):
Expand All @@ -293,7 +296,10 @@ def _normalize_typeid(typeid, field_name):
raise ValueError(msg)
return list_(_normalize_typeid(typeid[0], "0"))
if _is_typeid_supported(typeid):
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
if typeid is None: # noqa: SIM108
normalizer = lambda _: null() # noqa: E731
else:
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
return normalizer(typeid)
msg = f"Unsupported type identifier {typeid} for field {field_name}"
raise ValueError(msg)
Expand Down
48 changes: 34 additions & 14 deletions bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,23 +295,26 @@ def test_pymongo_error(self):

def _create_data(self):
schema = {k.__name__: v(True) for k, v in _TYPE_NORMALIZER_FACTORY.items()}
schema["null"] = pa.null()
schema["Binary"] = BinaryType(10)
schema["ObjectId"] = ObjectIdType()
schema["Decimal128"] = Decimal128Type()
schema["Code"] = CodeType()
pydict = {
"Int64": [i for i in range(2)],
"float": [i for i in range(2)],
"datetime": [i for i in range(2)],
"str": [str(i) for i in range(2)],
"int": [i for i in range(2)],
"bool": [True, False],
"null": [None for _ in range(2)],
"Binary": [b"1", b"23"],
"ObjectId": [ObjectId().binary, ObjectId().binary],
"Decimal128": [Decimal128(str(i)).bid for i in range(2)],
"Code": [str(i) for i in range(2)],
}
data = Table.from_pydict(
{
"Int64": [i for i in range(2)],
"float": [i for i in range(2)],
"datetime": [i for i in range(2)],
"str": [str(i) for i in range(2)],
"int": [i for i in range(2)],
"bool": [True, False],
"Binary": [b"1", b"23"],
"ObjectId": [ObjectId().binary, ObjectId().binary],
"Decimal128": [Decimal128(str(i)).bid for i in range(2)],
"Code": [str(i) for i in range(2)],
},
pydict,
ArrowSchema(schema),
)
return schema, data
Expand Down Expand Up @@ -355,8 +358,10 @@ def test_write_batching(self, mock):
self.round_trip(data, Schema(schema), coll=self.coll)
self.assertEqual(mock.call_count, 2)

def _create_nested_data(self, nested_elem=None):
def _create_nested_data(self, nested_elem=None, use_none=False):
schema = {k.__name__: v(0) for k, v in _TYPE_NORMALIZER_FACTORY.items()}
if use_none:
schema["null"] = pa.null()
if nested_elem:
schem_ent, nested_elem = nested_elem
schema["list"] = list_(schem_ent)
Expand All @@ -379,10 +384,12 @@ def _create_nested_data(self, nested_elem=None):
"date32": [date(2012, 1, 1) for i in range(3)],
"date64": [date(2012, 1, 1) for i in range(3)],
}
if use_none:
raw_data["null"] = [None for _ in range(3)]

def inner(i):
inner_dict = dict(
str=str(i),
str=None if use_none and i == 0 else str(i),
bool=bool(i),
float=i + 0.1,
Int64=i,
Expand All @@ -395,6 +402,8 @@ def inner(i):
date32=date(2012, 1, 1),
date64=date(2014, 1, 1),
)
if use_none:
inner_dict["null"] = None
if nested_elem:
inner_dict["list"] = [nested_elem]
return inner_dict
Expand Down Expand Up @@ -471,6 +480,17 @@ def test_auto_schema_nested(self):
for name in out.column_names:
self.assertEqual(data[name], out[name].cast(data[name].type))

def test_schema_nested_null(self):
schema, data = self._create_nested_data(use_none=True)

self.coll.drop()
res = write(self.coll, data)
self.assertEqual(len(data), res.raw_result["insertedCount"])
for func in [find_arrow_all, aggregate_arrow_all]:
out = func(self.coll, {} if func == find_arrow_all else [], schema=Schema(schema))
for name in out.column_names:
self.assertEqual(data[name], out[name].cast(data[name].type))

def test_auto_schema(self):
_, data = self._create_data()
self.coll.drop()
Expand Down
Loading