diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index 7fd097c..28eb497 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -386,7 +386,7 @@ struct CompressTraits {
                                 size_t num, CompressPerThread& tls,
                                 const PackedSpan& packed,
                                 const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    NuqCodec::EncInterleaved(df, raw, num, tls.buf, packed, packed_ofs);
 
     if (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {
@@ -396,8 +396,8 @@ struct CompressTraits {
       const hn::Repartition dbf;
       const size_t N16 = hn::Lanes(dbf);
       auto distorted = hwy::AllocateAligned(hwy::RoundUpTo(num, N16));
-      NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
-                                     distorted.get(), num);
+      NuqCodec::DecompressAndZeroPadInterleaved(
+          dbf, MakeConst(packed), packed_ofs, distorted.get(), num);
       DistortionStats stats;
       for (size_t i = 0; i < num; ++i) {
         stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
@@ -410,7 +410,7 @@ struct CompressTraits {
   static HWY_INLINE void Load2(D d, const PackedSpan& packed,
                                const size_t packed_ofs, hn::Vec& raw0,
                                hn::Vec& raw1) {
-    NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
+    NuqCodec::Dec2Interleaved(d, packed, packed_ofs, raw0, raw1);
   }
 
   // Store2 is not yet implemented.
@@ -419,7 +419,7 @@ struct CompressTraits {
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan& packed, const size_t packed_ofs, Raw* raw,
       const size_t num) {
-    NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, packed, packed_ofs, raw, num);
   }
 };
 
diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index 63c4255..50a27b0 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -21,6 +21,8 @@
 #include
 #include
 
+#include
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,12 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }
 
+  // Returns the byte offset of the table for the group containing element
+  // `packed_ofs`.
+  static constexpr size_t TableOffset(size_t packed_ofs) {
+    // Bytes per group: kClusters SFP-encoded centers followed by kGroupSize/2
+    // bytes of packed indices, i.e. NuqStream::PackedEnd(kGroupSize).
+    const size_t group_size = kClusters + kGroupSize / 2;
+    return (packed_ofs / kGroupSize) * group_size;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
@@ -606,6 +614,81 @@ class NuqCodec {
   }
 
  public:
+  // Encodes `num` floats from `raw` into `packed`, starting at element
+  // `packed_ofs` within the compressed storage. Each group's table of cluster
+  // centers is interleaved with (stored directly before) that group's packed
+  // indices, which simplifies unpacking. Returns the total number of unused
+  // clusters, which is typically zero.
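+  // Layout sketch of one group under this scheme (assuming one byte per
+  // SFP-encoded center, so a table occupies kClusters bytes):
+  //
+  //   | kClusters center bytes | kGroupSize/2 bytes of packed 4-bit indices |
+  //
+  // repeated per group, which is what TableOffset() computes.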
+  template <class DF>
+  static HWY_INLINE size_t EncInterleaved(DF df, const float* HWY_RESTRICT raw,
+                                          const size_t num,
+                                          NuqStream::ClusterBuf& buf,
+                                          const PackedSpan<NuqStream>& packed,
+                                          size_t packed_ofs) {
+    const hn::Repartition<uint16_t, DF> d16;
+    const hn::Repartition<uint8_t, DF> d8;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8 = hn::Vec<decltype(d8)>;
+    const size_t N16 = hn::Lanes(d16);
+
+    HWY_ASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    // TODO: remove the dynamic resize; it is no longer necessary because
+    // interleaved encoding only uses a single group-sized buffer.
+    buf.Resize(1);
+
+    size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
+    for (size_t g = 0; g < num_groups; ++g) {
+      const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
+      const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
+      unused_clusters +=
+          NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
+
+      uint8_t* centers = &packed.ptr->byte + TableOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));
+      uint8_t* packed_start = centers + kClusters;
+
+      current_offset += g_num;
+
+      HWY_DASSERT(g_num % (4 * N16) == 0);
+
+      size_t i = 0;
+      HWY_UNROLL(1)
+      for (; i < g_num; i += 4 * N16) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
+
+      const size_t remaining = g_num - i;
+      HWY_DASSERT(remaining < 4 * N16);
+      if (HWY_UNLIKELY(remaining != 0)) {
+        const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
+        const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
+        const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
+        const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
+        const V8 nibbles =
+            NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
+        // i is even, but remaining might not be.
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
+      }
+    }
+    return unused_clusters;
+  }
+
   // Encodes `num` floats from `raw`. `packed` points to compressed storage and
   // `packed_ofs` indicates the destination offset within it, in units of float
   // values, for parallel encoding by multiple threads. Returns the total
@@ -765,6 +848,103 @@ class NuqCodec {
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }
 
+  // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two
+  // vector lengths so that only one group's table needs to be loaded.
+  template <class DBF>
+  static HWY_INLINE void Dec2Interleaved(
+      DBF dbf, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DBF>& raw0, hn::Vec<DBF>& raw1) {
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const D8HFromD16<decltype(d16)> d8h;
+    using V16 = hn::Vec<decltype(d16)>;
+    using V8H = hn::Vec<decltype(d8h)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    const V8H nibbles = hn::LoadU(d8h, indices);
+
+    V16 c0, c1;
+    TableLookups(d16, tbl0, tbl1, nibbles, c0, c1);
+    raw0 = BitCast(dbf, c0);
+    raw1 = BitCast(dbf, c1);
+  }
+
+  // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two
+  // vector lengths so that only one group's table needs to be loaded.
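+  // Worked example of the addressing used below (a sketch, assuming
+  // kGroupSize = 256 and kClusters = 16): packed_ofs = 320 lies in group 1,
+  // so TableOffset(320) = 1 * (16 + 128) = 144; that group's centers occupy
+  // bytes [144, 160) and the indices for element 320 start at byte
+  // 160 + (320 % 256) / 2 = 192.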
+  template <class DF>
+  static HWY_INLINE void Dec2Interleaved(
+      DF df, const PackedSpan<const NuqStream>& packed,
+      const size_t packed_ofs, hn::Vec<DF>& raw0, hn::Vec<DF>& raw1) {
+    const hn::Repartition<BF16, DF> dbf;
+    const hn::RebindToUnsigned<decltype(dbf)> d16;
+    const hn::Half<D8HFromD16<decltype(d16)>> d8q;
+    using V8Q = hn::Vec<decltype(d8q)>;
+    using V16 = hn::Vec<decltype(d16)>;
+
+    const size_t within_group = packed_ofs % kGroupSize;
+    HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
+    const uint8_t* table = &packed.ptr->byte + TableOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
+
+    V16 tbl1 = Zero(d16);
+    const V16 tbl0 = LoadTable(d16, table, &tbl1);
+
+    // The single-vector TableLookups overload only calls OrderedUnpackU16<0>,
+    // which expects a quarter vector of bytes.
+    const V8Q nibbles = hn::LoadU(d8q, indices);
+
+    const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
+    raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
+    raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
+  }
+
+  // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
+  // elements in `raw`, then appends zeroes as required to round `num` up to
+  // one vector, if it is not already.
+  template <class D, typename Raw = hn::TFromD<D>>
+  static HWY_INLINE void DecompressAndZeroPadInterleaved(
+      D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
+      Raw* HWY_RESTRICT raw, size_t num) {
+    // If unaligned, decode elements of the first (partial) group and advance
+    // the arguments, from which new tables/indices are computed below.
+    size_t current_offset = packed_ofs;
+    if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
+      DecPartialGroup(d, tables, indices, raw, remaining);
+      packed_ofs += remaining;
+      current_offset += remaining;
+      raw += remaining;
+      num -= remaining;
+      if (num == 0) return;
+    }
+
+    HWY_DASSERT(packed_ofs % kGroupSize == 0);
+
+    const size_t num_groups = hwy::DivCeil(num, kGroupSize);
+    HWY_UNROLL(1)
+    for (size_t g = 0; g < num_groups - 1; ++g) {
+      const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
+    }
+
+    const size_t g = num_groups - 1;
+    const uint8_t* tables = &packed.ptr->byte + TableOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
+  }
+
   // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
   // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
   // round `num` up to one vector, if it is not already.
@@ -955,6 +1135,7 @@ class NuqCodec {
     }
 
     const size_t remaining = num - i;
+    HWY_DASSERT(remaining < 4 * NF);
 
     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.
diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc
index 8cbce6c..75d6575 100644
--- a/compression/nuq_test.cc
+++ b/compression/nuq_test.cc
@@ -277,6 +277,193 @@ struct TestOffset {
 void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
 void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
 
+// Can encode and decode sub-regions.
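+// Decoding `num` elements from an offset that is not group-aligned must match
+// the corresponding slice of a full decode.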
+struct TestUnalignedOffset {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    const size_t total = 10 * kGroupSize;  // already padded
+
+    const size_t num_unaligned_offsets = 4;
+    const std::array<size_t, 4> unaligned_offsets = {
+        4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
+    const std::array<size_t, 4> num = {4, 16, 32, 64};
+
+    for (size_t t = 0; t < num_unaligned_offsets; ++t) {
+      const size_t unaligned_offset = unaligned_offsets[t];
+      const size_t num_decompressed = num[t];
+
+      auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+      auto dec1 = hwy::AllocateAligned<T>(total);
+      auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+      // DecompressAndZeroPad may round `num_decompressed` up to one vector.
+      auto dec2 = hwy::AllocateAligned<T>(
+          hwy::RoundUpTo(num_decompressed, hn::Lanes(d)));
+      HWY_ASSERT(in && dec1 && dec2 && nuq);
+      const auto nuq_span = MakeSpan(nuq.get(), total);
+
+      hwy::RandomState rng;
+      for (size_t i = 0; i < total; ++i) {
+        in[i] = static_cast<float>(RandomGaussian(rng));
+      }
+
+      // Encode + decode everything.
+      NuqStream::ClusterBuf buf;
+      (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(),
+                                     total);
+
+      // Decode only the sub-region starting at the unaligned offset.
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), unaligned_offset,
+                                     dec2.get(), num_decompressed);
+
+      for (size_t i = 0; i < num_decompressed; ++i) {
+        const T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
+        const T actual = hwy::ConvertScalarTo<T>(dec2[i]);
+        HWY_ASSERT_EQ(expected, actual);
+      }
+    }
+  }
+};
+
+void TestUnalignedOffsetBF16() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
+}
+void TestUnalignedOffsetF32() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(float());
+}
+
+// Uses Dec2Interleaved to decode all elements in the packed buffer, then
+// compares against the non-interleaved encode + decode of the same input.
+struct TestDec2Interleaved {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    // One full group plus a partial group, to exercise partial-group handling.
+    const size_t total = 1 * kGroupSize + kGroupSize / 2;
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Non-interleaved encode + decode for comparison.
+    NuqStream::ClusterBuf buf0;
+    (void)NuqCodec::Enc(df, in.get(), total, buf0, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Interleaved encode, then decode two vectors at a time.
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::EncInterleaved(df, in.get(), total, buf, nuq_span, 0);
+
+    using V = hn::Vec<D>;
+    const size_t N = Lanes(d);
+
+    for (size_t i = 0; i < total; i += 2 * N) {
+      V f0, f1;
+      NuqCodec::Dec2Interleaved(d, MakeConst(nuq_span), i, f0, f1);
+      hn::StoreU(f0, d, dec1.get() + i + 0 * N);
+      hn::StoreU(f1, d, dec1.get() + i + 1 * N);
+    }
+
+    for (size_t i = 0; i < total; ++i) {
+      if (dec0[i] != dec1[i]) {
+        fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, (float)dec0[i],
+                i, (float)dec1[i]);
+      }
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+  }
+};
+
+void TestDec2BF16Interleaved() {
+  hn::ForGEVectors<128, TestDec2Interleaved>()(BF16());
+}
+void TestDec2F32Interleaved() {
+  hn::ForGEVectors<128, TestDec2Interleaved>()(float());
+}
+
+// Can encode and decode sub-regions of the interleaved stream.
+struct TestOffsetInterleaved {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    // Add a partial group to exercise partial-group handling.
+    const size_t total = 10 * kGroupSize + kGroupSize / 2;
+    const size_t kMidLen = 2 * kGroupSize;  // length of the middle piece
+
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto dec2 = hwy::AllocateAligned<T>(kMidLen);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Non-interleaved encode + decode for comparison.
+    NuqStream::ClusterBuf buf0;
+    (void)NuqCodec::Enc(df, in.get(), total, buf0, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Interleaved encode + decode of everything.
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::EncInterleaved(df, in.get(), total, buf, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPadInterleaved(d, MakeConst(nuq_span), 0,
+                                              dec1.get(), total);
+
+    for (size_t i = 0; i < total; ++i) {
+      if (dec0[i] != dec1[i]) {
+        fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, (float)dec0[i],
+                i, (float)dec1[i]);
+      }
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+
+    // Overwrite the middle with the first inputs.
+    const size_t offset = 5 * kGroupSize;
+    (void)NuqCodec::EncInterleaved(df, in.get(), kMidLen, buf, nuq_span,
+                                   offset);
+
+    // The decoded middle now matches the previously decoded first piece.
+    NuqCodec::DecompressAndZeroPadInterleaved(d, MakeConst(nuq_span), offset,
+                                              dec2.get(), kMidLen);
+    for (size_t i = 0; i < kMidLen; ++i) {
+      HWY_ASSERT(dec1[i] == dec2[i]);
+    }
+  }
+};
+
+void TestOffsetBF16Interleaved() {
+  hn::ForGEVectors<128, TestOffsetInterleaved>()(BF16());
+}
+void TestOffsetF32Interleaved() {
+  hn::ForGEVectors<128, TestOffsetInterleaved>()(float());
+}
+
 struct TestNibble {
   template
   HWY_INLINE void operator()(T /*unused*/, D d) {
@@ -409,6 +596,12 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2BF16Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2F32Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetBF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16Interleaved);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32Interleaved);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc
index a9d3894..130eb04 100644
--- a/compression/python/compression_clif_aux.cc
+++ b/compression/python/compression_clif_aux.cc
@@ -55,8 +55,7 @@ class SbsWriterImpl : public WriterInterface {
   template
   void AllocateAndCompress(const std::string& name, absl::Span weights) {
-    const size_t num_packed = CompressedArrayElements(weights.size());
-    MatPtrT storage(name, 1, num_packed);
+    MatPtrT storage(name, 1, weights.size());
     model_memory_.push_back(storage);
     model_memory_.back().Allocate();
     storage.SetPtr(model_memory_.back());
diff --git a/compression/shared.h b/compression/shared.h
index 74b7454..d257924 100644
--- a/compression/shared.h
+++ b/compression/shared.h
@@ -161,13 +161,16 @@ struct NuqStream {
   static constexpr size_t PackedStart(size_t capacity) {
-    // Round up to avoid cache-line splits when loading indices. No effect on
-    // size as long as capacity / kGroupSize is a multiple of 4.
-    return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
+    // With the interleaved layout, each group's table is stored inline,
+    // immediately before that group's packed indices, so only the first
+    // group's table precedes the rest of the stream.
+    return kClusters;
   }
 
   // Returns number of NuqStream to allocate for the stream, which matches its
   // size in bytes.
   static constexpr size_t PackedEnd(size_t capacity) {
-    return PackedStart(capacity) + hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
+    const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
+    return (kClusters * num_groups) +
+           hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
   }
 
   uint8_t byte;
diff --git a/gemma/instantiations/27b_nuq.cc b/gemma/instantiations/27b_nuq.cc
new file mode 100644
index 0000000..91ccdcc
--- /dev/null
+++ b/gemma/instantiations/27b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/27b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemma2_27B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/2b_nuq.cc b/gemma/instantiations/2b_nuq.cc
new file mode 100644
index 0000000..2e586bc
--- /dev/null
+++ b/gemma/instantiations/2b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/2b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemma2B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/7b_nuq.cc b/gemma/instantiations/7b_nuq.cc
new file mode 100644
index 0000000..11e4676
--- /dev/null
+++ b/gemma/instantiations/7b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/7b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemma7B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/9b_nuq.cc b/gemma/instantiations/9b_nuq.cc
new file mode 100644
index 0000000..d9e7254
--- /dev/null
+++ b/gemma/instantiations/9b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/9b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemma2_9B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/gemma2_2b_nuq.cc b/gemma/instantiations/gemma2_2b_nuq.cc
new file mode 100644
index 0000000..39d3fe6
--- /dev/null
+++ b/gemma/instantiations/gemma2_2b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/gemma2_2b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemma2_2B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/gr2b_nuq.cc b/gemma/instantiations/gr2b_nuq.cc
new file mode 100644
index 0000000..51e57f3
--- /dev/null
+++ b/gemma/instantiations/gr2b_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/gr2b_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGriffin2B
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/paligemma_224_nuq.cc b/gemma/instantiations/paligemma_224_nuq.cc
new file mode 100644
index 0000000..733dc16
--- /dev/null
+++ b/gemma/instantiations/paligemma_224_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/paligemma_224_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigPaliGemma_224
+#include "gemma/gemma-inl.h"
diff --git a/gemma/instantiations/tiny_nuq.cc b/gemma/instantiations/tiny_nuq.cc
new file mode 100644
index 0000000..0121cf5
--- /dev/null
+++ b/gemma/instantiations/tiny_nuq.cc
@@ -0,0 +1,21 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/gemma/instantiations/tiny_nuq.cc"
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_CONFIG ConfigGemmaTiny
+#include "gemma/gemma-inl.h"
diff --git a/gemma/weights.h b/gemma/weights.h
index 60e9d13..f9cfbf1 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -26,6 +26,7 @@
 #include
 #include
 
+#include "compression/compress-inl.h"
 #include "compression/compress.h"
 #include "compression/shared.h"
 #include "gemma/common.h"
@@ -33,6 +34,7 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"
 
 namespace gcpp {
 
@@ -218,6 +220,41 @@ struct LayerWeightsPtrs {
       storage->Allocate();
       att_weights.SetPtr(*storage);
     }
+
+    // NUQ streams cannot be reshaped in place: decompress to f32, reorder the
+    // rows, then recompress into att_weights.
+    if (hwy::IsSame<Weight, NuqStream>()) {
+      namespace hn = hwy::HWY_NAMESPACE;
+      const hn::ScalableTag<float> df;
+
+      hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
+          hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+      hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
+          hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+
+      HWY_NAMESPACE::DecompressAndZeroPad(
+          df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim),
+          0, attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
+
+      for (size_t m = 0; m < model_dim; ++m) {
+        float* HWY_RESTRICT out_row =
+            att_weights_tmp.get() + m * heads * qkv_dim;
+        for (size_t h = 0; h < heads; ++h) {
+          hwy::CopyBytes(attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim +
+                             m * qkv_dim,
+                         out_row + h * qkv_dim, qkv_dim * sizeof(float));
+        }
+      }
+
+      CompressWorkingSet work;
+      hwy::ThreadPool pool(0);
+
+      HWY_NAMESPACE::Compress(
+          att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
+          MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
+          /*packed_ofs=*/0, pool);
+
+      return;
+    }
+
     for (size_t m = 0; m < model_dim; ++m) {
       Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
       for (size_t h = 0; h < heads; ++h) {