From 22a0e76e652fa1d1bc2b8fc276fe343d6b81f6c1 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Sun, 28 Jul 2024 17:41:32 +0800 Subject: [PATCH] copy graph to a continuous buffer --- include/abstract_graph_store.h | 3 +- include/in_mem_graph_store.h | 2 +- include/in_mem_static_graph_store.h | 77 +++++++++++ include/index_config.h | 5 +- include/neighbor_list.h | 48 +++++++ src/in_mem_graph_store.cpp | 6 +- src/in_mem_static_graph_store.cpp | 193 ++++++++++++++++++++++++++++ src/index.cpp | 41 +++--- src/index_factory.cpp | 3 + src/neighbor_list.cpp | 72 +++++++++++ 10 files changed, 428 insertions(+), 22 deletions(-) create mode 100644 include/in_mem_static_graph_store.h create mode 100644 include/neighbor_list.h create mode 100644 src/in_mem_static_graph_store.cpp create mode 100644 src/neighbor_list.cpp diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 4d6906ca4..115d9ed1c 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -6,6 +6,7 @@ #include #include #include "types.h" +#include "neighbor_list.h" namespace diskann { @@ -27,7 +28,7 @@ class AbstractGraphStore const uint32_t start) = 0; // not synchronised, user should use lock when necvessary. - virtual const std::vector &get_neighbours(const location_t i) const = 0; + virtual const NeighborList get_neighbours(const location_t i) const = 0; virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0; virtual void clear_neighbours(const location_t i) = 0; virtual void swap_neighbours(const location_t a, location_t b) = 0; diff --git a/include/in_mem_graph_store.h b/include/in_mem_graph_store.h index d0206a7d6..fe1efd1b3 100644 --- a/include/in_mem_graph_store.h +++ b/include/in_mem_graph_store.h @@ -19,7 +19,7 @@ class InMemGraphStore : public AbstractGraphStore virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, const uint32_t start) override; - virtual const std::vector &get_neighbours(const location_t i) const override; + virtual const NeighborList get_neighbours(const location_t i) const override; virtual void add_neighbour(const location_t i, location_t neighbour_id) override; virtual void clear_neighbours(const location_t i) override; virtual void swap_neighbours(const location_t a, location_t b) override; diff --git a/include/in_mem_static_graph_store.h b/include/in_mem_static_graph_store.h new file mode 100644 index 000000000..fa5b8c2cd --- /dev/null +++ b/include/in_mem_static_graph_store.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "abstract_graph_store.h" + +namespace diskann +{ + +class InMemStaticGraphStore : public AbstractGraphStore +{ +public: + InMemStaticGraphStore(const size_t total_pts, const size_t reserve_graph_degree); + + // returns tuple of + virtual std::tuple load(const std::string& index_path_prefix, + const size_t num_points) override; + + virtual int store(const std::string& /*index_path_prefix*/, const size_t /*num_points*/, const size_t /*num_frozen_points*/, + const uint32_t /*start*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual const NeighborList get_neighbours(const location_t i) const override; + + virtual void add_neighbour(const location_t /*i*/, location_t /*neighbour_id*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual void clear_neighbours(const location_t /*i*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual void swap_neighbours(const location_t /*a*/, location_t /*b*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual void set_neighbours(const location_t /*i*/, std::vector& /*neighbors*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual size_t resize_graph(const size_t /*new_size*/) override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual void clear_graph() override + { + throw std::runtime_error("static memory graph only use for searching"); + } + + virtual size_t get_max_range_of_graph() override; + virtual uint32_t get_max_observed_degree() override; + +protected: + virtual std::tuple load_impl(const std::string& filename, size_t expected_num_points); +#ifdef EXEC_ENV_OLS + virtual std::tuple load_impl(AlignedFileReader& reader, size_t expected_num_points); +#endif + + +private: + size_t _max_range_of_graph = 0; + uint32_t _max_observed_degree = 0; + + std::vector _node_index; + std::vector _graph; +// std::vector> _graph; +}; + +} // namespace diskann diff --git a/include/index_config.h b/include/index_config.h index 452498b01..549f8265a 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -10,7 +10,8 @@ enum class DataStoreStrategy enum class GraphStoreStrategy { - MEMORY + MEMORY, + STATICMEMORY }; struct IndexConfig @@ -228,7 +229,7 @@ class IndexConfigBuilder private: DataStoreStrategy _data_strategy; - GraphStoreStrategy _graph_strategy; + GraphStoreStrategy _graph_strategy = GraphStoreStrategy::MEMORY; Metric _metric; size_t _dimension; diff --git a/include/neighbor_list.h b/include/neighbor_list.h new file mode 100644 index 000000000..0f404e8ac --- /dev/null +++ b/include/neighbor_list.h @@ -0,0 +1,48 @@ +#pragma once +#include +#include +#include "types.h" + +namespace diskann +{ + +class NeighborList +{ +public: + NeighborList(const location_t* data, size_t size); + + const location_t* data() const; + size_t size() const; + bool empty() const; + + // compatable with current interface, need deprecate later + void convert_to_vector(std::vector& vector_copy) const; + + class Iterator + { + public: + Iterator(const location_t* index); + + const location_t& operator*() const; + + const Iterator& operator++(); + + bool operator==(const Iterator& other) const; + + bool operator!=(const Iterator& other) const; + + private: + const location_t* _index; + }; + + // Iterator begin() = 0; + Iterator begin() const; + // Iterator end() = 0; + Iterator end() const; + +private: + const location_t* _data; + size_t _size; +}; + +} diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index c12b2514e..6ba41b148 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -6,6 +6,7 @@ namespace diskann { + InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree) : AbstractGraphStore(total_pts, reserve_graph_degree) { @@ -26,9 +27,10 @@ int InMemGraphStore::store(const std::string &index_path_prefix, const size_t nu { return save_graph(index_path_prefix, num_points, num_frozen_points, start); } -const std::vector &InMemGraphStore::get_neighbours(const location_t i) const +const NeighborList InMemGraphStore::get_neighbours(const location_t i) const { - return _graph.at(i); + auto& neighbor_vector = _graph.at(i); + return NeighborList(neighbor_vector.data(), neighbor_vector.size()); } void InMemGraphStore::add_neighbour(const location_t i, location_t neighbour_id) diff --git a/src/in_mem_static_graph_store.cpp b/src/in_mem_static_graph_store.cpp new file mode 100644 index 000000000..e54f1a491 --- /dev/null +++ b/src/in_mem_static_graph_store.cpp @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "in_mem_static_graph_store.h" +#include "utils.h" + +namespace diskann +{ + +InMemStaticGraphStore::InMemStaticGraphStore(const size_t total_pts, const size_t reserve_graph_degree) + : AbstractGraphStore(total_pts, reserve_graph_degree) +{ +} + +std::tuple InMemStaticGraphStore::load(const std::string& index_path_prefix, + const size_t num_points) +{ + return load_impl(index_path_prefix, num_points); +} + +const NeighborList InMemStaticGraphStore::get_neighbours(const location_t i) const +{ + assert(i < _node_index.size() - 1); + size_t start_index = _node_index[i]; + size_t end_index = _node_index[i + 1]; + size_t size = end_index - start_index; + const location_t* neighbor_start = _graph.data() + start_index; + return NeighborList(neighbor_start, size); +} + +#ifdef EXEC_ENV_OLS +std::tuple InMemGraphStore::load_impl(AlignedFileReader& reader, size_t expected_num_points) +{ + size_t expected_file_size; + size_t file_frozen_pts; + uint32_t start; + + auto max_points = get_max_points(); + int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); + std::unique_ptr header = std::make_unique(header_size); + read_array(reader, header.get(), header_size); + + expected_file_size = *((size_t*)header.get()); + _max_observed_degree = *((uint32_t*)(header.get() + sizeof(size_t))); + start = *((uint32_t*)(header.get() + sizeof(size_t) + sizeof(uint32_t))); + file_frozen_pts = *((size_t*)(header.get() + sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t))); + + diskann::cout << "From graph header, expected_file_size: " << expected_file_size + << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start + << ", file_frozen_pts: " << file_frozen_pts << std::endl; + + diskann::cout << "Loading vamana graph from reader..." << std::flush; + + // If user provides more points than max_points + // resize the _graph to the larger size. + if (get_total_points() < expected_num_points) + { + diskann::cout << "resizing graph to " << expected_num_points << std::endl; + this->resize_graph(expected_num_points); + } + + uint32_t nodes_read = 0; + size_t cc = 0; + size_t graph_offset = header_size; + while (nodes_read < expected_num_points) + { + uint32_t k; + read_value(reader, k, graph_offset); + graph_offset += sizeof(uint32_t); + std::vector tmp(k); + tmp.reserve(k); + read_array(reader, tmp.data(), k, graph_offset); + graph_offset += k * sizeof(uint32_t); + cc += k; + _graph[nodes_read].swap(tmp); + nodes_read++; + if (nodes_read % 1000000 == 0) + { + diskann::cout << "." << std::flush; + } + if (k > _max_range_of_graph) + { + _max_range_of_graph = k; + } + } + + diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start + << std::endl; + return std::make_tuple(nodes_read, start, file_frozen_pts); +} +#endif + +std::tuple InMemStaticGraphStore::load_impl(const std::string& filename, + size_t expected_num_points) +{ + size_t expected_file_size; + size_t file_frozen_pts; + uint32_t start; + size_t file_offset = 0; // will need this for single file format support + + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); + in.open(filename, std::ios::binary); + in.seekg(file_offset, in.beg); + in.read((char*)&expected_file_size, sizeof(size_t)); + in.read((char*)&_max_observed_degree, sizeof(uint32_t)); + in.read((char*)&start, sizeof(uint32_t)); + in.read((char*)&file_frozen_pts, sizeof(size_t)); + size_t vamana_metadata_size = sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(size_t); + + diskann::cout << "From graph header, expected_file_size: " << expected_file_size + << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start + << ", file_frozen_pts: " << file_frozen_pts << std::endl; + + diskann::cout << "Loading vamana graph " << filename << "..." << std::flush; + + std::vector buffer; + size_t graph_size = expected_file_size - vamana_metadata_size; + buffer.resize(graph_size); + + in.read(buffer.data(), graph_size); + in.close(); + + size_t cc = 0; + uint32_t nodes_read = 0; + + // first round to calculate memory size needed. + size_t cur_index = 0; + while (cur_index + sizeof(uint32_t) < graph_size) + { + uint32_t k; + memcpy((char*)&k, buffer.data() + cur_index, sizeof(uint32_t)); + cur_index += sizeof(uint32_t); + size_t neighbor_size = k * sizeof(uint32_t); + if (cur_index + neighbor_size > graph_size) + { + break; + } + cur_index += neighbor_size; + + cc += k; + ++nodes_read; + } + + // resize graph + _node_index.resize(nodes_read + 1); + _node_index[0] = 0; + _graph.resize(cc); + + // second round to insert graph data + nodes_read = 0; + cur_index = 0; + while (cur_index + sizeof(uint32_t) < graph_size) + { + uint32_t k; + memcpy((char*)&k, buffer.data() + cur_index, sizeof(uint32_t)); + cur_index += sizeof(uint32_t); + size_t neighbor_size = k * sizeof(uint32_t); + if (cur_index + neighbor_size > graph_size) + { + break; + } + + size_t offset = _node_index[nodes_read]; + std::uint32_t* neighborPtr = &_graph[offset]; + + memcpy(neighborPtr, buffer.data() + cur_index, neighbor_size); + _node_index[nodes_read + 1] = offset + k; + + cur_index += neighbor_size; + + if (nodes_read % 10000000 == 0) + std::cout << "." << std::flush; + + ++nodes_read; + } + + diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start + << std::endl; + return std::make_tuple(nodes_read, start, file_frozen_pts); +} + +size_t InMemStaticGraphStore::get_max_range_of_graph() +{ + return _max_range_of_graph; +} + +uint32_t InMemStaticGraphStore::get_max_observed_degree() +{ + return _max_observed_degree; +} + +} // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 3d4ae2619..0b01afa20 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -948,7 +948,8 @@ std::pair Index::iterate_to_fixed_point( if (_dynamic_index) { LockGuard guard(_locks[n]); - for (auto id : _graph_store->get_neighbours(n)) + auto neighbour_list = _graph_store->get_neighbours(n); + for (auto id : neighbour_list) { assert(id < _max_points + _num_frozen_pts); @@ -976,7 +977,7 @@ std::pair Index::iterate_to_fixed_point( { tmp_neighbor_list.clear(); _locks[n].lock_shared(); - auto& nbrs = _graph_store->get_neighbours(n); + auto nbrs = _graph_store->get_neighbours(n); tmp_neighbor_list.resize(nbrs.size()); memcpy(tmp_neighbor_list.data(), nbrs.data(), nbrs.size() * sizeof(location_t)); _locks[n].unlock_shared(); @@ -1272,7 +1273,7 @@ void Index::inter_insert(uint32_t n, std::vector &pru copy_of_neighbors.clear(); // LockGuard guard(_locks[des]); _locks[des].lock_shared(); - auto &des_pool = _graph_store->get_neighbours(des); + auto des_pool = _graph_store->get_neighbours(des); copy_of_neighbors.reserve(des_pool.size() + 1); for (auto& des_n : des_pool) { @@ -1416,7 +1417,8 @@ template void Index dummy_pool(0); std::vector new_out_neighbors; - for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) + auto neighbour_list = _graph_store->get_neighbours((location_t)node); + for (auto cur_nbr : neighbour_list) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { @@ -1461,7 +1463,8 @@ void Index::prune_all_neighbors(const uint32_t max_degree, cons ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); - for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) + auto neighbour_list = _graph_store->get_neighbours((location_t)node); + for (auto cur_nbr : neighbour_list) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { @@ -1484,7 +1487,7 @@ void Index::prune_all_neighbors(const uint32_t max_degree, cons { if (i < _nd || i >= _max_points) { - const std::vector &pool = _graph_store->get_neighbours((location_t)i); + const auto pool = _graph_store->get_neighbours((location_t)i); max = (std::max)(max, pool.size()); min = (std::min)(min, pool.size()); total += pool.size(); @@ -1614,7 +1617,7 @@ void Index::build_with_data_populated(const std::vector & size_t max = 0, min = SIZE_MAX, total = 0, cnt = 0; for (size_t i = 0; i < _nd; i++) { - auto &pool = _graph_store->get_neighbours((location_t)i); + auto pool = _graph_store->get_neighbours((location_t)i); max = std::max(max, pool.size()); min = std::min(min, pool.size()); total += pool.size(); @@ -2536,7 +2539,8 @@ inline void Index::process_delete(const tsl::robin_set adj_list_lock; if (_conc_consolidate) adj_list_lock = std::unique_lock(_locks[loc]); - adj_list = _graph_store->get_neighbours((location_t)loc); + auto adj_neighbor_list = _graph_store->get_neighbours((location_t)loc); + adj_neighbor_list.convert_to_vector(adj_list); } bool modify = false; @@ -2553,7 +2557,8 @@ inline void Index::process_delete(const tsl::robin_set ngh_lock; if (_conc_consolidate) ngh_lock = std::unique_lock(_locks[ngh]); - for (auto j : _graph_store->get_neighbours((location_t)ngh)) + auto neighbour_list = _graph_store->get_neighbours((location_t)ngh); + for (auto j : neighbour_list) if (j != loc && old_delete_set.find(j) == old_delete_set.end()) expanded_nodes_set.insert(j); } @@ -2768,8 +2773,9 @@ template void Index= _max_points && old < _max_points + _num_frozen_pts)) { - new_adj_list.reserve(_graph_store->get_neighbours((location_t)old).size()); - for (auto ngh_iter : _graph_store->get_neighbours((location_t)old)) + auto neighbour_list = _graph_store->get_neighbours((location_t)old); + new_adj_list.reserve(neighbour_list.size()); + for (auto ngh_iter : neighbour_list) { if (empty_locations.find(ngh_iter) != empty_locations.end()) { @@ -2918,8 +2924,9 @@ void Index::reposition_points(uint32_t old_location_start, uint std::vector updated_neighbours_location; for (uint32_t i = 0; i < _max_points + _num_frozen_pts; i++) { - auto &i_neighbours = _graph_store->get_neighbours((location_t)i); - std::vector i_neighbours_copy(i_neighbours.begin(), i_neighbours.end()); + auto i_neighbours = _graph_store->get_neighbours((location_t)i); + std::vector i_neighbours_copy; + i_neighbours.convert_to_vector(i_neighbours_copy); for (auto &loc : i_neighbours_copy) { if (loc >= old_location_start && loc < old_location_start + num_locations) @@ -3389,7 +3396,8 @@ template void Indexget_neighbours((location_t)node)) + auto neighbour_list = _graph_store->get_neighbours((location_t)node); + for (auto nghbr : neighbour_list) { if (!visited.test(nghbr)) { @@ -3428,9 +3436,10 @@ template void Indexget_neighbours(i).size(); + auto neighbour_list = _graph_store->get_neighbours(i); + uint32_t k = (uint32_t)neighbour_list.size(); std::memcpy(cur_node_offset, &k, sizeof(uint32_t)); - std::memcpy(cur_node_offset + sizeof(uint32_t), _graph_store->get_neighbours(i).data(), k * sizeof(uint32_t)); + std::memcpy(cur_node_offset + sizeof(uint32_t), neighbour_list.data(), k * sizeof(uint32_t)); // std::vector().swap(_graph_store->get_neighbours(i)); _graph_store->clear_neighbours(i); } diff --git a/src/index_factory.cpp b/src/index_factory.cpp index 616b2a6b6..5c7dbee6b 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -1,6 +1,7 @@ #include "index_factory.h" #include "tag_uint128.h" #include "pq_l2_distance.h" +#include "in_mem_static_graph_store.h" namespace diskann { @@ -89,6 +90,8 @@ std::unique_ptr IndexFactory::construct_graphstore(const Gra { case GraphStoreStrategy::MEMORY: return std::make_unique(size, reserve_graph_degree); + case GraphStoreStrategy::STATICMEMORY: + return std::make_unique(size, reserve_graph_degree); default: throw ANNException("Error : Current GraphStoreStratagy is not supported.", -1); } diff --git a/src/neighbor_list.cpp b/src/neighbor_list.cpp new file mode 100644 index 000000000..cdac9c4d4 --- /dev/null +++ b/src/neighbor_list.cpp @@ -0,0 +1,72 @@ +#include "neighbor_list.h" + +namespace diskann +{ + +NeighborList::NeighborList(const location_t* data, size_t size) + : _data(data) + , _size(size) +{ +} + +const location_t* NeighborList::data() const +{ + return _data; +} + +size_t NeighborList::size() const +{ + return _size; +} + +bool NeighborList::empty() const +{ + return _size == 0; +} + +void NeighborList::convert_to_vector(std::vector& vector_copy) const +{ + vector_copy.reserve(_size); + for (size_t i = 0; i < _size; i++) + { + vector_copy.push_back(_data[i]); + } +} + +NeighborList::Iterator::Iterator(const location_t* index) + : _index(index) +{ +} + +const location_t& NeighborList::Iterator::operator*() const +{ + return *_index; +} + +const NeighborList::Iterator& NeighborList::Iterator::operator++() +{ + _index++; + return *this; +} + +bool NeighborList::Iterator::operator==(const NeighborList::Iterator& other) const +{ + return _index == other._index; +} + +bool NeighborList::Iterator::operator!=(const NeighborList::Iterator& other) const +{ + return !(*this == other); +} + +NeighborList::Iterator NeighborList::begin() const +{ + return Iterator(_data); +} + +NeighborList::Iterator NeighborList::end() const +{ + return Iterator(_data + _size); +} + +} \ No newline at end of file