Skip to content

Commit

Permalink
[python] Replace SOMAArray.write with ManagedQuery.write
Browse files (browse the repository at this point in the history)
  • Loading branch information
nguyenv committed Jan 30, 2025
1 parent a42d6b5 commit 87af109
Show file tree
Hide file tree
Showing 14 changed files with 206 additions and 659 deletions.
11 changes: 5 additions & 6 deletions apis/python/src/tiledbsoma/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from . import pytiledbsoma as clib
from ._constants import SOMA_JOINID
from ._exception import SOMAError, map_exception_for_create
from ._read_iters import TableReadIter
from ._read_iters import ManagedQuery, TableReadIter
from ._soma_array import SOMAArray
from ._tdb_handles import DataFrameWrapper
from ._types import (
Expand Down Expand Up @@ -770,7 +770,6 @@ def write(
_util.check_type("values", values, (pa.Table,))

write_options: Union[TileDBCreateOptions, TileDBWriteOptions]
sort_coords = None
if isinstance(platform_config, TileDBCreateOptions):
raise ValueError(
"As of TileDB-SOMA 1.13, the write method takes "
Expand All @@ -779,13 +778,13 @@ def write(
write_options = TileDBWriteOptions.from_platform_config(platform_config)
sort_coords = write_options.sort_coords

clib_dataframe = self._handle._handle

for batch in values.to_batches():
clib_dataframe.write(batch, sort_coords or False)
mq = ManagedQuery(self)
mq._handle.set_array_data(batch)
mq._handle.submit_write(sort_coords or False)

if write_options.consolidate_and_vacuum:
clib_dataframe.consolidate_and_vacuum()
self._handle._handle.consolidate_and_vacuum()

return self

Expand Down
2 changes: 1 addition & 1 deletion apis/python/src/tiledbsoma/_dense_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def write(
mq = ManagedQuery(self, platform_config)
mq._handle.set_layout(order)
_util._set_coords(mq, new_coords)
mq._handle.set_soma_data(input)
mq._handle.set_column_data("soma_data", input)
mq._handle.submit_write()

tiledb_write_options = TileDBWriteOptions.from_platform_config(platform_config)
Expand Down
6 changes: 4 additions & 2 deletions apis/python/src/tiledbsoma/_point_cloud_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_revise_domain_for_extent,
)
from ._exception import SOMAError, map_exception_for_create
from ._read_iters import TableReadIter
from ._read_iters import ManagedQuery, TableReadIter
from ._spatial_dataframe import SpatialDataFrame
from ._spatial_util import (
coordinate_space_from_json,
Expand Down Expand Up @@ -484,7 +484,9 @@ def write(
clib_dataframe = self._handle._handle

for batch in values.to_batches():
clib_dataframe.write(batch, sort_coords or False)
mq = ManagedQuery(self, None)
mq._handle.set_array_data(batch)
mq._handle.submit_write(sort_coords or False)

if write_options.consolidate_and_vacuum:
clib_dataframe.consolidate_and_vacuum()
Expand Down
2 changes: 1 addition & 1 deletion apis/python/src/tiledbsoma/_read_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ class ManagedQuery:
"""Keep the lifetime of the SOMAArray tethered to ManagedQuery."""

_array: SOMAArray
_platform_config: options.PlatformConfig | None
_platform_config: options.PlatformConfig | None = None
_handle: clib.ManagedQuery = attrs.field(init=False)

def __attrs_post_init__(self) -> None:
Expand Down
39 changes: 25 additions & 14 deletions apis/python/src/tiledbsoma/_sparse_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._read_iters import (
BlockwiseScipyReadIter,
BlockwiseTableReadIter,
ManagedQuery,
SparseCOOTensorReadIter,
TableReadIter,
)
Expand Down Expand Up @@ -324,19 +325,23 @@ def write(
if isinstance(values, pa.SparseCOOTensor):
# Write bulk data
data, coords = values.to_numpy()
clib_sparse_array.write_coords(
[

mq = ManagedQuery(self, platform_config)
for i, c in enumerate(coords.T):
mq._handle.set_column_data(
f"soma_dim_{i}",
np.array(
c,
dtype=self.schema.field(f"soma_dim_{i}").type.to_pandas_dtype(),
)
for i, c in enumerate(coords.T)
],
),
)
mq._handle.set_column_data(
"soma_data",
np.array(
data, dtype=self.schema.field("soma_data").type.to_pandas_dtype()
),
sort_coords or True,
)
mq._handle.submit_write(sort_coords or True)

# Write bounding-box metadata. Note COO can be N-dimensional.
maxes = [e - 1 for e in values.shape]
Expand All @@ -356,19 +361,23 @@ def write(
# Write bulk data
# TODO: the ``to_scipy`` function is not zero copy. Need to explore zero-copy options.
sp = values.to_scipy().tocoo()
clib_sparse_array.write_coords(
[

mq = ManagedQuery(self, platform_config)
for i, c in enumerate([sp.row, sp.col]):
mq._handle.set_column_data(
f"soma_dim_{i}",
np.array(
c,
dtype=self.schema.field(f"soma_dim_{i}").type.to_pandas_dtype(),
)
for i, c in enumerate([sp.row, sp.col])
],
),
)
mq._handle.set_column_data(
"soma_data",
np.array(
sp.data, dtype=self.schema.field("soma_data").type.to_pandas_dtype()
),
sort_coords or True,
)
mq._handle.submit_write(sort_coords or True)

# Write bounding-box metadata. Note CSR and CSC are necessarily 2-dimensional.
nr, nc = values.shape
Expand All @@ -382,9 +391,11 @@ def write(

if isinstance(values, pa.Table):
# Write bulk data
values = _util.cast_values_to_target_schema(values, self.schema)
for batch in values.to_batches():
clib_sparse_array.write(batch, sort_coords or False)
# clib_sparse_array.write(batch, sort_coords or False)
mq = ManagedQuery(self, None)
mq._handle.set_array_data(batch)
mq._handle.submit_write(sort_coords or False)

# Write bounding-box metadata
maxes = []
Expand Down
14 changes: 0 additions & 14 deletions apis/python/src/tiledbsoma/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,20 +314,6 @@ def pa_types_is_string_or_bytes(dtype: pa.DataType) -> bool:
)


def cast_values_to_target_schema(values: pa.Table, schema: pa.Schema) -> pa.Table:
    """Return ``values`` unchanged; placeholder for casting to the on-disk schema.

    When writing data to a SOMAArray, the values that the user passes in may
    not match the schema on disk. This function is the hook where such values
    would be cast to the correct dtypes, but no cast is currently performed:
    the input table is handed back as-is.

    Args:
        values: The table supplied by the caller for writing.
        schema: The target (on-disk) schema. Currently unused.

    Returns:
        The very same ``values`` object that was passed in.
    """
    # NOTE(review): the cast is disabled. The earlier implementation built a
    # target schema by looking up each input field by name in ``schema`` (so
    # the fields stayed in input order) and then called ``values.cast`` with
    # that schema plus the input schema's metadata. Re-enable it here if
    # casting on write is wanted again.
    return values


def build_clib_platform_config(
platform_config: options.PlatformConfig | None,
) -> clib.PlatformConfig:
Expand Down
31 changes: 22 additions & 9 deletions apis/python/src/tiledbsoma/managed_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,28 +142,41 @@ void load_managed_query(py::module& m) {
arrow_array.release(&arrow_array);
})
.def(
"set_soma_data",
[](ManagedQuery& mq, py::array data) {
"set_column_data",
[](ManagedQuery& mq, std::string name, py::array data) {
py::buffer_info data_info = data.request();

py::gil_scoped_release release;
mq.setup_write_column(
"soma_data",
data.size(),
(const void*)data_info.ptr,
static_cast<uint64_t*>(nullptr),
std::nullopt);
try {
mq.setup_write_column(
name,
data.size(),
(const void*)data_info.ptr,
static_cast<uint64_t*>(nullptr),
std::nullopt);
} catch (const std::exception& e) {
TPY_ERROR_LOC(e.what());
}
py::gil_scoped_acquire acquire;
})
.def(
"submit_write",
&ManagedQuery::submit_write,
[](ManagedQuery& mq, bool sort_coords) {
try {
mq.submit_write(sort_coords);
} catch (const std::exception& e) {
TPY_ERROR_LOC(e.what());
}
},
"sort_coords"_a = false,
py::call_guard<py::gil_scoped_release>())

.def("reset", &ManagedQuery::reset)
.def("close", &ManagedQuery::close)

.def_property_readonly("result_order", &ManagedQuery::result_order)
.def_property_readonly("column_names", &ManagedQuery::column_names)

// The following short functions are expected to be invoked when the
// coords are Python list/tuple, or NumPy arrays. Arrow arrays are in
// the long if-else-if function above.
Expand Down
Loading

0 comments on commit 87af109

Please sign in to comment.