Skip to content

Commit

Permalink
Update Python API to send PointCloudDataFrame metadata to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-dark committed Jan 24, 2025
1 parent 854a823 commit 94ceb98
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
28 changes: 14 additions & 14 deletions apis/python/src/tiledbsoma/_point_cloud_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from ._constants import (
SOMA_COORDINATE_SPACE_METADATA_KEY,
SOMA_JOINID,
SOMA_SPATIAL_ENCODING_VERSION,
SOMA_SPATIAL_VERSION_METADATA_KEY,
SPATIAL_DISCLAIMER,
)
from ._dataframe import (
Expand Down Expand Up @@ -123,9 +121,16 @@ def create(
warnings.warn(SPATIAL_DISCLAIMER)

axis_dtype: pa.DataType | None = None
if not isinstance(coordinate_space, CoordinateSpace):
coordinate_space = CoordinateSpace.from_axis_names(coordinate_space)
for column_name in coordinate_space.axis_names:

# Get coordinate space axis data.
if isinstance(coordinate_space, CoordinateSpace):
axis_names = tuple(axis.name for axis in coordinate_space)
axis_units = tuple(axis.unit for axis in coordinate_space)
else:
axis_names = tuple(coordinate_space)
axis_units = tuple(len(axis_names) * [None])

for column_name in axis_names:
# Check axis column type is valid and all axis columns have the same type.
if axis_dtype is None:
try:
Expand All @@ -152,7 +157,7 @@ def create(
) from ke
if column_dtype != axis_dtype:
raise ValueError("All spatial axes must have the same datatype.")
index_column_names = coordinate_space.axis_names + (SOMA_JOINID,)
index_column_names = axis_names + (SOMA_JOINID,)

context = _validate_soma_tiledb_context(context)
schema = _canonicalize_schema(schema, index_column_names)
Expand Down Expand Up @@ -251,22 +256,17 @@ def create(
uri,
schema=schema,
index_column_info=index_column_info,
axis_names=axis_names,
axis_units=axis_units,
ctx=context.native_context,
platform_config=plt_cfg,
timestamp=(0, timestamp_ms),
)
except SOMAError as e:
raise map_exception_for_create(e, uri) from None

handle = cls._wrapper_type.open(uri, "w", context, tiledb_timestamp)
handle.metadata[SOMA_SPATIAL_VERSION_METADATA_KEY] = (
SOMA_SPATIAL_ENCODING_VERSION
)
handle.meta[SOMA_COORDINATE_SPACE_METADATA_KEY] = coordinate_space_to_json(
coordinate_space
)
return cls(
handle,
cls._wrapper_type.open(uri, "w", context, tiledb_timestamp),
_dont_call_this_use_create_or_open_instead="tiledbsoma-internal-code",
)

Expand Down
7 changes: 7 additions & 0 deletions apis/python/src/tiledbsoma/soma_point_cloud_dataframe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void load_soma_point_cloud_dataframe(py::module& m) {
[](std::string_view uri,
py::object py_schema,
py::object index_column_info,
std::vector<std::string> axis_names,
std::vector<std::optional<std::string>> axis_units,
std::shared_ptr<SOMAContext> context,
PlatformConfig platform_config,
std::optional<std::pair<uint64_t, uint64_t>> timestamp) {
Expand Down Expand Up @@ -80,13 +82,16 @@ void load_soma_point_cloud_dataframe(py::module& m) {
index_column_info.attr("_export_to_c")(
index_column_array_ptr, index_column_schema_ptr);

SOMACoordinateSpace coord_space{axis_names, axis_units};

try {
SOMAPointCloudDataFrame::create(
uri,
std::make_unique<ArrowSchema>(schema),
ArrowTable(
std::make_unique<ArrowArray>(index_column_array),
std::make_unique<ArrowSchema>(index_column_schema)),
coord_space,
context,
platform_config,
timestamp);
Expand All @@ -101,6 +106,8 @@ void load_soma_point_cloud_dataframe(py::module& m) {
py::kw_only(),
"schema"_a,
"index_column_info"_a,
"axis_names"_a,
"axis_units"_a,
"ctx"_a,
"platform_config"_a,
"timestamp"_a = py::none())
Expand Down

0 comments on commit 94ceb98

Please sign in to comment.