Skip to content

Commit

Permalink
Add serialization overload
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Mar 9, 2023
1 parent f0a7031 commit a6c54a4
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 45 deletions.
92 changes: 54 additions & 38 deletions cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,45 +55,52 @@ template struct check_index_layout<sizeof(index<double, std::uint64_t>), 368>;
*
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<T, IdxT>& index_)
void serialize(raft::device_resources const& handle, std::ostream& os, const index<T, IdxT>& index_)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); }

RAFT_LOG_DEBUG(
"Saving IVF-Flat index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

serialize_scalar(handle, of, serialization_version);
serialize_scalar(handle, of, index_.size());
serialize_scalar(handle, of, index_.dim());
serialize_scalar(handle, of, index_.n_lists());
serialize_scalar(handle, of, index_.metric());
serialize_scalar(handle, of, index_.adaptive_centers());
serialize_scalar(handle, of, index_.conservative_memory_allocation());
serialize_mdspan(handle, of, index_.centers());
serialize_scalar(handle, os, serialization_version);
serialize_scalar(handle, os, index_.size());
serialize_scalar(handle, os, index_.dim());
serialize_scalar(handle, os, index_.n_lists());
serialize_scalar(handle, os, index_.metric());
serialize_scalar(handle, os, index_.adaptive_centers());
serialize_scalar(handle, os, index_.conservative_memory_allocation());
serialize_mdspan(handle, os, index_.centers());
if (index_.center_norms()) {
bool has_norms = true;
serialize_scalar(handle, of, has_norms);
serialize_mdspan(handle, of, *index_.center_norms());
serialize_scalar(handle, os, has_norms);
serialize_mdspan(handle, os, *index_.center_norms());
} else {
bool has_norms = false;
serialize_scalar(handle, of, has_norms);
serialize_scalar(handle, os, has_norms);
}
auto sizes_host = make_host_vector<uint32_t, uint32_t>(index_.list_sizes().extent(0));
copy(sizes_host.data_handle(),
index_.list_sizes().data_handle(),
sizes_host.size(),
handle.get_stream());
handle.sync_stream();
serialize_mdspan(handle, of, sizes_host.view());
serialize_mdspan(handle, os, sizes_host.view());

list_spec<uint32_t, T, IdxT> list_store_spec{index_.dim(), true};
for (uint32_t label = 0; label < index_.n_lists(); label++) {
ivf::serialize_list(handle, of, index_.lists()[label], list_store_spec, sizes_host(label));
ivf::serialize_list(handle, os, index_.lists()[label], list_store_spec, sizes_host(label));
}
handle.sync_stream();
}

template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<T, IdxT>& index_)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

detail::serialize(handle, of, index_);

of.close();
if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
}
Expand All @@ -108,49 +115,58 @@ void serialize(raft::device_resources const& handle,
*
*/
template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& handle, const std::string& filename)
-> index<T, IdxT>
auto deserialize(raft::device_resources const& handle, std::istream& is) -> index<T, IdxT>
{
std::ifstream infile(filename, std::ios::in | std::ios::binary);

if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); }

auto ver = deserialize_scalar<int>(handle, infile);
auto ver = deserialize_scalar<int>(handle, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
}
auto n_rows = deserialize_scalar<IdxT>(handle, infile);
auto dim = deserialize_scalar<std::uint32_t>(handle, infile);
auto n_lists = deserialize_scalar<std::uint32_t>(handle, infile);
auto metric = deserialize_scalar<raft::distance::DistanceType>(handle, infile);
bool adaptive_centers = deserialize_scalar<bool>(handle, infile);
bool cma = deserialize_scalar<bool>(handle, infile);
auto n_rows = deserialize_scalar<IdxT>(handle, is);
auto dim = deserialize_scalar<std::uint32_t>(handle, is);
auto n_lists = deserialize_scalar<std::uint32_t>(handle, is);
auto metric = deserialize_scalar<raft::distance::DistanceType>(handle, is);
bool adaptive_centers = deserialize_scalar<bool>(handle, is);
bool cma = deserialize_scalar<bool>(handle, is);

index<T, IdxT> index_ = index<T, IdxT>(handle, metric, n_lists, adaptive_centers, cma, dim);

deserialize_mdspan(handle, infile, index_.centers());
bool has_norms = deserialize_scalar<bool>(handle, infile);
deserialize_mdspan(handle, is, index_.centers());
bool has_norms = deserialize_scalar<bool>(handle, is);
if (has_norms) {
index_.allocate_center_norms(handle);
if (!index_.center_norms()) {
RAFT_FAIL("Error inconsistent center norms");
} else {
auto center_norms = index_.center_norms().value();
deserialize_mdspan(handle, infile, center_norms);
deserialize_mdspan(handle, is, center_norms);
}
}
deserialize_mdspan(handle, infile, index_.list_sizes());
deserialize_mdspan(handle, is, index_.list_sizes());

list_spec<uint32_t, T, IdxT> list_device_spec{index_.dim(), cma};
list_spec<uint32_t, T, IdxT> list_store_spec{index_.dim(), true};
for (uint32_t label = 0; label < index_.n_lists(); label++) {
ivf::deserialize_list(handle, infile, index_.lists()[label], list_store_spec, list_device_spec);
ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec);
}
handle.sync_stream();
infile.close();

index_.recompute_internal_state(handle);

return index_;
}

template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& handle, const std::string& filename)
-> index<T, IdxT>
{
std::ifstream is(filename, std::ios::in | std::ios::binary);

if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

auto index = detail::deserialize<T, IdxT>(handle, is);

is.close();

return index;
}
} // namespace raft::neighbors::ivf_flat::detail
8 changes: 1 addition & 7 deletions cpp/include/raft/neighbors/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <raft/neighbors/detail/ivf_flat_build.cuh>
#include <raft/neighbors/detail/ivf_flat_search.cuh>
#include <raft/neighbors/detail/ivf_flat_serialize.cuh>
#include <raft/neighbors/ivf_flat_serialize.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>

#include <raft/core/device_resources.hpp>
Expand Down Expand Up @@ -100,7 +100,6 @@ auto build(raft::device_resources const& handle,
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] params configure the index building
Expand Down Expand Up @@ -191,8 +190,6 @@ auto extend(raft::device_resources const& handle,
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] orig_index original index
Expand Down Expand Up @@ -278,8 +275,6 @@ void extend(raft::device_resources const& handle,
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[inout] index
Expand Down Expand Up @@ -389,7 +384,6 @@ void search(raft::device_resources const& handle,
* @tparam value_t data element type
* @tparam idx_t type of the indices
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] index ivf-flat constructed index
Expand Down
156 changes: 156 additions & 0 deletions cpp/include/raft/neighbors/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "detail/ivf_flat_serialize.cuh"

namespace raft::neighbors::ivf_flat {

/**
* \defgroup ivf_flat_serialize IVF-Flat Serialize
* @{
*/

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = ivf_flat::build(...);`
* raft::serialize(handle, os, index);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index IVF-Flat index
*
* @return raft::neighbors::ivf_flat::index<T, IdxT>
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<T, IdxT>& index)
{
detail::serialize(handle, os, index);
}

/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = ivf_flat::build(...);`
* raft::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-Flat index
*
* @return raft::neighbors::ivf_flat::index<T, IdxT>
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
{
detail::serialize(handle, filename, index);
}

/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, is);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] is input stream
*
* @return raft::neighbors::ivf_flat::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, std::istream& is)
{
return detail::deserialize<T, IdxT>(handle, is);
}

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, filename);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
*
* @return raft::neighbors::ivf_flat::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, const std::string& filename)
{
return detail::deserialize<T, IdxT>(handle, filename);
}

/**@}*/

} // namespace raft::neighbors::ivf_flat
1 change: 1 addition & 0 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_);
handle_.sync_stream(stream_);
for (uint32_t l = 0; l < index_2.n_lists(); l++) {
if (list_sizes[l] == 0) continue;
rmm::device_uvector<float> cluster_data(list_sizes[l] * ps.dim, stream_);
raft::spatial::knn::detail::utils::copy_selected<float>((IdxT)list_sizes[l],
(IdxT)ps.dim,
Expand Down

0 comments on commit a6c54a4

Please sign in to comment.