Skip to content

Commit

Permalink
channel: don't use sqlite for meme detection
Browse files Browse the repository at this point in the history
  • Loading branch information
zephyrtronium committed Aug 7, 2024
1 parent a44050f commit acf4ea0
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 111 deletions.
161 changes: 93 additions & 68 deletions channel/copypasta.go
Original file line number Diff line number Diff line change
@@ -1,115 +1,140 @@
package channel

import (
"context"
_ "embed"
"errors"
"fmt"
"sync"
"time"

"github.com/mattn/go-sqlite3"
"gitlab.com/zephyrtronium/sq"
)

// MemeDetector is literally a meme detector. It remembers which users sent
// which messages recently so that Check can report a message as a meme once
// enough distinct users have sent it within the configured window.
type MemeDetector struct {
// db is a database of messages received and memes sent in a channel.
// NOTE(review): this listing interleaves both sides of a diff; db looks like
// the pre-change storage that the list and map below replace — confirm
// against the repository.
db *sq.Conn
// mu guards the messages list and the counts.
mu sync.Mutex
// front is the node of the most recent message in history.
// It is nil when the history is empty.
front *node
// back is the node of the oldest message in history.
// It is nil when the history is empty.
back *node
// counts tracks the expiry time of each message said by each user.
// Detected memes are recorded with the empty string as the user.
// The number of keys in counts[message] is what Check compares against need.
counts map[string]map[string]int64 // map[message]map[user]UnixMillis

// need is the number of messages needed to trigger memery.
need int
// within is the duration to hold messages.
within time.Duration
}

var (
memdb *sq.DB
memdbOnce sync.Once
)
// node is a node in a doubly linked list of messages sorted by time.
type node struct {
older, newer *node

func loadMemdb() {
var err error
memdb, err = sq.Open("sqlite3", ":memory:")
if err != nil {
panic(fmt.Errorf("couldn't open memory db: %w", err))
}
if err := memdb.Ping(context.Background()); err != nil {
panic(fmt.Errorf("couldn't ping memory db: %w", err))
}
msg string
user string
exp int64
}

// NewMemeDetector creates a MemeDetector that needs need messages within the
// duration within to trigger memery.
//
// As listed here, it also initializes the shared in-memory database on first
// use (through memdbOnce), takes a connection from it, and executes
// createCopypasta on that connection. It panics if obtaining the connection
// or setting up the tables fails.
// NOTE(review): this listing interleaves both sides of a diff, so the
// database setup is probably the pre-change code — confirm against the
// repository.
func NewMemeDetector(need int, within time.Duration) *MemeDetector {
ctx := context.Background()
memdbOnce.Do(loadMemdb)
conn, err := memdb.Conn(ctx)
if err != nil {
panic(fmt.Errorf("couldn't open single conn: %w", err))
}
if _, err := conn.Exec(ctx, createCopypasta); err != nil {
panic(fmt.Errorf("couldn't set up meme tables: %w", err))
}
return &MemeDetector{
db: conn,
// counts must be non-nil: insertLocked writes into it.
counts: make(map[string]map[string]int64),
need: need,
within: within,
}
}

func (m *MemeDetector) DB() *sq.Conn {
return m.db
// chopLocked drops every node at the back of the history whose expiry time is
// at or before now (Unix milliseconds), and forgets the matching entries in
// m.counts. The caller must hold m.mu.
func (m *MemeDetector) chopLocked(now int64) {
	for cur := m.back; cur != nil && cur.exp <= now; cur = m.back {
		// Only forget the user's entry when no later message from the same
		// user has pushed the recorded expiry past this node's.
		if users := m.counts[cur.msg]; users != nil && users[cur.user] <= cur.exp {
			delete(users, cur.user)
			// Drop the message itself once nobody is left who sent it,
			// so the map does not grow without bound.
			if len(users) == 0 {
				delete(m.counts, cur.msg)
			}
		}

		next := cur.newer
		if next == nil {
			// That was the last node; the history is now empty.
			m.front, m.back = nil, nil
			return
		}
		// Unlink cur in both directions so the collector can reclaim it.
		cur.newer = nil
		next.older = nil
		m.back = next
	}
}

// insertLocked records that user sent msg with expiry time exp (Unix
// milliseconds) and links a node for it into the history list. The list is
// kept sorted by expiry, with m.front the newest node and m.back the oldest.
// The caller must hold m.mu.
//
// If m.counts already holds an expiry for the same user and message that is
// at or after exp, the call does nothing.
func (m *MemeDetector) insertLocked(msg, user string, exp int64) {
	if exp <= m.counts[msg][user] {
		// This message is (somehow) older than another from the same user.
		// We don't care about it.
		return
	}
	if m.counts[msg] == nil {
		m.counts[msg] = make(map[string]int64)
	}
	m.counts[msg][user] = exp

	n := &node{msg: msg, user: user, exp: exp}
	if m.front == nil {
		m.front, m.back = n, n
		return
	}
	if exp >= m.front.exp {
		// Common case: the new node is the newest. Both directions must be
		// linked; chopLocked walks from the back through the newer pointers
		// and treats a nil newer as the end of the list.
		n.older = m.front
		m.front.newer = n
		m.front = n
		return
	}
	// Walk toward the back until the next older node does not expire after
	// the new one, then splice the new node in just behind newer.
	newer := m.front
	for newer.older != nil && exp < newer.older.exp {
		newer = newer.older
	}
	older := newer.older
	n.newer, n.older = newer, older
	newer.older = n
	if older == nil {
		// The new node expires before everything else.
		m.back = n
		return
	}
	older.newer = n
}

// Check determines whether a message is a meme. If it is not, the returned
// error is ErrNotCopypasta. Times passed to Check should be monotonic, as
// messages outside the detector's threshold are removed.
func (m *MemeDetector) Check(t time.Time, from, msg string) error {
ctx := context.Background()
tm := t.UnixMilli()
// Remove old messages.
_, err := m.db.Exec(ctx, `DELETE FROM Message WHERE time < ?`, tm-m.within.Milliseconds())
if err != nil {
return fmt.Errorf("couldn't remove old messages from meme detector: %w", err)
}
// Discard old memes.
_, err = m.db.Exec(ctx, `DELETE FROM Meme WHERE time < ?`, tm-15*time.Minute.Milliseconds())
if err != nil {
return fmt.Errorf("couldn't remove old memes: %w", err)
}
// Insert the new one.
_, err = m.db.Exec(ctx, `INSERT INTO Message (time, user, msg) VALUES (?, ?, ?)`, tm, from, msg)
if err != nil {
return fmt.Errorf("couldn't insert new message into meme detector: %w", err)
}
now := t.UnixMilli()
m.mu.Lock()
defer m.mu.Unlock()
// Remove old messages and discard old memes.
m.chopLocked(now)
// Insert the new message.
m.insertLocked(msg, from, now+m.within.Milliseconds())
// Get the meme metric: number of distinct users who sent this message in
// the time window.
var n int
err = m.db.QueryRow(ctx, `SELECT COUNT(DISTINCT user) FROM Message WHERE msg = ?`, msg).Scan(&n)
if err != nil {
return fmt.Errorf("couldn't get memery count: %w", err)
}
n := len(m.counts[msg])
if n < m.need {
return ErrNotCopypasta
}
// Genuine meme. But is it fresh?
_, err = m.db.Exec(ctx, `INSERT INTO Meme (time, msg) VALUES (?, ?)`, tm, msg)
if err != nil {
// Since we expect to react to (i.e. log) non-copypasta errors that
// aren't ErrNotCopypasta, it's more helpful to return it when it's a
// real reason not to be copypasta.
if err, ok := err.(sqlite3.Error); ok {
if err.Code == sqlite3.ErrConstraint && err.ExtendedCode == sqlite3.ErrConstraintUnique {
return ErrNotCopypasta
}
}
return fmt.Errorf("couldn't register fresh meme: %w %#v", err, err)
if _, ok := m.counts[msg][""]; ok {
return ErrNotCopypasta
}
// It is, but not for the following fifteen minutes.
m.insertLocked(msg, "", now+15*60*1000)
return nil
}

// ErrNotCopypasta is a sentinel error returned by MemeDetector.Check when a
// message is not copypasta. Check returns it unwrapped, so callers can compare
// against it directly.
var ErrNotCopypasta = errors.New("not copypasta")

//go:embed copypasta.sql
var createCopypasta string
12 changes: 0 additions & 12 deletions channel/copypasta.sql

This file was deleted.

31 changes: 0 additions & 31 deletions channel/copypasta_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package channel_test

import (
"context"
"testing"
"time"

"github.com/zephyrtronium/robot/channel"
"gitlab.com/zephyrtronium/sq"
)

func TestMemeDetector(t *testing.T) {
Expand Down Expand Up @@ -103,37 +101,8 @@ func TestMemeDetector(t *testing.T) {
err := d.Check(time.UnixMilli(m.when), m.who, m.text)
if err != m.err {
t.Errorf("wrong error for %+v: want %v, got %v", m, m.err, err)
dumpdb(context.Background(), t, d.DB())
}
}
})
}
}

// dumpdb logs the column names and every row of the Message table to t. It is
// a debugging aid used when a test case gets the wrong error. Failures to
// query, read the columns, or scan a row panic; an error reported after
// iteration is only logged.
func dumpdb(ctx context.Context, t *testing.T, db *sq.Conn) {
	t.Helper()
	t.Log("db content:")

	rows, err := db.Query(ctx, "SELECT * FROM Message")
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	names, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	t.Log(names)

	width := len(names)
	for rows.Next() {
		// Each slot starts as a pointer to itself, so Scan stores the column
		// value directly into the slot without a second slice.
		vals := make([]any, width)
		for i := 0; i < width; i++ {
			vals[i] = &vals[i]
		}
		if scanErr := rows.Scan(vals...); scanErr != nil {
			panic(scanErr)
		}
		t.Logf("%q", vals)
	}
	if iterErr := rows.Err(); iterErr != nil {
		t.Log(iterErr)
	}
}

0 comments on commit acf4ea0

Please sign in to comment.