From a6c54a40e1a3a528f5a3ca405f2c4da6830ab7f6 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 10 Mar 2023 00:17:30 +0100 Subject: [PATCH] Add serialization overload --- .../neighbors/detail/ivf_flat_serialize.cuh | 92 ++++++----- cpp/include/raft/neighbors/ivf_flat.cuh | 8 +- .../raft/neighbors/ivf_flat_serialize.cuh | 156 ++++++++++++++++++ cpp/test/neighbors/ann_ivf_flat.cuh | 1 + 4 files changed, 212 insertions(+), 45 deletions(-) create mode 100644 cpp/include/raft/neighbors/ivf_flat_serialize.cuh diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index 1195197fa3..18e8c2862e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -55,31 +55,26 @@ template struct check_index_layout), 368>; * */ template -void serialize(raft::device_resources const& handle, - const std::string& filename, - const index& index_) +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index_) { - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - RAFT_LOG_DEBUG( "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - serialize_scalar(handle, of, serialization_version); - serialize_scalar(handle, of, index_.size()); - serialize_scalar(handle, of, index_.dim()); - serialize_scalar(handle, of, index_.n_lists()); - serialize_scalar(handle, of, index_.metric()); - serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_scalar(handle, of, index_.conservative_memory_allocation()); - serialize_mdspan(handle, of, index_.centers()); + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index_.size()); + serialize_scalar(handle, os, index_.dim()); + serialize_scalar(handle, os, index_.n_lists()); + serialize_scalar(handle, os, index_.metric()); + serialize_scalar(handle, os, index_.adaptive_centers()); + serialize_scalar(handle, os, index_.conservative_memory_allocation()); + serialize_mdspan(handle, os, index_.centers()); if (index_.center_norms()) { bool has_norms = true; - serialize_scalar(handle, of, has_norms); - serialize_mdspan(handle, of, *index_.center_norms()); + serialize_scalar(handle, os, has_norms); + serialize_mdspan(handle, os, *index_.center_norms()); } else { bool has_norms = false; - serialize_scalar(handle, of, has_norms); + serialize_scalar(handle, os, has_norms); } auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); copy(sizes_host.data_handle(), @@ -87,13 +82,25 @@ void serialize(raft::device_resources const& handle, sizes_host.size(), handle.get_stream()); handle.sync_stream(); - serialize_mdspan(handle, of, sizes_host.view()); + serialize_mdspan(handle, os, sizes_host.view()); list_spec list_store_spec{index_.dim(), true}; for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::serialize_list(handle, of, index_.lists()[label], list_store_spec, sizes_host(label)); + ivf::serialize_list(handle, os, index_.lists()[label], list_store_spec, sizes_host(label)); } handle.sync_stream(); +} + +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(handle, of, index_); + of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } @@ -108,49 +115,58 @@ void serialize(raft::device_resources const& handle, * */ template -auto deserialize(raft::device_resources const& handle, const std::string& filename) - -> index +auto deserialize(raft::device_resources const& handle, std::istream& is) -> index { - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle, infile); + auto ver = deserialize_scalar(handle, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); } - auto n_rows = deserialize_scalar(handle, infile); - auto dim = deserialize_scalar(handle, infile); - auto n_lists = deserialize_scalar(handle, infile); - auto metric = deserialize_scalar(handle, infile); - bool adaptive_centers = deserialize_scalar(handle, infile); - bool cma = deserialize_scalar(handle, infile); + auto n_rows = deserialize_scalar(handle, is); + auto dim = deserialize_scalar(handle, is); + auto n_lists = deserialize_scalar(handle, is); + auto metric = deserialize_scalar(handle, is); + bool adaptive_centers = deserialize_scalar(handle, is); + bool cma = deserialize_scalar(handle, is); index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); - deserialize_mdspan(handle, infile, index_.centers()); - bool has_norms = deserialize_scalar(handle, infile); + deserialize_mdspan(handle, is, index_.centers()); + bool has_norms = deserialize_scalar(handle, is); if (has_norms) { index_.allocate_center_norms(handle); if (!index_.center_norms()) { RAFT_FAIL("Error inconsistent center norms"); } else { auto center_norms = index_.center_norms().value(); - deserialize_mdspan(handle, infile, center_norms); + deserialize_mdspan(handle, is, center_norms); } } - deserialize_mdspan(handle, infile, index_.list_sizes()); + deserialize_mdspan(handle, is, index_.list_sizes()); list_spec list_device_spec{index_.dim(), cma}; list_spec list_store_spec{index_.dim(), true}; for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::deserialize_list(handle, infile, index_.lists()[label], list_store_spec, list_device_spec); + ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); } handle.sync_stream(); - infile.close(); index_.recompute_internal_state(handle); return index_; } + +template +auto deserialize(raft::device_resources const& handle, const std::string& filename) + -> index +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = detail::deserialize(handle, is); + + is.close(); + + return index; +} } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 4a5e2a3641..f42bfe66c7 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include @@ -100,7 +100,6 @@ auto build(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] params configure the index building @@ -191,8 +190,6 @@ auto extend(raft::device_resources const& handle, * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] orig_index original index @@ -278,8 +275,6 @@ void extend(raft::device_resources const& handle, * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[inout] index @@ -389,7 +384,6 @@ void search(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] index ivf-flat constructed index diff --git a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..d2ec9a39bd --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/ivf_flat_serialize.cuh" + +namespace raft::neighbors::ivf_flat { + +/** + * \defgroup ivf_flat_serialize IVF-Flat Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index IVF-Flat index + * + * @return raft::neighbors::ivf_flat::index + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + * @return raft::neighbors::ivf_flat::index + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, std::istream& is) +{ + return detail::deserialize(handle, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, const std::string& filename) +{ + return detail::deserialize(handle, filename); +} + +/**@}*/ + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 34531ddb89..bc4de89319 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -230,6 +230,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_); handle_.sync_stream(stream_); for (uint32_t l = 0; l < index_2.n_lists(); l++) { + if (list_sizes[l] == 0) continue; rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); raft::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], (IdxT)ps.dim,