From 5456e65e8b213befad8d582801ce51a64538585e Mon Sep 17 00:00:00 2001 From: Branden J Brown Date: Thu, 15 Aug 2024 12:48:33 -0500 Subject: [PATCH] brain/*: record trace of message ids used to speak For #44. --- brain/braintest/bench.go | 6 +- brain/braintest/braintest.go | 182 ++++++++++++++++++++---------- brain/braintest/braintest_test.go | 93 +++++++-------- brain/builder.go | 50 ++++++++ brain/builder_test.go | 116 +++++++++++++++++++ brain/kvbrain/speak.go | 27 +++-- brain/kvbrain/speak_test.go | 6 +- brain/speak.go | 28 +++-- brain/speak_test.go | 25 +++- brain/sqlbrain/speak.go | 55 +++++---- brain/sqlbrain/speak_test.go | 34 +----- command/talk.go | 3 +- privmsg.go | 3 +- 13 files changed, 432 insertions(+), 196 deletions(-) create mode 100644 brain/builder.go create mode 100644 brain/builder_test.go diff --git a/brain/braintest/bench.go b/brain/braintest/bench.go index b97a654..9a7ad17 100644 --- a/brain/braintest/bench.go +++ b/brain/braintest/bench.go @@ -125,7 +125,7 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if _, err := brain.Speak(ctx, br, "bocchi", ""); err != nil { + if _, _, err := brain.Speak(ctx, br, "bocchi", ""); err != nil { b.Errorf("error while speaking: %v", err) } } @@ -170,7 +170,7 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if _, err := brain.Speak(ctx, br, "bocchi", ""); err != nil { + if _, _, err := brain.Speak(ctx, br, "bocchi", ""); err != nil { b.Errorf("error while speaking: %v", err) } } @@ -215,7 +215,7 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if _, err := brain.Speak(ctx, br, "bocchi", toks[rand.IntN(len(toks)-1)]); err != nil { + if _, _, err := brain.Speak(ctx, br, "bocchi", toks[rand.IntN(len(toks)-1)]); err != nil { b.Errorf("error while speaking: %v", err) } } diff --git a/brain/braintest/braintest.go b/brain/braintest/braintest.go index a6d80ac..9663809 100644 --- a/brain/braintest/braintest.go +++ b/brain/braintest/braintest.go @@ -3,12 +3,12 @@ package braintest import ( "context" - "maps" "slices" "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/zephyrtronium/robot/brain" "github.com/zephyrtronium/robot/userhash" ) @@ -111,15 +111,15 @@ func learn(ctx context.Context, t *testing.T, br brain.Learner) { } } -func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt string, iters int) map[string]bool { +func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt string, iters int) map[string]struct{} { t.Helper() - got := make(map[string]bool, 5) + got := make(map[string]struct{}, 20) for range iters { - s, err := brain.Speak(ctx, br, tag, prompt) + s, trace, err := brain.Speak(ctx, br, tag, prompt) if err != nil { t.Errorf("couldn't speak: %v", err) } - got[s] = true + got[strings.Join(trace, " ")+"#"+s] = struct{}{} } return got } @@ -128,33 +128,57 @@ func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt stri func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) { return func(t *testing.T) { learn(ctx, t, br) - got := speak(ctx, t, br, "kessoku", "", 256) - want := map[string]bool{ - "member bocchi": true, - "member ryou": true, - "member nijika": true, - "member kita": true, + got := speak(ctx, t, br, "kessoku", "", 2048) + want := map[string]struct{}{ + "1#member bocchi": {}, + "1 2#member bocchi": {}, + "1 3#member bocchi": {}, + "1 4#member bocchi": {}, + "1 2#member ryou": {}, + "2#member ryou": {}, + "2 3#member ryou": {}, + "2 4#member ryou": {}, + "1 3#member nijika": {}, + "2 3#member nijika": {}, + "3#member nijika": {}, + "3 4#member nijika": {}, + "1 4#member kita": {}, + "2 4#member kita": {}, + "3 4#member kita": {}, + "4#member kita": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong spoken messages for kessoku: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong spoken messages for kessoku (+got/-want):\n%s", diff) } - got = speak(ctx, t, br, "sickhack", "", 256) - want = map[string]bool{ - "member bocchi": true, - "member ryou": true, - "member nijika": true, - "member kita": true, - "manager seika": true, + got = speak(ctx, t, br, "sickhack", "", 2048) + want = map[string]struct{}{ + "5#member bocchi": {}, + "5 6#member bocchi": {}, + "5 7#member bocchi": {}, + "5 8#member bocchi": {}, + "5 6#member ryou": {}, + "6#member ryou": {}, + "6 7#member ryou": {}, + "6 8#member ryou": {}, + "5 7#member nijika": {}, + "6 7#member nijika": {}, + "7#member nijika": {}, + "7 8#member nijika": {}, + "5 8#member kita": {}, + "6 8#member kita": {}, + "7 8#member kita": {}, + "8#member kita": {}, + "9#manager seika": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong spoken messages for sickhack: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong spoken messages for sickhack (+got/-want):\n%s", diff) } got = speak(ctx, t, br, "sickhack", "manager", 32) - want = map[string]bool{ - "manager seika": true, + want = map[string]struct{}{ + "9#manager seika": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong prompted messages for sickhack: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong prompted messages for sickhack (+got/-want):\n%s", diff) } } } @@ -167,16 +191,19 @@ func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) { t.Errorf("couldn't forget: %v", err) } for range 100 { - s, err := brain.Speak(ctx, br, "kessoku", "") + s, trace, err := brain.Speak(ctx, br, "kessoku", "") if err != nil { t.Errorf("couldn't speak: %v", err) } if strings.Contains(s, "bocchi") { t.Errorf("remembered that which must be forgotten: %q", s) } + if trace[len(trace)-1] == messages[0].ID { + t.Errorf("id %q should have been forgotten but was used in %q", messages[0].ID, trace) + } } for range 10000 { - s, err := brain.Speak(ctx, br, "sickhack", "") + s, _, err := brain.Speak(ctx, br, "sickhack", "") if err != nil { t.Errorf("couldn't speak: %v", err) } @@ -195,25 +222,46 @@ func testForgetMessage(ctx context.Context, br brain.Brain) func(t *testing.T) { if err := br.ForgetMessage(ctx, "kessoku", messages[0].ID); err != nil { t.Errorf("failed to forget first message: %v", err) } - got := speak(ctx, t, br, "kessoku", "", 256) - want := map[string]bool{ - "member ryou": true, - "member nijika": true, - "member kita": true, + got := speak(ctx, t, br, "kessoku", "", 2048) + want := map[string]struct{}{ + // The current brains should delete the "member" with ID 1, but we + // don't strictly require it. + // This should change anyway once we stop deleting by tuples. + "2#member ryou": {}, + "2 3#member ryou": {}, + "2 4#member ryou": {}, + "2 3#member nijika": {}, + "3#member nijika": {}, + "3 4#member nijika": {}, + "2 4#member kita": {}, + "3 4#member kita": {}, + "4#member kita": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong messages after forgetting: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong messages after forgetting (+got/-want):\n%s", diff) } - got = speak(ctx, t, br, "sickhack", "", 256) - want = map[string]bool{ - "member bocchi": true, - "member ryou": true, - "member nijika": true, - "member kita": true, - "manager seika": true, + got = speak(ctx, t, br, "sickhack", "", 2048) + want = map[string]struct{}{ + "5#member bocchi": {}, + "5 6#member bocchi": {}, + "5 7#member bocchi": {}, + "5 8#member bocchi": {}, + "5 6#member ryou": {}, + "6#member ryou": {}, + "6 7#member ryou": {}, + "6 8#member ryou": {}, + "5 7#member nijika": {}, + "6 7#member nijika": {}, + "7#member nijika": {}, + "7 8#member nijika": {}, + "5 8#member kita": {}, + "6 8#member kita": {}, + "7 8#member kita": {}, + "8#member kita": {}, + "9#manager seika": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong messages in other tag after forgetting: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong messages in other tag after forgetting (+got/-want):\n%s", diff) } } } @@ -225,24 +273,38 @@ func testForgetDuring(ctx context.Context, br brain.Brain) func(t *testing.T) { if err := br.ForgetDuring(ctx, "kessoku", time.Unix(1, 0).Add(-time.Millisecond), time.Unix(2, 0).Add(time.Millisecond)); err != nil { t.Errorf("failed to forget: %v", err) } - got := speak(ctx, t, br, "kessoku", "", 256) - want := map[string]bool{ - "member bocchi": true, - "member kita": true, + got := speak(ctx, t, br, "kessoku", "", 2048) + want := map[string]struct{}{ + "1#member bocchi": {}, + "1 4#member bocchi": {}, + "1 4#member kita": {}, + "4#member kita": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong messages after forgetting: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong messages after forgetting (+got/-want):\n%s", diff) } - got = speak(ctx, t, br, "sickhack", "", 256) - want = map[string]bool{ - "member bocchi": true, - "member ryou": true, - "member nijika": true, - "member kita": true, - "manager seika": true, + got = speak(ctx, t, br, "sickhack", "", 2048) + want = map[string]struct{}{ + "5#member bocchi": {}, + "5 6#member bocchi": {}, + "5 7#member bocchi": {}, + "5 8#member bocchi": {}, + "5 6#member ryou": {}, + "6#member ryou": {}, + "6 7#member ryou": {}, + "6 8#member ryou": {}, + "5 7#member nijika": {}, + "6 7#member nijika": {}, + "7#member nijika": {}, + "7 8#member nijika": {}, + "5 8#member kita": {}, + "6 8#member kita": {}, + "7 8#member kita": {}, + "8#member kita": {}, + "9#manager seika": {}, } - if !maps.Equal(got, want) { - t.Errorf("wrong spoken messages for sickhack: want %v, got %v", want, got) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong spoken messages for sickhack (+got/-want):\n%s", diff) } } } @@ -278,7 +340,7 @@ func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) { } } allocs := testing.AllocsPerRun(10, func() { - _, err := brain.Speak(ctx, br, "bocchi", "") + _, _, err := brain.Speak(ctx, br, "bocchi", "") if err != nil { t.Errorf("couldn't speak: %v", err) } diff --git a/brain/braintest/braintest_test.go b/brain/braintest/braintest_test.go index 803c032..d13fe09 100644 --- a/brain/braintest/braintest_test.go +++ b/brain/braintest/braintest_test.go @@ -18,10 +18,9 @@ import ( // to verify that the integration tests test the correct things. type membrain struct { mu sync.Mutex - tups map[string]map[string][]string // map of tags to map of prefixes to suffixes - users map[userhash.Hash][][3]string // map of hashes to tag and prefix+suffix - ids map[string]map[string][][2]string // map of tags to map of ids to prefix+suffix - tms map[string]map[int64][][2]string // map of tags to map of timestamps to prefix+suffix + tups map[string]map[string][][2]string // map of tags to map of prefixes to id and suffix + users map[userhash.Hash][][2]string // map of hashes to tag and id + tms map[string]map[int64][]string // map of tags to map of timestamps to ids } var _ brain.Brain = (*membrain)(nil) @@ -29,42 +28,42 @@ var _ brain.Brain = (*membrain)(nil) func (m *membrain) Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []brain.Tuple) error { m.mu.Lock() defer m.mu.Unlock() - if m.tups == nil { - m.tups = make(map[string]map[string][]string) - m.users = make(map[userhash.Hash][][3]string) - m.ids = make(map[string]map[string][][2]string) - m.tms = make(map[string]map[int64][][2]string) - } if m.tups[tag] == nil { - m.tups[tag] = make(map[string][]string) - } - r := m.tups[tag] - if m.ids[tag] == nil { - m.ids[tag] = make(map[string][][2]string) - } - ids := m.ids[tag] - if m.tms[tag] == nil { - m.tms[tag] = make(map[int64][][2]string) + if m.tups == nil { + m.tups = make(map[string]map[string][][2]string) + m.users = make(map[userhash.Hash][][2]string) + m.tms = make(map[string]map[int64][]string) + } + m.tups[tag] = make(map[string][][2]string) + m.tms[tag] = make(map[int64][]string) } + m.users[user] = append(m.users[user], [2]string{tag, id}) tms := m.tms[tag] + tms[t.UnixNano()] = append(tms[t.UnixNano()], id) + r := m.tups[tag] for _, tup := range tuples { p := strings.Join(tup.Prefix, "\xff") - r[p] = append(r[p], tup.Suffix) - m.users[user] = append(m.users[user], [3]string{tag, p, tup.Suffix}) - ids[id] = append(ids[id], [2]string{p, tup.Suffix}) - tms[t.UnixNano()] = append(tms[t.UnixNano()], [2]string{p, tup.Suffix}) + r[p] = append(r[p], [2]string{id, tup.Suffix}) } return nil } -func (m *membrain) forgetLocked(tag, p, s string) { - u := m.tups[tag][p] - k := slices.Index(u, s) - if k < 0 { - return +func (m *membrain) forgetIDLocked(tag, id string) { + for p, u := range m.tups[tag] { + for len(u) > 0 { + k := slices.IndexFunc(u, func(v [2]string) bool { return v[0] == id }) + if k < 0 { + break + } + u[k], u[len(u)-1] = u[len(u)-1], u[k] + u = u[:len(u)-1] + } + if len(u) != 0 { + m.tups[tag][p] = u + } else { + delete(m.tups[tag], p) + } } - u[k], u[len(u)-1] = u[len(u)-1], u[k] - m.tups[tag][p] = u[:len(u)-1] } func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) error { @@ -72,7 +71,13 @@ func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) defer m.mu.Unlock() for _, tup := range tuples { p := strings.Join(tup.Prefix, "\xff") - m.forgetLocked(tag, p, tup.Suffix) + u := m.tups[tag][p] + k := slices.IndexFunc(u, func(v [2]string) bool { return v[1] == tup.Suffix }) + if k < 0 { + continue + } + u[k], u[len(u)-1] = u[len(u)-1], u[k] + m.tups[tag][p] = u[:len(u)-1] } return nil } @@ -80,11 +85,7 @@ func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) func (m *membrain) ForgetMessage(ctx context.Context, tag, id string) error { m.mu.Lock() defer m.mu.Unlock() - u := m.ids[tag][id] - for _, v := range u { - m.forgetLocked(tag, v[0], v[1]) - } - delete(m.ids[tag], id) + m.forgetIDLocked(tag, id) return nil } @@ -97,7 +98,7 @@ func (m *membrain) ForgetDuring(ctx context.Context, tag string, since, before t continue } for _, v := range u { - m.forgetLocked(tag, v[0], v[1]) + m.forgetIDLocked(tag, v) } delete(m.tms[tag], tm) // yea i modify the map during iteration, yea i'm cool } @@ -108,24 +109,24 @@ func (m *membrain) ForgetUser(ctx context.Context, user *userhash.Hash) error { m.mu.Lock() defer m.mu.Unlock() for _, v := range m.users[*user] { - m.forgetLocked(v[0], v[1], v[2]) + m.forgetIDLocked(v[0], v[1]) } delete(m.users, *user) return nil } -func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { +func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error { m.mu.Lock() defer m.mu.Unlock() var s string if len(prompt) == 0 { u := m.tups[tag][""] if len(u) == 0 { - return nil, nil + return nil } t := u[rand.IntN(len(u))] - w = append(w, t...) - s = brain.ReduceEntropy(t) + w.Append(t[0], []byte(t[1])) + s = brain.ReduceEntropy(t[1]) } else { s = brain.ReduceEntropy(prompt[len(prompt)-1]) } @@ -135,13 +136,13 @@ func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w []b break } t := u[rand.IntN(len(u))] - if t == "" { + if t[1] == "" { break } - w = append(w, t...) - s = brain.ReduceEntropy(t) + w.Append(t[0], []byte(t[1])) + s = brain.ReduceEntropy(t[1]) } - return w, nil + return nil } func TestTests(t *testing.T) { diff --git a/brain/builder.go b/brain/builder.go new file mode 100644 index 0000000..547e011 --- /dev/null +++ b/brain/builder.go @@ -0,0 +1,50 @@ +package brain + +import "slices" + +// Builder builds a spoken message along with its message trace. +type Builder struct { + w []byte + id []string +} + +// Append adds a term to the builder. +func (b *Builder) Append(id string, term []byte) { + b.w = append(b.w, term...) + k, ok := slices.BinarySearch(b.id, id) + if !ok { + b.id = slices.Insert(b.id, k, id) + } +} + +// prompt adds a term without an ID. +func (b *Builder) prompt(term string) { + b.w = append(b.w, term...) +} + +// grow reserves sufficient space to append at least n bytes without reallocating. +func (b *Builder) grow(n int) { + if cap(b.w)-len(b.w) >= n { + return + } + t := make([]byte, len(b.w), len(b.w)+n) + copy(t, b.w) + b.w = t +} + +// String returns the built message. +func (b *Builder) String() string { + return string(b.w) +} + +// Trace returns a direct reference to the message trace. +func (b *Builder) Trace() []string { + return b.id +} + +// Reset restores the builder to an empty state. +func (b *Builder) Reset() { + b.w = b.w[:0] + clear(b.id) // allow held strings to release + b.id = b.id[:0] +} diff --git a/brain/builder_test.go b/brain/builder_test.go new file mode 100644 index 0000000..2eda322 --- /dev/null +++ b/brain/builder_test.go @@ -0,0 +1,116 @@ +package brain_test + +import ( + "math/rand/v2" + "slices" + "strconv" + "testing" + + "github.com/zephyrtronium/robot/brain" +) + +func TestBuilder(t *testing.T) { + cases := []struct { + name string + terms [][2]string + want string + trace []string + }{ + { + name: "empty", + terms: nil, + want: "", + trace: nil, + }, + { + name: "single", + terms: [][2]string{ + {"bocchi", "ryo"}, + }, + want: "ryo", + trace: []string{ + "bocchi", + }, + }, + { + name: "multi", + terms: [][2]string{ + {"bocchi", "ryo"}, + {"nijika", "kita"}, + }, + want: "ryokita", + trace: []string{ + "bocchi", + "nijika", + }, + }, + { + name: "order", + terms: [][2]string{ + {"nijika", "ryo"}, + {"bocchi", "kita"}, + }, + want: "ryokita", + trace: []string{ + "bocchi", + "nijika", + }, + }, + { + name: "dedup", + terms: [][2]string{ + {"bocchi", "ryo"}, + {"bocchi", "kita"}, + }, + want: "ryokita", + trace: []string{ + "bocchi", + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var b brain.Builder + for _, t := range c.terms { + b.Append(t[0], []byte(t[1])) + } + got := b.String() + trace := b.Trace() + if got != c.want { + t.Errorf("wrong string: want %q, got %q", c.want, got) + } + if !slices.Equal(trace, c.trace) { + t.Errorf("wrong trace: want %q, got %q", c.trace, trace) + } + b.Reset() + got = b.String() + trace = b.Trace() + if got != "" { + t.Errorf("string %q not empty after reset", got) + } + if len(trace) != 0 { + t.Errorf("trace %q not empty after reset", trace) + } + }) + } +} + +func BenchmarkBuilder(b *testing.B) { + var ids [256]string + var words [256][]byte + for i := range ids { + ids[i] = strconv.FormatUint(rand.Uint64()>>(i/16), 2) + words[i] = []byte(ids[i]) + } + var m brain.Builder + b.ReportAllocs() + b.ResetTimer() + for range b.N { + m.Reset() + u := rand.Uint64() + m.Append(ids[byte(u>>0)], words[byte(u>>8)]) + m.Append(ids[byte(u>>16)], words[byte(u>>24)]) + m.Append(ids[byte(u>>32)], words[byte(u>>40)]) + m.Append(ids[byte(u>>48)], words[byte(u>>56)]) + } +} diff --git a/brain/kvbrain/speak.go b/brain/kvbrain/speak.go index 87d2626..eae6bb6 100644 --- a/brain/kvbrain/speak.go +++ b/brain/kvbrain/speak.go @@ -1,6 +1,7 @@ package kvbrain import ( + "bytes" "context" "fmt" "math/rand/v2" @@ -16,12 +17,13 @@ var prependerPool tpool.Pool[*prepend.List[string]] // Speak generates a full message and appends it to w. // The prompt is in reverse order and has entropy reduction applied. -func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { +func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error { search := prependerPool.Get().Set(prompt...) defer func() { prependerPool.Put(search) }() tb := hashTag(make([]byte, 0, tagHashLen), tag) b := make([]byte, 0, 128) + var id string opts := badger.DefaultIteratorOptions // We don't actually need to iterate over values, only the single value // that we decide to use per suffix. So, we can disable value prefetch. @@ -31,26 +33,27 @@ func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w []byt var err error var l int b = append(b[:0], tb...) - b, l, err = br.next(b, search.Slice(), opts) + b, id, l, err = br.next(b, search.Slice(), opts) if err != nil { - return nil, err + return err } if len(b) == 0 { break } - w = append(w, b...) + w.Append(id, b) search = search.Drop(search.Len() - l - 1).Prepend(brain.ReduceEntropy(string(b))) } - return w, nil + return nil } // next finds a single token to continue a prompt. // The returned values are, in order, // b with its contents replaced with the new term, +// the ID of the message used for the term, // the number of terms of the prompt which matched to produce the new term, // and any error. // If the returned term is the empty string, generation should end. -func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([]byte, int, error) { +func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([]byte, string, int, error) { // These definitions are outside the loop to ensure we don't bias toward // smaller contexts. var ( @@ -84,7 +87,7 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([ return nil }) if err != nil { - return nil, len(prompt), fmt.Errorf("couldn't read knowledge: %w", err) + return nil, "", len(prompt), fmt.Errorf("couldn't read knowledge: %w", err) } if picked < 3 && len(prompt) > 1 { // We haven't seen enough options, and we have context we could @@ -96,7 +99,7 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([ if key == nil { // We never saw any options. Since we always select the first, this // means there were no options. Don't look for nothing in the DB. - return b[:0], len(prompt), nil + return b[:0], "", len(prompt), nil } err = br.knowledge.View(func(txn *badger.Txn) error { item, err := txn.Get(key) @@ -109,6 +112,12 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([ } return nil }) - return b, len(prompt), err + // The id is everything after the first byte following the hash for + // empty prefixes, and everything after the first \xff\xff otherwise. + id := key[tagHashLen+1:] + if len(prompt) > 0 { + _, id, _ = bytes.Cut(key, []byte{0xff, 0xff}) + } + return b, string(id), len(prompt), err } } diff --git a/brain/kvbrain/speak_test.go b/brain/kvbrain/speak_test.go index 52b2d79..052800b 100644 --- a/brain/kvbrain/speak_test.go +++ b/brain/kvbrain/speak_test.go @@ -140,12 +140,14 @@ func TestSpeak(t *testing.T) { want[v] = true } got := make(map[string]bool, len(c.want)) + var w brain.Builder for range 256 { - m, err := br.Speak(ctx, "kessoku", slices.Clone(c.prompt), nil) + w.Reset() + err := br.Speak(ctx, "kessoku", slices.Clone(c.prompt), &w) if err != nil { t.Errorf("failed to speak: %v", err) } - got[string(m)] = true + got[w.String()] = true } if !maps.Equal(want, got) { t.Errorf("wrong results: want %v, got %v", want, got) diff --git a/brain/speak.go b/brain/speak.go index 61a4051..e416be7 100644 --- a/brain/speak.go +++ b/brain/speak.go @@ -1,10 +1,10 @@ package brain import ( - "bytes" "context" "fmt" "slices" + "strings" "github.com/zephyrtronium/robot/tpool" ) @@ -13,35 +13,33 @@ import ( type Speaker interface { // Speak generates a full message and appends it to w. // The prompt is in reverse order and has entropy reduction applied. - Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) -} - -type Prompt struct { - Terms []string + Speak(ctx context.Context, tag string, prompt []string, w *Builder) error } var ( tokensPool tpool.Pool[[]string] - builderPool tpool.Pool[[]byte] + builderPool = tpool.Pool[*Builder]{New: func() any { return new(Builder) }} ) -// Speak produces a new message from the given prompt. -func Speak(ctx context.Context, s Speaker, tag, prompt string) (string, error) { +// Speak produces a new message and the trace of messages used to form it +// from the given prompt. +func Speak(ctx context.Context, s Speaker, tag, prompt string) (string, []string, error) { w := builderPool.Get() toks := Tokens(tokensPool.Get(), prompt) defer func() { - builderPool.Put(w[:0]) + w.Reset() + builderPool.Put(w) tokensPool.Put(toks[:0]) }() - w = slices.Grow(w, len(prompt)+1) + w.grow(len(prompt) + 1) for i, t := range toks { - w = append(w, t...) + w.prompt(t) toks[i] = ReduceEntropy(t) } slices.Reverse(toks) - w, err := s.Speak(ctx, tag, toks, w) + err := s.Speak(ctx, tag, toks, w) if err != nil { - return "", fmt.Errorf("couldn't speak: %w", err) + return "", nil, fmt.Errorf("couldn't speak: %w", err) } - return string(bytes.TrimSpace(w)), nil + return strings.TrimSpace(w.String()), slices.Clone(w.Trace()), nil } diff --git a/brain/speak_test.go b/brain/speak_test.go index d6bb670..dba28f5 100644 --- a/brain/speak_test.go +++ b/brain/speak_test.go @@ -12,62 +12,76 @@ import ( type testSpeaker struct { prompt []string + id string append []byte } -func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { +func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error { t.prompt = prompt - return append(w, t.append...), nil + w.Append(t.id, t.append) + return nil } func TestSpeak(t *testing.T) { cases := []struct { name string prompt string + id string append []byte want []string + trace []string say string }{ { name: "empty", prompt: "", + id: "kessoku", append: nil, want: nil, + trace: []string{"kessoku"}, say: "", }, { name: "empty-add", prompt: "", + id: "kessoku", append: []byte("bocchi"), want: nil, + trace: []string{"kessoku"}, say: "bocchi", }, { name: "prompted", prompt: "bocchi ryo nijika", + id: "kessoku", append: nil, want: []string{"nijika ", "ryo ", "bocchi "}, + trace: []string{"kessoku"}, say: "bocchi ryo nijika", }, { name: "prompted-add", prompt: "bocchi ryo nijika", + id: "kessoku", append: []byte("kita"), want: []string{"nijika ", "ryo ", "bocchi "}, + trace: []string{"kessoku"}, say: "bocchi ryo nijika kita", }, { name: "entropy", prompt: "BOCCHI RYO NIJIKA", + id: "kessoku", append: []byte("KITA"), want: []string{"nijika ", "ryo ", "bocchi "}, + trace: []string{"kessoku"}, say: "BOCCHI RYO NIJIKA KITA", }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - s := testSpeaker{append: c.append} - r, err := brain.Speak(context.Background(), &s, "", c.prompt) + s := testSpeaker{id: c.id, append: c.append} + r, trace, err := brain.Speak(context.Background(), &s, "", c.prompt) if err != nil { t.Error(err) } @@ -77,6 +91,9 @@ func TestSpeak(t *testing.T) { if diff := cmp.Diff(c.say, r); diff != "" { t.Errorf("wrong result from %q:\n%s", c.say, diff) } + if diff := cmp.Diff(c.trace, trace); diff != "" { + t.Errorf("wrong trace:\n%s", diff) + } }) } } diff --git a/brain/sqlbrain/speak.go b/brain/sqlbrain/speak.go index 0438289..1f8342f 100644 --- a/brain/sqlbrain/speak.go +++ b/brain/sqlbrain/speak.go @@ -16,42 +16,44 @@ var prependerPool tpool.Pool[*prepend.List[string]] // Speak generates a full message and appends it to w. // The prompt is in reverse order and has entropy reduction applied. -func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { +func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error { search := prependerPool.Get().Set(prompt...) defer func() { prependerPool.Put(search) }() conn, err := br.db.Take(ctx) defer br.db.Put(conn) if err != nil { - return w, fmt.Errorf("couldn't get connection to speak: %w", err) + return fmt.Errorf("couldn't get connection to speak: %w", err) } b := make([]byte, 0, 128) for range 1024 { var err error var l int - b, l, err = next(conn, tag, b, search.Slice()) + var id string + b, id, l, err = next(conn, tag, b, search.Slice()) if err != nil { - return nil, err + return err } if len(b) == 0 { break } - w = append(w, b...) + w.Append(id, b) search = search.Drop(search.Len() - l - 1).Prepend(brain.ReduceEntropy(string(b))) } - return w, nil + return nil } -func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int, error) { +func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, string, int, error) { + var id string if len(prompt) == 0 { var err error - b, err = first(conn, tag, b) - return b, 0, err + b, id, err = first(conn, tag, b) + return b, id, 0, err } - st, err := conn.Prepare(`SELECT suffix FROM knowledge WHERE tag = :tag AND prefix >= :lower AND prefix < :upper AND LIKELY(deleted IS NULL)`) + st, err := conn.Prepare(`SELECT id, suffix FROM knowledge WHERE tag = :tag AND prefix >= :lower AND prefix < :upper AND LIKELY(deleted IS NULL)`) if err != nil { - return b[:0], len(prompt), fmt.Errorf("couldn't prepare term selection: %w", err) + return b[:0], "", len(prompt), fmt.Errorf("couldn't prepare term selection: %w", err) } st.SetText(":tag", tag) w := make([]byte, 0, 32) @@ -67,21 +69,22 @@ func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int for { ok, err := st.Step() if err != nil { - return b[:0], len(prompt), fmt.Errorf("couldn't step term selection: %w", err) + return b[:0], "", len(prompt), fmt.Errorf("couldn't step term selection: %w", err) } if !ok { break } - n := st.ColumnLen(0) + id = st.ColumnText(0) + n := st.ColumnLen(1) if cap(w) < n { w = make([]byte, n) } - w = w[:st.ColumnBytes(0, w[:n])] + w = w[:st.ColumnBytes(1, w[:n])] picked++ for range skip.N(rand.Uint64(), rand.Uint64()) { ok, err := st.Step() if err != nil { - return b[:0], len(prompt), fmt.Errorf("couldn't step term selection: %w", err) + return b[:0], "", len(prompt), fmt.Errorf("couldn't step term selection: %w", err) } if !ok { break sel @@ -93,13 +96,13 @@ func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int // lose. Do so and try again from the beginning. prompt = prompt[:len(prompt)-1] if err := st.Reset(); err != nil { - return b[:0], len(prompt), fmt.Errorf("couldn't reset term selection: %w", err) + return b[:0], "", len(prompt), fmt.Errorf("couldn't reset term selection: %w", err) } continue } // Note that this also handles the case where there were no results. b = append(b[:0], w...) - return b, len(prompt), nil + return b, id, len(prompt), nil } } @@ -117,11 +120,12 @@ func searchbounds(prefix []byte) (lower, upper []byte) { return lower, upper } -func first(conn *sqlite.Conn, tag string, b []byte) ([]byte, error) { +func first(conn *sqlite.Conn, tag string, b []byte) ([]byte, string, error) { + var id string b = b[:0] // in case we get no rows - s, err := conn.Prepare(`SELECT suffix FROM knowledge WHERE tag = :tag AND prefix = x'' AND LIKELY(deleted IS NULL)`) + s, err := conn.Prepare(`SELECT id, suffix FROM knowledge WHERE tag = :tag AND prefix = x'' AND LIKELY(deleted IS NULL)`) if err != nil { - return b, fmt.Errorf("couldn't prepare first term selection: %w", err) + return b, "", fmt.Errorf("couldn't prepare first term selection: %w", err) } s.SetText(":tag", tag) var skip brain.Skip @@ -129,25 +133,26 @@ sel: for { ok, err := s.Step() if err != nil { - return b[:0], fmt.Errorf("couldn't step first term selection: %w", err) + return b[:0], "", fmt.Errorf("couldn't step first term selection: %w", err) } if !ok { break } - n := s.ColumnLen(0) + id = s.ColumnText(0) + n := s.ColumnLen(1) if cap(b) < n { b = make([]byte, n) } - b = b[:s.ColumnBytes(0, b[:n])] + b = b[:s.ColumnBytes(1, b[:n])] for range skip.N(rand.Uint64(), rand.Uint64()) { ok, err := s.Step() if err != nil { - return b[:0], fmt.Errorf("couldn't step first term selection: %w", err) + return b[:0], "", fmt.Errorf("couldn't step first term selection: %w", err) } if !ok { break sel } } } - return b, nil + return b, id, nil } diff --git a/brain/sqlbrain/speak_test.go b/brain/sqlbrain/speak_test.go index b5bd7bf..8415104 100644 --- a/brain/sqlbrain/speak_test.go +++ b/brain/sqlbrain/speak_test.go @@ -21,7 +21,6 @@ func TestSpeak(t *testing.T) { know []know tag string prompt []string - w []byte want []string }{ { @@ -29,7 +28,6 @@ func TestSpeak(t *testing.T) { know: nil, tag: "kessoku", prompt: nil, - w: nil, // We should only ever get nil from the brain, // but that converts to the empty string. want: []string{""}, @@ -50,7 +48,6 @@ func TestSpeak(t *testing.T) { }, tag: "sickhack", prompt: nil, - w: nil, // We should only ever get nil from the brain, // but that converts to the empty string. want: []string{""}, @@ -71,7 +68,6 @@ func TestSpeak(t *testing.T) { }, tag: "kessoku", prompt: []string{"kikuri "}, - w: nil, // We should only ever get nil from the brain, // but that converts to the empty string. want: []string{""}, @@ -92,7 +88,6 @@ func TestSpeak(t *testing.T) { }, tag: "kessoku", prompt: nil, - w: nil, want: []string{"bocchi "}, }, { @@ -126,28 +121,8 @@ func TestSpeak(t *testing.T) { }, tag: "kessoku", prompt: nil, - w: nil, want: []string{"bocchi ryo nijika kita "}, }, - { - name: "append", - know: []know{ - { - tag: "kessoku", - prefix: "", - suffix: "bocchi ", - }, - { - tag: "kessoku", - prefix: "bocchi \x00", - suffix: "", - }, - }, - tag: "kessoku", - prompt: nil, - w: []byte("member "), - want: []string{"member bocchi "}, - }, { name: "multi", know: []know{ @@ -214,7 +189,6 @@ func TestSpeak(t *testing.T) { }, tag: "kessoku", prompt: nil, - w: nil, want: []string{"member bocchi ", "member ryo ", "member nijika ", "member kita "}, }, { @@ -328,7 +302,6 @@ func TestSpeak(t *testing.T) { }, tag: "sickhack", prompt: nil, - w: nil, want: []string{"member kikuri ", "member eliza ", "member shima "}, }, { @@ -400,7 +373,6 @@ func TestSpeak(t *testing.T) { }, tag: "kessoku", prompt: nil, - w: nil, want: []string{"member bocchi ", "member ryo ", "member nijika "}, }, } @@ -421,12 +393,14 @@ func TestSpeak(t *testing.T) { insert(t, conn, c.know, nil) slices.Sort(c.want) got := make([]string, 0, len(c.want)) + var w brain.Builder for range 10000 { - w, err := br.Speak(ctx, c.tag, c.prompt, slices.Clone(c.w)) + w.Reset() + err := br.Speak(ctx, c.tag, c.prompt, &w) if err != nil { t.Errorf("couldn't speak: %v", err) } - s := string(w) + s := w.String() k, ok := slices.BinarySearch(got, s) if !ok { got = slices.Insert(got, k, s) diff --git a/command/talk.go b/command/talk.go index 47d6cf6..c6409a3 100644 --- a/command/talk.go +++ b/command/talk.go @@ -19,7 +19,8 @@ func speakCmd(ctx context.Context, robo *Robot, call *Invocation) string { cancel() return "" } - m, err := brain.Speak(ctx, robo.Brain, call.Channel.Send, call.Args["prompt"]) + // TODO(zeph): record trace + m, _, err := brain.Speak(ctx, robo.Brain, call.Channel.Send, call.Args["prompt"]) if err != nil { robo.Log.ErrorContext(ctx, "couldn't speak", "err", err.Error()) cancel() diff --git a/privmsg.go b/privmsg.go index d5dbf61..df417cc 100644 --- a/privmsg.go +++ b/privmsg.go @@ -67,7 +67,8 @@ func (robo *Robot) tmiMessage(ctx context.Context, group *errgroup.Group, send c if rand.Float64() > ch.Responses { return } - s, err := brain.Speak(ctx, robo.brain, ch.Send, "") + // TODO(zeph): record trace + s, _, err := brain.Speak(ctx, robo.brain, ch.Send, "") if err != nil { slog.ErrorContext(ctx, "wanted to speak but failed", slog.String("err", err.Error())) return