Skip to content

Commit

Permalink
Incorporating CR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Oct 23, 2023
1 parent 84f1ee4 commit 74330ca
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
26 changes: 2 additions & 24 deletions include/pq_scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,8 @@ template <typename T> struct PQScratch
float *rotated_query = nullptr;
float *aligned_query_float = nullptr;

PQScratch(size_t graph_degree, size_t aligned_dim)
{
diskann::alloc_aligned((void **)&aligned_pq_coord_scratch,
(size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256);
diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float),
256);
diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256);
diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float));
diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float));

memset(aligned_query_float, 0, aligned_dim * sizeof(float));
memset(rotated_query, 0, aligned_dim * sizeof(float));
}

void initialize(size_t dim, const T *query, const float norm = 1.0f)
{
for (size_t d = 0; d < dim; ++d)
{
if (norm != 1.0f)
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]) / norm;
else
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]);
}
}
PQScratch(size_t graph_degree, size_t aligned_dim);
void initialize(size_t dim, const T *query, const float norm = 1.0f);
};

} // namespace diskann
27 changes: 27 additions & 0 deletions src/scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,33 @@ template <typename T> void SSDThreadData<T>::clear()
scratch.reset();
}

template<typename T>
PQScratch<T>::PQScratch(size_t graph_degree, size_t aligned_dim)
{
diskann::alloc_aligned((void **)&aligned_pq_coord_scratch,
(size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256);
diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256);
diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256);
diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float));
diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float));

memset(aligned_query_float, 0, aligned_dim * sizeof(float));
memset(rotated_query, 0, aligned_dim * sizeof(float));
}

template<typename T>
void PQScratch<T>::initialize(size_t dim, const T *query, const float norm = 1.0f)
{
for (size_t d = 0; d < dim; ++d)
{
if (norm != 1.0f)
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]) / norm;
else
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]);
}
}


template DISKANN_DLLEXPORT class InMemQueryScratch<int8_t>;
template DISKANN_DLLEXPORT class InMemQueryScratch<uint8_t>;
template DISKANN_DLLEXPORT class InMemQueryScratch<float>;
Expand Down

0 comments on commit 74330ca

Please sign in to comment.