Internal change #441

Open · wants to merge 1 commit into base: dev
10 changes: 5 additions & 5 deletions compression/compress-inl.h
@@ -386,7 +386,7 @@ struct CompressTraits<NuqStream> {
size_t num, CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);

if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
@@ -396,8 +396,8 @@ struct CompressTraits<NuqStream> {
const hn::Repartition<BF16, DF> dbf;
const size_t N16 = hn::Lanes(dbf);
auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
distorted.get(), num);
NuqCodec::DecompressAndZeroPadInterleaved(
dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
DistortionStats stats;
for (size_t i = 0; i < num; ++i) {
stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
@@ -410,7 +410,7 @@ struct CompressTraits<NuqStream> {
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<D>& raw0,
hn::Vec<D>& raw1) {
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
}

// Store2 is not yet implemented.
@@ -419,7 +419,7 @@ struct CompressTraits<NuqStream> {
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
Raw* raw, const size_t num) {
NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
}
};

20 changes: 15 additions & 5 deletions compression/compress.h
@@ -17,6 +17,7 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

#include "hwy/base.h"
#define COMPRESS_STATS 0

#include <stddef.h>
@@ -134,7 +135,7 @@ class MatPtr {
size_t NumElements() const { return num_elements_; }

// Returns the number of bytes in the array.
size_t SizeBytes() const { return num_elements_ * element_size_; }
virtual size_t SizeBytes() const { return num_elements_ * element_size_; }

// Returns the number of rows in the 2-d array (outer dimension).
size_t Rows() const { return rows_; }
@@ -240,10 +241,13 @@ class MatPtrT : public MatPtr {
return name;
}

// Sets the number of elements in the array. For use when the number of
// elements is != rows * cols ONLY.
void SetNumElements(size_t num_elements) {
num_elements_ = CompressedArrayElements<MatT>(num_elements);
// Returns the number of bytes in the array. Overrides MatPtr::SizeBytes()
// to account for NUQ's differing packed size.
size_t SizeBytes() const override {
if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
return NuqStream::PackedEnd(num_elements_);
}
return num_elements_ * element_size_;
}

// 2-d Accessor for a specific type but with a dynamic inner dimension.
@@ -333,6 +337,12 @@ class MatStorageT : public MatPtrT<MatT> {
// from the current num_elements_ which was set by the constructor from the
// rows and cols.
void Allocate(size_t num_elements = 0) {
// size_t num_elements = 0;
// TODO: optimize this check or obviate it.
if (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
HWY_DASSERT(num_elements == 0);
}

if (num_elements == 0) {
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
} else {
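Note on the SizeBytes() override and the NuqStream special case in Allocate(): NUQ's packed size is smaller than num_elements_ * element_size_ because each group stores only a small SFP table plus half a byte per element. Below is a minimal sketch of that arithmetic, assuming the per-group layout implied by TableByteOffset() in nuq-inl.h further down, with kClusters = 16, kGroupSize = 256, and 1-byte SfpStream entries (assumed values, not taken from this diff):

#include <cstddef>

// Assumed constants; the actual values live in NuqStream and are not shown here.
constexpr size_t kClusters = 16;    // 4-bit indices select one of 16 centers
constexpr size_t kGroupSize = 256;  // elements per group (assumed)

// Packed bytes for `num` elements: per group, a kClusters-byte SFP table
// followed by two 4-bit indices per byte; a trailing partial group still
// stores a full table.
constexpr size_t NuqPackedBytesSketch(size_t num) {
  const size_t whole = num / kGroupSize;
  const size_t rem = num % kGroupSize;
  const size_t per_group = kClusters + kGroupSize / 2;  // 144 bytes
  return whole * per_group + (rem != 0 ? kClusters + (rem + 1) / 2 : 0);
}

// Example: 4096 floats pack into 16 * 144 = 2304 bytes, far less than
// 4096 * sizeof(float) = 16384, hence the virtual SizeBytes().
static_assert(NuqPackedBytesSketch(4096) == 2304, "interleaved NUQ packing");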
189 changes: 189 additions & 0 deletions compression/nuq-inl.h
@@ -40,6 +40,7 @@
// After highway.h
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/profiler.h" // uses SIMD

HWY_BEFORE_NAMESPACE();
namespace gcpp {
@@ -529,12 +530,21 @@ class NuqCodec {
return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
}

// Offset (in bytes) of a group's table for packed_ofs (in elements) within a
// set of groups.
static constexpr size_t TableByteOffset(size_t packed_ofs) {
const size_t kBytesPerGroup =
(kClusters * sizeof(SfpStream)) + kGroupSize / 2;
return (packed_ofs / kGroupSize) * kBytesPerGroup;
}
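// Illustrative arithmetic (not part of this change): assuming kClusters = 16,
// sizeof(SfpStream) = 1, and kGroupSize = 256, each group occupies
// 16 + 128 = 144 bytes (SFP table followed by 4-bit indices). Thus
// TableByteOffset(0) = 0, TableByteOffset(256) = 144, and
// TableByteOffset(512) = 288; any offset inside a group maps to that group's
// table, e.g. TableByteOffset(300) = 144.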

// Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
// for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
// not be available for bf16.
template <class DU, HWY_IF_U16_D(DU)>
static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
hn::Vec<DU>* HWY_RESTRICT tbl1) {
PROFILER_FUNC;
// Cap to the table size (kClusters) for decoding SFP - sufficient, and may
// be faster than a large vector.
const hn::CappedTag<BF16, kClusters> d_table;
@@ -606,6 +616,81 @@ class NuqCodec {
}

public:
// Encodes `num` floats from `raw` into `packed`. `packed` points to
// compressed storage and `packed_ofs` indicates the destination offset within
// it, in number of elements. Tables are interleaved with indices (clustered
// elements) to allow for easier unpacking. Returns the total number of
// unused clusters, which is typically zero.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
const size_t num,
NuqStream::ClusterBuf& buf,
const PackedSpan<NuqStream>& packed,
size_t packed_ofs) {
const hn::Repartition<uint16_t, DF> d16;
const hn::Repartition<uint8_t, DF> d8;
using V16 = hn::Vec<decltype(d16)>;
using V8 = hn::Vec<decltype(d8)>;
const size_t N16 = hn::Lanes(d16);

HWY_ASSERT(packed_ofs % kGroupSize == 0);

const size_t num_groups = hwy::DivCeil(num, kGroupSize);
// TODO: dynamic resize should be removed; it is no longer necessary as
// interleaved encoding uses only a single buffer of the same size.
buf.Resize(1);

size_t unused_clusters = 0;
size_t current_offset = packed_ofs;
for (size_t g = 0; g < num_groups; ++g) {
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;

float* HWY_RESTRICT g_centers = buf.centers.get();
uint16_t* HWY_RESTRICT g_idx = buf.idx.get();

unused_clusters +=
NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);

uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
SfpCodec::Enc(df, buf.centers.get(), kClusters,
reinterpret_cast<SfpStream*>(centers));
uint8_t* packed_start = centers + kClusters;

current_offset += g_num;

HWY_DASSERT(g_num % (4 * N16) == 0);

size_t i = 0;
HWY_UNROLL(1)
for (; i < g_num; i += 4 * N16) {
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
const V8 nibbles =
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
hn::StoreU(nibbles, d8, packed_start + i / 2);
}

const size_t remaining = g_num - i;

HWY_DASSERT(remaining < 4 * N16);
if (HWY_UNLIKELY(remaining != 0)) {
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
const V8 nibbles =
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
// i is even, but remaining might not be.
hn::StoreN(nibbles, d8, packed_start + i / 2,
hwy::DivCeil(remaining, 2));
}
}
return unused_clusters;
}

// Encodes `num` floats from `raw`. `packed` points to compressed storage and
// `packed_ofs` indicates the destination offset within it, in units of float
// values, for parallel encoding by multiple threads. Returns the total
@@ -733,6 +818,8 @@ class NuqCodec {
raw1 = BitCast(dbf, c1);
}

// TODO(philculliton): Remove non-interleaved function versions now that
// interleaved is working / the default.
// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's table.
template <class DF, HWY_IF_F32_D(DF)>
@@ -765,6 +852,107 @@ class NuqCodec {
raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
}

// Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's table.
template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Dec2Interleaved(
DBF dbf, const PackedSpan<const NuqStream>& packed,
const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
PROFILER_FUNC;
const hn::RebindToUnsigned<decltype(dbf)> d16;
const D8HFromD16<DBF> d8h;
using V16 = hn::Vec<decltype(d16)>;
using V8H = hn::Vec<decltype(d8h)>;

const size_t within_group = packed_ofs % kGroupSize;
HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);

V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, table, &tbl1);

const V8H nibbles = hn::LoadU(d8h, indices);

V16 c0, c1;
TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
raw0 = BitCast(dbf, c0);
raw1 = BitCast(dbf, c1);
}

// Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
// vectors so that we only have to load one group's table.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Dec2Interleaved(
DF df, const PackedSpan<const NuqStream>& packed, const size_t packed_ofs,
hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
const hn::Repartition<BF16, decltype(df)> dbf;
const hn::RebindToUnsigned<decltype(dbf)> d16;
const hn::Half<D8HFromD16<decltype(d16)>> d8q;
using V8Q = hn::Vec<decltype(d8q)>;
using V16 = hn::Vec<decltype(d16)>;

const size_t within_group = packed_ofs % kGroupSize;
HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);

V16 tbl1 = Zero(d16);
const V16 tbl0 = LoadTable(d16, table, &tbl1);

// The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
// which expects a quarter vector of bytes.
const V8Q nibbles = hn::LoadU(d8q, indices);

// TODO(janwas): on AVX-512 we can likely get a bit more speed for this
// function by changing LoadTable to return floats; then we could have a
// single lookup here instead of PromoteUpperTo, which is not cheap.
const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
}

template <class D, typename Raw = hn::TFromD<D>>
static HWY_INLINE void DecompressAndZeroPadInterleaved(
D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
Raw* HWY_RESTRICT raw, size_t num) {
// If unaligned, load elements from the first group and update the args,
// from which we compute new tables/indices below.
size_t current_offset = packed_ofs;
if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
const uint8_t* tables =
&packed.ptr->byte + TableByteOffset(current_offset);
const uint8_t* indices = tables + kClusters;
const size_t remaining = HWY_MIN(num, kGroupSize - within_group);

DecPartialGroup(d, tables, indices, raw, remaining);
packed_ofs += remaining;
current_offset += remaining;
raw += remaining;
num -= remaining;
if (num == 0) return;
}

HWY_DASSERT(packed_ofs % kGroupSize == 0);

const size_t num_groups = hwy::DivCeil(num, kGroupSize);
HWY_UNROLL(1)
for (size_t g = 0; g < num_groups - 1; ++g) {
const uint8_t* tables =
&packed.ptr->byte + TableByteOffset(current_offset);
const uint8_t* indices = tables + kClusters;
DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
current_offset += kGroupSize;
}

const size_t g = num_groups - 1;
const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset);
const uint8_t* indices = tables + kClusters;
DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
num - g * kGroupSize);
}
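// Worked example of the flow above (illustrative; assumes kGroupSize = 256):
// for packed_ofs = 300 and num = 1000, within_group = 44, so the first
// DecPartialGroup handles remaining = min(1000, 256 - 44) = 212 elements.
// The loop then decodes num_groups - 1 = 3 whole groups (768 elements) and
// the final DecPartialGroup handles the last 20: 212 + 768 + 20 = 1000.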

// Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
// elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
// round `num` up to one vector, if it is not already.
@@ -955,6 +1143,7 @@ class NuqCodec {
}

const size_t remaining = num - i;

HWY_DASSERT(remaining < 4 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
// i is even, but remaining might not be.
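For orientation, a minimal round-trip sketch of the new interleaved entry points, using only the signatures visible in this diff. `raw`, `num`, `buf` (a NuqStream::ClusterBuf), and `packed` (a PackedSpan<NuqStream> sized for `num` elements) are assumed to exist, as in CompressTraits<NuqStream> in compress-inl.h above; this is illustrative, not part of the change:

namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const hn::Repartition<BF16, decltype(df)> dbf;

// Encode: writes, per group, one SFP table followed by packed 4-bit indices.
const size_t unused_clusters =
    NuqCodec::EncInterleaved(df, raw, num, buf, packed, /*packed_ofs=*/0);
(void)unused_clusters;

// Decode: the destination is padded up to a whole number of bf16 vectors.
auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
NuqCodec::DecompressAndZeroPadInterleaved(dbf, MakeConst(packed),
                                          /*packed_ofs=*/0, distorted.get(), num);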