From 0d6173da3a1cdf5f6e39711bcd60ac7bdd6221c7 Mon Sep 17 00:00:00 2001
From: weiihann
Date: Thu, 6 Feb 2025 17:01:54 +0800
Subject: [PATCH] Remove leading zero bytes from bit array

---
 core/trie2/trieutils/bitarray.go      | 30 +++++++++++++++++------
 core/trie2/trieutils/bitarray_test.go | 35 +++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 7 deletions(-)

diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go
index e9f781fab3..6d53057225 100644
--- a/core/trie2/trieutils/bitarray.go
+++ b/core/trie2/trieutils/bitarray.go
@@ -15,6 +15,7 @@ import (
 const (
 	maxUint64       = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF
 	maxUint8        = uint8(math.MaxUint8)
+	bytes32         = 32
 	MaxBitArraySize = 33 // (1 + 4 * 8) bytes
 )
 
@@ -456,7 +457,7 @@ func (b *BitArray) IsEmpty() bool {
 
 // Serialises the BitArray into a bytes buffer in the following format:
 // - First byte: length of the bit array (0-255)
-// - Remaining bytes: the necessary bytes included in big endian order
+// - Remaining bytes: the necessary bytes included in big endian order, without leading zeros
 // Example:
 //
 //	BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF]
@@ -481,16 +482,17 @@ func (b *BitArray) UnmarshalBinary(data []byte) error {
 	}
 
 	length := data[0]
-	byteCount := (uint(length) + 7) / 8 // Round up to nearest byte
+	byteCount := (int(length) + 7) / 8 // Get the total number of bytes needed to represent the bit array
 
-	if len(data) < int(byteCount)+1 {
-		return fmt.Errorf("invalid data length: got %d bytes, expected %d", len(data), byteCount+1)
+	if len(data) > byteCount+1 {
+		return fmt.Errorf("invalid data length: got %d bytes, expected <= %d", len(data), byteCount+1)
 	}
 
 	b.len = length
 
 	var bs [32]byte
-	copy(bs[32-byteCount:], data[1:])
+	bitArrBytes := data[1:]
+	copy(bs[32-len(bitArrBytes):], bitArrBytes) // Fill up the non-zero bytes at the end of the byte array
 	b.setBytes32(bs[:])
 
 	return nil
@@ -715,14 +717,28 @@ func (b *BitArray) byteCount() uint {
 }
 
 // Returns a slice containing only the bytes that are actually used by the bit array,
-// as specified by the length. The returned slice is in big-endian order.
+// as specified by the length, with leading zero bytes removed. A non-empty bit array
+// always yields at least one byte; an empty one yields nil. The slice is big-endian.
 //
 // Example:
 //
 //	len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF]
 func (b *BitArray) activeBytes() []byte {
+	if b.len == 0 {
+		return nil
+	}
+
 	wordsBytes := b.Bytes()
-	return wordsBytes[32-b.byteCount():]
+	end := uint(bytes32)
+	start := end - b.byteCount()
+
+	// Skip leading zero bytes, but stop at the last byte so that an all-zero
+	// bit array still serialises to a single 0x00 byte.
+	for start < end-1 && wordsBytes[start] == 0 {
+		start++
+	}
+
+	return wordsBytes[start:end]
 }
 
 func (b *BitArray) rsh64(x *BitArray) {
diff --git a/core/trie2/trieutils/bitarray_test.go b/core/trie2/trieutils/bitarray_test.go
index 6589f94939..1835b538bf 100644
--- a/core/trie2/trieutils/bitarray_test.go
+++ b/core/trie2/trieutils/bitarray_test.go
@@ -1095,6 +1095,41 @@ func TestWriteAndUnmarshalBinary(t *testing.T) {
 			},
 			want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes
 		},
+		{
+			name: "leading zeros in first byte",
+			ba: BitArray{
+				len:   8,
+				words: [4]uint64{0x0F, 0, 0, 0}, // 00001111
+			},
+			want: []byte{8, 0x0F}, // length byte + 1 data byte
+		},
+		{
+			name: "all leading zeros in first byte",
+			ba: BitArray{
+				len:   8,
+				words: [4]uint64{0x00, 0, 0, 0}, // 00000000
+			},
+			want: []byte{8, 0x00}, // length byte + 1 data byte
+		},
+		{
+			name: "leading zeros across multiple bytes",
+			ba: BitArray{
+				len:   24,
+				words: [4]uint64{0x0000FF, 0, 0, 0}, // 000000000000000011111111
+			},
+			want: []byte{24, 0xFF}, // length byte + 1 data byte
+		},
+		{
+			name: "leading zeros in large number",
+			ba: BitArray{
+				len:   255,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0}, // All 1s in lower bits, zeros in upper bits
+			},
+			want: append(
+				[]byte{255},                       // length byte
+				bytes.Repeat([]byte{0xFF}, 16)..., // 16 bytes of all 1s
+			),
+		},
 	}
 
 	for _, tt := range tests {