Commit d934fe3

Internal change

PiperOrigin-RevId: 686665933

pculliton authored and copybara-github committed Oct 29, 2024
1 parent ed40919 commit d934fe3
Showing 14 changed files with 590 additions and 9 deletions.
10 changes: 5 additions & 5 deletions compression/compress-inl.h
@@ -386,7 +386,7 @@ struct CompressTraits<NuqStream> {
                                   size_t num, CompressPerThread& tls,
                                   const PackedSpan<Packed>& packed,
                                   const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);
 
     if (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {
@@ -396,8 +396,8 @@ struct CompressTraits<NuqStream> {
       const hn::Repartition<BF16, DF> dbf;
       const size_t N16 = hn::Lanes(dbf);
       auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
-      NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
-                                     distorted.get(), num);
+      NuqCodec::DecompressAndZeroPadInterleaved(
+          dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
       DistortionStats stats;
       for (size_t i = 0; i < num; ++i) {
         stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
@@ -410,7 +410,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
                                const size_t packed_ofs, hn::Vec<D>& raw0,
                                hn::Vec<D>& raw1) {
-    NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
+    NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
   }
 
   // Store2 is not yet implemented.
@@ -419,7 +419,7 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
       Raw* raw, const size_t num) {
-    NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
   }
 };

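Taken together, the compress-inl.h changes reroute every NuqStream entry point in CompressTraits (Compress, Load2, DecompressAndZeroPad) to the new *Interleaved codec variants. A minimal round-trip sketch of that API, assuming the gemma.cpp/Highway context of these files (MakeSpan, MakeConst, and NuqStream::PackedEnd are taken to be the existing helpers from compression/shared.h; allocation details are illustrative rather than exact):

#include <cstddef>
#include <vector>
// Remaining names (NuqCodec, NuqStream, PackedSpan, kGroupSize) come from
// compression/nuq-inl.h and compression/shared.h.
namespace hn = hwy::HWY_NAMESPACE;

const hn::ScalableTag<float> df;
const size_t num = 4 * kGroupSize;  // whole groups keep the sketch simple
std::vector<float> raw(num, 0.5f);

// PackedEnd(num) is assumed to return the byte size of `num` interleaved
// elements (per group: 16 SFP table bytes + kGroupSize/2 nibble bytes);
// NuqStream is assumed to be a single-byte wrapper, as `&packed.ptr->byte`
// in this diff suggests.
std::vector<NuqStream> storage(NuqStream::PackedEnd(num));
const PackedSpan<NuqStream> packed = MakeSpan(storage.data(), storage.size());

NuqStream::ClusterBuf buf;  // EncInterleaved resizes it internally
const size_t unused_clusters = NuqCodec::EncInterleaved(
    df, raw.data(), num, buf, packed, /*packed_ofs=*/0);  // typically 0

// Decode; the destination is rounded up to a whole vector for zero-padding.
std::vector<float> out(hwy::RoundUpTo(num, hn::Lanes(df)));
NuqCodec::DecompressAndZeroPadInterleaved(df, MakeConst(packed),
                                          /*packed_ofs=*/0, out.data(), num);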
181 changes: 181 additions & 0 deletions compression/nuq-inl.h
@@ -21,6 +21,8 @@
 #include <stdint.h>
 #include <stdio.h>
 
+#include <cstdio>
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,12 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }
 
+  static constexpr size_t TableOffset(size_t packed_ofs) {
+    const size_t group_size =
+        16 + kGroupSize / 2;  // == NuqStream::PackedEnd(kGroupSize)
+    return (packed_ofs / kGroupSize) * group_size;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
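TableOffset is the heart of the new layout: each group now occupies a fixed-size record of 16 bytes of SFP-encoded centers (kClusters == 16, one byte each) followed by kGroupSize/2 bytes of packed 4-bit indices, so the table for the group containing any element offset is found by pure arithmetic. A standalone sketch of that arithmetic, with kGroupSize = 256 assumed purely for illustration:

#include <cstddef>

constexpr size_t kClusters = 16;    // centers per group; 1 byte each after SFP
constexpr size_t kGroupSize = 256;  // assumed value, not stated in this diff
constexpr size_t kBytesPerGroup = kClusters + kGroupSize / 2;  // table + nibbles

constexpr size_t TableOffset(size_t packed_ofs) {
  return (packed_ofs / kGroupSize) * kBytesPerGroup;
}

static_assert(TableOffset(0) == 0);           // group 0's table is at byte 0
static_assert(TableOffset(255) == 0);         // still group 0
static_assert(TableOffset(256) == 16 + 128);  // group 1 starts one record later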
@@ -606,6 +614,81 @@ class NuqCodec {
   }
 
  public:
+  // Encodes `num` floats from `raw` into `packed`. `packed` points to
+  // compressed storage and `packed_ofs` indicates the destination offset
+  // within it, in number of elements. Each group's table is interleaved with
+  // its indices (the clustered elements) to allow for easier unpacking.
+  // Returns the total number of unused clusters, which is typically zero.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
+                                          const size_t num,
+                                          NuqStream::ClusterBuf& buf,
+                                          const PackedSpan<NuqStream>& packed,
+                                          size_t packed_ofs) {
+    const hn::Repartition<uint16_t, DF> d16;
+    const hn::Repartition<uint8_t, DF> d8;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8 = hn::Vec<decltype(d8)>;
+    const size_t N16 = hn::Lanes(d16);
+
+    HWY_ASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    // TODO: remove the dynamic resize; it is no longer necessary because
+    // interleaved encoding uses only a single group-sized buffer.
+    buf.Resize(1);
+
+    size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
+    for (size_t g = 0; g < num_groups; ++g) {
+      const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
+      const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
+      unused_clusters +=
+          NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
+
+      uint8_t* centers = &packed.ptr->byte + TableOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));
+      uint8_t* packed_start = centers + kClusters;
+
+      current_offset += g_num;
+
+      size_t i = 0;
+      HWY_UNROLL(1)
+      for (; i + 4 * N16 <= g_num; i += 4 * N16) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
+
+      const size_t remaining = g_num - i;
+
+      HWY_DASSERT(remaining < 4 * N16);
+      if (HWY_UNLIKELY(remaining != 0)) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        // i is even, but remaining might not be.
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
+      }
+    }
+    return unused_clusters;
+  }
+
   // Encodes `num` floats from `raw`. `packed` points to compressed storage and
   // `packed_ofs` indicates the destination offset within it, in units of float
   // values, for parallel encoding by multiple threads. Returns the total
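Note the main loop condition above was tightened to `i + 4 * N16 <= g_num` so that the StoreN tail path is actually reachable for partial groups; the flattened view showed `i < g_num`, which would have overshot the group. The packing step stores two 4-bit cluster indices per byte via NibbleCodec::OrderedPackU16. A scalar sketch of the equivalent operation (low-nibble-first order is this sketch's assumption; the exact lane order is whatever OrderedPackU16 defines):

#include <cstddef>
#include <cstdint>

// Pack `num` 4-bit indices (stored one per uint16_t) into num/2 bytes.
inline void PackNibblesScalar(const uint16_t* idx, size_t num, uint8_t* out) {
  for (size_t i = 0; i < num; i += 2) {
    const uint8_t lo = idx[i] & 0xF;                            // even element
    const uint8_t hi = (i + 1 < num) ? (idx[i + 1] & 0xF) : 0;  // odd, if any
    out[i / 2] = static_cast<uint8_t>(lo | (hi << 4));
  }
}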
@@ -765,6 +848,103 @@ class NuqCodec {
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }
 
+  // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DBF, HWY_IF_BF16_D(DBF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DBF dbf, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const D8HFromD16<DBF> d8h;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8H = hn::Vec<decltype(d8h)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    const V8H nibbles = hn::LoadU(d8h, indices);
+
+    V16 c0, c1;
+    TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
+    raw0 = BitCast(dbf, c0);
+    raw1 = BitCast(dbf, c1);
+  }
+
+  // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
+  // vectors so that we only have to load one group's table.
+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Dec2Interleaved(
+      DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
+      hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
+    const hn::Repartition<BF16, decltype(df)> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const hn::Half<D8HFromD16<decltype(d16)>> d8q;
+    using V8Q = hn::Vec<decltype(d8q)>;
+    using V16 = hn::Vec<decltype(d16)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    // The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
+    // which expects a quarter vector of bytes.
+    const V8Q nibbles = hn::LoadU(d8q, indices);
+
+    const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
+    raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
+    raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
+  }
+
+  template <class D, typename Raw = hn::TFromD<D>>
+  static HWY_INLINE void DecompressAndZeroPadInterleaved(
+      D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
+      Raw* HWY_RESTRICT raw, size_t num) {
+    // If `packed_ofs` is not group-aligned, decode the rest of that first
+    // group and advance the arguments, from which we compute new
+    // tables/indices below.
+    size_t current_offset = packed_ofs;
+    if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
+      DecPartialGroup(d, tables, indices, raw, remaining);
+      packed_ofs += remaining;
+      current_offset += remaining;
+      raw += remaining;
+      num -= remaining;
+      if (num == 0) return;
+    }
+
+    HWY_DASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups - 1; ++g) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
+    }
+
+    const size_t g = num_groups - 1;
+    const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
+  }
+
   // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
   // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
   // round `num` up to one vector, if it is not already.
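Conceptually, decoding an interleaved group is one table lookup per nibble: unpack the group's 16 centers, then index them with each 4-bit value. A scalar sketch, illustrative only (the committed path decodes the SFP table once via LoadTable and uses vector TableLookups):

#include <cstddef>
#include <cstdint>

// `centers`: the group's kClusters centers, already decoded from SFP to float.
// `indices`: DivCeil(g_num, 2) bytes of packed nibbles for this group.
// Assumes the same low-nibble-first order as the packing sketch above.
inline void DecGroupScalar(const float* centers, const uint8_t* indices,
                           size_t g_num, float* out) {
  for (size_t i = 0; i < g_num; ++i) {
    const uint8_t byte = indices[i / 2];
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0xF) : (byte >> 4);
    out[i] = centers[nibble];  // nibble < kClusters == 16
  }
}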
@@ -955,6 +1135,7 @@ class NuqCodec {
   }
 
   const size_t remaining = num - i;
+
   HWY_DASSERT(remaining < 4 * NF);
   if (HWY_UNLIKELY(remaining != 0)) {
     // i is even, but remaining might not be.
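For completeness, a sketch of how a caller might stream through a packed tensor two vectors at a time with Dec2Interleaved (BF16 is the alias this file uses for hwy::bfloat16_t; `num` is assumed to be a multiple of two vectors, and kGroupSize a multiple of that, so each call stays within one group as the Dec2Interleaved comment requires):

// Fragment sketch; `packed` and `num` are as in the encode example above.
const hn::ScalableTag<BF16> dbf;
hn::Vec<decltype(dbf)> v0, v1;
for (size_t ofs = 0; ofs < num; ofs += 2 * hn::Lanes(dbf)) {
  NuqCodec::Dec2Interleaved(dbf, MakeConst(packed), ofs, v0, v1);
  // ... consume v0 and v1, e.g. accumulate a dot product ...
}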