-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
channel: don't use sqlite for meme detection
- Loading branch information
1 parent
a44050f
commit acf4ea0
Showing
3 changed files
with
93 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,115 +1,140 @@ | ||
package channel | ||
|
||
import ( | ||
"context" | ||
_ "embed" | ||
"errors" | ||
"fmt" | ||
"sync" | ||
"time" | ||
|
||
"github.com/mattn/go-sqlite3" | ||
"gitlab.com/zephyrtronium/sq" | ||
) | ||
|
||
// MemeDetector is literally a meme detector. | ||
type MemeDetector struct { | ||
// db is a database of messages received and memes sent in a channel. | ||
db *sq.Conn | ||
// mu guards the messages list and the counts. | ||
mu sync.Mutex | ||
// front is the node of the most recent message in history. | ||
front *node | ||
// back is the node of the oldest message in history. | ||
back *node | ||
// counts tracks the expiry time of each message said by each user. | ||
// Detected memes are recorded with the empty string as the user. | ||
counts map[string]map[string]int64 // map[message]map[user]UnixMillis | ||
|
||
// need is the number of messages needed to trigger memery. | ||
need int | ||
// within is the duration to hold messages. | ||
within time.Duration | ||
} | ||
|
||
var ( | ||
memdb *sq.DB | ||
memdbOnce sync.Once | ||
) | ||
// node is a node in a doubly linked list of messages sorted by time. | ||
type node struct { | ||
older, newer *node | ||
|
||
func loadMemdb() { | ||
var err error | ||
memdb, err = sq.Open("sqlite3", ":memory:") | ||
if err != nil { | ||
panic(fmt.Errorf("couldn't open memory db: %w", err)) | ||
} | ||
if err := memdb.Ping(context.Background()); err != nil { | ||
panic(fmt.Errorf("couldn't ping memory db: %w", err)) | ||
} | ||
msg string | ||
user string | ||
exp int64 | ||
} | ||
|
||
// NewMemeDetector creates. | ||
func NewMemeDetector(need int, within time.Duration) *MemeDetector { | ||
ctx := context.Background() | ||
memdbOnce.Do(loadMemdb) | ||
conn, err := memdb.Conn(ctx) | ||
if err != nil { | ||
panic(fmt.Errorf("couldn't open single conn: %w", err)) | ||
} | ||
if _, err := conn.Exec(ctx, createCopypasta); err != nil { | ||
panic(fmt.Errorf("couldn't set up meme tables: %w", err)) | ||
} | ||
return &MemeDetector{ | ||
db: conn, | ||
counts: make(map[string]map[string]int64), | ||
need: need, | ||
within: within, | ||
} | ||
} | ||
|
||
func (m *MemeDetector) DB() *sq.Conn { | ||
return m.db | ||
func (m *MemeDetector) chopLocked(now int64) { | ||
b := m.back | ||
// For each expired node at the back: | ||
for b != nil && b.exp <= now { | ||
// We will drop this node. | ||
// If the most recent expiry time from this user isn't newer than b, | ||
// we also need to stop tracking it in the map. | ||
if m.counts[b.msg][b.user] <= b.exp { | ||
if m.counts[b.msg] != nil { | ||
delete(m.counts[b.msg], b.user) | ||
// If there are no users left who sent this message, | ||
// stop tracking the message as well to control memory usage. | ||
if len(m.counts[b.msg]) == 0 { | ||
delete(m.counts, b.msg) | ||
} | ||
} | ||
} | ||
if b.newer == nil { | ||
m.front, m.back = nil, nil | ||
return | ||
} | ||
m.back = b.newer | ||
b.newer = nil // clear the pointer to help the gc | ||
b = m.back | ||
b.older = nil | ||
} | ||
} | ||
|
||
func (m *MemeDetector) insertLocked(msg, user string, exp int64) { | ||
new := &node{ | ||
msg: msg, | ||
user: user, | ||
exp: exp, | ||
} | ||
if new.exp <= m.counts[msg][user] { | ||
// This message is (somehow) older than another from the same user. | ||
// We don't care about it. | ||
return | ||
} | ||
if m.counts[msg] == nil { | ||
m.counts[msg] = make(map[string]int64) | ||
} | ||
m.counts[msg][user] = new.exp | ||
|
||
if m.front == nil { | ||
m.front, m.back = new, new | ||
return | ||
} | ||
if new.exp >= m.front.exp { | ||
new.older = m.front | ||
m.front = new | ||
return | ||
} | ||
l := m.front | ||
for l.older != nil { | ||
if new.exp >= l.older.exp { | ||
new.newer, new.older = l, l.older | ||
l.older, l.older.newer = new, new | ||
return | ||
} | ||
l = l.older | ||
} | ||
new.newer = l | ||
l.older, m.back = new, new | ||
} | ||
|
||
// Check determines whether a message is a meme. If it is not, the returned | ||
// error is NotCopypasta. Times passed to Check should be monotonic, as | ||
// messages outside the detector's threshold are removed. | ||
func (m *MemeDetector) Check(t time.Time, from, msg string) error { | ||
ctx := context.Background() | ||
tm := t.UnixMilli() | ||
// Remove old messages. | ||
_, err := m.db.Exec(ctx, `DELETE FROM Message WHERE time < ?`, tm-m.within.Milliseconds()) | ||
if err != nil { | ||
return fmt.Errorf("couldn't remove old messages from meme detector: %w", err) | ||
} | ||
// Discard old memes. | ||
_, err = m.db.Exec(ctx, `DELETE FROM Meme WHERE time < ?`, tm-15*time.Minute.Milliseconds()) | ||
if err != nil { | ||
return fmt.Errorf("couldn't remove old memes: %w", err) | ||
} | ||
// Insert the new one. | ||
_, err = m.db.Exec(ctx, `INSERT INTO Message (time, user, msg) VALUES (?, ?, ?)`, tm, from, msg) | ||
if err != nil { | ||
return fmt.Errorf("couldn't insert new message into meme detector: %w", err) | ||
} | ||
now := t.UnixMilli() | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
// Remove old messages and discard old memes. | ||
m.chopLocked(now) | ||
// Insert the new message. | ||
m.insertLocked(msg, from, now+m.within.Milliseconds()) | ||
// Get the meme metric: number of distinct users who sent this message in | ||
// the time window. | ||
var n int | ||
err = m.db.QueryRow(ctx, `SELECT COUNT(DISTINCT user) FROM Message WHERE msg = ?`, msg).Scan(&n) | ||
if err != nil { | ||
return fmt.Errorf("couldn't get memery count: %w", err) | ||
} | ||
n := len(m.counts[msg]) | ||
if n < m.need { | ||
return ErrNotCopypasta | ||
} | ||
// Genuine meme. But is it fresh? | ||
_, err = m.db.Exec(ctx, `INSERT INTO Meme (time, msg) VALUES (?, ?)`, tm, msg) | ||
if err != nil { | ||
// Since we expect to react to (i.e. log) non-copypasta errors that | ||
// aren't ErrNotCopypasta, it's more helpful to return it when it's a | ||
// real reason not to be copypasta. | ||
if err, ok := err.(sqlite3.Error); ok { | ||
if err.Code == sqlite3.ErrConstraint && err.ExtendedCode == sqlite3.ErrConstraintUnique { | ||
return ErrNotCopypasta | ||
} | ||
} | ||
return fmt.Errorf("couldn't register fresh meme: %w %#v", err, err) | ||
if _, ok := m.counts[msg][""]; ok { | ||
return ErrNotCopypasta | ||
} | ||
// It is, but not for the following fifteen minutes. | ||
m.insertLocked(msg, "", now+15*60*1000) | ||
return nil | ||
} | ||
|
||
// ErrNotCopypasta is a sentinel error returned by MemeDetector.Check when a | ||
// message is not copypasta. | ||
var ErrNotCopypasta = errors.New("not copypasta") | ||
|
||
//go:embed copypasta.sql | ||
var createCopypasta string |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters