From bd85c99c7780d8a52ff24f36d1bbec2256c8238a Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sat, 4 Jan 2025 20:06:09 -0800 Subject: [PATCH] stree: add a Find method This is similar to Cursor, but selects the first key greater than or equal to the given key rather than failing when the exact key is not found. --- stree/stree.go | 10 ++++++++++ stree/stree_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/stree/stree.go b/stree/stree.go index c3a4dba..ca4e181 100644 --- a/stree/stree.go +++ b/stree/stree.go @@ -277,6 +277,16 @@ func (t *Tree[T]) Get(key T) (_ T, ok bool) { return } +// Find returns a cursor to the smallest key in the tree greater than or equal +// to key. If no such key exists, Find returns nil. +func (t *Tree[T]) Find(key T) *Cursor[T] { + path := t.root.pathTo(key, t.compare) + if len(path) == 0 || t.compare(path[len(path)-1].X, key) < 0 { + return nil + } + return &Cursor[T]{path: path} +} + // Inorder is a range function that visits each key of t in order. func (t *Tree[T]) Inorder(yield func(key T) bool) { t.root.inorder(yield) } diff --git a/stree/stree_test.go b/stree/stree_test.go index c17d24f..3a6b247 100644 --- a/stree/stree_test.go +++ b/stree/stree_test.go @@ -322,6 +322,45 @@ func TestCursor(t *testing.T) { t.Errorf("Right tree (-got, +want):\n%s", diff) } }) + + t.Run("Find", func(t *testing.T) { + tree := stree.New(250, strings.Compare, "a", "e", "i", "o", "u") + + t.Run("None", func(t *testing.T) { + if got := tree.Find("z"); got != nil { + t.Errorf("Find z: got %v, want nil", got) + } + }) + t.Run("Exact", func(t *testing.T) { + if got := tree.Find("e"); got.Key() != "e" { + t.Errorf("Find e: got %q, want e", got.Key()) + } + }) + t.Run("Before", func(t *testing.T) { + got := tree.Find("0") + if got.Key() != "a" { + t.Errorf("Find 0: got %q, want a", got.Key()) + } + if next := got.Next(); next.Key() != "e" { + t.Errorf("Next a: got %q, want e", next.Key()) + } + if prev := got.Prev().Prev(); prev.Valid() { + t.Errorf("Prev a: got %v, want invalid", prev) + } + }) + t.Run("Between", func(t *testing.T) { + got := tree.Find("k") + if got.Key() != "o" { + t.Errorf("Find k: got %q, want o", got.Key()) + } + if next := got.Next(); next.Key() != "u" { + t.Errorf("Next o: got %q, want u", next.Key()) + } + if prev := got.Prev().Prev(); prev.Key() != "i" { + t.Errorf("Prev o: got %q, want i", prev.Key()) + } + }) + }) } func TestKV(t *testing.T) {