diff --git a/brain/braintest/bench.go b/brain/braintest/bench.go index 9fb5f50..be425f9 100644 --- a/brain/braintest/bench.go +++ b/brain/braintest/bench.go @@ -24,24 +24,25 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, if cleanup != nil { b.Cleanup(func() { cleanup(l) }) } - var msg brain.MessageMeta b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { var t int64 - toks := make([]string, 2+l.Order()) - for i := range toks { - toks[i] = hex.EncodeToString(randbytes(make([]byte, 16))) + toks := []string{ + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), } for pb.Next() { t++ toks[len(toks)-1] = strconv.FormatInt(t, 10) - msg = brain.MessageMeta{ - ID: uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))), - User: userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))), - Tag: "bocchi", - Time: time.Unix(t, 0), - } - err := brain.Learn(ctx, l, &msg, toks) + id := uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))) + u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))) + err := brain.Learn(ctx, l, "bocchi", u, id, time.Unix(t, 0), toks) if err != nil { b.Errorf("error while learning: %v", err) } @@ -53,25 +54,33 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, if cleanup != nil { b.Cleanup(func() { cleanup(l) }) } - var msg brain.MessageMeta b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { var t int64 - order := l.Order() - toks := make([]string, 16+order) - for i := range toks { - toks[i] = hex.EncodeToString(randbytes(make([]byte, 16))) + toks := []string{ + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), } for pb.Next() { t++ rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] }) - msg = brain.MessageMeta{ - ID: uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))), - User: userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))), - Tag: "bocchi", - Time: time.Unix(t, 0), - } - err := brain.Learn(ctx, l, &msg, toks[:2+order]) + id := uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))) + u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))) + err := brain.Learn(ctx, l, "bocchi", u, id, time.Unix(t, 0), toks[:8]) if err != nil { b.Errorf("error while learning: %v", err) } @@ -91,21 +100,21 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.Cleanup(func() { cleanup(br) }) } // First fill the brain. - var msg brain.MessageMeta - order := br.Order() - toks := make([]string, 2+order) - for i := range toks { - toks[i] = hex.EncodeToString(randbytes(make([]byte, 16))) + toks := []string{ + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), } for t := range size { toks[len(toks)-1] = strconv.FormatInt(t, 10) - msg = brain.MessageMeta{ - ID: uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))), - User: userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))), - Tag: "bocchi", - Time: time.Unix(t, 0), - } - err := brain.Learn(ctx, br, &msg, toks) + id := uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))) + u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))) + err := brain.Learn(ctx, br, "bocchi", u, id, time.Unix(t, 0), toks) if err != nil { b.Errorf("error while learning: %v", err) } @@ -128,21 +137,29 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.Cleanup(func() { cleanup(br) }) } // First fill the brain. - var msg brain.MessageMeta - order := br.Order() - toks := make([]string, 16+order) - for i := range toks { - toks[i] = hex.EncodeToString(randbytes(make([]byte, 16))) + toks := []string{ + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), } for t := range size { rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] }) - msg = brain.MessageMeta{ - ID: uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))), - User: userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))), - Tag: "bocchi", - Time: time.Unix(t, 0), - } - err := brain.Learn(ctx, br, &msg, toks) + id := uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))) + u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))) + err := brain.Learn(ctx, br, "bocchi", u, id, time.Unix(t, 0), toks) if err != nil { b.Errorf("error while learning: %v", err) } @@ -165,21 +182,29 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b.Cleanup(func() { cleanup(br) }) } // First fill the brain. - var msg brain.MessageMeta - order := br.Order() - toks := make([]string, 16+order) - for i := range toks { - toks[i] = hex.EncodeToString(randbytes(make([]byte, 16))) + toks := []string{ + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), + hex.EncodeToString(randbytes(make([]byte, 4))), } for t := range size { rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] }) - msg = brain.MessageMeta{ - ID: uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))), - User: userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))), - Tag: "bocchi", - Time: time.Unix(t, 0), - } - err := brain.Learn(ctx, br, &msg, toks) + id := uuid.UUID(randbytes(make([]byte, len(uuid.UUID{})))) + u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{})))) + err := brain.Learn(ctx, br, "bocchi", u, id, time.Unix(t, 0), toks) if err != nil { b.Errorf("error while learning: %v", err) } @@ -197,9 +222,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, } } -// randbytes fills a slice of at least length 16 with random data. +// randbytes fills a slice of at least length 4 with random data. func randbytes(b []byte) []byte { - binary.NativeEndian.PutUint64(b[8:], rand.Uint64()) - binary.NativeEndian.PutUint64(b, rand.Uint64()) + binary.NativeEndian.PutUint32(b, rand.Uint32()) return b } diff --git a/brain/braintest/braintest.go b/brain/braintest/braintest.go index 68e2e63..1ba1d3c 100644 --- a/brain/braintest/braintest.go +++ b/brain/braintest/braintest.go @@ -4,6 +4,7 @@ package braintest import ( "context" "maps" + "slices" "strings" "testing" "time" @@ -31,97 +32,88 @@ func Test(ctx context.Context, t *testing.T, new func(context.Context) Interface t.Run("combinatoric", testCombinatoric(ctx, new(ctx))) } +func these(s ...string) func() []string { + return func() []string { + return slices.Clone(s) + } +} + var messages = [...]struct { - brain.MessageMeta - Tokens []string + ID uuid.UUID + User userhash.Hash + Tag string + Time time.Time + Tokens func() []string }{ { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, - Tokens: []string{"member", "bocchi"}, + ID: uuid.UUID{1}, + User: userhash.Hash{2}, + Tag: "kessoku", + Time: time.Unix(0, 0), + Tokens: these("member", "bocchi"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{2}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, - Tokens: []string{"member", "ryou"}, + ID: uuid.UUID{2}, + User: userhash.Hash{2}, + Tag: "kessoku", + Time: time.Unix(1, 0), + Tokens: these("member", "ryou"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{3}, - User: userhash.Hash{3}, - Tag: "kessoku", - Time: time.Unix(2, 0), - }, - Tokens: []string{"member", "nijika"}, + ID: uuid.UUID{3}, + User: userhash.Hash{3}, + Tag: "kessoku", + Time: time.Unix(2, 0), + Tokens: these("member", "nijika"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{4}, - User: userhash.Hash{3}, - Tag: "kessoku", - Time: time.Unix(3, 0), - }, - Tokens: []string{"member", "kita"}, + ID: uuid.UUID{4}, + User: userhash.Hash{3}, + Tag: "kessoku", + Time: time.Unix(3, 0), + Tokens: these("member", "kita"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{5}, - User: userhash.Hash{2}, - Tag: "sickhack", - Time: time.Unix(0, 0), - }, - Tokens: []string{"member", "bocchi"}, + ID: uuid.UUID{5}, + User: userhash.Hash{2}, + Tag: "sickhack", + Time: time.Unix(0, 0), + Tokens: these("member", "bocchi"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{6}, - User: userhash.Hash{2}, - Tag: "sickhack", - Time: time.Unix(1, 0), - }, - Tokens: []string{"member", "ryou"}, + ID: uuid.UUID{6}, + User: userhash.Hash{2}, + Tag: "sickhack", + Time: time.Unix(1, 0), + Tokens: these("member", "ryou"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{7}, - User: userhash.Hash{3}, - Tag: "sickhack", - Time: time.Unix(2, 0), - }, - Tokens: []string{"member", "nijika"}, + ID: uuid.UUID{7}, + User: userhash.Hash{3}, + Tag: "sickhack", + Time: time.Unix(2, 0), + Tokens: these("member", "nijika"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{8}, - User: userhash.Hash{3}, - Tag: "sickhack", - Time: time.Unix(3, 0), - }, - Tokens: []string{"member", "kita"}, + ID: uuid.UUID{8}, + User: userhash.Hash{3}, + Tag: "sickhack", + Time: time.Unix(3, 0), + Tokens: these("member", "kita"), }, { - MessageMeta: brain.MessageMeta{ - ID: uuid.UUID{9}, - User: userhash.Hash{4}, - Tag: "sickhack", - Time: time.Unix(43, 0), - }, - Tokens: []string{"manager", "seika"}, + ID: uuid.UUID{9}, + User: userhash.Hash{4}, + Tag: "sickhack", + Time: time.Unix(43, 0), + Tokens: these("manager", "seika"), }, } func learn(ctx context.Context, t *testing.T, br brain.Learner) { t.Helper() for _, m := range messages { - if err := brain.Learn(ctx, br, &m.MessageMeta, m.Tokens); err != nil { + if err := brain.Learn(ctx, br, m.Tag, m.User, m.ID, m.Time, m.Tokens()); err != nil { t.Fatalf("couldn't learn message %v: %v", m.ID, err) } } @@ -179,7 +171,7 @@ func testSpeak(ctx context.Context, br Interface) func(t *testing.T) { func testForget(ctx context.Context, br Interface) func(t *testing.T) { return func(t *testing.T) { learn(ctx, t, br) - if err := brain.Forget(ctx, br, "kessoku", messages[0].Tokens); err != nil { + if err := brain.Forget(ctx, br, "kessoku", messages[0].Tokens()); err != nil { t.Errorf("couldn't forget: %v", err) } for range 100 { @@ -191,16 +183,16 @@ func testForget(ctx context.Context, br Interface) func(t *testing.T) { t.Errorf("remembered that which must be forgotten: %q", s) } } - for { + for range 10000 { s, err := brain.Speak(ctx, br, "sickhack", "") if err != nil { t.Errorf("couldn't speak: %v", err) } if strings.Contains(s, "bocchi") { - break + return } - // The failure condition is that this loop is infinite. } + t.Error("didn't see bocchi in many attempts; deleted from wrong tag?") } } @@ -269,12 +261,7 @@ func testForgetDuring(ctx context.Context, br Interface) func(t *testing.T) { // overlap in learned material. func testCombinatoric(ctx context.Context, br Interface) func(t *testing.T) { return func(t *testing.T) { - msg := brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "bocchi", - Time: time.Unix(0, 0), - } + u := userhash.Hash{2} band := []string{"bocchi", "ryou", "nijika", "kita"} toks := make([]string, 6) for _, toks[0] = range band { @@ -285,7 +272,8 @@ func testCombinatoric(ctx context.Context, br Interface) func(t *testing.T) { for _, toks[5] = range band { toks := toks for len(toks) > 1 { - err := brain.Learn(ctx, br, &msg, toks) + id := uuid.New() + err := brain.Learn(ctx, br, "bocchi", u, id, time.Unix(0, 0), toks) if err != nil { t.Fatalf("couldn't learn init: %v", err) } @@ -297,7 +285,7 @@ func testCombinatoric(ctx context.Context, br Interface) func(t *testing.T) { } } } - allocs := testing.AllocsPerRun(100, func() { + allocs := testing.AllocsPerRun(10, func() { _, err := brain.Speak(ctx, br, "bocchi", "") if err != nil { t.Errorf("couldn't speak: %v", err) diff --git a/brain/braintest/braintest_test.go b/brain/braintest/braintest_test.go new file mode 100644 index 0000000..ebeebe4 --- /dev/null +++ b/brain/braintest/braintest_test.go @@ -0,0 +1,153 @@ +package braintest_test + +import ( + "context" + "math/rand/v2" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/brain/braintest" + "github.com/zephyrtronium/robot/userhash" +) + +// membrain is an implementation of braintest.Interface using in-memory maps +// to verify that the integration tests test the correct things. +type membrain struct { + mu sync.Mutex + tups map[string]map[string][]string // map of tags to map of prefixes to suffixes + users map[userhash.Hash][][3]string // map of hashes to tag and prefix+suffix + ids map[string]map[uuid.UUID][][2]string // map of tags to map of ids to prefix+suffix + tms map[string]map[int64][][2]string // map of tags to map of timestamps to prefix+suffix +} + +var _ braintest.Interface = (*membrain)(nil) + +func (m *membrain) Learn(ctx context.Context, tag string, user userhash.Hash, id uuid.UUID, t time.Time, tuples []brain.Tuple) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.tups == nil { + m.tups = make(map[string]map[string][]string) + m.users = make(map[userhash.Hash][][3]string) + m.ids = make(map[string]map[uuid.UUID][][2]string) + m.tms = make(map[string]map[int64][][2]string) + } + if m.tups[tag] == nil { + m.tups[tag] = make(map[string][]string) + } + r := m.tups[tag] + if m.ids[tag] == nil { + m.ids[tag] = make(map[uuid.UUID][][2]string) + } + ids := m.ids[tag] + if m.tms[tag] == nil { + m.tms[tag] = make(map[int64][][2]string) + } + tms := m.tms[tag] + for _, tup := range tuples { + p := strings.Join(tup.Prefix, "\xff") + r[p] = append(r[p], tup.Suffix) + m.users[user] = append(m.users[user], [3]string{tag, p, tup.Suffix}) + ids[id] = append(ids[id], [2]string{p, tup.Suffix}) + tms[t.UnixNano()] = append(tms[t.UnixNano()], [2]string{p, tup.Suffix}) + } + return nil +} + +func (m *membrain) forgetLocked(tag, p, s string) { + u := m.tups[tag][p] + k := slices.Index(u, s) + if k < 0 { + return + } + u[k], u[len(u)-1] = u[len(u)-1], u[k] + m.tups[tag][p] = u[:len(u)-1] +} + +func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) error { + m.mu.Lock() + defer m.mu.Unlock() + for _, tup := range tuples { + p := strings.Join(tup.Prefix, "\xff") + m.forgetLocked(tag, p, tup.Suffix) + } + return nil +} + +func (m *membrain) ForgetMessage(ctx context.Context, tag string, msg uuid.UUID) error { + m.mu.Lock() + defer m.mu.Unlock() + u := m.ids[tag][msg] + for _, v := range u { + m.forgetLocked(tag, v[0], v[1]) + } + delete(m.ids[tag], msg) + return nil +} + +func (m *membrain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error { + m.mu.Lock() + defer m.mu.Unlock() + s, b := since.UnixNano(), before.UnixNano() + for tm, u := range m.tms[tag] { + if tm < s || tm > b { + continue + } + for _, v := range u { + m.forgetLocked(tag, v[0], v[1]) + } + delete(m.tms[tag], tm) // yea i modify the map during iteration, yea i'm cool + } + return nil +} + +func (m *membrain) ForgetUser(ctx context.Context, user *userhash.Hash) error { + m.mu.Lock() + defer m.mu.Unlock() + for _, v := range m.users[*user] { + m.forgetLocked(v[0], v[1], v[2]) + } + delete(m.users, *user) + return nil +} + +func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + var s string + if len(prompt) == 0 { + u := m.tups[tag][""] + if len(u) == 0 { + return nil, nil + } + t := u[rand.IntN(len(u))] + w = append(w, t...) + w = append(w, ' ') + s = brain.ReduceEntropy(t) + } else { + s = brain.ReduceEntropy(prompt[len(prompt)-1]) + } + for range 256 { + u := m.tups[tag][s] + if len(u) == 0 { + break + } + t := u[rand.IntN(len(u))] + if t == "" { + break + } + w = append(w, t...) + w = append(w, ' ') + s = brain.ReduceEntropy(t) + } + return w, nil +} + +func TestTests(t *testing.T) { + braintest.Test(context.Background(), t, func(ctx context.Context) braintest.Interface { return new(membrain) }) +} diff --git a/brain/kvbrain/forget.go b/brain/kvbrain/forget.go index 9133673..58cc587 100644 --- a/brain/kvbrain/forget.go +++ b/brain/kvbrain/forget.go @@ -3,6 +3,7 @@ package kvbrain import ( "bytes" "context" + "errors" "fmt" "slices" "strings" @@ -108,14 +109,13 @@ func (br *Brain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) e return p }) err := br.knowledge.Update(func(txn *badger.Txn) error { + var errs error opts := badger.DefaultIteratorOptions - opts.Prefix = hashTag(nil, tag) - it := txn.NewIterator(badger.DefaultIteratorOptions) + it := txn.NewIterator(opts) defer it.Close() - var b []byte + b := hashTag(nil, tag) for _, t := range tuples { - b = hashTag(b[:0], tag) - b = append(appendPrefix(b, t.Prefix), '\xff') // terminate the prefix + b = append(appendPrefix(b[:tagHashLen], t.Prefix), '\xff') // terminate the prefix it.Seek(b) for it.ValidForPrefix(b) { v := it.Item() @@ -123,22 +123,23 @@ func (br *Brain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) e if v.IsDeletedOrExpired() { continue } - var err error - b, err = v.ValueCopy(b[:0]) + u, err := v.ValueCopy(nil) if err != nil { - // TODO(zeph): collect and continue - return err + errs = errors.Join(errs, err) + continue } - if string(b) != t.Suffix { + if string(u) != t.Suffix { continue } if err := txn.Delete(v.KeyCopy(nil)); err != nil { - // TODO(zeph): collect and continue - return err + errs = errors.Join(errs, err) + continue } + // Only delete a single instance of each tuple. + break } } - return nil + return errs }) if err != nil { return fmt.Errorf("couldn't forget: %w", err) diff --git a/brain/kvbrain/forget_test.go b/brain/kvbrain/forget_test.go index cf2f960..74bcfbe 100644 --- a/brain/kvbrain/forget_test.go +++ b/brain/kvbrain/forget_test.go @@ -212,7 +212,10 @@ func use(x [][]byte) {} func TestForget(t *testing.T) { type message struct { - msg brain.MessageMeta + id uuid.UUID + user userhash.Hash + tag string + time time.Time tups []brain.Tuple } cases := []struct { @@ -225,15 +228,13 @@ func TestForget(t *testing.T) { name: "none", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -241,7 +242,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"kikuri", "eliza"}, + Prefix: []string{"eliza", "kikuri"}, Suffix: "shima", }, }, @@ -253,15 +254,13 @@ func TestForget(t *testing.T) { name: "suffix", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -269,7 +268,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"kikuri", "eliza"}, + Prefix: []string{"eliza", "kikuri"}, Suffix: "kita", }, }, @@ -281,15 +280,13 @@ func TestForget(t *testing.T) { name: "prefix", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -297,7 +294,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "shima", }, }, @@ -309,15 +306,13 @@ func TestForget(t *testing.T) { name: "tag", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "sickhack", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "sickhack", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -325,7 +320,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -337,15 +332,13 @@ func TestForget(t *testing.T) { name: "match", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -353,7 +346,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -363,29 +356,25 @@ func TestForget(t *testing.T) { name: "single", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, }, { - msg: brain.MessageMeta{ - ID: uuid.UUID{2}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{2}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -393,7 +382,7 @@ func TestForget(t *testing.T) { }, forget: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -412,7 +401,7 @@ func TestForget(t *testing.T) { } br := New(db) for _, msg := range c.msgs { - err := br.Learn(ctx, &msg.msg, msg.tups) + err := br.Learn(ctx, msg.tag, msg.user, msg.id, msg.time, msg.tups) if err != nil { t.Errorf("failed to learn: %v", err) } @@ -427,7 +416,10 @@ func TestForget(t *testing.T) { func TestForgetMessage(t *testing.T) { type message struct { - msg brain.MessageMeta + id uuid.UUID + user userhash.Hash + tag string + time time.Time tups []brain.Tuple } cases := []struct { @@ -440,12 +432,10 @@ func TestForgetMessage(t *testing.T) { name: "single", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ {Prefix: []string{"bocchi"}, Suffix: "ryou"}, }, @@ -458,12 +448,10 @@ func TestForgetMessage(t *testing.T) { name: "several", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ {Prefix: []string{"bocchi"}, Suffix: "ryou"}, {Prefix: []string{"nijika"}, Suffix: "kita"}, @@ -477,12 +465,10 @@ func TestForgetMessage(t *testing.T) { name: "tagged", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "sickhack", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "sickhack", + time: time.Unix(0, 0), tups: []brain.Tuple{ {Prefix: []string{"bocchi"}, Suffix: "ryou"}, }, @@ -497,12 +483,10 @@ func TestForgetMessage(t *testing.T) { name: "unseen", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ {Prefix: []string{"bocchi"}, Suffix: "ryou"}, }, @@ -524,7 +508,7 @@ func TestForgetMessage(t *testing.T) { } br := New(db) for _, msg := range c.msgs { - err := br.Learn(ctx, &msg.msg, msg.tups) + err := br.Learn(ctx, msg.tag, msg.user, msg.id, msg.time, msg.tups) if err != nil { t.Errorf("failed to learn: %v", err) } @@ -539,7 +523,10 @@ func TestForgetMessage(t *testing.T) { func TestForgetDuring(t *testing.T) { type message struct { - msg brain.MessageMeta + id uuid.UUID + user userhash.Hash + tag string + time time.Time tups []brain.Tuple } cases := []struct { @@ -552,15 +539,13 @@ func TestForgetDuring(t *testing.T) { name: "single", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -574,29 +559,25 @@ func TestForgetDuring(t *testing.T) { name: "several", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, }, { - msg: brain.MessageMeta{ - ID: uuid.UUID{2}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{2}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -610,15 +591,13 @@ func TestForgetDuring(t *testing.T) { name: "none", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(5, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(5, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -634,15 +613,13 @@ func TestForgetDuring(t *testing.T) { name: "tagged", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "sickhack", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "sickhack", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -665,7 +642,7 @@ func TestForgetDuring(t *testing.T) { } br := New(db) for _, msg := range c.msgs { - err := br.Learn(ctx, &msg.msg, msg.tups) + err := br.Learn(ctx, msg.tag, msg.user, msg.id, msg.time, msg.tups) if err != nil { t.Errorf("failed to learn: %v", err) } @@ -682,7 +659,10 @@ func TestForgetDuring(t *testing.T) { func TestForgetUserSince(t *testing.T) { type message struct { - msg brain.MessageMeta + id uuid.UUID + user userhash.Hash + tag string + time time.Time tups []brain.Tuple } cases := []struct { @@ -695,15 +675,13 @@ func TestForgetUserSince(t *testing.T) { name: "match", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -716,15 +694,13 @@ func TestForgetUserSince(t *testing.T) { name: "different", msgs: []message{ { - msg: brain.MessageMeta{ - ID: uuid.UUID{1}, - User: userhash.Hash{2}, - Tag: "kessoku", - Time: time.Unix(1, 0), - }, + id: uuid.UUID{1}, + user: userhash.Hash{2}, + tag: "kessoku", + time: time.Unix(1, 0), tups: []brain.Tuple{ { - Prefix: []string{"bocchi", "ryou"}, + Prefix: []string{"ryou", "bocchi"}, Suffix: "kita", }, }, @@ -746,7 +722,7 @@ func TestForgetUserSince(t *testing.T) { } br := New(db) for _, msg := range c.msgs { - err := br.Learn(ctx, &msg.msg, msg.tups) + err := br.Learn(ctx, msg.tag, msg.user, msg.id, msg.time, msg.tups) if err != nil { t.Errorf("failed to learn: %v", err) } diff --git a/brain/kvbrain/kvbrain.go b/brain/kvbrain/kvbrain.go index b2f35df..1ce2d43 100644 --- a/brain/kvbrain/kvbrain.go +++ b/brain/kvbrain/kvbrain.go @@ -48,17 +48,11 @@ func New(knowledge *badger.DB) *Brain { } } -// Order returns the number of elements in the prefix of a chain. It is -// called once at the beginning of learning. The returned value must always -// be at least 1. -func (br *Brain) Order() int { - // TOOD(zeph): this can go away one day - return 250 -} - // hashTag appends the hash of a tag to b to serve as the start of a knowledge key. func hashTag(b []byte, tag string) []byte { h := fnv.New64a() io.WriteString(h, tag) return h.Sum(b) } + +const tagHashLen = 8 diff --git a/brain/kvbrain/learn.go b/brain/kvbrain/learn.go index d7d402d..bc0addf 100644 --- a/brain/kvbrain/learn.go +++ b/brain/kvbrain/learn.go @@ -5,8 +5,12 @@ import ( "context" "errors" "fmt" + "time" + + "github.com/google/uuid" "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/userhash" ) // Learn records a set of tuples. Each tuple prefix has length equal to the @@ -14,7 +18,7 @@ import ( // denote the start of the message and end with one empty suffix to denote // the end; all other tokens are non-empty. Each tuple's prefix has entropy // reduction transformations applied. -func (br *Brain) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []brain.Tuple) error { +func (br *Brain) Learn(ctx context.Context, tag string, user userhash.Hash, id uuid.UUID, t time.Time, tuples []brain.Tuple) error { if len(tuples) == 0 { return errors.New("no tuples to learn") } @@ -25,12 +29,11 @@ func (br *Brain) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []br keys := make([][]byte, len(tuples)) vals := make([][]byte, len(tuples)) // TODO(zeph): could do one call to make var b []byte - tag := meta.Tag for i, t := range tuples { b = hashTag(b[:0], tag) b = append(appendPrefix(b, t.Prefix), '\xff') // Write message ID. - b = append(b, meta.ID[:]...) + b = append(b, id[:]...) keys[i] = bytes.Clone(b) vals[i] = []byte(t.Suffix) } @@ -41,7 +44,7 @@ func (br *Brain) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []br // overwrite if that happens. p, _ = br.past.LoadOrStore(tag, new(past)) } - p.record(meta.ID, meta.User, meta.Time.UnixNano(), keys) + p.record(id, user, t.UnixNano(), keys) batch := br.knowledge.NewWriteBatch() defer batch.Cancel() @@ -64,11 +67,8 @@ func (br *Brain) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []br // append a final \xff to terminate the prefix before appending the message ID // to form a complete key. func appendPrefix(b []byte, prefix []string) []byte { - for i := len(prefix) - 1; i >= 0; i-- { - if prefix[i] == "" { - break - } - b = append(b, prefix[i]...) + for _, w := range prefix { + b = append(b, w...) b = append(b, '\xff') } return b diff --git a/brain/kvbrain/learn_test.go b/brain/kvbrain/learn_test.go index 1aa7e47..284c78d 100644 --- a/brain/kvbrain/learn_test.go +++ b/brain/kvbrain/learn_test.go @@ -55,21 +55,22 @@ func TestLearn(t *testing.T) { h := userhash.Hash{2} cases := []struct { name string - msg brain.MessageMeta + id uuid.UUID + user userhash.Hash + tag string + time time.Time tups []brain.Tuple want map[string]string }{ { name: "single", - msg: brain.MessageMeta{ - ID: uu, - User: h, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uu, + user: h, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{""}, + Prefix: nil, Suffix: "bocchi", }, }, @@ -79,45 +80,43 @@ func TestLearn(t *testing.T) { }, { name: "full", - msg: brain.MessageMeta{ - ID: uu, - User: h, - Tag: "kessoku", - Time: time.Unix(0, 0), - }, + id: uu, + user: h, + tag: "kessoku", + time: time.Unix(0, 0), tups: []brain.Tuple{ { - Prefix: []string{"", "", "", ""}, - Suffix: "bocchi", + Prefix: []string{"seika", "kita", "nijika", "ryou", "bocchi"}, + Suffix: "", }, { - Prefix: []string{"", "", "", "bocchi"}, - Suffix: "ryou", + Prefix: []string{"kita", "nijika", "ryou", "bocchi"}, + Suffix: "seika", }, { - Prefix: []string{"", "", "bocchi", "ryou"}, - Suffix: "nijika", + Prefix: []string{"nijika", "ryou", "bocchi"}, + Suffix: "kita", }, { - Prefix: []string{"", "bocchi", "ryou", "nijika"}, - Suffix: "kita", + Prefix: []string{"ryou", "bocchi"}, + Suffix: "nijika", }, { - Prefix: []string{"bocchi", "ryou", "nijika", "kita"}, - Suffix: "seika", + Prefix: []string{"bocchi"}, + Suffix: "ryou", }, { - Prefix: []string{"ryou", "nijika", "kita", "seika"}, - Suffix: "", + Prefix: nil, + Suffix: "bocchi", }, }, want: map[string]string{ - mkey("kessoku", "\xff", uu): "bocchi", - mkey("kessoku", "bocchi\xff\xff", uu): "ryou", - mkey("kessoku", "ryou\xffbocchi\xff\xff", uu): "nijika", - mkey("kessoku", "nijika\xffryou\xffbocchi\xff\xff", uu): "kita", - mkey("kessoku", "kita\xffnijika\xffryou\xffbocchi\xff\xff", uu): "seika", - mkey("kessoku", "seika\xffkita\xffnijika\xffryou\xff\xff", uu): "", + mkey("kessoku", "\xff", uu): "bocchi", + mkey("kessoku", "bocchi\xff\xff", uu): "ryou", + mkey("kessoku", "ryou\xffbocchi\xff\xff", uu): "nijika", + mkey("kessoku", "nijika\xffryou\xffbocchi\xff\xff", uu): "kita", + mkey("kessoku", "kita\xffnijika\xffryou\xffbocchi\xff\xff", uu): "seika", + mkey("kessoku", "seika\xffkita\xffnijika\xffryou\xffbocchi\xff\xff", uu): "", }, }, } @@ -130,7 +129,7 @@ func TestLearn(t *testing.T) { t.Fatal(err) } br := New(db) - if err := br.Learn(ctx, &c.msg, c.tups); err != nil { + if err := br.Learn(ctx, c.tag, c.user, c.id, c.time, c.tups); err != nil { t.Errorf("failed to learn: %v", err) } dbcheck(t, db, c.want) diff --git a/brain/kvbrain/speak.go b/brain/kvbrain/speak.go index 2d6f99e..f9eb552 100644 --- a/brain/kvbrain/speak.go +++ b/brain/kvbrain/speak.go @@ -8,63 +8,50 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/prepend" + "github.com/zephyrtronium/robot/tpool" ) -// New finds a prompt to begin a random message. When a message is -// generated with no prompt, the result from New is passed directly to -// Speak; it is the speaker's responsibility to ensure it meets -// requirements with regard to length and matchable content. Only data -// originally learned with the given tag should be used to generate a -// prompt. -func (br *Brain) New(ctx context.Context, tag string) ([]string, error) { - return br.Speak(ctx, tag, nil) -} +var prependerPool tpool.Pool[*prepend.List[string]] -// Speak generates a full message from the given prompt. The prompt is -// guaranteed to have length equal to the value returned from Order, unless -// it is a prompt returned from New. If the number of tokens in the prompt -// is smaller than Order, the difference is made up by prepending empty -// strings to the prompt. The speaker should use ReduceEntropy on all -// tokens, including those in the prompt, when generating a message. -// Empty strings at the start and end of the result will be trimmed. Only -// data originally learned with the given tag should be used to generate a -// message. -func (br *Brain) Speak(ctx context.Context, tag string, prompt []string) ([]string, error) { - terms := make([]string, 0, len(prompt)) - for i, s := range prompt { - if s == "" { - continue - } - terms = append(terms, s) - prompt[i] = brain.ReduceEntropy(s) - } - var b []byte +// Speak generates a full message and appends it to w. +// The prompt is in reverse order and has entropy reduction applied. +func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { + search := prependerPool.Get().Set(prompt...) + defer func() { prependerPool.Put(search) }() + + tb := hashTag(make([]byte, 0, tagHashLen), tag) + b := make([]byte, 0, 128) opts := badger.DefaultIteratorOptions // We don't actually need to iterate over values, only the single value // that we decide to use per suffix. So, we can disable value prefetch. opts.PrefetchValues = false opts.Prefix = hashTag(nil, tag) - for { + for range 1024 { var err error - var s string - b = hashTag(b[:0], tag) - s, b, prompt, err = br.next(b, prompt, opts) + var l int + b = append(b[:0], tb...) + b, l, err = br.next(b, search.Slice(), opts) if err != nil { return nil, err } - if s == "" { - return terms, nil + if len(b) == 0 { + break } - terms = append(terms, s) - prompt = append(prompt, brain.ReduceEntropy(s)) + w = append(w, b...) + w = append(w, ' ') + search = search.Drop(search.Len() - l - 1).Prepend(brain.ReduceEntropy(string(b))) } + return w, nil } // next finds a single token to continue a prompt. -// The returned values are, in order, the new term, b with possibly appended -// memory, the suffix of prompt which matched to produce the new term, and -// any error. If the returned term is the empty string, generation should end. -func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) (string, []byte, []string, error) { +// The returned values are, in order, +// b with its contents replaced with the new term, +// the number of terms of the prompt which matched to produce the new term, +// and any error. +// If the returned term is the empty string, generation should end. +func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) ([]byte, int, error) { // These definitions are outside the loop to ensure we don't bias toward // smaller contexts. var ( @@ -86,6 +73,7 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) (s for it.ValidForPrefix(b) { // We generate a uniform variate per key, then choose the key // that gets the maximum variate. + // TODO(zeph): gumbel distribution u := rand.Uint64() if m <= u { item := it.Item() @@ -100,22 +88,19 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) (s return nil }) if err != nil { - return "", b, prompt, fmt.Errorf("couldn't read knowledge: %w", err) + return nil, len(prompt), fmt.Errorf("couldn't read knowledge: %w", err) } if picked < 3 && len(prompt) > 1 { // We haven't seen enough options, and we have context we could // lose. Do so and try again from the beginning. - // TODO(zeph): we could save the start of the prompt so we don't - // reallocate, and we could construct the next key to use by - // trimming off the end of the current one - prompt = prompt[1:] - b = appendPrefix(b[:8], prompt) + prompt = prompt[:len(prompt)-1] + b = appendPrefix(b[:tagHashLen], prompt) continue } if key == nil { // We never saw any options. Since we always select the first, this // means there were no options. Don't look for nothing in the DB. - return "", b, prompt, nil + return b[:0], len(prompt), nil } err = br.knowledge.View(func(txn *badger.Txn) error { item, err := txn.Get(key) @@ -128,9 +113,6 @@ func (br *Brain) next(b []byte, prompt []string, opts badger.IteratorOptions) (s } return nil }) - if err != nil { - return "", b, prompt, err - } - return string(b), b, prompt, nil + return b, len(prompt), err } } diff --git a/brain/kvbrain/speak_test.go b/brain/kvbrain/speak_test.go index b811cf0..3cce239 100644 --- a/brain/kvbrain/speak_test.go +++ b/brain/kvbrain/speak_test.go @@ -1,6 +1,7 @@ package kvbrain import ( + "bytes" "context" "errors" "maps" @@ -82,7 +83,7 @@ func TestSpeak(t *testing.T) { }, prompt: []string{"bocchi"}, want: [][]string{ - {"bocchi", "ryou", "nijika", "kita"}, + {"ryou", "nijika", "kita"}, }, }, { @@ -94,9 +95,9 @@ func TestSpeak(t *testing.T) { {mkey("kessoku", "nijika\xffryou\xffbocchi\xff\xff", uu), "KITA"}, {mkey("kessoku", "kita\xffnijika\xffryou\xffbocchi\xff\xff", uu), ""}, }, - prompt: []string{"BOCCHI"}, + prompt: []string{"bocchi"}, want: [][]string{ - {"BOCCHI", "RYOU", "NIJIKA", "KITA"}, + {"RYOU", "NIJIKA", "KITA"}, }, }, { @@ -138,15 +139,15 @@ func TestSpeak(t *testing.T) { br := New(db) want := make(map[string]bool, len(c.want)) for _, v := range c.want { - want[strings.Join(v, ":")] = true + want[strings.Join(v, " ")] = true } got := make(map[string]bool, len(c.want)) for range 256 { - m, err := br.Speak(ctx, "kessoku", slices.Clone(c.prompt)) + m, err := br.Speak(ctx, "kessoku", slices.Clone(c.prompt), nil) if err != nil { t.Errorf("failed to speak: %v", err) } - got[strings.Join(m, ":")] = true + got[string(bytes.TrimSpace(m))] = true } if !maps.Equal(want, got) { t.Errorf("wrong results: want %v, got %v", want, got) diff --git a/brain/learn.go b/brain/learn.go index 23877d4..eba8ff1 100644 --- a/brain/learn.go +++ b/brain/learn.go @@ -2,49 +2,38 @@ package brain import ( "context" - "fmt" + "slices" "time" "github.com/google/uuid" + "github.com/zephyrtronium/robot/tpool" "github.com/zephyrtronium/robot/userhash" ) // Tuple is a single Markov chain tuple. type Tuple struct { + // Prefix is the entropy-reduced prefix in reverse order relative to the + // source message. Prefix []string + // Suffix is the full-entropy term following the prefix. Suffix string } -// MessageMeta holds metadata about a message. -type MessageMeta struct { - // ID is a UUID for the message. - ID uuid.UUID - // User is an identifier for the user. It is obfuscated such that the user - // cannot be identified and is not correlated between rooms. - User userhash.Hash - // Tag is a tag that should be associated with the message data. - Tag string - // Time is the time at which the message was sent. - Time time.Time -} - // Learner records Markov chain tuples. type Learner interface { - // Order returns the number of elements in the prefix of a chain. It is - // called once at the beginning of learning. The returned value must always - // be at least 1. - Order() int - // Learn records a set of tuples. Each tuple prefix has length equal to the - // result of Order. The tuples begin with empty strings in the prefix to - // denote the start of the message and end with one empty suffix to denote - // the end; all other tokens are non-empty. Each tuple's prefix has entropy - // reduction transformations applied. - Learn(ctx context.Context, meta *MessageMeta, tuples []Tuple) error - // Forget removes a set of recorded tuples. The tuples provided are as for - // Learn. If a tuple has been recorded multiple times, only the first - // should be deleted. If a tuple has not been recorded, it should be - // ignored. + // Learn records a set of tuples. + // One tuple has an empty prefix to denote the start of the message, and + // a different tuple has the empty string as its suffix to denote the end + // of the message. The positions of each in the argument are not guaranteed. + // Each tuple's prefix has entropy reduction transformations applied. + // Tuples in the argument may share storage for prefixes. + Learn(ctx context.Context, tag string, user userhash.Hash, id uuid.UUID, t time.Time, tuples []Tuple) error + // Forget removes a set of recorded tuples. + // The tuples provided are as for Learn. + // If a tuple has been recorded multiple times, only the first + // should be deleted. + // If a tuple has not been recorded, it should be ignored. Forget(ctx context.Context, tag string, tuples []Tuple) error // ForgetMessage forgets everything learned from a single given message. // If nothing has been learned from the message, it should be ignored. @@ -55,38 +44,43 @@ type Learner interface { ForgetUser(ctx context.Context, user *userhash.Hash) error } +var tuplesPool tpool.Pool[[]Tuple] + // Learn records tokens into a Learner. -func Learn(ctx context.Context, l Learner, meta *MessageMeta, toks []string) error { - n := l.Order() - if n < 1 { - panic(fmt.Errorf("order must be at least 1, got %d from %#v", n, l)) +func Learn(ctx context.Context, l Learner, tag string, user userhash.Hash, id uuid.UUID, t time.Time, toks []string) error { + if len(toks) == 0 { + return nil } - tt := tupleToks(make([]Tuple, 0, len(toks)+1), toks, n) - return l.Learn(ctx, meta, tt) + tt := tuplesPool.Get() + defer func() { tuplesPool.Put(tt[:0]) }() + tt = slices.Grow(tt, len(toks)+1) + tt = tupleToks(tt, toks) + return l.Learn(ctx, tag, user, id, t, tt) } // Forget removes tokens from a Learner. func Forget(ctx context.Context, l Learner, tag string, toks []string) error { - n := l.Order() - if n < 1 { - panic(fmt.Errorf("order must be at least 1, got %d from %#v", n, l)) + if len(toks) == 0 { + return nil } - tt := tupleToks(make([]Tuple, 0, len(toks)+1), toks, n) + tt := tuplesPool.Get() + defer func() { tuplesPool.Put(tt[:0]) }() + tt = slices.Grow(tt, len(toks)+1) + tt = tupleToks(tt, toks) return l.Forget(ctx, tag, tt) } -func tupleToks(tt []Tuple, toks []string, n int) []Tuple { - p := Tuple{Prefix: make([]string, n)} - for _, w := range toks { - q := Tuple{Prefix: make([]string, n), Suffix: w} - copy(q.Prefix, p.Prefix[1:]) - q.Prefix[n-1] = ReduceEntropy(p.Suffix) - tt = append(tt, q) - p = q +func tupleToks(tt []Tuple, toks []string) []Tuple { + slices.Reverse(toks) + pres := slices.Clone(toks) + for i, w := range pres { + pres[i] = ReduceEntropy(w) + } + suf := "" + for i, w := range toks { + tt = append(tt, Tuple{Prefix: pres[i:], Suffix: suf}) + suf = w } - q := Tuple{Prefix: make([]string, n), Suffix: ""} - copy(q.Prefix, p.Prefix[1:]) - q.Prefix[n-1] = ReduceEntropy(p.Suffix) - tt = append(tt, q) + tt = append(tt, Tuple{Prefix: nil, Suffix: suf}) return tt } diff --git a/brain/learn_test.go b/brain/learn_test.go index 57d91e3..87ff191 100644 --- a/brain/learn_test.go +++ b/brain/learn_test.go @@ -53,17 +53,12 @@ func TestTokens(t *testing.T) { } type testLearner struct { - order int learned []brain.Tuple forgot []brain.Tuple err error } -func (t *testLearner) Order() int { - return t.order -} - -func (t *testLearner) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []brain.Tuple) error { +func (t *testLearner) Learn(ctx context.Context, tag string, user userhash.Hash, id uuid.UUID, tm time.Time, tuples []brain.Tuple) error { t.learned = append(t.learned, tuples...) return t.err } @@ -88,80 +83,46 @@ func (t *testLearner) ForgetUser(ctx context.Context, user *userhash.Hash) error func TestLearn(t *testing.T) { s := func(x ...string) []string { return x } cases := []struct { - name string - msg []string - order int - want []brain.Tuple + name string + msg []string + want []brain.Tuple }{ { - name: "single-1", - msg: s("word"), - order: 1, + name: "single", + msg: s("word"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "word"}, {Prefix: s("word"), Suffix: ""}, + {Prefix: nil, Suffix: "word"}, }, }, { - name: "single-3", - msg: s("word"), - order: 3, - want: []brain.Tuple{ - {Prefix: s("", "", ""), Suffix: "word"}, - {Prefix: s("", "", "word"), Suffix: ""}, - }, - }, - { - name: "many-1", - msg: s("many", "words", "in", "this", "message"), - order: 1, + name: "many", + msg: s("many", "words", "in", "this", "message"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "many"}, + {Prefix: s("message", "this", "in", "words", "many"), Suffix: ""}, + {Prefix: s("this", "in", "words", "many"), Suffix: "message"}, + {Prefix: s("in", "words", "many"), Suffix: "this"}, + {Prefix: s("words", "many"), Suffix: "in"}, {Prefix: s("many"), Suffix: "words"}, - {Prefix: s("words"), Suffix: "in"}, - {Prefix: s("in"), Suffix: "this"}, - {Prefix: s("this"), Suffix: "message"}, - {Prefix: s("message"), Suffix: ""}, - }, - }, - { - name: "many-3", - msg: s("many", "words", "in", "this", "message"), - order: 3, - want: []brain.Tuple{ - {Prefix: s("", "", ""), Suffix: "many"}, - {Prefix: s("", "", "many"), Suffix: "words"}, - {Prefix: s("", "many", "words"), Suffix: "in"}, - {Prefix: s("many", "words", "in"), Suffix: "this"}, - {Prefix: s("words", "in", "this"), Suffix: "message"}, - {Prefix: s("in", "this", "message"), Suffix: ""}, + {Prefix: nil, Suffix: "many"}, }, }, { - name: "entropy", - msg: s("A"), - order: 1, + name: "entropy", + msg: s("A"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "A"}, {Prefix: s("a"), Suffix: ""}, + {Prefix: nil, Suffix: "A"}, }, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - l := testLearner{order: c.order} - err := brain.Learn(context.Background(), &l, nil, c.msg) + var l testLearner + err := brain.Learn(context.Background(), &l, "", userhash.Hash{}, uuid.UUID{}, time.Unix(0, 0), c.msg) if err != nil { t.Error(err) } - // Check lengths of prefixes against the order we put down rather - // than leaving it to cmp, because I have been known to typo a test - // case or two when writing them at 5 AM. - for _, p := range l.learned { - if len(p.Prefix) != c.order { - t.Errorf("wrong prefix size: want %d, got %d", c.order, len(p.Prefix)) - } - } if diff := cmp.Diff(c.want, l.learned); diff != "" { t.Errorf("learned the wrong things from %q:\n%s", c.msg, diff) } @@ -172,90 +133,49 @@ func TestLearn(t *testing.T) { func TestForget(t *testing.T) { s := func(x ...string) []string { return x } cases := []struct { - name string - msg []string - order int - want []brain.Tuple + name string + msg []string + want []brain.Tuple }{ { - name: "single-1", - msg: s("word"), - order: 1, + name: "single", + msg: s("word"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "word"}, {Prefix: s("word"), Suffix: ""}, + {Prefix: nil, Suffix: "word"}, }, }, { - name: "single-3", - msg: s("word"), - order: 3, - want: []brain.Tuple{ - {Prefix: s("", "", ""), Suffix: "word"}, - {Prefix: s("", "", "word"), Suffix: ""}, - }, - }, - { - name: "many-1", - msg: s("many", "words", "in", "this", "message"), - order: 1, + name: "many-1", + msg: s("many", "words", "in", "this", "message"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "many"}, + {Prefix: s("message", "this", "in", "words", "many"), Suffix: ""}, + {Prefix: s("this", "in", "words", "many"), Suffix: "message"}, + {Prefix: s("in", "words", "many"), Suffix: "this"}, + {Prefix: s("words", "many"), Suffix: "in"}, {Prefix: s("many"), Suffix: "words"}, - {Prefix: s("words"), Suffix: "in"}, - {Prefix: s("in"), Suffix: "this"}, - {Prefix: s("this"), Suffix: "message"}, - {Prefix: s("message"), Suffix: ""}, - }, - }, - { - name: "many-3", - msg: s("many", "words", "in", "this", "message"), - order: 3, - want: []brain.Tuple{ - {Prefix: s("", "", ""), Suffix: "many"}, - {Prefix: s("", "", "many"), Suffix: "words"}, - {Prefix: s("", "many", "words"), Suffix: "in"}, - {Prefix: s("many", "words", "in"), Suffix: "this"}, - {Prefix: s("words", "in", "this"), Suffix: "message"}, - {Prefix: s("in", "this", "message"), Suffix: ""}, + {Prefix: nil, Suffix: "many"}, }, }, { - name: "entropy", - msg: s("A"), - order: 1, + name: "entropy", + msg: s("A"), want: []brain.Tuple{ - {Prefix: s(""), Suffix: "A"}, {Prefix: s("a"), Suffix: ""}, + {Prefix: nil, Suffix: "A"}, }, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - l := testLearner{order: c.order} + var l testLearner err := brain.Forget(context.Background(), &l, "", c.msg) if err != nil { t.Error(err) } - for _, p := range l.forgot { - if len(p.Prefix) != c.order { - t.Errorf("wrong prefix size: want %d, got %d", c.order, len(p.Prefix)) - } - } if diff := cmp.Diff(c.want, l.forgot); diff != "" { t.Errorf("forgot the wrong things from %q:\n%s", c.msg, diff) } }) } } - -func TestMinimumOrder(t *testing.T) { - defer func() { - err := recover() - if err == nil { - t.Error("no panic") - } - }() - brain.Learn(context.Background(), new(testLearner), nil, []string{"word"}) -} diff --git a/brain/speak.go b/brain/speak.go index 48a6bf8..bdd98a7 100644 --- a/brain/speak.go +++ b/brain/speak.go @@ -1,78 +1,48 @@ package brain import ( + "bytes" "context" "fmt" - "strings" + "slices" + + "github.com/zephyrtronium/robot/tpool" ) // Speaker produces random messages. type Speaker interface { - // Order is the length of prompts given to Speak. - Order() int - // New finds a prompt to begin a random message. When a message is - // generated with no prompt, the result from New is passed directly to - // Speak; it is the speaker's responsibility to ensure it meets - // requirements with regard to length and matchable content. Only data - // originally learned with the given tag should be used to generate a - // prompt. - New(ctx context.Context, tag string) ([]string, error) - // Speak generates a full message from the given prompt. The prompt is - // guaranteed to have length equal to the value returned from Order, unless - // it is a prompt returned from New. If the number of tokens in the prompt - // is smaller than Order, the difference is made up by prepending empty - // strings to the prompt. The speaker should use ReduceEntropy on all - // tokens, including those in the prompt, when generating a message. - // Empty strings at the start and end of the result will be trimmed. Only - // data originally learned with the given tag should be used to generate a - // message. - Speak(ctx context.Context, tag string, prompt []string) ([]string, error) + // Speak generates a full message and appends it to w. + // The prompt is in reverse order and has entropy reduction applied. + Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) +} + +type Prompt struct { + Terms []string } +var ( + tokensPool tpool.Pool[[]string] + builderPool tpool.Pool[[]byte] +) + // Speak produces a new message from the given prompt. func Speak(ctx context.Context, s Speaker, tag, prompt string) (string, error) { - toks := Tokens(nil, prompt) - var p []string - if len(toks) == 0 { - // No prompt; get one from the speaker instead. - var err error - p, err = s.New(ctx, tag) - if err != nil { - return "", fmt.Errorf("couldn't get a new prompt: %w", err) - } - } else { - // Make sure the prompt is the right size and has empty strings to - // make up the difference. - n := s.Order() - switch { - case len(toks) < n: - p = make([]string, n-len(toks), n) - p = append(p, toks...) - toks = toks[:0] - case len(toks) >= n: - k := len(toks) - n - toks, p = toks[:k], toks[k:] - } + w := builderPool.Get() + toks := Tokens(tokensPool.Get(), prompt) + defer func() { + builderPool.Put(w[:0]) + tokensPool.Put(toks[:0]) + }() + w = slices.Grow(w, len(prompt)) + for i, t := range toks { + w = append(w, t...) + w = append(w, ' ') + toks[i] = ReduceEntropy(t) } - r, err := s.Speak(ctx, tag, p) + slices.Reverse(toks) + w, err := s.Speak(ctx, tag, toks, w) if err != nil { return "", fmt.Errorf("couldn't speak: %w", err) } - return strings.Join(append(toks, trim(r)...), " "), nil -} - -// trim removes empty strings from the start and end of r. -func trim(r []string) []string { - for k := len(r) - 1; k >= 0; k-- { - if r[k] != "" { - r = r[:k+1] - break - } - } - for k, v := range r { - if v != "" { - return r[k:] - } - } - return nil + return string(bytes.TrimSpace(w)), nil } diff --git a/brain/speak_test.go b/brain/speak_test.go index f4f171c..db047e8 100644 --- a/brain/speak_test.go +++ b/brain/speak_test.go @@ -2,7 +2,6 @@ package brain_test import ( "context" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -10,70 +9,62 @@ import ( ) type testSpeaker struct { - order int - new []string prompt []string + append []byte } -func (t *testSpeaker) Order() int { - return t.order -} - -func (t *testSpeaker) New(ctx context.Context, tag string) ([]string, error) { - return t.new, nil -} - -func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string) ([]string, error) { - // TODO(zeph): use ReduceEntropy +func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { t.prompt = prompt - return prompt, nil + return append(w, t.append...), nil } func TestSpeak(t *testing.T) { cases := []struct { name string prompt string - order int - new []string + append []byte want []string + say string }{ { name: "empty", prompt: "", - order: 1, - new: nil, + append: nil, want: nil, + say: "", }, { - name: "empty-new", + name: "empty-add", prompt: "", - order: 1, - new: []string{"anime"}, - want: []string{"anime"}, + append: []byte("bocchi"), + want: nil, + say: "bocchi", }, { - name: "prompted-short", - prompt: "madoka", - order: 2, - want: []string{"", "madoka"}, + name: "prompted", + prompt: "bocchi ryo nijika", + append: nil, + want: []string{"nijika", "ryo", "bocchi"}, + say: "bocchi ryo nijika", }, { - name: "prompted-even", - prompt: "madoka homura", - order: 2, - want: []string{"madoka", "homura"}, + name: "prompted-add", + prompt: "bocchi ryo nijika", + append: []byte("kita"), + want: []string{"nijika", "ryo", "bocchi"}, + say: "bocchi ryo nijika kita", }, { - name: "prompted-long", - prompt: "madoka homura anime", - order: 2, - want: []string{"homura", "anime"}, + name: "entropy", + prompt: "BOCCHI RYO NIJIKA", + append: []byte("KITA"), + want: []string{"nijika", "ryo", "bocchi"}, + say: "BOCCHI RYO NIJIKA KITA", }, - // TODO(zeph): cases testing entropy reduction } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - s := testSpeaker{order: c.order, new: c.new} + s := testSpeaker{append: c.append} r, err := brain.Speak(context.Background(), &s, "", c.prompt) if err != nil { t.Error(err) @@ -81,48 +72,8 @@ func TestSpeak(t *testing.T) { if diff := cmp.Diff(c.want, s.prompt); diff != "" { t.Errorf("wrong prompt from %q:\n%s", c.prompt, diff) } - // Check that each word the speaker gave to Speak appears in - // sequence in the result. - t.Logf("%q gave result %q", c.prompt, r) - for _, w := range c.want { - if w == "" { - continue - } - k := strings.Index(r, w) - if k < 0 { - t.Errorf("%q doesn't appear in %q", w, r) - continue - } - r = r[k:] - } - }) - } -} - -func TestSpeakResult(t *testing.T) { - cases := []struct { - name string - prompt string - order int - new []string - want string - }{ - { - name: "long-prompt", - prompt: "madoka homura sayaka mami", - order: 1, - want: "madoka homura sayaka mami", - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := testSpeaker{order: c.order, new: c.new} - r, err := brain.Speak(context.Background(), &s, "", c.prompt) - if err != nil { - t.Error(err) - } - if r != c.want { - t.Errorf("wrong result: wanted %q, got %q", c.want, r) + if diff := cmp.Diff(c.say, r); diff != "" { + t.Errorf("wrong result from %q:\n%s", c.say, diff) } }) } diff --git a/brain/sqlbrain/brain.go b/brain/sqlbrain/brain.go index 3bfce88..6c94da8 100644 --- a/brain/sqlbrain/brain.go +++ b/brain/sqlbrain/brain.go @@ -1,170 +1,53 @@ package sqlbrain import ( - "bytes" "context" - "embed" + _ "embed" "fmt" - "io/fs" - "strings" - "text/template" - _ "github.com/mattn/go-sqlite3" // driver - "gitlab.com/zephyrtronium/sq" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" ) +// Brain is an implementation of knowledge using an SQLite database. type Brain struct { - db DB - tpl *template.Template - stmts statements - order int + db *sqlitex.Pool } -type statements struct { - // selectTuple selects a tuple with a given tag and current state. - selectTuple *sq.Stmt - // newTuple selects a single starting term with a given tag. - newTuple *sq.Stmt - // deleteTuple is a sequence of statements to remove a single tuple with a - // given tag. It is strings instead of prepared statements because the - // sqlite3 driver actively resists my attempts to do horrible things. - deleteTuple []string -} - -// DB encapsulates database methods a Brain requires to allow use of a DB or a -// single Conn. -type DB interface { - Exec(ctx context.Context, query string, args ...any) (sq.Result, error) - Query(ctx context.Context, query string, args ...any) (*sq.Rows, error) - QueryRow(ctx context.Context, query string, args ...any) *sq.Row - Begin(ctx context.Context, opts *sq.TxOptions) (*sq.Tx, error) - Prepare(ctx context.Context, query string) (*sq.Stmt, error) - Close() error -} - -var _, _ DB = (*sq.DB)(nil), (*sq.Conn)(nil) - -// Open returns a brain within the given database. The db must remain open for -// the lifetime of the brain. -func Open(ctx context.Context, db DB) (*Brain, error) { - br := Brain{ - db: db, - tpl: template.New("base"), - } - err := db.QueryRow(ctx, `SELECT value FROM Config WHERE option='order'`).Scan(&br.order) - if err != nil { - return nil, fmt.Errorf("couldn't get order from database (not a brain?): %w", err) - } - // Parse templates. - // tuple.insert.sql is special because it is executed independently for - // every call instead of being executed once and prepared. - template.Must(br.tpl.New("tuple.insert.sql").Parse(insertTuple)) - br.stmts.newTuple = br.initTpStmt(ctx, "tuple.new.sql", newTuple) - br.stmts.selectTuple = br.initTpStmt(ctx, "tuple.select.sql", selectTuple) - br.initDelete() - +// Open returns a brain within the given database. +// The db must remain open for the lifetime of the brain. +func Open(ctx context.Context, db *sqlitex.Pool) (*Brain, error) { + // TODO(zeph): validate schema + br := Brain{db} return &br, nil } -// Close closes the brain's database. -func (br *Brain) Close() error { - return br.db.Close() -} - -// initTpStmt initializes a SQL statement that requires ahead-of-time template -// initialization. Panics on any error. -func (br *Brain) initTpStmt(ctx context.Context, name, text string) *sq.Stmt { - fib := make([]int, br.order-1) - a, b := 1, 1 - for i := range fib { - a, b = b, a+b - fib[i] = b - } - if br.order == 1 { - // Special case for the minimum order. In this case, the minimum score - // must be 0, because the score of every match is 0, since there is - // nothing additional in the prefix to contribute score. - a = 0 - } - - data := struct { - Iter []struct{} - Fibonacci []int - NM1 int - MinScore int - }{ - Iter: make([]struct{}, br.order), - Fibonacci: fib, - NM1: br.order - 1, - MinScore: a, - } - buf := make([]byte, 0, 2048) - w := bytes.NewBuffer(buf) - - tp, err := br.tpl.New(name).Parse(text) - if err != nil { - panic(fmt.Errorf("couldn't parse template %s: %w", name, err)) - } - if err := tp.Execute(w, &data); err != nil { - panic(fmt.Errorf("couldn't execute template %s: %w", name, err)) - } - s, err := br.db.Prepare(ctx, w.String()) - if err != nil { - panic(fmt.Errorf("couldn't prepare statement from %s: %w", name, err)) - } - return s -} - -// Order returns the brain's configured Markov chain order. -func (br *Brain) Order() int { - return br.order -} - -// Tx opens a transaction directly with the brain's database. Passing nil for -// opts uses reasonable defaults. The returned transaction must be committed -// or rolled back once finished. -func (br *Brain) Tx(ctx context.Context, opts *sq.TxOptions) (*sq.Tx, error) { - return br.db.Begin(ctx, opts) -} - -//go:embed templates/*.create.sql templates/*.pragma.sql -var createFiles embed.FS -var createTemplates = template.Must(template.ParseFS(createFiles, "templates/*.sql")) - -// Create initializes a new brain with the given order within a database. -func Create(ctx context.Context, db DB, order int) error { - // Create the query to generate the right schema with the given order. - data := struct { - N, NM1 int - Version int - Iter []struct{} - }{ - N: order, NM1: order - 1, - Version: SchemaVersion, - Iter: make([]struct{}, order), - } - var query strings.Builder - files, err := fs.ReadDir(createFiles, "templates") - if err != nil { - panic(err) - } - for _, file := range files { - err := createTemplates.ExecuteTemplate(&query, file.Name(), &data) +//go:embed schema.sql +var schemaSQL string + +// Create initializes a new brain in a database. +// For convenience, it accepts either a single connection or a pool. +func Create[DB *sqlite.Conn | *sqlitex.Pool](ctx context.Context, db DB) error { + var conn *sqlite.Conn + switch db := any(db).(type) { + case *sqlite.Conn: + conn = db + case *sqlitex.Pool: + var err error + conn, err = db.Take(ctx) + defer db.Put(conn) if err != nil { - // A problem here is a problem with the templates. - panic(fmt.Errorf("couldn't interpret %s: %w", file.Name(), err)) + return fmt.Errorf("couldn't get connection from pool: %w", err) } } - // Execute the query. - tx, err := db.Begin(ctx, nil) + err := sqlitex.ExecuteScript(conn, schemaSQL, nil) if err != nil { - return fmt.Errorf("couldn't open transaction: %w", err) - } - if _, err := tx.Exec(ctx, query.String()); err != nil { - return fmt.Errorf("couldn't exec %s\n%w", query.String(), err) - } - if err := tx.Commit(); err != nil { - return fmt.Errorf("couldn't commit: %w", err) + return fmt.Errorf("couldn't initialize schema: %w", err) } return nil } + +// Close closes the underlying database. +func (br *Brain) Close() error { + return br.db.Close() +} diff --git a/brain/sqlbrain/brain_test.go b/brain/sqlbrain/brain_test.go index ef79b08..9d801f5 100644 --- a/brain/sqlbrain/brain_test.go +++ b/brain/sqlbrain/brain_test.go @@ -6,41 +6,31 @@ import ( "sync/atomic" "testing" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" + "github.com/zephyrtronium/robot/brain" "github.com/zephyrtronium/robot/brain/braintest" "github.com/zephyrtronium/robot/brain/sqlbrain" - "gitlab.com/zephyrtronium/sq" - - _ "github.com/mattn/go-sqlite3" // driver for tests ) -var testdbCounter atomic.Uint64 +var dbCount atomic.Int64 -func testDB(order int) sqlbrain.DB { - ctx := context.Background() - k := testdbCounter.Add(1) - db, err := sq.Open("sqlite3", fmt.Sprintf("file:%d.db?cache=shared&mode=memory", k)) +func testDB(ctx context.Context) *sqlitex.Pool { + k := dbCount.Add(1) + pool, err := sqlitex.NewPool(fmt.Sprintf("file:%d.db?mode=memory&cache=shared", k), sqlitex.PoolOptions{Flags: sqlite.OpenReadWrite | sqlite.OpenCreate | sqlite.OpenMemory | sqlite.OpenSharedCache | sqlite.OpenURI}) if err != nil { panic(err) } - if err := db.Ping(ctx); err != nil { + conn, err := pool.Take(ctx) + defer pool.Put(conn) + if err != nil { panic(err) } - if err := sqlbrain.Create(ctx, db, order); err != nil { + if err := sqlbrain.Create(ctx, conn); err != nil { panic(err) } - return db -} - -func TestOpen(t *testing.T) { - ctx := context.Background() - br, err := sqlbrain.Open(ctx, testDB(2)) - if err != nil { - t.Error(err) - } - if got := br.Order(); got != 2 { - t.Errorf("wrong order after opening brain: want 2, got %d", got) - } + return pool } var _ brain.Learner = (*sqlbrain.Brain)(nil) @@ -50,10 +40,10 @@ func TestIntegrated(t *testing.T) { t.Parallel() ctx := context.Background() new := func(ctx context.Context) braintest.Interface { - db := testDB(2) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { - t.Fatalf("couldn't open brain: %v", err) + panic(err) } return br } diff --git a/brain/sqlbrain/config.go b/brain/sqlbrain/config.go deleted file mode 100644 index 126c20d..0000000 --- a/brain/sqlbrain/config.go +++ /dev/null @@ -1,4 +0,0 @@ -package sqlbrain - -// SchemaVersion is the version of the brain database. -const SchemaVersion = 1 diff --git a/brain/sqlbrain/forget.go b/brain/sqlbrain/forget.go index 15edf27..a9197e3 100644 --- a/brain/sqlbrain/forget.go +++ b/brain/sqlbrain/forget.go @@ -2,132 +2,201 @@ package sqlbrain import ( "context" - "database/sql" _ "embed" "fmt" - "strconv" - "strings" + "slices" "time" "github.com/google/uuid" - "gitlab.com/zephyrtronium/sq" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" "github.com/zephyrtronium/robot/brain" "github.com/zephyrtronium/robot/userhash" ) -// Forget deletes tuples from the database. To ensure consistency and accuracy, -// the ForgetMessage, ForgetDuring, and ForgetUserSince methods should be -// preferred where possible. -func (br *Brain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) error { - names := make([]sq.NamedArg, 2+br.order) - names[0] = sql.Named("tag", tag) - terms := make([]string, 1+br.order) - for i := 0; i < br.order; i++ { - names[i+1] = sql.Named("p"+strconv.Itoa(i), &terms[i]) - } - names[br.order+1] = sql.Named("suffix", &terms[br.order]) - p := make([]any, len(names)) - for i := range names { - p[i] = names[i] - } - tx, err := br.db.Begin(ctx, nil) +//go:embed forget.sql +var forgetSQL string + +// Forget removes a set of recorded tuples. +func (br *Brain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) (err error) { + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) if err != nil { - return fmt.Errorf("couldn't open transaction: %w", err) + return fmt.Errorf("couldn't get connection to forget: %w", err) } - defer tx.Rollback() - for _, tup := range tuples { - // Note that each p[i] is a named arg, and those for the prefix and - // suffix each point to an element of terms. So, updating terms is - // sufficient to update the query parameters. - copy(terms, tup.Prefix) - terms[br.order] = tup.Suffix - // Execute the statements in order. We do this manually because the - // arguments differ for some statements, and the SQLite3 driver - // complains if we give the wrong ones. - snd := func(_ sq.Result, err error) error { return err } - steps := []func() error{ - func() error { return snd(tx.Exec(ctx, br.stmts.deleteTuple[0])) }, - func() error { return snd(tx.Exec(ctx, br.stmts.deleteTuple[1], p...)) }, - func() error { return snd(tx.Exec(ctx, br.stmts.deleteTuple[2], p[1:]...)) }, - func() error { return snd(tx.Exec(ctx, br.stmts.deleteTuple[3])) }, - func() error { return snd(tx.Exec(ctx, br.stmts.deleteTuple[4])) }, - } - for i, step := range steps { - err := step() - if err != nil { - return fmt.Errorf("couldn't remove tuples (step %d failed): %w", i, err) - } + defer sqlitex.Transaction(conn)(&err) + p := make([]byte, 0, 256) + s := make([]byte, 0, 32) + for _, tt := range tuples { + p = prefix(p[:0], tt.Prefix) + s = append(s[:0], tt.Suffix...) + // Unlike learning and speaking, forgetting is generally outside the hot path. + // So, it's fine to have extra allocations and reflection here. + opts := sqlitex.ExecOptions{ + Named: map[string]any{ + ":tag": tag, + ":prefix": p, + ":suffix": s, + }, + } + if err := sqlitex.Execute(conn, forgetSQL, &opts); err != nil { + return fmt.Errorf("couldn't forget: %w", err) } - } - if err := tx.Commit(); err != nil { - return fmt.Errorf("couldn't commit delete ops: %w", err) } return nil } -// ForgetMessage removes tuples associated with a message from the database. -// The delete reason is set to "CLEARMSG". -func (br *Brain) ForgetMessage(ctx context.Context, tag string, msg uuid.UUID) error { - res, err := br.db.Exec(ctx, `UPDATE Message SET deleted='CLEARMSG' WHERE id = ?`, msg) +// ForgetMessage forgets everything learned from a single given message. +// If nothing has been learned from the message, nothing happens. +func (br *Brain) ForgetMessage(ctx context.Context, tag string, msg uuid.UUID) (err error) { + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) if err != nil { - return fmt.Errorf("couldn't delete message %v: %w", msg, err) + return fmt.Errorf("couldn't get connection to forget message %v: %w", msg, err) } - n, err := res.RowsAffected() - if err != nil { - // Since the query succeeded, an error here is probably from the driver - // not supporting RowsAffected (although those we use do). Don't care. - return nil + defer sqlitex.Transaction(conn)(&err) + { + // First forget the message, so that an attempt to learn it later will fail. + const forget = ` + INSERT INTO messages (tag, id, deleted) VALUES (:tag, :id, 'CLEARMSG') + ON CONFLICT DO UPDATE SET deleted = 'CLEARMSG' + ` + st, err := conn.Prepare(forget) + if err != nil { + return fmt.Errorf("couldn't prepare delete for message %v: %w", msg, err) + } + st.SetText(":tag", tag) + st.SetBytes(":id", msg[:]) + if err := allsteps(st); err != nil { + return fmt.Errorf("couldn't delete message %v: %w", msg, err) + } } - if n == 0 { - return fmt.Errorf("no message with id %v", msg) + { + // Now forget tuples. + const forget = `UPDATE knowledge SET deleted = 'CLEARMSG' WHERE tag=:tag AND id=:id` + st, err := conn.Prepare(forget) + if err != nil { + return fmt.Errorf("couldn't prepare delete for tuples of message %v: %w", msg, err) + } + st.SetText(":tag", tag) + st.SetBytes(":id", msg[:]) + if err := allsteps(st); err != nil { + return fmt.Errorf("couldn't delete tuples of message %v: %w", msg, err) + } } return nil } -// ForgetDuring removes tuples associated with messages learned in the given -// time span. The delete reason is set to "TIMED". -func (br *Brain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error { - a, b := since.UnixMilli(), before.UnixMilli() - _, err := br.db.Exec(ctx, `UPDATE Message SET deleted='TIMED' WHERE tag = ? AND time BETWEEN ? AND ?`, tag, a, b) +// ForgetDuring forgets all messages learned in the given time span. +func (br *Brain) ForgetDuring(ctx context.Context, tag string, since time.Time, before time.Time) error { + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) if err != nil { - return fmt.Errorf("couldn't delete messages between %v and %v: %w", since, before, err) + return fmt.Errorf("couldn't get connection to forget time span: %w", err) + } + defer sqlitex.Transaction(conn)(&err) + // Forget messages by time and get their IDs. + const forgetTime = `UPDATE messages SET deleted = 'TIME' WHERE tag=:tag AND time BETWEEN :since AND :before RETURNING id` + sm, err := conn.Prepare(forgetTime) + if err != nil { + return fmt.Errorf("couldn't prepare delete for messages in time span: %w", err) + } + sm.SetText(":tag", tag) + sm.SetInt64(":since", since.UnixNano()) + sm.SetInt64(":before", before.UnixNano()) + const forgetTuple = `UPDATE knowledge SET deleted = 'TIME' WHERE tag=:tag AND id=:id` + st, err := conn.Prepare(forgetTuple) + if err != nil { + return fmt.Errorf("couldn't prepare delete for tuples in time span: %w", err) + } + st.SetText(":tag", tag) + // Now forget tuples by the IDs. + id := make([]byte, 0, 16) + for { + ok, err := sm.Step() + if err != nil { + return fmt.Errorf("couldn't step delete for messages in time span: %w", err) + } + if !ok { + break + } + idk := sm.ColumnIndex("id") + if idk < 0 { + panic("sqlbrain: no index for id column") + } + l := sm.ColumnLen(idk) + id = slices.Grow(id[:0], l)[:l] + sm.ColumnBytes(idk, id) + st.SetBytes(":id", id) + if err := allsteps(st); err != nil { + return fmt.Errorf("couldn't step delete for tuples in time span: %w", err) + } + if err := st.Reset(); err != nil { + return fmt.Errorf("couldn't reset delete for tuples in time span: %w", err) + } } return nil } -// ForgetUser removes tuples learned from the given user hash. -// The delete reason is set to "CLEARCHAT". +// ForgetUser forgets all messages associated with a userhash. func (br *Brain) ForgetUser(ctx context.Context, user *userhash.Hash) error { - _, err := br.db.Exec(ctx, `UPDATE Message SET deleted='CLEARCHAT' WHERE user = ?`, user[:]) + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) if err != nil { - return fmt.Errorf("couldn't forget messages from %x: %w", user, err) + return fmt.Errorf("couldn't get connection to forget from user: %w", err) } - return nil -} - -func (br *Brain) initDelete() { - tp, err := br.tpl.Parse(deleteTuple) + defer sqlitex.Transaction(conn)(&err) + // Forget messages by user and get their IDs. + const forgetUser = `UPDATE messages SET deleted = 'CLEARCHAT' WHERE user = :user RETURNING tag, id` + sm, err := conn.Prepare(forgetUser) if err != nil { - panic(fmt.Errorf("couldn't parse tuple.delete.sql: %w", err)) + return fmt.Errorf("couldn't prepare delete for messages from user: %w", err) } - const numTemplates = 5 - br.stmts.deleteTuple = make([]string, numTemplates) - data := struct { - Iter []struct{} - }{ - Iter: make([]struct{}, br.order), + sm.SetBytes(":user", user[:]) + const forgetTuple = `UPDATE knowledge SET deleted = 'CLEARCHAT' WHERE tag=:tag AND id=:id` + st, err := conn.Prepare(forgetTuple) + if err != nil { + return fmt.Errorf("couldn't prepare delete for tuples from user: %w", err) } - var b strings.Builder - for i := range br.stmts.deleteTuple { - b.Reset() - err := tp.ExecuteTemplate(&b, fmt.Sprintf("tuple.delete.%d", i), &data) + // Now forget by the IDs. + id := make([]byte, 0, 16) + for { + ok, err := sm.Step() if err != nil { - panic(fmt.Errorf("couldn't exec tuple.delete.%d: %w", i, err)) + return fmt.Errorf("couldn't step delete for messages from user: %w", err) + } + if !ok { + break + } + tag := sm.GetText("tag") + idk := sm.ColumnIndex("id") + if idk < 0 { + panic("sqlbrain: no index for id column") + } + l := sm.ColumnLen(idk) + id = slices.Grow(id[:0], l)[:l] + sm.ColumnBytes(idk, id) + st.SetText(":tag", tag) + st.SetBytes(":id", id) + if err := allsteps(st); err != nil { + return fmt.Errorf("couldn't step delete for tuples from user: %w", err) + } + if err := st.Reset(); err != nil { + return fmt.Errorf("couldn't reset delete for tuples from user: %w", err) } - br.stmts.deleteTuple[i] = b.String() } + return nil } -//go:embed templates/tuple.delete.sql -var deleteTuple string +func allsteps(st *sqlite.Stmt) error { + for { + ok, err := st.Step() + if err != nil { + return err + } + if !ok { + return nil + } + } +} diff --git a/brain/sqlbrain/forget.sql b/brain/sqlbrain/forget.sql new file mode 100644 index 0000000..2d54d2d --- /dev/null +++ b/brain/sqlbrain/forget.sql @@ -0,0 +1,14 @@ +UPDATE knowledge +SET deleted = 'FORGET' +WHERE tag = :tag + AND id = ( + SELECT id + FROM knowledge + WHERE tag = :tag + AND prefix = :prefix + AND suffix = :suffix + AND LIKELY(deleted IS NULL) + LIMIT 1 + ) + AND prefix = :prefix + -- We don't need to match suffix because every prefix of a given message is unique. diff --git a/brain/sqlbrain/forget_test.go b/brain/sqlbrain/forget_test.go index 8cd025e..3d63f51 100644 --- a/brain/sqlbrain/forget_test.go +++ b/brain/sqlbrain/forget_test.go @@ -2,15 +2,11 @@ package sqlbrain_test import ( "context" - "slices" - "sort" + "strings" "testing" "time" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" - "gitlab.com/zephyrtronium/sq" "github.com/zephyrtronium/robot/brain" "github.com/zephyrtronium/robot/brain/sqlbrain" @@ -18,1220 +14,1800 @@ import ( ) func TestForget(t *testing.T) { - type insert struct { - tag string - tuples []brain.Tuple + learn := []learn{ + { + tag: "結束", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("喜多 虹夏 リョウ ぼっち"), Suffix: ""}, + {Prefix: strings.Fields("虹夏 リョウ ぼっち"), Suffix: "喜多"}, + {Prefix: strings.Fields("リョウ ぼっち"), Suffix: "虹夏"}, + {Prefix: strings.Fields("ぼっち"), Suffix: "リョウ"}, + {Prefix: nil, Suffix: "ぼっち"}, + }, + }, + { + tag: "結束", + user: userhash.Hash{4}, + id: uuid.UUID{5}, + t: 6, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "結束", + user: userhash.Hash{7}, + id: uuid.UUID{8}, + t: 9, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "sickhack", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + } + initKnow := []know{ + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + } + initMsgs := []msg{ + { + tag: "結束", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "結束", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "結束", + id: uuid.UUID{8}, + time: 9, + user: userhash.Hash{7}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + } + type forget struct { + tag string + tups []brain.Tuple } cases := []struct { name string - order int - insert []insert - forget []insert - left []insert + forget []forget + know []know + msgs []msg }{ { - name: "empty-1", - order: 1, - insert: nil, - forget: []insert{ + name: "empty", + forget: nil, + know: initKnow, + msgs: initMsgs, + }, + { + name: "none", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: strings.Fields("tuples such no"), Suffix: ""}, + {Prefix: strings.Fields("such no"), Suffix: "tuples"}, + {Prefix: strings.Fields("no"), Suffix: "such"}, + {Prefix: nil, Suffix: "no"}, }, }, }, - left: nil, + know: initKnow, + msgs: initMsgs, }, { - name: "success-1", - order: 1, - insert: []insert{ + name: "single", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: nil, Suffix: "ぼっち"}, }, }, }, - forget: []insert{ + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", }, - }, - left: nil, - }, - { - name: "prefix-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", }, - }, - }, - { - name: "suffix-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + deleted: ref("FORGET"), }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"c"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", }, - }, - }, - { - name: "single-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", }, }, + msgs: initMsgs, }, { - name: "idempotent-1", - order: 1, - insert: []insert{ + name: "all", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "c"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: strings.Fields("喜多 虹夏 リョウ ぼっち"), Suffix: ""}, + {Prefix: strings.Fields("虹夏 リョウ ぼっち"), Suffix: "喜多"}, + {Prefix: strings.Fields("リョウ ぼっち"), Suffix: "虹夏"}, + {Prefix: strings.Fields("ぼっち"), Suffix: "リョウ"}, + {Prefix: nil, Suffix: "ぼっち"}, }, }, }, - forget: []insert{ + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", + deleted: ref("FORGET"), }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", + deleted: ref("FORGET"), }, - }, - }, - { - name: "repeat-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", + deleted: ref("FORGET"), }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", + deleted: ref("FORGET"), }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + deleted: ref("FORGET"), }, - }, - }, - { - name: "tagged-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", }, - }, - forget: []insert{ { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", }, - }, - }, - { - name: "empty-2", - order: 2, - insert: nil, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", }, - }, - left: nil, - }, - { - name: "success-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", }, }, - left: nil, + msgs: initMsgs, }, { - name: "prefix-2", - order: 2, - insert: []insert{ + name: "once", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "d"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: nil, Suffix: "bocchi"}, }, }, }, - forget: []insert{ + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "d"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", }, - }, - }, - { - name: "suffix-first-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"d", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", }, - }, - }, - { - name: "suffix-second-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "d"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("FORGET"), }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", }, - }, - }, - { - name: "single-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", }, }, + msgs: initMsgs, }, { - name: "idempotent-2", - order: 2, - insert: []insert{ + name: "multi", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "d"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: nil, Suffix: "bocchi"}, }, }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "c"}, + tag: "結束", + tups: []brain.Tuple{ + {Prefix: nil, Suffix: "bocchi"}, }, }, }, - left: []insert{ + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "d"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", }, - }, - }, - { - name: "repeat-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", }, - }, - forget: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("FORGET"), + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", + deleted: ref("FORGET"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", }, }, + msgs: initMsgs, }, { - name: "tagged-2", - order: 2, - insert: []insert{ + name: "tag", + forget: []forget{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, + tag: "sickhack", + tups: []brain.Tuple{ + {Prefix: nil, Suffix: "bocchi"}, }, }, }, - forget: []insert{ + know: []know{ { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", }, - }, - left: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{8}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + deleted: ref("FORGET"), }, }, + msgs: initMsgs, }, } for _, c := range cases { - c := c t.Run(c.name, func(t *testing.T) { t.Parallel() ctx := context.Background() - db := testDB(c.order) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - for _, v := range c.insert { - err := addTuples(ctx, db, tagged(v.tag), v.tuples) - if err != nil { - t.Fatal(err) - } - // Double-check that the tuples are in. - // This is largely testing that tuples() works as advertised. - tups, err := tuples(ctx, db, v.tag, c.order) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(v.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Fatalf("wrong tuples added before test (+got/-want):\n%s", diff) - } - } - for i, v := range c.forget { - err := br.Forget(ctx, v.tag, v.tuples) - if err != nil { - t.Errorf("couldn't forget %q (group %d): %v", v, i, err) - } - } - var wantTags []string - for _, left := range c.left { - wantTags = append(wantTags, left.tag) - tups, err := tuples(ctx, db, left.tag, c.order) + for _, m := range learn { + err := br.Learn(ctx, m.tag, m.user, m.id, time.Unix(0, m.t), m.tups) if err != nil { - t.Errorf("couldn't get remaining tuples for tag %s: %v", left.tag, err) - continue - } - if diff := cmp.Diff(left.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("wrong tuples left with tag %s (+got/-want):\n%s", left.tag, diff) + t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err) } } - sort.Strings(wantTags) - gotTags, err := tags(ctx, t, db) + conn, err := db.Take(ctx) + defer db.Put(conn) if err != nil { - t.Errorf("couldn't get tags list: %v", err) - } - if diff := cmp.Diff(wantTags, gotTags); diff != "" { - t.Errorf("wrong tags have tuples (+got/-want):\n%s", diff) + t.Fatalf("couldn't get conn to check db state: %v", err) } + contents(t, conn, initKnow, initMsgs) if t.Failed() { - dumpdb(ctx, t, db) + t.Fatal("setup failed") + } + for _, v := range c.forget { + err := br.Forget(ctx, v.tag, v.tups) + if err != nil { + t.Errorf("couldn't forget %q in %v: %v", v.tups, v.tag, err) + } } + contents(t, conn, c.know, c.msgs) }) } } func TestForgetMessage(t *testing.T) { - type insert struct { - id uuid.UUID - tag string - tuples []brain.Tuple - } - type forget struct { - id uuid.UUID - tag string - } - type remain struct { - tag string - tuples []brain.Tuple - } - uuids := []uuid.UUID{ - uuid.New(), - uuid.New(), - } - cases := []struct { - name string - order int - insert []insert - forget []forget - left []remain - errs bool - }{ + learn := []learn{ { - name: "single-1", - order: 1, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, - }, + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""}, + {Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"}, + {Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"}, + {Prefix: strings.Fields("bocchi"), Suffix: "ryo"}, + {Prefix: nil, Suffix: "bocchi"}, }, - forget: []forget{ - {id: uuids[0], tag: "madoka"}, - }, - left: nil, }, { - name: "multi-1", - order: 1, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: "c"}, - {Prefix: []string{"c"}, Suffix: "d"}, - }, - }, - }, - forget: []forget{ - {id: uuids[0], tag: "madoka"}, + tag: "kessoku", + user: userhash.Hash{4}, + id: uuid.UUID{5}, + t: 6, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, }, - left: nil, }, { - name: "unmatched-1", - order: 1, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, - }, - }, - forget: []forget{ - {id: uuids[1], tag: "madoka"}, + tag: "sickhack", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"kikuri"}, Suffix: ""}, + {Prefix: nil, Suffix: "kikuri"}, }, - left: []remain{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, - }, - }, - errs: true, }, + } + initKnow := []know{ { - name: "single-2", - order: 2, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, - }, - }, - forget: []forget{ - {id: uuids[0], tag: "madoka"}, - }, - left: nil, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", }, { - name: "multi-2", - order: 2, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"b", "c"}, Suffix: "d"}, - {Prefix: []string{"c", "d"}, Suffix: "e"}, - }, - }, - }, - forget: []forget{ - {id: uuids[0], tag: "madoka"}, - }, - left: nil, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", }, { - name: "unmatched-2", - order: 2, - insert: []insert{ - { - id: uuids[0], - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, - }, - }, - forget: []forget{ - {id: uuids[1], tag: "madoka"}, - }, - left: []remain{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, - }, - }, - errs: true, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", }, - } - for _, c := range cases { - c := c - t.Run(c.name, func(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := testDB(c.order) + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, + } + initMsgs := []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + } + cases := []struct { + name string + tag string + id uuid.UUID + know []know + msgs []msg + }{ + { + name: "none", + tag: "kessoku", + id: uuid.UUID{}, + know: initKnow, + msgs: initMsgs, + }, + { + name: "first", + tag: "kessoku", + id: uuid.UUID{2}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, + }, + msgs: []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + }, + }, + { + name: "second", + tag: "kessoku", + id: uuid.UUID{5}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + deleted: ref("CLEARMSG"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("CLEARMSG"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, + }, + msgs: []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + deleted: ref("CLEARMSG"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + }, + }, + { + name: "tagged", + tag: "sickhack", + id: uuid.UUID{2}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + deleted: ref("CLEARMSG"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + deleted: ref("CLEARMSG"), + }, + }, + msgs: []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("CLEARMSG"), + }, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - for _, v := range c.insert { - md := brain.MessageMeta{ - ID: v.id, - Tag: v.tag, - } - err := addTuples(ctx, db, md, v.tuples) - if err != nil { - t.Fatal(err) - } - // Double-check that the tuples are in. - tups, err := tuples(ctx, db, v.tag, c.order) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(v.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Fatalf("wrong tuples added before test (+got/-want):\n%s", diff) - } - } - for _, f := range c.forget { - err := br.ForgetMessage(ctx, f.tag, f.id) - if err != nil && !c.errs { - t.Errorf("couldn't forget message %v: %v", f.id, err) - } else if err == nil && c.errs { - t.Error("expected forget to fail") - } - } - var wantTags []string - for _, left := range c.left { - wantTags = append(wantTags, left.tag) - tups, err := tuples(ctx, db, left.tag, c.order) + for _, m := range learn { + err := br.Learn(ctx, m.tag, m.user, m.id, time.Unix(0, m.t), m.tups) if err != nil { - t.Errorf("couldn't get remaining tuples for tag %s: %v", left.tag, err) - continue - } - if diff := cmp.Diff(left.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("wrong tuples left with tag %s (+got/-want):\n%s", left.tag, diff) + t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err) } } - sort.Strings(wantTags) - gotTags, err := tags(ctx, t, db) + conn, err := db.Take(ctx) + defer db.Put(conn) if err != nil { - t.Errorf("couldn't get tags list: %v", err) - } - if diff := cmp.Diff(wantTags, gotTags); diff != "" { - t.Errorf("wrong tags have tuples (+got/-want):\n%s", diff) + t.Fatalf("couldn't get conn to check db state: %v", err) } + contents(t, conn, initKnow, initMsgs) if t.Failed() { - dumpdb(ctx, t, db) + t.Fatal("setup failed") + } + if err := br.ForgetMessage(ctx, c.tag, c.id); err != nil { + t.Errorf("failed to delete %v/%v: %v", c.tag, c.id, err) } + contents(t, conn, c.know, c.msgs) }) } } func TestForgetDuring(t *testing.T) { - type insert struct { - tag string - time int64 - tuples []brain.Tuple + learn := []learn{ + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""}, + {Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"}, + {Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"}, + {Prefix: strings.Fields("bocchi"), Suffix: "ryo"}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "kessoku", + user: userhash.Hash{4}, + id: uuid.UUID{5}, + t: 6, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "sickhack", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"kikuri"}, Suffix: ""}, + {Prefix: nil, Suffix: "kikuri"}, + }, + }, } - type remain struct { - tag string - tuples []brain.Tuple + initKnow := []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, + } + initMsgs := []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, } cases := []struct { name string - order int - insert []insert tag string - forget [2]int64 - left []remain + since int64 + before int64 + know []know + msgs []msg }{ { - name: "single-1", - order: 1, - insert: []insert{ - { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, - }, - }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: nil, + name: "none", + tag: "kessoku", + since: 100, + before: 200, + know: initKnow, + msgs: initMsgs, }, { - name: "multiple-1", - order: 1, - insert: []insert{ + name: "early", + tag: "kessoku", + since: 1, + before: 4, + know: []know{ { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + deleted: ref("TIME"), }, - }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: nil, - }, - { - name: "outside-1", - order: 1, - insert: []insert{ { - tag: "madoka", - time: 4, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + deleted: ref("TIME"), }, - }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: []remain{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + deleted: ref("TIME"), }, - }, - }, - { - name: "tagged-1", - order: 1, - insert: []insert{ { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", }, }, - tag: "homura", - forget: [2]int64{1, 3}, - left: []remain{ + msgs: []msg{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, }, { - name: "single-2", - order: 2, - insert: []insert{ + name: "late", + tag: "kessoku", + since: 5, + before: 8, + know: []know{ { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("TIME"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", }, }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: nil, - }, - { - name: "multiple-2", - order: 2, - insert: []insert{ + msgs: []msg{ { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"b", "c"}, Suffix: "d"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + deleted: ref("TIME"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: nil, }, { - name: "outside-2", - order: 2, - insert: []insert{ + name: "all", + tag: "kessoku", + since: 1, + before: 8, + know: []know{ { - tag: "madoka", - time: 4, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("TIME"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", }, }, - tag: "madoka", - forget: [2]int64{1, 3}, - left: []remain{ + msgs: []msg{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("TIME"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + deleted: ref("TIME"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, }, { - name: "tagged-2", - order: 2, - insert: []insert{ + name: "tagged", + tag: "sickhack", + since: 1, + before: 7, + know: []know{ { - tag: "madoka", - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + deleted: ref("TIME"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + deleted: ref("TIME"), }, }, - tag: "homura", - forget: [2]int64{1, 3}, - left: []remain{ + msgs: []msg{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("TIME"), }, }, }, } for _, c := range cases { - c := c t.Run(c.name, func(t *testing.T) { t.Parallel() ctx := context.Background() - db := testDB(c.order) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - for _, v := range c.insert { - md := brain.MessageMeta{ - ID: uuid.New(), - Time: time.UnixMilli(v.time), - Tag: v.tag, - } - err := addTuples(ctx, db, md, v.tuples) - if err != nil { - t.Fatal(err) - } - // Double-check that the tuples are in. - tups, err := tuples(ctx, db, v.tag, c.order) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(v.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Fatalf("wrong tuples added before test (+got/-want):\n%s", diff) - } - } - a, b := time.UnixMilli(c.forget[0]), time.UnixMilli(c.forget[1]) - if err := br.ForgetDuring(ctx, c.tag, a, b); err != nil { - t.Errorf("could't forget: %v", err) - } - var wantTags []string - for _, left := range c.left { - wantTags = append(wantTags, left.tag) - tups, err := tuples(ctx, db, left.tag, c.order) + for _, m := range learn { + err := br.Learn(ctx, m.tag, m.user, m.id, time.Unix(0, m.t), m.tups) if err != nil { - t.Errorf("couldn't get remaining tuples for tag %s: %v", left.tag, err) - continue - } - if diff := cmp.Diff(left.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("wrong tuples left with tag %s (+got/-want):\n%s", left.tag, diff) + t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err) } } - sort.Strings(wantTags) - gotTags, err := tags(ctx, t, db) + conn, err := db.Take(ctx) + defer db.Put(conn) if err != nil { - t.Errorf("couldn't get tags list: %v", err) - } - if diff := cmp.Diff(wantTags, gotTags); diff != "" { - t.Errorf("wrong tags have tuples (+got/-want):\n%s", diff) + t.Fatalf("couldn't get conn to check db state: %v", err) } + contents(t, conn, initKnow, initMsgs) if t.Failed() { - dumpdb(ctx, t, db) + t.Fatal("setup failed") + } + since, before := time.Unix(0, c.since), time.Unix(0, c.before) + if err := br.ForgetDuring(ctx, c.tag, since, before); err != nil { + t.Errorf("couldn't delete in %v between %d and %d: %v", c.tag, c.since, c.before, err) } + contents(t, conn, c.know, c.msgs) }) } } -func TestForgetUserSince(t *testing.T) { - type insert struct { - tag string - user userhash.Hash - time int64 - tuples []brain.Tuple +func TestForgetUser(t *testing.T) { + learn := []learn{ + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""}, + {Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"}, + {Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"}, + {Prefix: strings.Fields("bocchi"), Suffix: "ryo"}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{5}, + t: 6, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "sickhack", + user: userhash.Hash{4}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"kikuri"}, Suffix: ""}, + {Prefix: nil, Suffix: "kikuri"}, + }, + }, } - type remain struct { - tag string - tuples []brain.Tuple + initKnow := []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, + } + initMsgs := []msg{ + { + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{1}, + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{4}, + }, } cases := []struct { - name string - order int - insert []insert - user userhash.Hash - left []remain + name string + user userhash.Hash + know []know + msgs []msg }{ { - name: "single-1", - order: 1, - insert: []insert{ - { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, - }, - }, - user: userhash.Hash{0: 1}, - left: nil, + name: "none", + user: userhash.Hash{100}, + know: initKnow, + msgs: initMsgs, }, { - name: "multiple-1", - order: 1, - insert: []insert{ + name: "all", + user: userhash.Hash{1}, + know: []know{ { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + deleted: ref("CLEARCHAT"), }, - }, - user: userhash.Hash{0: 1}, - left: nil, - }, - { - name: "user-1", - order: 1, - insert: []insert{ { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + deleted: ref("CLEARCHAT"), }, - }, - user: userhash.Hash{0: 2}, - left: []remain{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a"}, Suffix: "b"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + deleted: ref("CLEARCHAT"), }, - }, - }, - { - name: "single-2", - order: 2, - insert: []insert{ { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + deleted: ref("CLEARCHAT"), }, - }, - user: userhash.Hash{0: 1}, - left: nil, - }, - { - name: "multiple-2", - order: 2, - insert: []insert{ { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - {Prefix: []string{"b", "c"}, Suffix: "d"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + deleted: ref("CLEARCHAT"), }, - }, - user: userhash.Hash{0: 1}, - left: nil, - }, - { - name: "user-2", - order: 2, - insert: []insert{ { - tag: "madoka", - user: userhash.Hash{0: 1}, - time: 2, - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "bocchi\x00", + suffix: "", + deleted: ref("CLEARCHAT"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "bocchi", + deleted: ref("CLEARCHAT"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", }, }, - user: userhash.Hash{0: 2}, - left: []remain{ + msgs: []msg{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + deleted: ref("CLEARCHAT"), + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{1}, + deleted: ref("CLEARCHAT"), + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{4}, }, }, }, } for _, c := range cases { - c := c t.Run(c.name, func(t *testing.T) { t.Parallel() ctx := context.Background() - db := testDB(c.order) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - for _, v := range c.insert { - md := brain.MessageMeta{ - ID: uuid.New(), - User: v.user, - Tag: v.tag, - Time: time.UnixMilli(v.time), - } - err := addTuples(ctx, db, md, v.tuples) + for _, m := range learn { + err := br.Learn(ctx, m.tag, m.user, m.id, time.Unix(0, m.t), m.tups) if err != nil { - t.Fatal(err) - } - // Double-check that the tuples are in. - tups, err := tuples(ctx, db, v.tag, c.order) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(v.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Fatalf("wrong tuples added before test (+got/-want):\n%s", diff) + t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err) } } - if err := br.ForgetUser(ctx, &c.user); err != nil { - t.Errorf("couldn't forget: %v", err) - } - var wantTags []string - for _, left := range c.left { - wantTags = append(wantTags, left.tag) - tups, err := tuples(ctx, db, left.tag, c.order) - if err != nil { - t.Errorf("couldn't get remaining tuples for tag %s: %v", left.tag, err) - continue - } - if diff := cmp.Diff(left.tuples, tups, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("wrong tuples left with tag %s (+got/-want):\n%s", left.tag, diff) - } - } - sort.Strings(wantTags) - gotTags, err := tags(ctx, t, db) + conn, err := db.Take(ctx) + defer db.Put(conn) if err != nil { - t.Errorf("couldn't get tags list: %v", err) - } - if diff := cmp.Diff(wantTags, gotTags); diff != "" { - t.Errorf("wrong tags have tuples (+got/-want):\n%s", diff) + t.Fatalf("couldn't get conn to check db state: %v", err) } + contents(t, conn, initKnow, initMsgs) if t.Failed() { - dumpdb(ctx, t, db) + t.Fatal("setup failed") } - }) - } -} - -// tuples gets all tuples in db with the given tag. The returned tuples are in -// lexicographically ascending order. -func tuples(ctx context.Context, db sqlbrain.DB, tag string, order int) ([]brain.Tuple, error) { - rows, err := db.Query(ctx, "SELECT Tuple.* FROM Tuple JOIN Message AS m ON m.id = Tuple.msg WHERE m.tag = ? AND m.deleted IS NULL", tag) - if err != nil { - panic(err) - } - defer rows.Close() - var all []brain.Tuple - r := make([]any, order+2) - for i := range r { - r[i] = &sq.NullString{} - } - for rows.Next() { - if err := rows.Scan(r...); err != nil { - panic(err) - } - t := brain.Tuple{Prefix: make([]string, order)} - for i := range t.Prefix { - s := r[i+1].(*sq.NullString) - if s.Valid { - t.Prefix[i] = s.String + if err := br.ForgetUser(ctx, &c.user); err != nil { + t.Errorf("couldn't delete from %x: %v", c.user, err) } - } - s := r[order+1].(*sq.NullString) - if s.Valid { - t.Suffix = s.String - } - all = append(all, t) - } - sort.Slice(all, func(i, j int) bool { - switch d := slices.Compare(all[i].Prefix, all[j].Prefix); d { - case -1: - return true - case 1: - return false - default: - return all[i].Suffix < all[j].Suffix - } - }) - return all, rows.Err() -} - -// tags gets all tags with any associated tuples in db in ascending order. -func tags(ctx context.Context, t *testing.T, db sqlbrain.DB) ([]string, error) { - t.Helper() - rows, err := db.Query(ctx, `SELECT DISTINCT m.tag FROM Message AS m INNER JOIN Tuple ON m.id = Tuple.msg WHERE m.deleted IS NULL ORDER BY tag`) - if err != nil { - panic(err) - } - defer rows.Close() - var tags []string - for rows.Next() { - var s string - if err := rows.Scan(&s); err != nil { - panic(err) - } - tags = append(tags, s) + contents(t, conn, c.know, c.msgs) + }) } - return tags, rows.Err() } diff --git a/brain/sqlbrain/learn.go b/brain/sqlbrain/learn.go index f979c4b..2dcb572 100644 --- a/brain/sqlbrain/learn.go +++ b/brain/sqlbrain/learn.go @@ -2,82 +2,65 @@ package sqlbrain import ( "context" - "database/sql" - _ "embed" "fmt" - "strings" + "time" + + "github.com/google/uuid" + "zombiezen.com/go/sqlite/sqlitex" "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/userhash" ) -// Learn learns a message. -func (br *Brain) Learn(ctx context.Context, meta *brain.MessageMeta, tuples []brain.Tuple) error { - s := br.tupleInsert(tuples) - // Convert the tuples to SQL parameters. The first parameter to the SQL - // statement is the message ID, and all the rest are the tuple terms in - // sequence. Since the parameters are passed as ...any, we need to build a - // slice of all of them. Using pointers to the strings instead of the - // strings directly avoids extra allocations. - p := make([]any, 1, 1+len(tuples)*br.order) - p[0] = &meta.ID - for i := range tuples { - tuple := &tuples[i] - for i := range tuple.Prefix { - p = append(p, &tuple.Prefix[i]) - } - p = append(p, &tuple.Suffix) - } - // Now execute SQL statements. - tx, err := br.db.Begin(ctx, nil) +// Learn records a set of tuples. +func (br *Brain) Learn(ctx context.Context, tag string, user userhash.Hash, id uuid.UUID, t time.Time, tuples []brain.Tuple) (err error) { + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) if err != nil { - return fmt.Errorf("couldn't open transaction: %w", err) + return fmt.Errorf("couldn't get connection to learn: %w", err) } - defer tx.Rollback() - // We must insert the message first because tuples use it as an FK. - // The INSERT returns the delete reason, which is probably but not - // certainly NULL. - var deleted sql.NullString - id := &meta.ID - user := meta.User[:] - err = tx.QueryRow(ctx, insertMessage, id, user, meta.Tag, meta.Time.UnixMilli()).Scan(&deleted) + defer sqlitex.Transaction(conn)(&err) + + st, err := conn.Prepare(`INSERT INTO knowledge(tag, id, prefix, suffix) VALUES (:tag, :id, :prefix, :suffix)`) if err != nil { - return fmt.Errorf("couldn't insert message: %w", err) + return fmt.Errorf("couldn't prepare tuple insert: %w", err) + } + p := make([]byte, 0, 256) + s := make([]byte, 0, 32) + for _, tt := range tuples { + p = prefix(p[:0], tt.Prefix) + s = append(s[:0], tt.Suffix...) + st.SetText(":tag", tag) + st.SetBytes(":id", id[:]) + st.SetBytes(":prefix", p) + st.SetBytes(":suffix", s) + _, err := st.Step() + if err != nil { + return fmt.Errorf("couldn't insert tuple: %w", err) + } + st.Reset() } - if deleted.Valid { - return fmt.Errorf("message %v was already deleted: %s", meta.ID, deleted.String) + + sm, err := conn.Prepare(`INSERT INTO messages(tag, id, time, user) VALUES (:tag, :id, :time, :user)`) + if err != nil { + return fmt.Errorf("couldn't prepare message insert: %w", err) } - // Now insert tuples. - _, err = tx.Exec(ctx, s, p...) + sm.SetText(":tag", tag) + sm.SetBytes(":id", id[:]) + sm.SetInt64(":time", t.UnixNano()) + sm.SetBytes(":user", user[:]) + _, err = sm.Step() if err != nil { - return fmt.Errorf("couldn't insert tuples: %w", err) + return fmt.Errorf("couldn't insert message: %w", err) } - if err := tx.Commit(); err != nil { - return fmt.Errorf("couldn't commit tuples: %w", err) - } return nil } -// tupleInsert formats the tuple insert message for the given tuples. -func (br *Brain) tupleInsert(tuples []brain.Tuple) string { - data := struct { - Iter []struct{} - // We don't actually have to pass the tuples to the template, we just - // need the right number of elements. - Tuples []struct{} - }{ - Iter: make([]struct{}, br.order), - Tuples: make([]struct{}, len(tuples)), +func prefix(b []byte, tup []string) []byte { + for _, w := range tup { + b = append(b, w...) + b = append(b, 0) } - var b strings.Builder - if err := br.tpl.ExecuteTemplate(&b, "tuple.insert.sql", &data); err != nil { - panic(err) - } - return b.String() + return b } - -//go:embed templates/message.insert.sql -var insertMessage string - -//go:embed templates/tuple.insert.sql -var insertTuple string diff --git a/brain/sqlbrain/learn_test.go b/brain/sqlbrain/learn_test.go index 8fe51dd..62f50c5 100644 --- a/brain/sqlbrain/learn_test.go +++ b/brain/sqlbrain/learn_test.go @@ -4,15 +4,13 @@ import ( "context" "fmt" "path/filepath" - "strconv" "strings" "testing" - "text/template" "time" - "github.com/google/go-cmp/cmp" "github.com/google/uuid" - "gitlab.com/zephyrtronium/sq" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" "github.com/zephyrtronium/robot/brain" "github.com/zephyrtronium/robot/brain/braintest" @@ -20,210 +18,407 @@ import ( "github.com/zephyrtronium/robot/userhash" ) -func TestLearn(t *testing.T) { - type row struct { - ID uuid.UUID - User userhash.Hash - Tag string - Ts int64 - Pn []string - Suf string +type learn struct { + tag string + user userhash.Hash + id uuid.UUID + t int64 + tups []brain.Tuple +} + +type know struct { + tag string + id uuid.UUID + prefix string + suffix string + deleted *string +} + +type msg struct { + tag string + id uuid.UUID + time int64 + user userhash.Hash + deleted *string +} + +func ref[T any](x T) *T { return &x } + +func contents(t *testing.T, conn *sqlite.Conn, know []know, msgs []msg) { + t.Helper() + for _, want := range know { + opts := sqlitex.ExecOptions{ + Named: map[string]any{ + ":tag": want.tag, + ":id": want.id[:], + ":prefix": []byte(want.prefix), + ":suffix": []byte(want.suffix), + }, + ResultFunc: func(stmt *sqlite.Stmt) error { + n := stmt.ColumnInt64(0) + if n != 1 { + t.Errorf("wrong number of rows for tag=%q id=%v prefix=%q suffix=%q: want 1, got %d", want.tag, want.id, want.prefix, want.suffix, n) + } + return nil + }, + } + if want.deleted == nil { + err := sqlitex.Execute(conn, `SELECT COUNT(*) FROM knowledge WHERE tag=:tag AND id=:id AND prefix=:prefix AND suffix=:suffix AND deleted IS NULL`, &opts) + if err != nil { + t.Errorf("couldn't check for tag=%q id=%v prefix=%q suffix=%q: %v", want.tag, want.id, want.prefix, want.suffix, err) + } + } else { + opts.Named[":deleted"] = *want.deleted + err := sqlitex.Execute(conn, `SELECT COUNT(*) FROM knowledge WHERE tag=:tag AND id=:id AND prefix=:prefix AND suffix=:suffix AND deleted=:deleted`, &opts) + if err != nil { + t.Errorf("couldn't check for tag=%q id=%v prefix=%q suffix=%q: %v", want.tag, want.id, want.prefix, want.suffix, err) + } + } + } + for _, want := range msgs { + opts := sqlitex.ExecOptions{ + Named: map[string]any{ + ":tag": want.tag, + ":id": want.id[:], + ":time": want.time, + ":user": want.user[:], + }, + ResultFunc: func(stmt *sqlite.Stmt) error { + n := stmt.ColumnInt64(0) + if n != 1 { + t.Errorf("wrong number of rows for tag=%q id=%v time=%d user=%02x: want 1, got %d", want.tag, want.id, want.time, want.user, n) + } + return nil + }, + } + if want.deleted == nil { + err := sqlitex.Execute(conn, `SELECT COUNT(*) FROM messages WHERE tag=:tag AND id=:id AND time=:time AND user=:user AND deleted IS NULL`, &opts) + if err != nil { + t.Errorf("couldn't check for tag=%q id=%v time=%d user=%02x deleted=null: %v", want.tag, want.id, want.time, want.user, err) + } + } else { + opts.Named[":deleted"] = *want.deleted + err := sqlitex.Execute(conn, `SELECT COUNT(*) FROM messages WHERE tag=:tag AND id=:id AND time=:time AND user=:user AND deleted=:deleted`, &opts) + if err != nil { + t.Errorf("couldn't check for tag=%q id=%v time=%d user=%02x deleted=%s: %v", want.tag, want.id, want.time, want.user, *want.deleted, err) + } + } } +} + +func TestLearn(t *testing.T) { cases := []struct { name string - order int - msg brain.MessageMeta - tups []brain.Tuple - want []row + learn []learn + know []know + msgs []msg }{ { - name: "2x1", - order: 2, - msg: brain.MessageMeta{ - ID: uuid.UUID([16]byte{0: 1}), - User: userhash.Hash{1: 2}, - Tag: "tag", - Time: time.UnixMilli(3), + name: "empty", + learn: nil, + know: nil, + msgs: nil, + }, + { + name: "terms", + learn: []learn{ + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""}, + {Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"}, + {Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"}, + {Prefix: strings.Fields("bocchi"), Suffix: "ryo"}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, }, - tups: []brain.Tuple{ - {Prefix: []string{"a", "b"}, Suffix: "c"}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, }, - want: []row{ + msgs: []msg{ { - ID: uuid.UUID([16]byte{0: 1}), - User: userhash.Hash{1: 2}, - Tag: "tag", - Ts: 3, - Pn: []string{"a", "b"}, - Suf: "c", + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, }, { - name: "2x2", - order: 2, - msg: brain.MessageMeta{ - ID: uuid.UUID([16]byte{1: 1}), - User: userhash.Hash{2: 2}, - Tag: "tag", - Time: time.UnixMilli(4), - }, - tups: []brain.Tuple{ - {Prefix: []string{"u", "v"}, Suffix: "w"}, - {Prefix: []string{"v", "w"}, Suffix: "x"}, + name: "unicode", + learn: []learn{ + { + tag: "結束", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: strings.Fields("喜多 虹夏 リョウ ぼっち"), Suffix: ""}, + {Prefix: strings.Fields("虹夏 リョウ ぼっち"), Suffix: "喜多"}, + {Prefix: strings.Fields("リョウ ぼっち"), Suffix: "虹夏"}, + {Prefix: strings.Fields("ぼっち"), Suffix: "リョウ"}, + {Prefix: nil, Suffix: "ぼっち"}, + }, + }, }, - want: []row{ + know: []know{ + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "喜多\x00虹夏\x00リョウ\x00ぼっち\x00", + suffix: "", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "虹夏\x00リョウ\x00ぼっち\x00", + suffix: "喜多", + }, + { + tag: "結束", + id: uuid.UUID{2}, + prefix: "リョウ\x00ぼっち\x00", + suffix: "虹夏", + }, { - ID: uuid.UUID([16]byte{1: 1}), - User: userhash.Hash{2: 2}, - Tag: "tag", - Ts: 4, - Pn: []string{"u", "v"}, - Suf: "w", + tag: "結束", + id: uuid.UUID{2}, + prefix: "ぼっち\x00", + suffix: "リョウ", }, { - ID: uuid.UUID([16]byte{1: 1}), - User: userhash.Hash{2: 2}, - Tag: "tag", - Ts: 4, - Pn: []string{"v", "w"}, - Suf: "x", + tag: "結束", + id: uuid.UUID{2}, + prefix: "", + suffix: "ぼっち", + }, + }, + msgs: []msg{ + { + tag: "結束", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, }, { - name: "4x1", - order: 4, - msg: brain.MessageMeta{ - ID: uuid.UUID([16]byte{2: 1}), - User: userhash.Hash{3: 2}, - Tag: "tag", - Time: time.UnixMilli(5), + name: "msgs", + learn: []learn{ + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "kessoku", + user: userhash.Hash{4}, + id: uuid.UUID{5}, + t: 6, + tups: []brain.Tuple{ + {Prefix: []string{"ryo"}, Suffix: ""}, + {Prefix: nil, Suffix: "ryo"}, + }, + }, }, - tups: []brain.Tuple{ - {Prefix: []string{"a", "b", "c", "d"}, Suffix: "e"}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "ryo\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + prefix: "", + suffix: "ryo", + }, }, - want: []row{ + msgs: []msg{ { - ID: uuid.UUID([16]byte{2: 1}), - User: userhash.Hash{3: 2}, - Tag: "tag", - Ts: 5, - Pn: []string{"a", "b", "c", "d"}, - Suf: "e", + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, + }, + { + tag: "kessoku", + id: uuid.UUID{5}, + time: 6, + user: userhash.Hash{4}, }, }, }, { - name: "4x2", - order: 4, - msg: brain.MessageMeta{ - ID: uuid.UUID([16]byte{3: 1}), - User: userhash.Hash{4: 2}, - Tag: "tag", - Time: time.UnixMilli(6), + name: "tagged", + learn: []learn{ + { + tag: "kessoku", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"bocchi"}, Suffix: ""}, + {Prefix: nil, Suffix: "bocchi"}, + }, + }, + { + tag: "sickhack", + user: userhash.Hash{1}, + id: uuid.UUID{2}, + t: 3, + tups: []brain.Tuple{ + {Prefix: []string{"kikuri"}, Suffix: ""}, + {Prefix: nil, Suffix: "kikuri"}, + }, + }, }, - tups: []brain.Tuple{ - {Prefix: []string{"u", "v", "w", "x"}, Suffix: "y"}, - {Prefix: []string{"v", "w", "x", "y"}, Suffix: "z"}, + know: []know{ + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "bocchi\x00", + suffix: "", + }, + { + tag: "kessoku", + id: uuid.UUID{2}, + prefix: "", + suffix: "bocchi", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "kikuri\x00", + suffix: "", + }, + { + tag: "sickhack", + id: uuid.UUID{2}, + prefix: "", + suffix: "kikuri", + }, }, - want: []row{ + msgs: []msg{ { - ID: uuid.UUID([16]byte{3: 1}), - User: userhash.Hash{4: 2}, - Tag: "tag", - Ts: 6, - Pn: []string{"u", "v", "w", "x"}, - Suf: "y", + tag: "kessoku", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, { - ID: uuid.UUID([16]byte{3: 1}), - User: userhash.Hash{4: 2}, - Tag: "tag", - Ts: 6, - Pn: []string{"v", "w", "x", "y"}, - Suf: "z", + tag: "sickhack", + id: uuid.UUID{2}, + time: 3, + user: userhash.Hash{1}, }, }, }, } for _, c := range cases { - c := c t.Run(c.name, func(t *testing.T) { t.Parallel() ctx := context.Background() - db := testDB(c.order) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - if err := br.Learn(ctx, &c.msg, c.tups); err != nil { - t.Errorf("couldn't learn: %v", err) + for _, m := range c.learn { + err := br.Learn(ctx, m.tag, m.user, m.id, time.Unix(0, m.t), m.tups) + if err != nil { + t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err) + } } - q := `SELECT id, user, tag, time, {{range $i, $_ := $}}p{{$i}}, {{end}}suffix FROM Message, Tuple` - var b strings.Builder - template.Must(template.New("query").Parse(q)).Execute(&b, make([]struct{}, c.order)) - rows, err := db.Query(ctx, b.String()) + conn, err := db.Take(ctx) + defer db.Put(conn) if err != nil { - t.Fatalf("couldn't select: %v", err) - } - for i := 0; rows.Next(); i++ { - got := row{Pn: make([]string, c.order)} - dst := []any{&got.ID, &got.User, &got.Tag, &got.Ts} - for i := range got.Pn { - dst = append(dst, &got.Pn[i]) - } - dst = append(dst, &got.Suf) - if err := rows.Scan(dst...); err != nil { - t.Errorf("couldn't scan: %v", err) - } - if i >= len(c.want) { - t.Errorf("too many rows: got %+v", got) - continue - } - if diff := cmp.Diff(c.want[i], got); diff != "" { - t.Errorf("got wrong row #%d: %s", i, diff) - } + t.Fatalf("couldn't get conn to check db state: %v", err) } + contents(t, conn, c.know, c.msgs) }) } } func BenchmarkLearn(b *testing.B) { dir := filepath.ToSlash(b.TempDir()) - for _, order := range []int{2, 4, 6, 16} { - b.Run(strconv.Itoa(order), func(b *testing.B) { - new := func(ctx context.Context, b *testing.B) brain.Learner { - dsn := fmt.Sprintf("file:%s/benchmark_learn_%d.db?_journal=WAL&_mutex=full", dir, order) - db, err := sq.Open("sqlite3", dsn) - if err != nil { - b.Fatal(err) - } - // The benchmark function will run this multiple times to - // estimate iteration count, so we need to drop tables and - // views if they exist. - stmts := []string{ - `DROP VIEW IF EXISTS MessageTuple`, - `DROP TABLE IF EXISTS Tuple`, - `DROP TABLE IF EXISTS Message`, - `DROP TABLE IF EXISTS Config`, - } - for _, s := range stmts { - _, err := db.Exec(context.Background(), s) - if err != nil { - b.Fatal(err) - } - } - if err := sqlbrain.Create(context.Background(), db, order); err != nil { - b.Fatal(err) - } - br, err := sqlbrain.Open(ctx, db) - if err != nil { - b.Fatalf("couldn't open brain: %v", err) - } - return br + new := func(ctx context.Context, b *testing.B) brain.Learner { + dsn := fmt.Sprintf("file:%s/benchmark_learn.db?_journal=WAL", dir) + db, err := sqlitex.NewPool(dsn, sqlitex.PoolOptions{Flags: sqlite.OpenCreate | sqlite.OpenReadWrite | sqlite.OpenURI | sqlite.OpenWAL}) + if err != nil { + b.Fatal(err) + } + { + conn, err := db.Take(ctx) + if err != nil { + b.Fatal(err) } - braintest.BenchLearn(context.Background(), b, new, func(l brain.Learner) { l.(*sqlbrain.Brain).Close() }) - }) + // The benchmark function will run this multiple times to estimate + // iteration count, so we need to drop tables if they exist. + if err := sqlitex.ExecuteScript(conn, `DROP TABLE IF EXISTS knowledge; DROP TABLE IF EXISTS messages;`, nil); err != nil { + b.Fatal(err) + } + if err := sqlbrain.Create(ctx, conn); err != nil { + b.Fatal(err) + } + db.Put(conn) + } + br, err := sqlbrain.Open(ctx, db) + if err != nil { + b.Fatalf("couldn't open brain: %v", err) + } + return br } + braintest.BenchLearn(context.Background(), b, new, func(l brain.Learner) { l.(*sqlbrain.Brain).Close() }) } diff --git a/brain/sqlbrain/schema.sql b/brain/sqlbrain/schema.sql new file mode 100644 index 0000000..b53396b --- /dev/null +++ b/brain/sqlbrain/schema.sql @@ -0,0 +1,51 @@ +PRAGMA journal_mode = WAL; + +CREATE TABLE knowledge ( + -- Tag or tenant for the entry. + tag TEXT NOT NULL, + -- Message ID, particularly UUID. + id BLOB NOT NULL, + -- Prefix stored with entropy-reduced tokens in reverse order, + -- with each token terminated by \x00 in the string. + prefix BLOB NOT NULL, + -- Full-entropy suffix. + suffix BLOB NOT NULL, + -- Reason for delete, if any. + -- Values may include: + -- 'FORGET', for tuples deleted by content; + -- 'CLEARMSG', for messages deleted by ID; + -- 'CLEARCHAT', for messages deleted by userhash; + -- 'TIME', for messages deleted in a time range; + -- or NULL, for tuples which have not been deleted. + -- These values are only for analytics; any non-null value indicates the + -- tuple should be treated as deleted. + deleted TEXT +) STRICT; + +CREATE TABLE messages ( + -- Tag or tenant for the message. + tag TEXT NOT NULL, + -- Message ID, particularly UUID. + id BLOB NOT NULL, + -- Message timestamp as nanoseconds from the UNIX epoch. + -- May be null for messages imported from other sources or for messages + -- deleted before being fully learned. + time INTEGER, + -- Sender userhash. + -- May be null for messages imported from other sources or for messages + -- deleted before being fully learned. + user BLOB, + -- Reason for delete, if any. + -- Same meaning as in knowledge, except that the value 'FORGET' will never + -- appear (since that is specifically for operating on tuples). + -- Denormalized here to allow soft deletes of messages before they are + -- actually learned. + deleted TEXT, + + PRIMARY KEY(tag, id) +) STRICT; + +CREATE INDEX ids ON knowledge (tag, id); +CREATE INDEX prefixes ON knowledge (tag, prefix); +CREATE INDEX times ON messages (tag, time); +CREATE INDEX users ON messages (user); diff --git a/brain/sqlbrain/speak.go b/brain/sqlbrain/speak.go index 3e3ec0a..394e485 100644 --- a/brain/sqlbrain/speak.go +++ b/brain/sqlbrain/speak.go @@ -2,91 +2,147 @@ package sqlbrain import ( "context" - "database/sql" - _ "embed" "fmt" "math/rand/v2" - "strconv" - "gitlab.com/zephyrtronium/sq" + "zombiezen.com/go/sqlite" "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/prepend" + "github.com/zephyrtronium/robot/tpool" ) -func gumbelscan(rows *sq.Rows) (string, error) { - var s string +var prependerPool tpool.Pool[*prepend.List[string]] + +// Speak generates a full message and appends it to w. +// The prompt is in reverse order and has entropy reduction applied. +func (br *Brain) Speak(ctx context.Context, tag string, prompt []string, w []byte) ([]byte, error) { + search := prependerPool.Get().Set(prompt...) + defer func() { prependerPool.Put(search) }() + + conn, err := br.db.Take(ctx) + defer br.db.Put(conn) + if err != nil { + return w, fmt.Errorf("couldn't get connection to speak: %w", err) + } + + b := make([]byte, 0, 128) + for range 1024 { + var err error + var l int + b, l, err = next(conn, tag, b, search.Slice()) + if err != nil { + return nil, err + } + if len(b) == 0 { + break + } + w = append(w, b...) + w = append(w, ' ') + search = search.Drop(search.Len() - l - 1).Prepend(brain.ReduceEntropy(string(b))) + } + return w, nil +} + +func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int, error) { + if len(prompt) == 0 { + var err error + b, err = first(conn, tag, b) + return b, 0, err + } + st, err := conn.Prepare(`SELECT suffix FROM knowledge WHERE tag = :tag AND prefix >= :lower AND prefix < :upper AND LIKELY(deleted IS NULL)`) + if err != nil { + return b[:0], len(prompt), fmt.Errorf("couldn't prepare term selection: %w", err) + } + st.SetText(":tag", tag) + w := make([]byte, 0, 32) + var d []byte var m uint64 - for rows.Next() { - u := rand.Uint64() - if m <= u { - err := rows.Scan(&s) + picked := 0 + for { + b = prefix(b[:0], prompt) + b, d = searchbounds(b) + st.SetBytes(":lower", b) + st.SetBytes(":upper", d) + for { + ok, err := st.Step() if err != nil { - return "", fmt.Errorf("couldn't scan string for sample: %w", err) + return b[:0], len(prompt), fmt.Errorf("couldn't step term selection: %w", err) + } + if !ok { + break + } + // We generate a uniform variate per row, then choose the row that + // gets the maximum variate. + // TODO(zeph): gumbel distribution + u := rand.Uint64() + if m > u { + continue } + picked++ m = u + n := st.ColumnLen(0) + if cap(w) < n { + w = make([]byte, n) + } + w = w[:st.ColumnBytes(0, w[:n])] } + if picked < 3 && len(prompt) > 1 { + // We haven't seen enough options, and we have context we could + // lose. Do so and try again from the beginning. + prompt = prompt[:len(prompt)-1] + if err := st.Reset(); err != nil { + return b[:0], len(prompt), fmt.Errorf("couldn't reset term selection: %w", err) + } + continue + } + // Note that this also handles the case where there were no results. + b = append(b[:0], w...) + return b, len(prompt), nil } - if rows.Err() != nil { - return "", fmt.Errorf("couldn't get sample: %w", rows.Err()) - } - return s, nil } -// New creates a new prompt. -func (br *Brain) New(ctx context.Context, tag string) ([]string, error) { - rows, err := br.stmts.newTuple.Query(ctx, tag) - if err != nil { - return nil, fmt.Errorf("couldn't run query for new chain: %w", err) +// searchbounds produces the lower and upper bounds for a search by prefix. +// The upper bound is always a slice of the lower bound's underlying array. +func searchbounds(prefix []byte) (lower, upper []byte) { + lower = append(prefix, prefix...) + lower, upper = lower[:len(prefix)], lower[len(prefix):] + if len(upper) != 0 { + // The prefix is a list of terms each followed by a 0 byte. + // So, the supremum of all strings with that prefix is the same with + // the last byte replaced by 1. + upper[len(upper)-1] = 1 } - s, err := gumbelscan(rows) - if err != nil { - return nil, fmt.Errorf("couldn't get new chain: %w", err) - } - r := make([]string, br.order) - r[br.order-1] = s - return r, nil + return lower, upper } -// Speak creates a message from a prompt. -func (br *Brain) Speak(ctx context.Context, tag string, prompt []string) ([]string, error) { - names := make([]sq.NamedArg, 1+len(prompt)) - names[0] = sql.Named("tag", tag) - terms := make([]string, len(prompt)) - nn := 0 - for i, w := range prompt { - nn += len(w) + 1 - terms[i] = brain.ReduceEntropy(w) - names[i+1] = sql.Named("p"+strconv.Itoa(i), &terms[i]) - } - p := make([]any, len(names)) - for i := range names { - p[i] = names[i] +func first(conn *sqlite.Conn, tag string, b []byte) ([]byte, error) { + b = b[:0] // in case we get no rows + s, err := conn.Prepare(`SELECT suffix FROM knowledge WHERE tag = :tag AND prefix = x'' AND LIKELY(deleted IS NULL)`) + if err != nil { + return b, fmt.Errorf("couldn't prepare first term selection: %w", err) } - for nn < 500 { - rows, err := br.stmts.selectTuple.Query(ctx, p...) - if err != nil { - return nil, fmt.Errorf("couldn't run query to continue chain with terms %v: %w", terms, err) - } - w, err := gumbelscan(rows) + s.SetText(":tag", tag) + var m uint64 + for { + ok, err := s.Step() if err != nil { - return nil, fmt.Errorf("couldn't continue chain with terms %v: %w", terms, err) + return b[:0], fmt.Errorf("couldn't step first term selection: %w", err) } - if w == "" { + if !ok { break } - nn += len(w) + 1 - prompt = append(prompt, w) - // Note that each p[i] is a named arg, and each name for prefix - // elements aliases an element of terms. So, just updating terms is - // sufficient to update the query parameters. - copy(terms, terms[1:]) - terms[len(terms)-1] = w + // TODO(zeph): gumbel distribution + u := rand.Uint64() + if m > u { + continue + } + m = u + n := s.ColumnLen(0) + if cap(b) < n { + b = make([]byte, n) + } + b = b[:s.ColumnBytes(0, b[:n])] } - return prompt, nil + return b, nil } - -//go:embed templates/tuple.new.sql -var newTuple string - -//go:embed templates/tuple.select.sql -var selectTuple string diff --git a/brain/sqlbrain/speak_test.go b/brain/sqlbrain/speak_test.go index 9aa6027..7838344 100644 --- a/brain/sqlbrain/speak_test.go +++ b/brain/sqlbrain/speak_test.go @@ -2,531 +2,482 @@ package sqlbrain_test import ( "context" - "fmt" - "path/filepath" "slices" - "strconv" - "strings" "testing" - "github.com/google/go-cmp/cmp" - "github.com/google/uuid" - "github.com/zephyrtronium/robot/brain" - "github.com/zephyrtronium/robot/brain/braintest" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" + "github.com/zephyrtronium/robot/brain/sqlbrain" - "gitlab.com/zephyrtronium/sq" ) -func TestNew(t *testing.T) { - type insert struct { - tag string - tuples []brain.Tuple - } +func TestSpeak(t *testing.T) { cases := []struct { name string - order int - insert []insert + know []know tag string - want [][]string + prompt []string + w []byte + want []string }{ { - name: "include-1", - order: 1, - insert: []insert{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{""}, Suffix: "b"}, - }, - }, - }, - tag: "madoka", - want: [][]string{ - {"a"}, - {"b"}, - }, + name: "empty", + know: nil, + tag: "kessoku", + prompt: nil, + w: nil, + // We should only ever get nil from the brain, + // but that converts to the empty string. + want: []string{""}, }, { - name: "start-1", - order: 1, - insert: []insert{ + name: "empty-tagged", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{""}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: "c"}, - }, + tag: "kessoku", + prefix: "", + suffix: "bocchi", + }, + { + tag: "kessoku", + prefix: "bocchi\x00", + suffix: "", }, }, - tag: "madoka", - want: [][]string{ - {"a"}, - {"b"}, - }, + tag: "sickhack", + prompt: nil, + w: nil, + // We should only ever get nil from the brain, + // but that converts to the empty string. + want: []string{""}, }, { - name: "tagged-1", - order: 1, - insert: []insert{ + name: "empty-prompted", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{""}, Suffix: "b"}, - }, + tag: "kessoku", + prefix: "", + suffix: "bocchi", }, { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "x"}, - {Prefix: []string{""}, Suffix: "y"}, - }, + tag: "kessoku", + prefix: "bocchi\x00", + suffix: "", }, }, - tag: "madoka", - want: [][]string{ - {"a"}, - {"b"}, - }, + tag: "kessoku", + prompt: []string{"kikuri"}, + w: nil, + // We should only ever get nil from the brain, + // but that converts to the empty string. + want: []string{""}, }, { - name: "include-2", - order: 2, - insert: []insert{ + name: "single", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", ""}, Suffix: "b"}, - }, + tag: "kessoku", + prefix: "", + suffix: "bocchi", }, - }, - tag: "madoka", - want: [][]string{ - {"", "a"}, - {"", "b"}, - }, - }, - { - name: "start-2", - order: 2, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", ""}, Suffix: "b"}, - {Prefix: []string{"", "b"}, Suffix: "c"}, - }, + tag: "kessoku", + prefix: "bocchi\x00", + suffix: "", }, }, - tag: "madoka", - want: [][]string{ - {"", "a"}, - {"", "b"}, - }, + tag: "kessoku", + prompt: nil, + w: nil, + want: []string{"bocchi "}, }, { - name: "tagged-2", - order: 2, - insert: []insert{ + name: "several", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", ""}, Suffix: "b"}, - }, + tag: "kessoku", + prefix: "", + suffix: "bocchi", }, { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "x"}, - {Prefix: []string{"", ""}, Suffix: "y"}, - }, + tag: "kessoku", + prefix: "bocchi\x00", + suffix: "ryo", }, - }, - tag: "madoka", - want: [][]string{ - {"", "a"}, - {"", "b"}, - }, - }, - } - for _, c := range cases { - c := c - t.Run(c.name, func(t *testing.T) { - t.Parallel() - ctx := context.Background() - db := testDB(c.order) - br, err := sqlbrain.Open(ctx, db) - if err != nil { - t.Fatalf("couldn't open brain: %v", err) - } - for _, v := range c.insert { - err := addTuples(ctx, db, tagged(v.tag), v.tuples) - if err != nil { - t.Fatal(err) - } - } - var got [][]string - for i := 0; i < 100; i++ { - p, err := br.New(ctx, c.tag) - if err != nil { - t.Errorf("err from new: %v", err) - } - got = lexset(got, p) - } - if diff := cmp.Diff(c.want, got); diff != "" { - t.Errorf("wrong prompts: (-want/+got)\n%s", diff) - } - if t.Failed() { - dumpdb(ctx, t, db) - } - }) - } -} - -func TestSpeak(t *testing.T) { - type insert struct { - tag string - tuples []brain.Tuple - } - cases := []struct { - name string - order int - insert []insert - tag string - prompt []string - want [][]string - }{ - { - name: "include-1", - order: 1, - insert: []insert{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "ryo\x00bocchi\x00", + suffix: "nijika", }, - }, - tag: "madoka", - prompt: []string{""}, - want: [][]string{ - {"a", "b"}, - }, - }, - { - name: "branch-1", - order: 1, - insert: []insert{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"a"}, Suffix: "c"}, - {Prefix: []string{"b"}, Suffix: ""}, - {Prefix: []string{"c"}, Suffix: ""}, - }, + { + tag: "kessoku", + prefix: "nijika\x00ryo\x00bocchi\x00", + suffix: "kita", + }, + { + tag: "kessoku", + prefix: "kita\x00nijika\x00ryo\x00bocchi\x00", + suffix: "", }, }, - tag: "madoka", - prompt: []string{""}, - want: [][]string{ - {"a", "b"}, - {"a", "c"}, - }, + tag: "kessoku", + prompt: nil, + w: nil, + want: []string{"bocchi ryo nijika kita "}, }, { - name: "tagged-1", - order: 1, - insert: []insert{ + name: "append", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{"a"}, Suffix: "b"}, - {Prefix: []string{"b"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "", + suffix: "bocchi", }, { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{""}, Suffix: "a"}, - {Prefix: []string{"a"}, Suffix: "c"}, - {Prefix: []string{"c"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "bocchi\x00", + suffix: "", }, }, - tag: "madoka", - prompt: []string{""}, - want: [][]string{ - {"a", "b"}, - }, + tag: "kessoku", + prompt: nil, + w: []byte("member "), + want: []string{"member bocchi "}, }, { - name: "include-2", - order: 2, - insert: []insert{ + name: "multi", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", "a"}, Suffix: "b"}, - {Prefix: []string{"a", "b"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "", + suffix: "member", }, - }, - tag: "madoka", - prompt: []string{"", ""}, - want: [][]string{ - {"a", "b"}, - }, - }, - { - name: "branch-2", - order: 2, - insert: []insert{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", "a"}, Suffix: "b"}, - {Prefix: []string{"", "a"}, Suffix: "c"}, - {Prefix: []string{"a", "b"}, Suffix: ""}, - {Prefix: []string{"a", "c"}, Suffix: ""}, - }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "bocchi", + }, + { + tag: "kessoku", + prefix: "bocchi\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + prefix: "ryo\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + prefix: "nijika\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "kita", + }, + { + tag: "kessoku", + prefix: "kita\x00member\x00", + suffix: "", }, }, - tag: "madoka", - prompt: []string{"", ""}, - want: [][]string{ - {"a", "b"}, - {"a", "c"}, - }, + tag: "kessoku", + prompt: nil, + w: nil, + want: []string{"member bocchi ", "member ryo ", "member nijika ", "member kita "}, }, { - name: "tagged-2", - order: 2, - insert: []insert{ + name: "multi-tagged", + know: []know{ { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", "a"}, Suffix: "b"}, - {Prefix: []string{"a", "b"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "", + suffix: "member", }, { - tag: "homura", - tuples: []brain.Tuple{ - {Prefix: []string{"", ""}, Suffix: "a"}, - {Prefix: []string{"", "a"}, Suffix: "c"}, - {Prefix: []string{"a", "c"}, Suffix: ""}, - }, + tag: "kessoku", + prefix: "member\x00", + suffix: "bocchi", + }, + { + tag: "kessoku", + prefix: "bocchi\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + prefix: "ryo\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + prefix: "nijika\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "kita", + }, + { + tag: "kessoku", + prefix: "kita\x00member\x00", + suffix: "", + }, + { + tag: "sickhack", + prefix: "", + suffix: "member", + }, + { + tag: "sickhack", + prefix: "member\x00", + suffix: "kikuri", + }, + { + tag: "sickhack", + prefix: "kikuri\x00member\x00", + suffix: "", + }, + { + tag: "sickhack", + prefix: "", + suffix: "member", + }, + { + tag: "sickhack", + prefix: "member\x00", + suffix: "eliza", + }, + { + tag: "sickhack", + prefix: "eliza\x00member\x00", + suffix: "", + }, + { + tag: "sickhack", + prefix: "", + suffix: "member", + }, + { + tag: "sickhack", + prefix: "member\x00", + suffix: "shima", + }, + { + tag: "sickhack", + prefix: "shima\x00member\x00", + suffix: "", }, }, - tag: "madoka", - prompt: []string{"", ""}, - want: [][]string{ - {"a", "b"}, - }, + tag: "sickhack", + prompt: nil, + w: nil, + want: []string{"member kikuri ", "member eliza ", "member shima "}, }, { - name: "long", - order: 4, - insert: []insert{ - { - tag: "madoka", - tuples: []brain.Tuple{ - {Prefix: []string{"", "", "", ""}, Suffix: "a"}, - {Prefix: []string{"", "", "", "a"}, Suffix: "b"}, - {Prefix: []string{"", "", "a", "b"}, Suffix: "c"}, - {Prefix: []string{"", "a", "b", "c"}, Suffix: "d"}, - {Prefix: []string{"a", "b", "c", "d"}, Suffix: "e"}, - {Prefix: []string{"b", "c", "d", "e"}, Suffix: "f"}, - {Prefix: []string{"c", "d", "e", "f"}, Suffix: ""}, - }, + name: "forgort", + know: []know{ + { + tag: "kessoku", + prefix: "", + suffix: "member", + deleted: ref("FORGET"), + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "bocchi", + }, + { + tag: "kessoku", + prefix: "bocchi\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "ryo", + }, + { + tag: "kessoku", + prefix: "ryo\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "nijika", + }, + { + tag: "kessoku", + prefix: "nijika\x00member\x00", + suffix: "", + }, + { + tag: "kessoku", + prefix: "", + suffix: "member", + }, + { + tag: "kessoku", + prefix: "member\x00", + suffix: "kita", + deleted: ref("FORGET"), + }, + { + tag: "kessoku", + prefix: "kita\x00member\x00", + suffix: "", + deleted: ref("FORGET"), }, }, - tag: "madoka", - prompt: []string{"", "", "", ""}, - want: [][]string{ - {"a", "b", "c", "d", "e", "f"}, - }, + tag: "kessoku", + prompt: nil, + w: nil, + want: []string{"member bocchi ", "member ryo ", "member nijika "}, }, } for _, c := range cases { - c := c t.Run(c.name, func(t *testing.T) { t.Parallel() ctx := context.Background() - db := testDB(c.order) + db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { t.Fatalf("couldn't open brain: %v", err) } - for _, v := range c.insert { - err := addTuples(ctx, db, tagged(v.tag), v.tuples) - if err != nil { - t.Fatal(err) - } + conn, err := db.Take(ctx) + defer db.Put(conn) + if err != nil { + t.Fatalf("couldn't get conn: %v", err) } - var got [][]string - for i := 0; i < 100; i++ { - msg, err := br.Speak(ctx, c.tag, c.prompt) + insert(t, conn, c.know, nil) + slices.Sort(c.want) + got := make([]string, 0, len(c.want)) + for range 10000 { + w, err := br.Speak(ctx, c.tag, c.prompt, slices.Clone(c.w)) if err != nil { - t.Errorf("err from speak: %v", err) + t.Errorf("couldn't speak: %v", err) + } + s := string(w) + k, ok := slices.BinarySearch(got, s) + if !ok { + got = slices.Insert(got, k, s) + if len(got) == len(c.want) { + break + } } - got = lexset(got, trim(msg)) - } - if diff := cmp.Diff(c.want, got); diff != "" { - t.Errorf("wrong prompts: (-want/+got)\n%s", diff) } - if t.Failed() { - dumpdb(ctx, t, db) + if !slices.Equal(c.want, got) { + t.Errorf("wrong results:\nwant %q\ngot %q", c.want, got) } }) } } -// addTuples inserts tuples into a test db. -func addTuples(ctx context.Context, db sqlbrain.DB, msg brain.MessageMeta, tuples []brain.Tuple) error { - order := len(tuples[0].Prefix) - tx, err := db.Begin(ctx, nil) - if err != nil { - panic(err) - } - defer tx.Rollback() - _, err = tx.Exec(ctx, "INSERT INTO Message(id, user, tag, time) VALUES (?, ?, ?, ?)", msg.ID, msg.User[:], msg.Tag, msg.Time.UnixMilli()) - if err != nil { - return fmt.Errorf("couldn't add message: %v", err) - } - var b strings.Builder - for _, tup := range tuples { - b.Reset() - b.WriteString("INSERT INTO Tuple(msg") - for i := 0; i < order; i++ { - fmt.Fprintf(&b, ", p%d", i) +func insert(t *testing.T, conn *sqlite.Conn, know []know, msgs []msg) { + t.Helper() + for _, v := range know { + opts := sqlitex.ExecOptions{ + Named: map[string]any{ + ":tag": v.tag, + ":id": v.id[:], + ":prefix": []byte(v.prefix), + ":suffix": []byte(v.suffix), + }, } - b.WriteString(", suffix) VALUES (?, ?") - b.WriteString(strings.Repeat(", ?", order)) - b.WriteByte(')') - args := []any{msg.ID} - for _, w := range tup.Prefix { - args = append(args, w) + var err error + if v.deleted != nil { + opts.Named[":deleted"] = *v.deleted + err = sqlitex.Execute(conn, `INSERT INTO knowledge(tag, id, prefix, suffix, deleted) VALUES (:tag, :id, :prefix, :suffix, :deleted)`, &opts) + } else { + err = sqlitex.Execute(conn, `INSERT INTO knowledge(tag, id, prefix, suffix) VALUES (:tag, :id, :prefix, :suffix)`, &opts) } - args = append(args, tup.Suffix) - _, err := tx.Exec(ctx, b.String(), args...) if err != nil { - return fmt.Errorf("couldn't add tuples %q with query %q: %w", tuples, b.String(), err) + t.Errorf("couldn't learn knowledge %v %q %q: %v", v.id, v.prefix, v.suffix, err) } } - return tx.Commit() -} - -// tagged is a shortcut to create message metadata holding only a given tag -// and a random UUID. -func tagged(tag string) brain.MessageMeta { - return brain.MessageMeta{ - ID: uuid.New(), - Tag: tag, - } -} - -// lexset adds a []string to a [][]string such that the latter remains in -// sorted order without duplicates. -func lexset(dst [][]string, n []string) [][]string { - k, ok := slices.BinarySearchFunc(dst, n, slices.Compare[[]string]) - if ok { - return dst - } - return slices.Insert(dst, k, n) -} - -func trim(r []string) []string { - for k := len(r) - 1; k >= 0; k-- { - if r[k] != "" { - r = r[:k+1] - break - } - } - for k, v := range r { - if v != "" { - return r[k:] + for _, v := range msgs { + opts := sqlitex.ExecOptions{ + Named: map[string]any{ + ":tag": v.tag, + ":id": v.id[:], + ":time": v.time, + ":user": v.user[:], + }, } - } - return nil -} - -func dumpdb(ctx context.Context, t *testing.T, db sqlbrain.DB) { - t.Helper() - t.Log("db content:") - rows, err := db.Query(ctx, "SELECT m.user, m.tag, m.time, m.deleted, Tuple.* FROM Message AS m JOIN Tuple ON m.id = Tuple.msg") - if err != nil { - panic(err) - } - defer rows.Close() - cols, err := rows.Columns() - if err != nil { - panic(err) - } - t.Log(cols) - for rows.Next() { - r := make([]any, len(cols)) - for i := range r { - r[i] = &r[i] + var err error + if v.deleted != nil { + opts.Named[":deleted"] = *v.deleted + err = sqlitex.Execute(conn, `INSERT INTO message(tag, id, time, user, deleted) VALUES (:tag, :id, time, :user, :deleted)`, &opts) + } else { + err = sqlitex.Execute(conn, `INSERT INTO message(tag, id, time, user) VALUES (:tag, :id, time, :user)`, &opts) } - if err := rows.Scan(r...); err != nil { - panic(err) + if err != nil { + t.Errorf("couldn't learn message %v: %v", v.id, err) } - t.Logf("%q", r) - } - if rows.Err() != nil { - t.Log(rows.Err()) - } -} - -func BenchmarkSpeak(b *testing.B) { - dir := filepath.ToSlash(b.TempDir()) - for _, order := range []int{2, 4, 6} { - b.Run(strconv.Itoa(order), func(b *testing.B) { - new := func(ctx context.Context, b *testing.B) braintest.Interface { - dsn := fmt.Sprintf("file:%s/benchmark_learn_%d.db?_journal=WAL&_mutex=full", dir, order) - db, err := sq.Open("sqlite3", dsn) - if err != nil { - b.Fatal(err) - } - // The benchmark function will run this multiple times to - // estimate iteration count, so we need to drop tables and - // views if they exist. - stmts := []string{ - `DROP VIEW IF EXISTS MessageTuple`, - `DROP TABLE IF EXISTS Tuple`, - `DROP TABLE IF EXISTS Message`, - `DROP TABLE IF EXISTS Config`, - } - for _, s := range stmts { - _, err := db.Exec(context.Background(), s) - if err != nil { - b.Fatal(err) - } - } - if err := sqlbrain.Create(context.Background(), db, order); err != nil { - b.Fatal(err) - } - br, err := sqlbrain.Open(ctx, db) - if err != nil { - b.Fatalf("couldn't open brain: %v", err) - } - return br - } - braintest.BenchSpeak(context.Background(), b, new, func(br braintest.Interface) { br.(*sqlbrain.Brain).Close() }) - }) } } diff --git a/brain/sqlbrain/templates/config.create.sql b/brain/sqlbrain/templates/config.create.sql deleted file mode 100644 index ec1b153..0000000 --- a/brain/sqlbrain/templates/config.create.sql +++ /dev/null @@ -1,9 +0,0 @@ --- Config holds global config for the Markov chain data. -CREATE TABLE Config ( - option TEXT NOT NULL, - value ANY -) STRICT; - -INSERT INTO Config(option, value) VALUES - ('schema-version', {{$.Version}}), - ('order', {{$.N}}); diff --git a/brain/sqlbrain/templates/message.create.sql b/brain/sqlbrain/templates/message.create.sql deleted file mode 100644 index 950a352..0000000 --- a/brain/sqlbrain/templates/message.create.sql +++ /dev/null @@ -1,40 +0,0 @@ --- Define both Message and Tuple tables. Tuple depends on Message for an FK, --- and putting the CREATE TABLEs for both in one file ensures they serialize. - --- Message holds message metadata. -CREATE TABLE Message ( - id TEXT PRIMARY KEY, -- Message UUID. - user BLOB NOT NULL, -- Obfuscated user hash. - tag TEXT, -- Tag used to learn the message. - time INTEGER, -- Message send timestamp. Can be null for migrated data. - deleted TEXT -- Message delete reason. Null indicates not deleted. -) STRICT; - -CREATE UNIQUE INDEX IdxMessageIDs ON Message(id); -CREATE INDEX IdxMessageTags ON Message(tag); - --- Tuple holds actual Markov chain tuples. -CREATE TABLE Tuple ( - msg TEXT REFERENCES Message(id), - {{- range $i, $_ := $.Iter }} - p{{$i}} TEXT, - {{- end }} - suffix TEXT -) STRICT; - -CREATE INDEX IdxTupleMsg ON Tuple(msg); -CREATE INDEX IdxTuplePN ON Tuple(p{{ $.NM1 }}); - --- MessageTuple contains exactly those tuples which should be considered for --- generating messages. -CREATE VIEW MessageTuple AS - SELECT - Message.tag AS tag, - {{- range $i, $_ := $.Iter }} - Tuple.p{{$i}} AS p{{$i}}, - {{- end }} - Tuple.suffix AS suffix - FROM Message - INNER JOIN Tuple ON Message.id = Tuple.msg - WHERE - LIKELY(Message.deleted IS NULL); diff --git a/brain/sqlbrain/templates/message.delete.sql b/brain/sqlbrain/templates/message.delete.sql deleted file mode 100644 index a9e4f18..0000000 --- a/brain/sqlbrain/templates/message.delete.sql +++ /dev/null @@ -1,7 +0,0 @@ --- "Delete" a message. --- We delete a message by recording a reason why the message is deleted. --- Nominally this would be an UPDATE, but we don't necessarily guarantee that --- the INSERT to record the message happens-before that update. -INSERT INTO Message(id, user, deleted) VALUES (?1, ?2, ?3) -ON CONFLICT DO UPDATE -SET deleted=excluded.deleted; diff --git a/brain/sqlbrain/templates/message.insert.sql b/brain/sqlbrain/templates/message.insert.sql deleted file mode 100644 index a55ce7d..0000000 --- a/brain/sqlbrain/templates/message.insert.sql +++ /dev/null @@ -1,12 +0,0 @@ --- Insert a new message. --- Inserting messages happens-before inserting tuple data. Two goroutines can --- concurrently insert a message if it is deleted immediately upon receive. So, --- we use upserts (both here and in deletion) to ensure that a delete is always --- recorded. Here, if another goroutine inserts a record to delete this --- message, we still update the fields the "delete" won't set. -INSERT INTO Message(id, user, tag, time) VALUES (?, ?, ?, ?) -ON CONFLICT DO UPDATE -SET - tag=excluded.tag, - time=excluded.time -RETURNING deleted; diff --git a/brain/sqlbrain/templates/sqlite.pragma.sql b/brain/sqlbrain/templates/sqlite.pragma.sql deleted file mode 100644 index a524395..0000000 --- a/brain/sqlbrain/templates/sqlite.pragma.sql +++ /dev/null @@ -1,2 +0,0 @@ -PRAGMA journal_mode = WAL; -PRAGMA foreign_keys = ON; diff --git a/brain/sqlbrain/templates/tuple.delete.sql b/brain/sqlbrain/templates/tuple.delete.sql deleted file mode 100644 index b4b1f86..0000000 --- a/brain/sqlbrain/templates/tuple.delete.sql +++ /dev/null @@ -1,74 +0,0 @@ -{{- /* We want to delete exactly one tuple that matches the input. */ -}} -{{- /* Unfortunately, SQLite3's DELETE ... LIMIT 1 is behind a compile-time */ -}} -{{- /* option which is disabled by default. The best solution I've come up */ -}} -{{- /* with is to insert all rows into a temporary table, delete them from */ -}} -{{- /* tuples, and insert all but one back. Any other solution would seem */ -}} -{{- /* to require having a PK, which we want to avoid. */ -}} - -{{- /* We also define this as a sequence of templates each containing one */ -}} -{{- /* statement, because SQLite3 doesn't support preparing multiple */ -}} -{{- /* statements at a time, and the SQL parameters in each differ. */ -}} - -{{- define "tuple.delete.0" -}} -CREATE TEMPORARY TABLE delete_hold ( - id INTEGER PRIMARY KEY, - msg BLOB, - {{- range $i, $_ := $.Iter}} - p{{$i}} TEXT, - {{- end}} - suffix TEXT -); -{{- end -}} - -{{- define "tuple.delete.1" -}} -INSERT INTO delete_hold ( - msg, - {{- range $i, $_ := $.Iter}} - p{{$i}}, - {{- end}} - suffix -) SELECT - msg, - {{- range $i, $_ := $.Iter}} - p{{$i}}, - {{- end}} - suffix -FROM main.Tuple - INNER JOIN Message ON Tuple.msg = Message.id -WHERE tag = :tag - {{- range $i, $_ := $.Iter}} - AND p{{$i}} IS :p{{$i}} - {{- end}} - AND suffix IS :suffix - AND LIKELY(deleted IS NULL); -{{- end -}} - -{{- define "tuple.delete.2" -}} -DELETE FROM main.Tuple -WHERE msg IN (SELECT msg FROM delete_hold) - {{- range $i, $_ := $.Iter}} - AND p{{$i}} IS :p{{$i}} - {{- end}} - AND suffix IS :suffix; -{{- end -}} - -{{- define "tuple.delete.3" -}} -INSERT INTO main.Tuple ( - msg, - {{- range $i, $_ := $.Iter}} - p{{$i}}, - {{- end}} - suffix -) SELECT - msg, - {{- range $i, $_ := $.Iter}} - p{{$i}}, - {{- end}} - suffix -FROM delete_hold -WHERE id != (SELECT id FROM delete_hold LIMIT 1); -{{- end -}} - -{{- define "tuple.delete.4" -}} -DROP TABLE delete_hold; -{{- end -}} diff --git a/brain/sqlbrain/templates/tuple.insert.sql b/brain/sqlbrain/templates/tuple.insert.sql deleted file mode 100644 index 58faf54..0000000 --- a/brain/sqlbrain/templates/tuple.insert.sql +++ /dev/null @@ -1,10 +0,0 @@ --- Insert a sequence of tuples. -WITH MsgID AS ( - VALUES (?) -) -INSERT INTO Tuple(msg, {{range $i, $_ := $.Iter}}p{{$i}}, {{end}}suffix) -VALUES - ((SELECT * FROM MsgID LIMIT 1), {{range $.Iter}}?, {{end}}?) - {{- range slice $.Tuples 1}}, - ((SELECT * FROM MsgID LIMIT 1), {{range $.Iter}}?, {{end}}?) - {{- end}}; diff --git a/brain/sqlbrain/templates/tuple.new.sql b/brain/sqlbrain/templates/tuple.new.sql deleted file mode 100644 index 4c95283..0000000 --- a/brain/sqlbrain/templates/tuple.new.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Select all start-of-message terms with a given tag. --- The template input requires $.NM1 to be (order-1). -SELECT suffix -FROM MessageTuple -WHERE tag = ? - AND p{{$.NM1}} = '' diff --git a/brain/sqlbrain/templates/tuple.select.sql b/brain/sqlbrain/templates/tuple.select.sql deleted file mode 100644 index 7fab7ae..0000000 --- a/brain/sqlbrain/templates/tuple.select.sql +++ /dev/null @@ -1,34 +0,0 @@ --- Select a single suffix. --- For performance reasons, we always require the final prefix element to match --- exactly. For the remaining elements, we want a higher probability to select --- a given tuple as more of its terms match, weighted toward later elements. --- The template inputs include $.Fibonacci, a slice of (order-1) consecutive --- Fibonacci numbers; $.NM1, which is (order-1); and $.MinScore, which is the --- minimum sum of the Fibonacci numbers corresponding to matching terms for a --- tuple to be considered. --- We use one Fibonacci number less than you might expect, because the final --- prefix element must be an exact match anyway. Using consecutive Fibonacci --- numbers expresses the goal that we can drop one term if the two previous --- ones match. --- The SQL statement inputs are :tag and :p0, :p1, ... as named parameters. -WITH InitialSet AS ( - SELECT {{- range $i, $_ := $.Fibonacci}} - p{{$i}}, - {{- end}} - suffix - FROM MessageTuple - WHERE tag = :tag - AND p{{$.NM1}} IS :p{{$.NM1}} -), Scored AS ( - SELECT 0 {{- range $i, $w := $.Fibonacci -}} - + {{$w}}*(p{{$i}} IS :p{{$i}}) - {{- end}} AS score, - suffix - FROM InitialSet -), Thresholded AS ( - SELECT suffix - FROM Scored - WHERE score >= {{$.MinScore}} -) -SELECT suffix -FROM Thresholded diff --git a/go.mod b/go.mod index a3c597f..8e9571e 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( golang.org/x/sync v0.6.0 golang.org/x/time v0.5.0 gopkg.in/typ.v4 v4.3.0 + zombiezen.com/go/sqlite v1.3.0 ) require ( @@ -30,11 +31,18 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.7+incompatible // indirect github.com/klauspost/compress v1.17.7 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.22.0 // indirect golang.org/x/sys v0.18.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/protobuf v1.33.0 // indirect + modernc.org/libc v1.41.0 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect + modernc.org/sqlite v1.29.1 // indirect ) diff --git a/go.sum b/go.sum index e753038..4cffe84 100644 --- a/go.sum +++ b/go.sum @@ -66,13 +66,19 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -142,6 +148,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -151,6 +158,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -200,3 +209,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +modernc.org/libc v1.41.0 h1:g9YAc6BkKlgORsUWj+JwqoB1wU3o4DE3bM3yvA3k+Gk= +modernc.org/libc v1.41.0/go.mod h1:w0eszPsiXoOnoMJgrXjglgLuDy/bt5RR4y3QzUUeodY= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= +modernc.org/sqlite v1.29.1 h1:19GY2qvWB4VPw0HppFlZCPAbmxFU41r+qjKZQdQ1ryA= +modernc.org/sqlite v1.29.1/go.mod h1:hG41jCYxOAOoO6BRK66AdRlmOcDzXf7qnwlwjUIOqa0= +zombiezen.com/go/sqlite v1.3.0 h1:98g1gnCm+CNz6AuQHu0gqyw7gR2WU3O3PJufDOStpUs= +zombiezen.com/go/sqlite v1.3.0/go.mod h1:yRl27//s/9aXU3RWs8uFQwjkTG9gYNGEls6+6SvrclY= diff --git a/prepend/prepend.go b/prepend/prepend.go new file mode 100644 index 0000000..73cde0e --- /dev/null +++ b/prepend/prepend.go @@ -0,0 +1,83 @@ +package prepend + +// List is a list that minimizes copying while prepending. +// A nil *List is useful; methods which modify the list return a possibly new +// value, similar to the append builtin function. +type List[E any] struct { + space []E + k int +} + +// Len returns the number of elements in the list. +func (p *List[E]) Len() int { + if p == nil { + return 0 + } + return len(p.space) - p.k +} + +// Slice returns the elements in the list as a Slice directly into the list's +// owned memory. +func (p *List[E]) Slice() []E { + if p == nil { + return nil + } + return p.space[p.k:] +} + +// Set sets the contents of the list. +func (p *List[E]) Set(ee ...E) *List[E] { + if len(ee) == 0 { + return p.Reset() + } + p = p.Reset() + if len(ee) > len(p.space) { + p.space = make([]E, len(ee)) + } + p.k = len(p.space) - len(ee) + copy(p.space[p.k:], ee) + return p +} + +// Prepend inserts elements in provided order at the start of the list. +func (p *List[E]) Prepend(ee ...E) *List[E] { + if p == nil { + p = new(List[E]) + } + if p.k < len(ee) { + // We don't expect enormous prompts, so a simple growth algorithm is fine. + b := make([]E, cap(p.space)*2+len(ee)) + p.k = len(b) - len(p.space) + copy(b[p.k:], p.space) + p.space = b + } + p.k -= len(ee) + copy(p.space[p.k:], ee) + return p +} + +// Drop removes the last n terms from the list. +// If n <= 0, there is no change. +// If n >= p.len(), the list becomes empty. +func (p *List[E]) Drop(n int) *List[E] { + if n <= 0 { + return p + } + if n >= p.Len() { + // As a special case, we can reset the entire list when we drop all. + // Note this branch also includes p == nil. + return p.Reset() + } + p.space = p.space[:len(p.space)-n] + return p +} + +// Reset removes all elements from the list. +func (p *List[E]) Reset() *List[E] { + if p == nil { + return new(List[E]) + } + p.space = p.space[:cap(p.space):cap(p.space)] + p.k = cap(p.space) + return p +} diff --git a/prepend/prepend_test.go b/prepend/prepend_test.go new file mode 100644 index 0000000..c151976 --- /dev/null +++ b/prepend/prepend_test.go @@ -0,0 +1,110 @@ +package prepend + +import ( + "fmt" + "slices" + "testing" +) + +func TestPrepender(t *testing.T) { + cases := []struct { + name string + set []int + pre [][]int + drop int + want []int + }{ + { + name: "empty", + set: nil, + pre: nil, + drop: 0, + want: nil, + }, + { + name: "set", + set: []int{1, 2}, + pre: nil, + drop: 0, + want: []int{1, 2}, + }, + { + name: "empty-pre", + set: nil, + pre: [][]int{{1}}, + drop: 0, + want: []int{1}, + }, + { + name: "pre", + set: []int{2}, + pre: [][]int{{1}}, + drop: 0, + want: []int{1, 2}, + }, + { + name: "many-pre", + set: nil, + pre: [][]int{{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}, {11}, {12}, {13}, {14}, {15}, {16}}, + drop: 0, + // prepending gives reverse order + want: []int{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + }, + { + name: "multi-pre", + set: nil, + pre: [][]int{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}}, + drop: 0, + want: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + { + name: "empty-drop", + set: nil, + pre: nil, + drop: 1, + want: nil, + }, + { + name: "drop", + set: []int{1, 2}, + pre: nil, + drop: 1, + want: []int{1}, + }, + { + name: "drop-minus", + set: []int{1, 2}, + pre: nil, + drop: -1, + want: []int{1, 2}, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + var p *List[int] + invariants := func(step string) { + if want, got := len(p.Slice()), p.Len(); want != got { + t.Errorf("lengths disagree after %s: slice gives %d, len gives %d", step, want, got) + } + } + invariants("nil decl") + p = p.Set(c.set...) + invariants("set") + for _, x := range c.pre { + p = p.Prepend(x...) + invariants(fmt.Sprintf("prepend %d", x)) + } + p = p.Drop(c.drop) + invariants("drop") + if !slices.Equal(p.Slice(), c.want) { + t.Errorf("wrong final list:\nwant %v\ngot %v", c.want, p.Slice()) + } + p = p.Reset() + invariants("reset") + if len(p.Slice()) != 0 { + t.Errorf("not empty after reset: %v", p.Slice()) + } + }) + } +} diff --git a/privacy/privacy_test.go b/privacy/privacy_test.go index 3ea3caa..d2b5fa2 100644 --- a/privacy/privacy_test.go +++ b/privacy/privacy_test.go @@ -6,7 +6,6 @@ import ( "gitlab.com/zephyrtronium/sq" - "github.com/zephyrtronium/robot/brain/sqlbrain" "github.com/zephyrtronium/robot/privacy" _ "github.com/mattn/go-sqlite3" // driver @@ -39,14 +38,15 @@ func TestInit(t *testing.T) { // TestCohabitant tests that a privacy list and an sqlbrain can exist in the // same database. func TestCohabitant(t *testing.T) { - ctx := context.Background() - db := testConn() - if err := sqlbrain.Create(ctx, db, 1); err != nil { - t.Errorf("couldn't create sqlbrain in the first place: %v", err) - } - if err := privacy.Init(ctx, db); err != nil { - t.Errorf("couldn't create privacy list together with sqlbrain: %v", err) - } + t.Skip("package needs update to support new sqlite provider") + // ctx := context.Background() + // db := testConn() + // if err := sqlbrain.Create(ctx, db); err != nil { + // t.Errorf("couldn't create sqlbrain in the first place: %v", err) + // } + // if err := privacy.Init(ctx, db); err != nil { + // t.Errorf("couldn't create privacy list together with sqlbrain: %v", err) + // } } func TestList(t *testing.T) { diff --git a/privmsg.go b/privmsg.go index af33939..361846f 100644 --- a/privmsg.go +++ b/privmsg.go @@ -178,13 +178,7 @@ func (robo *Robot) learn(ctx context.Context, ch *channel.Channel, hasher userha // Continue on with a zero UUID. } user := hasher.Hash(new(userhash.Hash), msg.Sender, msg.To, msg.Time()) - meta := &brain.MessageMeta{ - ID: id, - User: *user, - Tag: ch.Learn, - Time: msg.Time(), - } - if err := brain.Learn(ctx, robo.brain, meta, brain.Tokens(nil, msg.Text)); err != nil { + if err := brain.Learn(ctx, robo.brain, ch.Learn, *user, id, msg.Time(), brain.Tokens(nil, msg.Text)); err != nil { slog.ErrorContext(ctx, "failed to learn", slog.String("err", err.Error())) } } diff --git a/tpool/tpool.go b/tpool/tpool.go new file mode 100644 index 0000000..45c422e --- /dev/null +++ b/tpool/tpool.go @@ -0,0 +1,25 @@ +// Package tpool provides a generic, type-safe sync.Pool wrapper. +package tpool + +import "sync" + +// Pool is a type-safe wrapper around a [sync.Pool]. +// To obtain one, declare a variable or convert an existing sync.Pool to it. +// In the latter case, if the pool's New field is non-nil, +// it must return values which assert to T. +type Pool[T any] sync.Pool + +// Get pulls a value from the pool. +// It is a thin wrapper around [*sync.Pool.Get] and so mirrors its semantics. +// If the pool's New field is non-nil and returns a value which does not assert to T, +// then the result is the zero value of T. +func (p *Pool[T]) Get() T { + r, _ := (*sync.Pool)(p).Get().(T) + return r +} + +// Put returns a value to the pool. +// It is a thin wrapper around [*sync.Pool.Put] and so mirrors its semantics. +func (p *Pool[T]) Put(e T) { + (*sync.Pool)(p).Put(e) +} diff --git a/tpool/tpool_test.go b/tpool/tpool_test.go new file mode 100644 index 0000000..c2aa86a --- /dev/null +++ b/tpool/tpool_test.go @@ -0,0 +1,60 @@ +package tpool_test + +import ( + "sync" + "testing" + + "github.com/zephyrtronium/robot/tpool" +) + +func TestAllocs(t *testing.T) { + const iters, runs int = 1e3, 1e3 + u := testing.AllocsPerRun(runs, func() { + var pool sync.Pool + for range iters { + x, _ := pool.Get().(*int) + if x == nil { + x = new(int) + } + pool.Put(x) + pool.Put(new(int)) + } + }) + v := testing.AllocsPerRun(runs, func() { + var pool tpool.Pool[*int] + for range iters { + x := pool.Get() + if x == nil { + x = new(int) + } + pool.Put(x) + pool.Put(new(int)) + } + }) + if u != v { + t.Errorf("different allocs per run: sync.Pool has %v, tpool.Pool[*int] has %v", u, v) + } +} + +func TestAllocsNew(t *testing.T) { + const iters, runs int = 1e3, 1e3 + u := testing.AllocsPerRun(runs, func() { + pool := sync.Pool{New: func() any { return new(int) }} + for range iters { + x, _ := pool.Get().(*int) + pool.Put(x) + pool.Put(new(int)) + } + }) + v := testing.AllocsPerRun(runs, func() { + pool := tpool.Pool[*int]{New: func() any { return new(int) }} + for range iters { + x := pool.Get() + pool.Put(x) + pool.Put(new(int)) + } + }) + if u != v { + t.Errorf("different allocs per run: sync.Pool has %v, tpool.Pool[*int] has %v", u, v) + } +}