diff --git a/include/restapi/common.h b/include/restapi/common.h index ec321ec9a..b8339635a 100644 --- a/include/restapi/common.h +++ b/include/restapi/common.h @@ -6,14 +6,12 @@ #include #include -namespace diskann { +namespace diskann +{ // Constants -static const std::string VECTOR_KEY = "query", K_KEY = "k", - INDICES_KEY = "indices", DISTANCES_KEY = "distances", - TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", - ERROR_MESSAGE_KEY = "error", L_KEY = "Ls", - TIME_TAKEN_KEY = "time_taken_in_us", - PARTITION_KEY = "partition", +static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances", + TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls", + TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition", UNKNOWN_ERROR = "unknown_error"; const unsigned int DEFAULT_L = 100; diff --git a/include/restapi/search_wrapper.h b/include/restapi/search_wrapper.h index e7ed1725e..d41b2b7cd 100644 --- a/include/restapi/search_wrapper.h +++ b/include/restapi/search_wrapper.h @@ -10,106 +10,131 @@ #include #include -namespace diskann { -class SearchResult { -public: - SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, - const unsigned *const indices, const float *const distances, - const std::string *const tags = nullptr, - const unsigned *const partitions = nullptr); - - const std::vector &get_indices() const { return _indices; } - const std::vector &get_distances() const { return _distances; } - bool tags_enabled() const { return _tags_enabled; } - const std::vector &get_tags() const { return _tags; } - bool partitions_enabled() const { return _partitions_enabled; } - const std::vector &get_partitions() const { return _partitions; } - unsigned get_time() const { return _search_time_in_ms; } - -private: - unsigned int _K; - unsigned int _search_time_in_ms; - std::vector _indices; - std::vector _distances; - - bool _tags_enabled; - std::vector _tags; - - bool _partitions_enabled; - std::vector _partitions; +namespace diskann +{ +class SearchResult +{ + public: + SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices, + const float *const distances, const std::string *const tags = nullptr, + const unsigned *const partitions = nullptr); + + const std::vector &get_indices() const + { + return _indices; + } + const std::vector &get_distances() const + { + return _distances; + } + bool tags_enabled() const + { + return _tags_enabled; + } + const std::vector &get_tags() const + { + return _tags; + } + bool partitions_enabled() const + { + return _partitions_enabled; + } + const std::vector &get_partitions() const + { + return _partitions; + } + unsigned get_time() const + { + return _search_time_in_ms; + } + + private: + unsigned int _K; + unsigned int _search_time_in_ms; + std::vector _indices; + std::vector _distances; + + bool _tags_enabled; + std::vector _tags; + + bool _partitions_enabled; + std::vector _partitions; }; -class SearchNotImplementedException : public std::logic_error { -private: - std::string _errormsg; - -public: - SearchNotImplementedException(const char *type) - : std::logic_error("Not Implemented") { - _errormsg = "Search with data type "; - _errormsg += std::string(type); - _errormsg += " not implemented : "; - _errormsg += __FUNCTION__; - } - - virtual const char *what() const throw() { return _errormsg.c_str(); } +class SearchNotImplementedException : public std::logic_error +{ + private: + std::string _errormsg; + + public: + SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented") + { + _errormsg = "Search with data type "; + _errormsg += std::string(type); + _errormsg += " not implemented : "; + _errormsg += __FUNCTION__; + } + + virtual const char *what() const throw() + { + return _errormsg.c_str(); + } }; -class BaseSearch { -public: - BaseSearch(const std::string &tagsFile = nullptr); - virtual SearchResult search(const float *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("float"); - } - virtual SearchResult search(const int8_t *query, - const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("int8_t"); - } - - virtual SearchResult search(const uint8_t *query, - const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("uint8_t"); - } - - void lookup_tags(const unsigned K, const unsigned *indices, - std::string *ret_tags); - -protected: - bool _tags_enabled; - std::vector _tags_str; +class BaseSearch +{ + public: + BaseSearch(const std::string &tagsFile = nullptr); + virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("float"); + } + virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("int8_t"); + } + + virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("uint8_t"); + } + + void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags); + + protected: + bool _tags_enabled; + std::vector _tags_str; }; -template class InMemorySearch : public BaseSearch { -public: - InMemorySearch(const std::string &baseFile, const std::string &indexFile, - const std::string &tagsFile, Metric m, uint32_t num_threads, - uint32_t search_l); - virtual ~InMemorySearch(); +template class InMemorySearch : public BaseSearch +{ + public: + InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m, + uint32_t num_threads, uint32_t search_l); + virtual ~InMemorySearch(); - SearchResult search(const T *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls); + SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls); -private: - unsigned int _dimensions, _numPoints; - std::unique_ptr> _index; + private: + unsigned int _dimensions, _numPoints; + std::unique_ptr> _index; }; -template class PQFlashSearch : public BaseSearch { -public: - PQFlashSearch(const std::string &indexPrefix, - const unsigned num_nodes_to_cache, const unsigned num_threads, - const std::string &tagsFile, Metric m); - virtual ~PQFlashSearch(); +template class PQFlashSearch : public BaseSearch +{ + public: + PQFlashSearch(const std::string &indexPrefix, const unsigned num_nodes_to_cache, const unsigned num_threads, + const std::string &tagsFile, Metric m); + virtual ~PQFlashSearch(); - SearchResult search(const T *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls); + SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls); -private: - unsigned int _dimensions, _numPoints; - std::unique_ptr> _index; - std::shared_ptr reader; + private: + unsigned int _dimensions, _numPoints; + std::unique_ptr> _index; + std::shared_ptr reader; }; } // namespace diskann diff --git a/include/restapi/server.h b/include/restapi/server.h index 9cb9449da..ddb19d17a 100644 --- a/include/restapi/server.h +++ b/include/restapi/server.h @@ -6,44 +6,40 @@ #include #include -namespace diskann { -class Server { -public: - Server(web::uri &url, - std::vector> &multi_searcher, - const std::string &typestring); - virtual ~Server(); - - pplx::task open(); - pplx::task close(); - -protected: - template void handle_post(web::http::http_request message); - - template - web::json::value - toJsonArray(const std::vector &v, - std::function valConverter); - web::json::value prepareResponse(const int64_t &queryId, const int k); - - template - void parseJson(const utility::string_t &body, unsigned int &k, - int64_t &queryId, T *&queryVector, unsigned int &dimensions, - unsigned &Ls); - - web::json::value idsToJsonArray(const diskann::SearchResult &result); - web::json::value distancesToJsonArray(const diskann::SearchResult &result); - web::json::value tagsToJsonArray(const diskann::SearchResult &result); - web::json::value partitionsToJsonArray(const diskann::SearchResult &result); - - SearchResult - aggregate_results(const unsigned K, - const std::vector &results); - -private: - bool _isDebug; - std::unique_ptr _listener; - const bool _multi_search; - std::vector> _multi_searcher; +namespace diskann +{ +class Server +{ + public: + Server(web::uri &url, std::vector> &multi_searcher, + const std::string &typestring); + virtual ~Server(); + + pplx::task open(); + pplx::task close(); + + protected: + template void handle_post(web::http::http_request message); + + template + web::json::value toJsonArray(const std::vector &v, std::function valConverter); + web::json::value prepareResponse(const int64_t &queryId, const int k); + + template + void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector, + unsigned int &dimensions, unsigned &Ls); + + web::json::value idsToJsonArray(const diskann::SearchResult &result); + web::json::value distancesToJsonArray(const diskann::SearchResult &result); + web::json::value tagsToJsonArray(const diskann::SearchResult &result); + web::json::value partitionsToJsonArray(const diskann::SearchResult &result); + + SearchResult aggregate_results(const unsigned K, const std::vector &results); + + private: + bool _isDebug; + std::unique_ptr _listener; + const bool _multi_search; + std::vector> _multi_searcher; }; } // namespace diskann