diff --git a/api/internalutils/stream/stream.go b/api/internalutils/stream/stream.go index e891015862f2..8032b6daf1da 100644 --- a/api/internalutils/stream/stream.go +++ b/api/internalutils/stream/stream.go @@ -408,6 +408,116 @@ func Take[T any](stream Stream[T], n int) ([]T, bool) { return items, true } +type skip[T any] struct { + inner Stream[T] + skip int +} + +func (s *skip[T]) Next() bool { + for i := 0; i < s.skip; i++ { + if !s.inner.Next() { + return false + } + } + s.skip = 0 + return s.inner.Next() +} + +func (s *skip[T]) Item() T { + return s.inner.Item() +} + +func (s *skip[T]) Done() error { + return s.inner.Done() +} + +// Skip skips the first n items from a stream. Zero/negative values of n +// have no effect. +func Skip[T any](stream Stream[T], n int) Stream[T] { + return &skip[T]{ + inner: stream, + skip: n, + } +} + +type flatten[T any] struct { + inner Stream[Stream[T]] + current Stream[T] + err error +} + +func (stream *flatten[T]) Next() bool { + for { + if stream.current != nil { + if stream.current.Next() { + return true + } + stream.err = stream.current.Done() + stream.current = nil + if stream.err != nil { + return false + } + } + + if !stream.inner.Next() { + return false + } + + stream.current = stream.inner.Item() + } +} + +func (stream *flatten[T]) Item() T { + return stream.current.Item() +} + +func (stream *flatten[T]) Done() error { + if stream.current != nil { + stream.err = stream.current.Done() + stream.current = nil + } + + ierr := stream.inner.Done() + if stream.err != nil { + return stream.err + } + return ierr +} + +// Flatten flattens a stream of streams into a single stream of items. +func Flatten[T any](stream Stream[Stream[T]]) Stream[T] { + return &flatten[T]{ + inner: stream, + } +} + +type mapErr[T any] struct { + inner Stream[T] + fn func(error) error +} + +func (stream *mapErr[T]) Next() bool { + return stream.inner.Next() +} + +func (stream *mapErr[T]) Item() T { + return stream.inner.Item() +} + +func (stream *mapErr[T]) Done() error { + return stream.fn(stream.inner.Done()) +} + +// MapErr maps over the error returned by Done(). The supplied function is called +// for all invocations of Done(), meaning that it can change, suppress, or create +// errors as needed. +func MapErr[T any](stream Stream[T], fn func(error) error) Stream[T] { + return &mapErr[T]{ + inner: stream, + fn: fn, + } +} + type rateLimit[T any] struct { inner Stream[T] wait func() error diff --git a/api/internalutils/stream/stream_test.go b/api/internalutils/stream/stream_test.go index d55cb19c20c3..b70972a461ec 100644 --- a/api/internalutils/stream/stream_test.go +++ b/api/internalutils/stream/stream_test.go @@ -506,6 +506,141 @@ func TestTake(t *testing.T) { } } +// TestSkip tests the Skip combinator. +func TestSkip(t *testing.T) { + t.Parallel() + + // normal usage + s, err := Collect(Skip(Slice([]int{1, 2, 3, 4}), 2)) + require.NoError(t, err) + require.Equal(t, []int{3, 4}, s) + + // skip all + s, err = Collect(Skip(Slice([]int{1, 2, 3, 4}), 4)) + require.NoError(t, err) + require.Empty(t, s) + + // skip none + s, err = Collect(Skip(Slice([]int{1, 2, 3, 4}), 0)) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4}, s) + + // negative skip + s, err = Collect(Skip(Slice([]int{1, 2, 3, 4}), -1)) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4}, s) + + // skip more than available + s, err = Collect(Skip(Slice([]int{1, 2, 3, 4}), 5)) + require.NoError(t, err) + require.Empty(t, s) + + // positive skip on empty stream + s, err = Collect(Skip(Empty[int](), 2)) + require.NoError(t, err) + require.Empty(t, s) + + // zero skip on empty stream + s, err = Collect(Skip(Empty[int](), 0)) + require.NoError(t, err) + require.Empty(t, s) + + // negative skip on empty stream + s, err = Collect(Skip(Empty[int](), -1)) + require.NoError(t, err) + require.Empty(t, s) + + // immediate failure + err = Drain(Skip(Fail[int](fmt.Errorf("unexpected error")), 1)) + require.Error(t, err) + + // failure during skip + err = Drain(Skip(Chain( + Slice([]int{1, 2}), + Fail[int](fmt.Errorf("unexpected error")), + Slice([]int{3, 4}), + ), 3)) + require.Error(t, err) +} + +// TestFlatten tests the Flatten combinator. +func TestFlatten(t *testing.T) { + t.Parallel() + + // normal usage + s, err := Collect(Flatten(Slice([]Stream[int]{ + Slice([]int{1, 2}), + Slice([]int{3, 4}), + Slice([]int{5, 6}), + }))) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s) + + // empty stream + s, err = Collect(Flatten(Empty[Stream[int]]())) + require.NoError(t, err) + require.Empty(t, s) + + // empty substreams + s, err = Collect(Flatten(Slice([]Stream[int]{ + Empty[int](), + Slice([]int{1, 2, 3}), + Empty[int](), + Slice([]int{4, 5, 6}), + Empty[int](), + }))) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s) + + // immediate failure + err = Drain(Flatten(Fail[Stream[int]](fmt.Errorf("unexpected error")))) + require.Error(t, err) + + // failure during streaming + s, err = Collect(Flatten(Slice([]Stream[int]{ + Slice([]int{1, 2}), + Fail[int](fmt.Errorf("unexpected error")), + Slice([]int{3, 4}), + }))) + require.Error(t, err) + require.Equal(t, []int{1, 2}, s) +} + +// TestMapErr tests the MapErr combinator. +func TestMapErr(t *testing.T) { + t.Parallel() + + // normal inject error + err := Drain(MapErr(Slice([]int{1, 2, 3}), func(err error) error { + require.NoError(t, err) + return fmt.Errorf("unexpected error") + })) + require.Error(t, err) + + // empty inject error + err = Drain(MapErr(Empty[int](), func(err error) error { + require.NoError(t, err) + return fmt.Errorf("unexpected error") + })) + require.Error(t, err) + + // normal suppress error + s, err := Collect(MapErr(Chain(Slice([]int{1, 2, 3}), Fail[int](fmt.Errorf("unexpected error"))), func(err error) error { + require.Error(t, err) + return nil + })) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3}, s) + + // empty suppress error + s, err = Collect(MapErr(Fail[int](fmt.Errorf("unexpected error")), func(err error) error { + require.Error(t, err) + return nil + })) + require.NoError(t, err) + require.Empty(t, s) +} + // TestRateLimitFailure verifies the expected failure conditions of the RateLimit helper. func TestRateLimitFailure(t *testing.T) { t.Parallel()