Skip to content

Commit

Permalink
#39 The alignment of the input mask has to be covered by the caller.
Browse files Browse the repository at this point in the history
  • Loading branch information
carljohnsen committed May 14, 2024
1 parent d2f1fa1 commit 67d0ba4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
3 changes: 0 additions & 3 deletions src/lib/cpp/gpu/bitpacking.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ namespace gpu {
const uint32_t *mask32 = (const uint32_t *) mask;
uint32_t local[buffer_size]; // Shared memory

// TODO handle block_size unalignment

#pragma acc data copyin(mask32[0:n/sizeof(uint32_t)]) copyout(packed[0:n/(uint64_t)T_bits])
#pragma acc parallel vector_length(vec_size) num_workers(worker_size)
{
Expand Down Expand Up @@ -129,7 +127,6 @@ namespace gpu {
for (uint64_t z = 0; z < sz; z++) {
for (uint64_t y = 0; y < sy; y++) {
for (uint64_t x = 0; x < sx; x++) {
// TODO Handle unalignment
uint64_t packed_offset = (oz+z)*Ny*Nx + (oy+y)*Nx + ox+x;
uint64_t slice_offset = z*sy*sx + y*sx + x;
slice[slice_offset] = packed[packed_offset];
Expand Down
21 changes: 11 additions & 10 deletions src/pybind/bitpacking-pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ namespace python_api {
auto mask_info = np_mask.request();
auto packed_info = np_packed.request();

assert(packed_info.size * T_bits >= (uint64_t) mask_info.size);
assert(packed_info.size * T_bits >= (uint64_t) mask_info.size && "Packed array is too small");
assert(mask_info.size % (T_bits*T_bits) == 0 && "Mask size must be a multiple of T_bits*T_bits");

const uint8_t *mask = static_cast<const uint8_t*>(mask_info.ptr);
//uint8_t *packed = static_cast<uint8_t*>(packed_info.ptr);
T *packed = static_cast<T*>(packed_info.ptr);

NS::encode(mask, mask_info.size, packed);
Expand All @@ -34,7 +34,6 @@ namespace python_api {

assert(packed_info.size * T_bits >= (uint64_t) mask_info.size);

//const uint8_t *packed = static_cast<const uint8_t*>(packed_info.ptr);
const T *packed = static_cast<const T*>(packed_info.ptr);
uint8_t *mask = static_cast<uint8_t*>(mask_info.ptr);

Expand All @@ -44,15 +43,17 @@ namespace python_api {
}

// Defines the `bitpacking` Python extension module through pybind11: sets the
// module docstring and registers the `encode` / `decode` bindings, each backed by
// an instantiation of the python_api::encode<T> / python_api::decode<T> templates.
//
// NOTE(review): this span looks like a rendered diff in which the pre-change and
// post-change lines are interleaved without +/- markers (two m.doc() assignments;
// the uint8_t/uint16_t/uint64_t bindings appear both active and commented out).
// Presumably only the uint32_t bindings are meant to stay active — confirm against
// the real file before relying on this listing.
PYBIND11_MODULE(bitpacking, m) {
m.doc() = "Bitpacking functions for encoding and decoding. A bool should only take up 1 bit."; // optional module docstring
m.doc() = "Bitpacking functions for encoding and decoding. A bool should only take up 1 bit. Current implementations are built around the packed datatype being uint32_t."; // optional module docstring

// TODO Currently, the GPU implementation only supports uint32_t.

// encode(np_mask, np_packed): one overload per packed element type T.
// Both arguments are marked noconvert(), which in pybind11 requests that no
// implicit conversion be applied to the argument when matching the overload.
m.def("encode", &python_api::encode<uint8_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
m.def("encode", &python_api::encode<uint16_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
//m.def("encode", &python_api::encode<uint8_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
//m.def("encode", &python_api::encode<uint16_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
m.def("encode", &python_api::encode<uint32_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
m.def("encode", &python_api::encode<uint64_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());
//m.def("encode", &python_api::encode<uint64_t>, py::arg("np_mask").noconvert(), py::arg("np_packed").noconvert());

// decode(np_packed, np_mask): note the argument order is the reverse of encode
// (packed input first, mask output second). Same noconvert() marking as above.
m.def("decode", &python_api::decode<uint8_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
m.def("decode", &python_api::decode<uint16_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
//m.def("decode", &python_api::decode<uint8_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
//m.def("decode", &python_api::decode<uint16_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
m.def("decode", &python_api::decode<uint32_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
m.def("decode", &python_api::decode<uint64_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
//m.def("decode", &python_api::decode<uint64_t>, py::arg("np_packed").noconvert(), py::arg("np_mask").noconvert());
}

0 comments on commit 67d0ba4

Please sign in to comment.