diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 39db09be77..9377e75eed 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -9,6 +9,7 @@ #pragma once #include "kv_tensor_wrapper.h" +#include "ssd_table_batched_embeddings.h" namespace ssd { @@ -58,30 +59,34 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { enable_async_update)) {} void set_cuda( - Tensor indices, - Tensor weights, - Tensor count, + at::Tensor indices, + at::Tensor weights, + at::Tensor count, int64_t timestep, bool is_bwd) { return impl_->set_cuda(indices, weights, count, timestep, is_bwd); } - void get_cuda(Tensor indices, Tensor weights, Tensor count) { + void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) { return impl_->get_cuda(indices, weights, count); } - void set(Tensor indices, Tensor weights, Tensor count) { + void set(at::Tensor indices, at::Tensor weights, at::Tensor count) { return impl_->set(indices, weights, count); } void set_range_to_storage( - const Tensor& weights, + const at::Tensor& weights, const int64_t start, const int64_t length) { return impl_->set_range_to_storage(weights, start, length); } - void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) { + void get( + at::Tensor indices, + at::Tensor weights, + at::Tensor count, + int64_t sleep_ms) { return impl_->get(indices, weights, count, sleep_ms); }