Skip to content

Commit

Permalink
brain/kvbrain: reimplement forgetting
Browse files Browse the repository at this point in the history
Fixes #43.
  • Loading branch information
zephyrtronium committed Jan 1, 2025
1 parent 6bf6296 commit 0bc2085
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 187 deletions.
3 changes: 0 additions & 3 deletions brain/braintest/braintest.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ func testForget(ctx context.Context, br brain.Interface) func(t *testing.T) {
}
got := speak(ctx, t, br, "kessoku", "", 2048)
want := map[string]struct{}{
// The current brains should delete the "member" with ID 1, but we
// don't strictly require it.
// This should change anyway once we stop deleting by tuples.
"2#member ryou": {},
"2 3#member ryou": {},
"2 4#member ryou": {},
Expand Down
70 changes: 9 additions & 61 deletions brain/kvbrain/forget.go
Original file line number Diff line number Diff line change
@@ -1,76 +1,24 @@
package kvbrain

import (
"bytes"
"context"
"fmt"
"slices"
"sync"

"github.com/zephyrtronium/robot/userhash"
"github.com/dgraph-io/badger/v4"
)

type past struct {
mu sync.Mutex

k uint8
key [256][][]byte
id [256]string
user [256]userhash.Hash
time [256]int64 // unix nano
}

// record associates a message with a knowledge key.
func (p *past) record(id string, user userhash.Hash, nanotime int64, keys [][]byte) {
p.mu.Lock()
p.key[p.k] = slices.Grow(p.key[p.k][:0], len(keys))[:len(keys)]
for i, key := range keys {
p.key[p.k][i] = append(p.key[p.k][i][:0], key...)
}
p.id[p.k] = id
p.user[p.k] = user
p.time[p.k] = nanotime
p.k++
p.mu.Unlock()
}

// findID finds all keys corresponding to the given UUID.
func (p *past) findID(id string) [][]byte {
r := make([][]byte, 0, 64)
p.mu.Lock()
defer p.mu.Unlock()
for k, v := range p.id {
if v == id {
keys := p.key[k]
r = slices.Grow(r, len(keys))
for _, v := range keys {
r = append(r, bytes.Clone(v))
}
return r
}
}
return nil
}

// Forget forgets everything learned from a single given message.
// If nothing has been learned from the message, it should be ignored.
func (br *Brain) Forget(ctx context.Context, tag, id string) error {
past, _ := br.past.Load(tag)
if past == nil {
return nil
}
keys := past.findID(id)
batch := br.knowledge.NewWriteBatch()
defer batch.Cancel()
for _, key := range keys {
err := batch.Delete(key)
if err != nil {
return err
}
}
err := batch.Flush()
err := br.knowledge.Update(func(txn *badger.Txn) error {
k := make([]byte, 0, tagHashLen+2+len(id))
k = hashTag(k, tag)
k = append(k, 0xfe, 0xfe)
k = append(k, id...)
return txn.Set(k, []byte{})
})
if err != nil {
return fmt.Errorf("couldn't commit deleting message %v: %w", id, err)
return fmt.Errorf("couldn't forget: %w", err)
}
return nil
}
115 changes: 13 additions & 102 deletions brain/kvbrain/forget_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package kvbrain

import (
"bytes"
"context"
"slices"
"testing"
"time"

Expand All @@ -13,102 +11,6 @@ import (
"github.com/zephyrtronium/robot/userhash"
)

func TestPastRecord(t *testing.T) {
var p past
ch := make(chan struct{})
for i := range len(p.id) {
go func() {
p.record("1"+string(byte(i)), userhash.Hash{2, byte(i)}, int64(i), [][]byte{{4, byte(i)}, {5}})
ch <- struct{}{}
}()
}
for range len(p.id) {
<-ch
}
if p.k != 0 {
t.Errorf("wrong final index: %d should be zero", p.k)
}
for k := range len(p.id) {
key, id, user, time := p.key[k], p.id[k], p.user[k], p.time[k]
if want := [][]byte{{4, byte(time)}, {5}}; !slices.EqualFunc(key, want, bytes.Equal) {
t.Errorf("wrong association between key and time: want %v, got %v", want, key)
}
if want := ("1" + string(byte(time))); id != want {
t.Errorf("wrong association between id and time: want %v, got %v", want, id)
}
if want := (userhash.Hash{2, byte(time)}); user != want {
t.Errorf("wrong association between user and time: want %v, got %v", want, user)
}
}
// Do it again to verify we overwrite.
for i := range len(p.id) {
go func() {
p.record("5"+string(byte(i)), userhash.Hash{6, byte(i)}, int64(i), [][]byte{{8, byte(i)}, {9}})
ch <- struct{}{}
}()
}
for range len(p.id) {
<-ch
}
if p.k != 0 {
t.Errorf("wrong final index: %d should be zero", p.k)
}
for k := range len(p.id) {
key, id, user, time := p.key[k], p.id[k], p.user[k], p.time[k]
if want := [][]byte{{8, byte(time)}, {9}}; !slices.EqualFunc(key, want, bytes.Equal) {
t.Errorf("wrong association between key and time: want %v, got %v", want, key)
}
if want := ("5" + string(byte(time))); id != want {
t.Errorf("wrong association between id and time: want %v, got %v", want, id)
}
if want := (userhash.Hash{6, byte(time)}); user != want {
t.Errorf("wrong association between user and time: want %v, got %v", want, user)
}
}
}

func TestPastFind(t *testing.T) {
uu := "1"
p := past{
k: 127,
key: [256][][]byte{255: {[]byte("bocchi")}},
id: [256]string{255: uu},
user: [256]userhash.Hash{255: {2}},
time: [256]int64{255: 3},
}
if got, want := p.findID(uu), [][]byte{[]byte("bocchi")}; !slices.EqualFunc(got, want, bytes.Equal) {
t.Errorf("wrong key: want %q, got %q", want, got)
}
if got := p.findID("fake"); got != nil {
t.Errorf("non-nil key %q finding fake uuid", got)
}
}

func BenchmarkPastRecord(b *testing.B) {
var p past
uu := "1"
user := userhash.Hash{2}
b.ReportAllocs()
for i := range b.N {
p.record(uu, user, int64(i), [][]byte{{byte(i)}})
}
}

func BenchmarkPastFind(b *testing.B) {
var p past
for i := range len(p.id) {
p.record(string(byte(i)), userhash.Hash{byte(i)}, int64(i), [][]byte{{byte(i)}})
}
b.ReportAllocs()
b.ResetTimer()
for i := range b.N {
use(p.findID(string(byte(i))))
}
}

//go:noinline
func use(x [][]byte) {}

func TestForget(t *testing.T) {
type message struct {
id string
Expand Down Expand Up @@ -136,8 +38,11 @@ func TestForget(t *testing.T) {
},
},
},
uu: "1",
want: map[string]string{},
uu: "1",
want: map[string]string{
mkey("kessoku", "bocchi\xff\xff", "1"): "ryou",
mkey("kessoku", "\xfe\xfe", "1"): "",
},
},
{
name: "several",
Expand All @@ -153,8 +58,12 @@ func TestForget(t *testing.T) {
},
},
},
uu: "1",
want: map[string]string{},
uu: "1",
want: map[string]string{
mkey("kessoku", "bocchi\xff\xff", "1"): "ryou",
mkey("kessoku", "nijika\xff\xff", "1"): "kita",
mkey("kessoku", "\xfe\xfe", "1"): "",
},
},
{
name: "tagged",
Expand All @@ -172,6 +81,7 @@ func TestForget(t *testing.T) {
uu: "1",
want: map[string]string{
mkey("sickhack", "bocchi\xff\xff", "1"): "ryou",
mkey("kessoku", "\xfe\xfe", "1"): "",
},
},
{
Expand All @@ -190,6 +100,7 @@ func TestForget(t *testing.T) {
uu: "2",
want: map[string]string{
mkey("kessoku", "bocchi\xff\xff", "1"): "ryou",
mkey("kessoku", "\xfe\xfe", "2"): "",
},
},
}
Expand Down
2 changes: 0 additions & 2 deletions brain/kvbrain/kvbrain.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"io"

"github.com/dgraph-io/badger/v4"
"gopkg.in/typ.v4/sync2"

"github.com/zephyrtronium/robot/brain"
)
Expand Down Expand Up @@ -37,7 +36,6 @@ Operations:

type Brain struct {
knowledge *badger.DB
past sync2.Map[string, *past]
}

var _ brain.Interface = (*Brain)(nil)
Expand Down
9 changes: 0 additions & 9 deletions brain/kvbrain/learn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ func (br *Brain) Learn(ctx context.Context, tag string, msg *brain.Message, tupl
vals[i] = []byte(t.Suffix)
}

p, _ := br.past.Load(tag)
if p == nil {
// We might race with others also creating this past. Ensure we don't
// overwrite if that happens.
p, _ = br.past.LoadOrStore(tag, new(past))
}
// Scale the timestamp from milliseconds to nanoseconds for historical reasons.
p.record(msg.ID, msg.Sender, msg.Timestamp*1e6, keys)

batch := br.knowledge.NewWriteBatch()
defer batch.Cancel()
for i, key := range keys {
Expand Down
54 changes: 44 additions & 10 deletions brain/kvbrain/speak.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func (br *Brain) Think(ctx context.Context, tag string, prompt []string) iter.Se
// We don't actually need to iterate over values, only the single value
// that we decide to use per suffix. So, we can disable value prefetch.
opts.PrefetchValues = false
d := make([]byte, 0, 64)
b := make([]byte, 0, 128)
b = append(b, opts.Prefix...)
b = appendPrefix(b, prompt)
Expand All @@ -40,14 +41,8 @@ func (br *Brain) Think(ctx context.Context, tag string, prompt []string) iter.Se
// after its definition, because our half of the loop needs to use
// them too.
f := func(id, suf *[]byte) error {
// The id is everything after the first byte following the hash
// for empty prefixes, and everything after the first \xff\xff
// otherwise.
k := key[tagHashLen+1:]
if len(prompt) > 0 {
_, k, _ = bytes.Cut(k, []byte{0xff, 0xff})
}
*id = k
_, _, t := keyparts(key)
*id = append(*id, t...)
var err error
*suf, err = item.ValueCopy(*suf)
if err != nil {
Expand All @@ -57,9 +52,22 @@ func (br *Brain) Think(ctx context.Context, tag string, prompt []string) iter.Se
}
for it.ValidForPrefix(b) {
item = it.Item()
// TODO(zeph): for #43, check deleted uuids so we never pick
// a message that has been deleted
key = item.KeyCopy(key[:0])
// Check whether the message ID is deleted.
tag, _, id := keyparts(key)
d = append(d[:0], tag...)
d = append(d, 0xfe, 0xfe)
d = append(d, id...)
switch _, err := txn.Get(d); err {
case badger.ErrKeyNotFound: // do nothing
case nil:
// The fact that there is a value means the message was
// forgotten. Don't yield it.
it.Next()
continue
default:
erf(fmt.Errorf("couldn't check for deleted message: %w", err))
}
if !yield(f) {
break
}
Expand All @@ -72,3 +80,29 @@ func (br *Brain) Think(ctx context.Context, tag string, prompt []string) iter.Se
}
}
}

func keyparts(key []byte) (tag, content, id []byte) {
if len(key) < tagHashLen+2 {
return nil, nil, nil
}
tag = key[:tagHashLen]
switch key[tagHashLen] {
case 0xff:
// Empty prefix sentinel. The rest is the ID.
content = key[tagHashLen : tagHashLen+1]
id = key[tagHashLen+1:]
case 0xfe:
// Deleted ID sentinel. Two bytes long, and the rest is the ID.
content = key[tagHashLen : tagHashLen+2]
id = key[tagHashLen+2:]
default:
// Non-empty prefix. Ends after \xff\xff.
k := bytes.Index(key[tagHashLen:], []byte{0xff, 0xff})
if k < 0 {
panic("kvbrain: invalid key")
}
content = key[tagHashLen : tagHashLen+k+2]
id = key[tagHashLen+k+2:]
}
return tag, content, id
}

0 comments on commit 0bc2085

Please sign in to comment.