Skip to content

Commit

Permalink
brain/sqlbrain: use efficient sequence sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
zephyrtronium committed Aug 8, 2024
1 parent e3932d2 commit 37ce5e9
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions brain/sqlbrain/speak.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int
st.SetText(":tag", tag)
w := make([]byte, 0, 32)
var d []byte
var m uint64
var skip brain.Skip
picked := 0
for {
b = prefix(b[:0], prompt)
b, d = searchbounds(b)
st.SetBytes(":lower", b)
st.SetBytes(":upper", d)
sel:
for {
ok, err := st.Step()
if err != nil {
Expand All @@ -72,20 +73,21 @@ func next(conn *sqlite.Conn, tag string, b []byte, prompt []string) ([]byte, int
if !ok {
break
}
// We generate a uniform variate per row, then choose the row that
// gets the maximum variate.
// TODO(zeph): gumbel distribution
u := rand.Uint64()
if m > u {
continue
}
picked++
m = u
n := st.ColumnLen(0)
if cap(w) < n {
w = make([]byte, n)
}
w = w[:st.ColumnBytes(0, w[:n])]
picked++
for range skip.N(rand.Uint64(), rand.Uint64()) {
ok, err := st.Step()
if err != nil {
return b[:0], len(prompt), fmt.Errorf("couldn't step term selection: %w", err)
}
if !ok {
break sel
}
}
}
if picked < 3 && len(prompt) > 1 {
// We haven't seen enough options, and we have context we could
Expand Down Expand Up @@ -123,7 +125,8 @@ func first(conn *sqlite.Conn, tag string, b []byte) ([]byte, error) {
return b, fmt.Errorf("couldn't prepare first term selection: %w", err)
}
s.SetText(":tag", tag)
var m uint64
var skip brain.Skip
sel:
for {
ok, err := s.Step()
if err != nil {
Expand All @@ -132,17 +135,20 @@ func first(conn *sqlite.Conn, tag string, b []byte) ([]byte, error) {
if !ok {
break
}
// TODO(zeph): gumbel distribution
u := rand.Uint64()
if m > u {
continue
}
m = u
n := s.ColumnLen(0)
if cap(b) < n {
b = make([]byte, n)
}
b = b[:s.ColumnBytes(0, b[:n])]
for range skip.N(rand.Uint64(), rand.Uint64()) {
ok, err := s.Step()
if err != nil {
return b[:0], fmt.Errorf("couldn't step first term selection: %w", err)
}
if !ok {
break sel
}
}
}
return b, nil
}

0 comments on commit 37ce5e9

Please sign in to comment.