Skip to content

Commit 68e77ee

Browse files
Merge pull request #221 from Workiva/add-bitarray-getsetbits
Add `GetSetBits` and `Count` to `BitArray`
2 parents c466da2 + 8f1c722 commit 68e77ee

File tree

5 files changed

+254
-2
lines changed

5 files changed

+254
-2
lines changed

bitarray/bitarray.go

+79-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ efficient way. This is *NOT* a threadsafe package.
2020
*/
2121
package bitarray
2222

23+
import "math/bits"
24+
2325
// bitArray is a struct that maintains state of a bit array.
2426
type bitArray struct {
2527
blocks []block
@@ -116,7 +118,74 @@ func (ba *bitArray) GetBit(k uint64) (bool, error) {
116118
return result, nil
117119
}
118120

119-
//ClearBit will unset a bit at the given index if it is set.
121+
// GetSetBits gets the position of bits set in the array.
122+
func (ba *bitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
123+
fromBlockIndex, fromOffset := getIndexAndRemainder(from)
124+
return getSetBitsInBlocks(
125+
fromBlockIndex,
126+
fromOffset,
127+
ba.blocks[fromBlockIndex:],
128+
nil,
129+
buffer,
130+
)
131+
}
132+
133+
// getSetBitsInBlocks fills a buffer with positions of set bits in the provided blocks. Optionally, indices may be
134+
// provided for sparse/non-consecutive blocks.
135+
func getSetBitsInBlocks(
136+
fromBlockIndex, fromOffset uint64,
137+
blocks []block,
138+
indices []uint64,
139+
buffer []uint64,
140+
) []uint64 {
141+
bufferCapacity := cap(buffer)
142+
if bufferCapacity == 0 {
143+
return buffer[:0]
144+
}
145+
146+
results := buffer[:bufferCapacity]
147+
resultSize := 0
148+
149+
for i, block := range blocks {
150+
blockIndex := fromBlockIndex + uint64(i)
151+
if indices != nil {
152+
blockIndex = indices[i]
153+
}
154+
155+
isFirstBlock := blockIndex == fromBlockIndex
156+
if isFirstBlock {
157+
block >>= fromOffset
158+
}
159+
160+
for block != 0 {
161+
trailing := bits.TrailingZeros64(uint64(block))
162+
163+
if isFirstBlock {
164+
results[resultSize] = uint64(trailing) + (blockIndex << 6) + fromOffset
165+
} else {
166+
results[resultSize] = uint64(trailing) + (blockIndex << 6)
167+
}
168+
resultSize++
169+
170+
if resultSize == cap(results) {
171+
return results[:resultSize]
172+
}
173+
174+
// Clear the bit we just added to the result, which is the last bit set in the block. Ex.:
175+
// block 01001100
176+
// ^block 10110011
177+
// (^block) + 1 10110100
178+
// block & (^block) + 1 00000100
179+
// block ^ mask 01001000
180+
mask := block & ((^block) + 1)
181+
block = block ^ mask
182+
}
183+
}
184+
185+
return results[:resultSize]
186+
}
187+
188+
// ClearBit will unset a bit at the given index if it is set.
120189
func (ba *bitArray) ClearBit(k uint64) error {
121190
if k >= ba.Capacity() {
122191
return OutOfRangeError(k)
@@ -137,6 +206,15 @@ func (ba *bitArray) ClearBit(k uint64) error {
137206
return nil
138207
}
139208

209+
// Count returns the number of set bits in this array.
210+
func (ba *bitArray) Count() int {
211+
count := 0
212+
for _, block := range ba.blocks {
213+
count += bits.OnesCount64(uint64(block))
214+
}
215+
return count
216+
}
217+
140218
// Or will bitwise or two bit arrays and return a new bit array
141219
// representing the result.
142220
func (ba *bitArray) Or(other BitArray) BitArray {

bitarray/bitarray_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"testing"
2121

2222
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
2324
)
2425

2526
func TestBitOperations(t *testing.T) {
@@ -142,6 +143,28 @@ func TestIsEmpty(t *testing.T) {
142143
assert.False(t, ba.IsEmpty())
143144
}
144145

146+
func TestCount(t *testing.T) {
147+
ba := newBitArray(500)
148+
assert.Equal(t, 0, ba.Count())
149+
150+
require.NoError(t, ba.SetBit(0))
151+
assert.Equal(t, 1, ba.Count())
152+
153+
require.NoError(t, ba.SetBit(40))
154+
require.NoError(t, ba.SetBit(64))
155+
require.NoError(t, ba.SetBit(100))
156+
require.NoError(t, ba.SetBit(200))
157+
require.NoError(t, ba.SetBit(469))
158+
require.NoError(t, ba.SetBit(500))
159+
assert.Equal(t, 7, ba.Count())
160+
161+
require.NoError(t, ba.ClearBit(200))
162+
assert.Equal(t, 6, ba.Count())
163+
164+
ba.Reset()
165+
assert.Equal(t, 0, ba.Count())
166+
}
167+
145168
func TestClear(t *testing.T) {
146169
ba := newBitArray(10)
147170

@@ -195,6 +218,53 @@ func BenchmarkGetBit(b *testing.B) {
195218
}
196219
}
197220

221+
func TestGetSetBits(t *testing.T) {
222+
ba := newBitArray(1000)
223+
buf := make([]uint64, 0, 5)
224+
225+
require.NoError(t, ba.SetBit(1))
226+
require.NoError(t, ba.SetBit(4))
227+
require.NoError(t, ba.SetBit(8))
228+
require.NoError(t, ba.SetBit(63))
229+
require.NoError(t, ba.SetBit(64))
230+
require.NoError(t, ba.SetBit(200))
231+
require.NoError(t, ba.SetBit(1000))
232+
233+
assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
234+
assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))
235+
236+
assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf))
237+
assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf))
238+
assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf))
239+
assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf))
240+
241+
require.NoError(t, ba.ClearBit(4))
242+
require.NoError(t, ba.ClearBit(64))
243+
assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
244+
assert.Empty(t, ba.GetSetBits(1001, buf))
245+
246+
ba.Reset()
247+
assert.Empty(t, ba.GetSetBits(0, buf))
248+
}
249+
250+
func BenchmarkGetSetBits(b *testing.B) {
251+
numItems := uint64(168000)
252+
253+
ba := newBitArray(numItems)
254+
for i := uint64(0); i < numItems; i++ {
255+
if i%13 == 0 || i%5 == 0 {
256+
require.NoError(b, ba.SetBit(i))
257+
}
258+
}
259+
260+
buf := make([]uint64, 0, ba.Capacity())
261+
262+
b.ResetTimer()
263+
for i := 0; i < b.N; i++ {
264+
ba.GetSetBits(0, buf)
265+
}
266+
}
267+
198268
func TestEquality(t *testing.T) {
199269
ba := newBitArray(s + 1)
200270
other := newBitArray(s + 1)

bitarray/interface.go

+6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ type BitArray interface {
3636
// function returns an error if the position is out
3737
// of range. A sparse bit array never returns an error.
3838
GetBit(k uint64) (bool, error)
39+
// GetSetBits gets the position of bits set in the array. Will
40+
// return as many set bits as can fit in the provided buffer
41+
// starting from the specified position in the array.
42+
GetSetBits(from uint64, buffer []uint64) []uint64
3943
// ClearBit clears the bit at the given position. This
4044
// function returns an error if the position is out
4145
// of range. A sparse bit array never returns an error.
@@ -55,6 +59,8 @@ type BitArray interface {
5559
// in the case of a dense bit array or the highest possible
5660
// seen capacity of the sparse array.
5761
Capacity() uint64
62+
// Count returns the number of set bits in this array.
63+
Count() int
5864
// Or will bitwise or the two bitarrays and return a new bitarray
5965
// representing the result.
6066
Or(other BitArray) BitArray

bitarray/sparse_bitarray.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License.
1616

1717
package bitarray
1818

19-
import "sort"
19+
import (
20+
"math/bits"
21+
"sort"
22+
)
2023

2124
// uintSlice is an alias for a slice of ints. Len, Swap, and Less
2225
// are exported to fulfill an interface needed for the search
@@ -127,6 +130,24 @@ func (sba *sparseBitArray) GetBit(k uint64) (bool, error) {
127130
return sba.blocks[i].get(position), nil
128131
}
129132

133+
// GetSetBits gets the position of bits set in the array.
134+
func (sba *sparseBitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
135+
fromBlockIndex, fromOffset := getIndexAndRemainder(from)
136+
137+
fromBlockLocation := sba.indices.search(fromBlockIndex)
138+
if int(fromBlockLocation) == len(sba.indices) {
139+
return buffer[:0]
140+
}
141+
142+
return getSetBitsInBlocks(
143+
fromBlockIndex,
144+
fromOffset,
145+
sba.blocks[fromBlockLocation:],
146+
sba.indices[fromBlockLocation:],
147+
buffer,
148+
)
149+
}
150+
130151
// ToNums converts this sparse bitarray to a list of numbers contained
131152
// within it.
132153
func (sba *sparseBitArray) ToNums() []uint64 {
@@ -225,6 +246,15 @@ func (sba *sparseBitArray) Equals(other BitArray) bool {
225246
return true
226247
}
227248

249+
// Count returns the number of set bits in this array.
250+
func (sba *sparseBitArray) Count() int {
251+
count := 0
252+
for _, block := range sba.blocks {
253+
count += bits.OnesCount64(uint64(block))
254+
}
255+
return count
256+
}
257+
228258
// Or will perform a bitwise or operation with the provided bitarray and
229259
// return a new result bitarray.
230260
func (sba *sparseBitArray) Or(other BitArray) BitArray {

bitarray/sparse_bitarray_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"testing"
2121

2222
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
2324
)
2425

2526
func TestGetCompressedBit(t *testing.T) {
@@ -76,6 +77,73 @@ func BenchmarkSetCompressedBit(b *testing.B) {
7677
}
7778
}
7879

80+
func TestGetSetCompressedBits(t *testing.T) {
81+
ba := newSparseBitArray()
82+
buf := make([]uint64, 0, 5)
83+
84+
require.NoError(t, ba.SetBit(1))
85+
require.NoError(t, ba.SetBit(4))
86+
require.NoError(t, ba.SetBit(8))
87+
require.NoError(t, ba.SetBit(63))
88+
require.NoError(t, ba.SetBit(64))
89+
require.NoError(t, ba.SetBit(200))
90+
require.NoError(t, ba.SetBit(1000))
91+
92+
assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
93+
assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))
94+
95+
assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf))
96+
assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf))
97+
assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf))
98+
assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf))
99+
100+
require.NoError(t, ba.ClearBit(4))
101+
require.NoError(t, ba.ClearBit(64))
102+
assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
103+
assert.Empty(t, ba.GetSetBits(1001, buf))
104+
105+
ba.Reset()
106+
assert.Empty(t, ba.GetSetBits(0, buf))
107+
}
108+
109+
func BenchmarkGetSetCompressedBits(b *testing.B) {
110+
ba := newSparseBitArray()
111+
for i := uint64(0); i < 168000; i++ {
112+
if i%13 == 0 || i%5 == 0 {
113+
require.NoError(b, ba.SetBit(i))
114+
}
115+
}
116+
117+
buf := make([]uint64, 0, ba.Capacity())
118+
119+
b.ResetTimer()
120+
for i := 0; i < b.N; i++ {
121+
ba.GetSetBits(0, buf)
122+
}
123+
}
124+
125+
func TestCompressedCount(t *testing.T) {
126+
ba := newSparseBitArray()
127+
assert.Equal(t, 0, ba.Count())
128+
129+
require.NoError(t, ba.SetBit(0))
130+
assert.Equal(t, 1, ba.Count())
131+
132+
require.NoError(t, ba.SetBit(40))
133+
require.NoError(t, ba.SetBit(64))
134+
require.NoError(t, ba.SetBit(100))
135+
require.NoError(t, ba.SetBit(200))
136+
require.NoError(t, ba.SetBit(469))
137+
require.NoError(t, ba.SetBit(500))
138+
assert.Equal(t, 7, ba.Count())
139+
140+
require.NoError(t, ba.ClearBit(200))
141+
assert.Equal(t, 6, ba.Count())
142+
143+
ba.Reset()
144+
assert.Equal(t, 0, ba.Count())
145+
}
146+
79147
func TestClearCompressedBit(t *testing.T) {
80148
ba := newSparseBitArray()
81149
ba.SetBit(5)

0 commit comments

Comments
 (0)