Skip to content

Commit

Permalink
Generalize the iterator implementation and add a batch iterator (#502)
Browse files Browse the repository at this point in the history
* Move the iteration code to a generalized library.

This iteration code will be used by criticality_score and eventually
collect_signals to process input repositories.

Signed-off-by: Caleb Brown <[email protected]>

* Bump go.work.sum

Signed-off-by: Caleb Brown <[email protected]>

* Fix broken test

Signed-off-by: Caleb Brown <[email protected]>

* Fix broken tests.

Signed-off-by: Caleb Brown <[email protected]>

---------

Signed-off-by: Caleb Brown <[email protected]>
  • Loading branch information
calebbrown authored Dec 14, 2023
1 parent b3aca0a commit c870dcd
Show file tree
Hide file tree
Showing 10 changed files with 586 additions and 139 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ linters:
enable:
- asciicheck
#- bodyclose # Temporarily disabled.
- depguard
#- depguard # Temporarily disabled.
- dogsled
#- errcheck # Temporarily disabled.
- errorlint
Expand Down
13 changes: 4 additions & 9 deletions cmd/criticality_score/inputiter/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
package inputiter

import (
"bufio"
"context"
"errors"
"net/url"
"os"

"github.com/ossf/criticality_score/internal/infile"
"github.com/ossf/criticality_score/internal/iterator"
)

// osErrorWithFilename is an os-specific helper for determining if a particular
Expand Down Expand Up @@ -53,7 +53,7 @@ func errWithFilename(err error) bool {
//
// TODO: support the ability to force args to be interpreted as either a file,
// or a list of repos.
func New(args []string) (IterCloser[string], error) {
func New(args []string) (iterator.IterCloser[string], error) {
if len(args) == 1 {
// If there is 1 arg, attempt to open it as a file.
fileOrRepo := args[0]
Expand All @@ -65,10 +65,7 @@ func New(args []string) (IterCloser[string], error) {
// Open the in-file for reading
r, err := infile.Open(context.Background(), fileOrRepo)
if err == nil {
return &scannerIter{
c: r,
scanner: bufio.NewScanner(r),
}, nil
return iterator.Lines(r), nil
}
if urlParseFailed || !errWithFilename(err) {
// Only report errors if the file doesn't appear to be a URL, if the
Expand All @@ -78,7 +75,5 @@ func New(args []string) (IterCloser[string], error) {
}
// If file loading failed, or there are 2 or more args, treat args as a list
// of repos.
return &sliceIter[string]{
values: args,
}, nil
return iterator.Slice(args), nil
}
237 changes: 237 additions & 0 deletions go.work.sum

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions internal/iterator/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package iterator

import "fmt"

type batchIter[T any] struct {
input IterCloser[T]
lastErr error
item []T
batchSize int
}

func (i *batchIter[T]) nextBatch() ([]T, error) {
var batch []T

for i.input.Next() {
item := i.input.Item()
batch = append(batch, item)
if len(batch) >= i.batchSize {
break
}
}
if err := i.input.Err(); err != nil {
// The input iterator failed, so return an error.
return nil, fmt.Errorf("input iter: %w", err)
}
if len(batch) == 0 {
// We've passed the end.
return nil, nil
}
return batch, nil
}

func (i *batchIter[T]) Item() []T {
return i.item
}

func (i *batchIter[T]) Next() bool {
if i.lastErr != nil {
// Stop if we've encountered an error.
return false
}
batch, err := i.nextBatch()
if err != nil {
i.lastErr = err
return false
}
if len(batch) == 0 {
// We are also done at this point.
return false
}
i.item = batch
return true
}

func (i *batchIter[T]) Err() error {
return i.lastErr
}

func (i *batchIter[T]) Close() error {
if err := i.input.Close(); err != nil {
return fmt.Errorf("input close: %w", i.input.Close())
}
return nil
}

func Batch[T any](input IterCloser[T], batchSize int) IterCloser[[]T] {
return &batchIter[T]{
input: input,
batchSize: batchSize,
}
}
137 changes: 137 additions & 0 deletions internal/iterator/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package iterator_test

import (
"bytes"
"errors"
"io"
"strings"
"testing"
"testing/iotest"

"golang.org/x/exp/slices"

"github.com/ossf/criticality_score/internal/iterator"
)

func TestBatchIter_Empty(t *testing.T) {
var b bytes.Buffer
i := iterator.Batch(iterator.Lines(io.NopCloser(&b)), 10)

if got := i.Next(); got {
t.Errorf("Next() = %v; want false", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err() = %v; want no error", err)
}
}

func TestBatchIter_SingleLine(t *testing.T) {
want := []string{"test line"}
b := bytes.NewBuffer([]byte(strings.Join(want, "\n")))
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 10)

if got := i.Next(); !got {
t.Errorf("Next() = %v; want true", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err() = %v; want no error", err)
}
if got := i.Item(); !slices.Equal(got, want) {
t.Errorf("Item() = %v; want %v", got, want)
}
if got := i.Next(); got {
t.Errorf("Next()#2 = %v; want false", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err()#2 = %v; want no error", err)
}
}

func TestBatchIter_MultiLineSingleBatch(t *testing.T) {
want := []string{"line one", "line two", "line three"}
b := bytes.NewBuffer([]byte(strings.Join(want, "\n")))
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 10)

if got := i.Next(); !got {
t.Errorf("Next() = %v; want true", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err() = %v; want no error", err)
}
if got := i.Item(); !slices.Equal(got, want) {
t.Errorf("Item() = %v; want %v", got, want)
}
if got := i.Next(); got {
t.Errorf("Next()#2 = %v; want false", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err()#2 = %v; want no error", err)
}
}

func TestBatchIter_MultiLineMultiBatch(t *testing.T) {
want1 := []string{"line one", "line two"}
want2 := []string{"line three"}
b := bytes.NewBuffer([]byte(strings.Join(append(want1, want2...), "\n")))
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 2)

if got := i.Next(); !got {
t.Errorf("Next() = %v; want true", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err() = %v; want no error", err)
}
if got := i.Item(); !slices.Equal(got, want1) {
t.Errorf("Item()#1 = %v; want %v", got, want1)
}
if got := i.Next(); !got {
t.Errorf("Next()#2 = %v; want true", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err()#2 = %v; want no error", err)
}
if got := i.Item(); !slices.Equal(got, want2) {
t.Errorf("Item()#2 = %v; want %v", got, want2)
}
if got := i.Next(); got {
t.Errorf("Next()#3 = %v; want false", got)
}
if err := i.Err(); err != nil {
t.Errorf("Err()#3 = %v; want no error", err)
}
}

func TestBatchIter_Error(t *testing.T) {
want := errors.New("error")
r := iotest.ErrReader(want)
i := iterator.Batch(iterator.Lines(io.NopCloser(r)), 10)

if got := i.Next(); got {
t.Errorf("Next() = %v; want false", got)
}
if err := i.Err(); err == nil || !errors.Is(err, want) {
t.Errorf("Err() = %v; want %v", err, want)
}
}

func TestBatchIter_Close(t *testing.T) {
got := 0
i := iterator.Batch(iterator.Lines(&struct {
closerFn
io.Reader
}{
closerFn: closerFn(func() error {
got++
return nil
}),
Reader: &bytes.Buffer{},
}), 10)
err := i.Close()

if got != 1 {
t.Errorf("Close() called %d times; want 1", got)
}
if err != nil {
t.Errorf("Err() = %v; want no error", err)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package inputiter
package iterator

import (
"bufio"
"io"
)

Expand Down Expand Up @@ -45,51 +44,3 @@ type IterCloser[T any] interface {
Iter[T]
io.Closer
}

// scannerIter implements Iter[string] using a bufio.Scanner to iterate through
// lines in a file.
type scannerIter struct {
c io.Closer
scanner *bufio.Scanner
}

func (i *scannerIter) Item() string {
return i.scanner.Text()
}

func (i *scannerIter) Next() bool {
return i.scanner.Scan()
}

func (i *scannerIter) Err() error {
return i.scanner.Err()
}

func (i *scannerIter) Close() error {
return i.c.Close()
}

// sliceIter implements iter using a slice for iterating.
type sliceIter[T any] struct {
values []T
next int
}

func (i *sliceIter[T]) Item() T {
return i.values[i.next-1]
}

func (i *sliceIter[T]) Next() bool {
if i.next <= len(i.values) {
i.next++
}
return i.next <= len(i.values)
}

func (i *sliceIter[T]) Err() error {
return nil
}

func (i *sliceIter[T]) Close() error {
return nil
}
43 changes: 43 additions & 0 deletions internal/iterator/scanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package iterator

import (
"bufio"
"fmt"
"io"
)

// scannerIter implements Iter[string] using a bufio.Scanner to iterate through
// lines in a file.
type scannerIter struct {
c io.Closer
scanner *bufio.Scanner
}

func (i *scannerIter) Item() string {
return i.scanner.Text()
}

func (i *scannerIter) Next() bool {
return i.scanner.Scan()
}

func (i *scannerIter) Err() error {
if err := i.scanner.Err(); err != nil {
return fmt.Errorf("scanner: %w", i.scanner.Err())
}
return nil
}

func (i *scannerIter) Close() error {
if err := i.c.Close(); err != nil {
return fmt.Errorf("closer: %w", i.c.Close())
}
return nil
}

func Lines(r io.ReadCloser) IterCloser[string] {
return &scannerIter{
c: r,
scanner: bufio.NewScanner(r),
}
}
Loading

0 comments on commit c870dcd

Please sign in to comment.