Skip to content

Commit

Permalink
clang-formatted restapi files
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Nov 13, 2024
1 parent 505572f commit 7ac2891
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 135 deletions.
12 changes: 5 additions & 7 deletions include/restapi/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
#include <cpprest/base_uri.h>
#include <restapi/search_wrapper.h>

namespace diskann {
namespace diskann
{
// Constants
static const std::string VECTOR_KEY = "query", K_KEY = "k",
INDICES_KEY = "indices", DISTANCES_KEY = "distances",
TAGS_KEY = "tags", QUERY_ID_KEY = "query_id",
ERROR_MESSAGE_KEY = "error", L_KEY = "Ls",
TIME_TAKEN_KEY = "time_taken_in_us",
PARTITION_KEY = "partition",
static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances",
TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls",
TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition",
UNKNOWN_ERROR = "unknown_error";
const unsigned int DEFAULT_L = 100;

Expand Down
203 changes: 114 additions & 89 deletions include/restapi/search_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,106 +10,131 @@
#include <index.h>
#include <pq_flash_index.h>

namespace diskann {
class SearchResult {
public:
SearchResult(unsigned int K, unsigned int elapsed_time_in_ms,
const unsigned *const indices, const float *const distances,
const std::string *const tags = nullptr,
const unsigned *const partitions = nullptr);

const std::vector<unsigned int> &get_indices() const { return _indices; }
const std::vector<float> &get_distances() const { return _distances; }
bool tags_enabled() const { return _tags_enabled; }
const std::vector<std::string> &get_tags() const { return _tags; }
bool partitions_enabled() const { return _partitions_enabled; }
const std::vector<unsigned> &get_partitions() const { return _partitions; }
unsigned get_time() const { return _search_time_in_ms; }

private:
unsigned int _K;
unsigned int _search_time_in_ms;
std::vector<unsigned int> _indices;
std::vector<float> _distances;

bool _tags_enabled;
std::vector<std::string> _tags;

bool _partitions_enabled;
std::vector<unsigned> _partitions;
namespace diskann
{
class SearchResult
{
public:
SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices,
const float *const distances, const std::string *const tags = nullptr,
const unsigned *const partitions = nullptr);

const std::vector<unsigned int> &get_indices() const
{
return _indices;
}
const std::vector<float> &get_distances() const
{
return _distances;
}
bool tags_enabled() const
{
return _tags_enabled;
}
const std::vector<std::string> &get_tags() const
{
return _tags;
}
bool partitions_enabled() const
{
return _partitions_enabled;
}
const std::vector<unsigned> &get_partitions() const
{
return _partitions;
}
unsigned get_time() const
{
return _search_time_in_ms;
}

private:
unsigned int _K;
unsigned int _search_time_in_ms;
std::vector<unsigned int> _indices;
std::vector<float> _distances;

bool _tags_enabled;
std::vector<std::string> _tags;

bool _partitions_enabled;
std::vector<unsigned> _partitions;
};

class SearchNotImplementedException : public std::logic_error {
private:
std::string _errormsg;

public:
SearchNotImplementedException(const char *type)
: std::logic_error("Not Implemented") {
_errormsg = "Search with data type ";
_errormsg += std::string(type);
_errormsg += " not implemented : ";
_errormsg += __FUNCTION__;
}

virtual const char *what() const throw() { return _errormsg.c_str(); }
class SearchNotImplementedException : public std::logic_error
{
private:
std::string _errormsg;

public:
SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented")
{
_errormsg = "Search with data type ";
_errormsg += std::string(type);
_errormsg += " not implemented : ";
_errormsg += __FUNCTION__;
}

virtual const char *what() const throw()
{
return _errormsg.c_str();
}
};

class BaseSearch {
public:
BaseSearch(const std::string &tagsFile = nullptr);
virtual SearchResult search(const float *query, const unsigned int dimensions,
const unsigned int K, const unsigned int Ls) {
throw SearchNotImplementedException("float");
}
virtual SearchResult search(const int8_t *query,
const unsigned int dimensions,
const unsigned int K, const unsigned int Ls) {
throw SearchNotImplementedException("int8_t");
}

virtual SearchResult search(const uint8_t *query,
const unsigned int dimensions,
const unsigned int K, const unsigned int Ls) {
throw SearchNotImplementedException("uint8_t");
}

void lookup_tags(const unsigned K, const unsigned *indices,
std::string *ret_tags);

protected:
bool _tags_enabled;
std::vector<std::string> _tags_str;
class BaseSearch
{
public:
BaseSearch(const std::string &tagsFile = nullptr);
virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("float");
}
virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("int8_t");
}

virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K,
const unsigned int Ls)
{
throw SearchNotImplementedException("uint8_t");
}

void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags);

protected:
bool _tags_enabled;
std::vector<std::string> _tags_str;
};

template <typename T> class InMemorySearch : public BaseSearch {
public:
InMemorySearch(const std::string &baseFile, const std::string &indexFile,
const std::string &tagsFile, Metric m, uint32_t num_threads,
uint32_t search_l);
virtual ~InMemorySearch();
template <typename T> class InMemorySearch : public BaseSearch
{
public:
InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m,
uint32_t num_threads, uint32_t search_l);
virtual ~InMemorySearch();

SearchResult search(const T *query, const unsigned int dimensions,
const unsigned int K, const unsigned int Ls);
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);

private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::Index<T>> _index;
private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::Index<T>> _index;
};

template <typename T> class PQFlashSearch : public BaseSearch {
public:
PQFlashSearch(const std::string &indexPrefix,
const unsigned num_nodes_to_cache, const unsigned num_threads,
const std::string &tagsFile, Metric m);
virtual ~PQFlashSearch();
template <typename T> class PQFlashSearch : public BaseSearch
{
public:
PQFlashSearch(const std::string &indexPrefix, const unsigned num_nodes_to_cache, const unsigned num_threads,
const std::string &tagsFile, Metric m);
virtual ~PQFlashSearch();

SearchResult search(const T *query, const unsigned int dimensions,
const unsigned int K, const unsigned int Ls);
SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls);

private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::PQFlashIndex<T>> _index;
std::shared_ptr<AlignedFileReader> reader;
private:
unsigned int _dimensions, _numPoints;
std::unique_ptr<diskann::PQFlashIndex<T>> _index;
std::shared_ptr<AlignedFileReader> reader;
};
} // namespace diskann
74 changes: 35 additions & 39 deletions include/restapi/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,40 @@
#include <cpprest/http_listener.h>
#include <restapi/common.h>

namespace diskann {
class Server {
public:
Server(web::uri &url,
std::vector<std::unique_ptr<diskann::BaseSearch>> &multi_searcher,
const std::string &typestring);
virtual ~Server();

pplx::task<void> open();
pplx::task<void> close();

protected:
template <class T> void handle_post(web::http::http_request message);

template <typename T>
web::json::value
toJsonArray(const std::vector<T> &v,
std::function<web::json::value(const T &)> valConverter);
web::json::value prepareResponse(const int64_t &queryId, const int k);

template <class T>
void parseJson(const utility::string_t &body, unsigned int &k,
int64_t &queryId, T *&queryVector, unsigned int &dimensions,
unsigned &Ls);

web::json::value idsToJsonArray(const diskann::SearchResult &result);
web::json::value distancesToJsonArray(const diskann::SearchResult &result);
web::json::value tagsToJsonArray(const diskann::SearchResult &result);
web::json::value partitionsToJsonArray(const diskann::SearchResult &result);

SearchResult
aggregate_results(const unsigned K,
const std::vector<diskann::SearchResult> &results);

private:
bool _isDebug;
std::unique_ptr<web::http::experimental::listener::http_listener> _listener;
const bool _multi_search;
std::vector<std::unique_ptr<diskann::BaseSearch>> _multi_searcher;
namespace diskann
{
class Server
{
public:
Server(web::uri &url, std::vector<std::unique_ptr<diskann::BaseSearch>> &multi_searcher,
const std::string &typestring);
virtual ~Server();

pplx::task<void> open();
pplx::task<void> close();

protected:
template <class T> void handle_post(web::http::http_request message);

template <typename T>
web::json::value toJsonArray(const std::vector<T> &v, std::function<web::json::value(const T &)> valConverter);
web::json::value prepareResponse(const int64_t &queryId, const int k);

template <class T>
void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector,
unsigned int &dimensions, unsigned &Ls);

web::json::value idsToJsonArray(const diskann::SearchResult &result);
web::json::value distancesToJsonArray(const diskann::SearchResult &result);
web::json::value tagsToJsonArray(const diskann::SearchResult &result);
web::json::value partitionsToJsonArray(const diskann::SearchResult &result);

SearchResult aggregate_results(const unsigned K, const std::vector<diskann::SearchResult> &results);

private:
bool _isDebug;
std::unique_ptr<web::http::experimental::listener::http_listener> _listener;
const bool _multi_search;
std::vector<std::unique_ptr<diskann::BaseSearch>> _multi_searcher;
};
} // namespace diskann

0 comments on commit 7ac2891

Please sign in to comment.