-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize the iterator implementation and add a batch iterator (#502)
* Move the iteration code to a generalized library. This iteration code will be used by criticality_score and eventually collect_signals to process input repositories. Signed-off-by: Caleb Brown <[email protected]> * Bump go.work.sum Signed-off-by: Caleb Brown <[email protected]> * Fix broken test Signed-off-by: Caleb Brown <[email protected]> * Fix broken tests. Signed-off-by: Caleb Brown <[email protected]> --------- Signed-off-by: Caleb Brown <[email protected]>
- Loading branch information
1 parent
b3aca0a
commit c870dcd
Showing
10 changed files
with
586 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package iterator | ||
|
||
import "fmt" | ||
|
||
type batchIter[T any] struct { | ||
input IterCloser[T] | ||
lastErr error | ||
item []T | ||
batchSize int | ||
} | ||
|
||
func (i *batchIter[T]) nextBatch() ([]T, error) { | ||
var batch []T | ||
|
||
for i.input.Next() { | ||
item := i.input.Item() | ||
batch = append(batch, item) | ||
if len(batch) >= i.batchSize { | ||
break | ||
} | ||
} | ||
if err := i.input.Err(); err != nil { | ||
// The input iterator failed, so return an error. | ||
return nil, fmt.Errorf("input iter: %w", err) | ||
} | ||
if len(batch) == 0 { | ||
// We've passed the end. | ||
return nil, nil | ||
} | ||
return batch, nil | ||
} | ||
|
||
func (i *batchIter[T]) Item() []T { | ||
return i.item | ||
} | ||
|
||
func (i *batchIter[T]) Next() bool { | ||
if i.lastErr != nil { | ||
// Stop if we've encountered an error. | ||
return false | ||
} | ||
batch, err := i.nextBatch() | ||
if err != nil { | ||
i.lastErr = err | ||
return false | ||
} | ||
if len(batch) == 0 { | ||
// We are also done at this point. | ||
return false | ||
} | ||
i.item = batch | ||
return true | ||
} | ||
|
||
func (i *batchIter[T]) Err() error { | ||
return i.lastErr | ||
} | ||
|
||
func (i *batchIter[T]) Close() error { | ||
if err := i.input.Close(); err != nil { | ||
return fmt.Errorf("input close: %w", i.input.Close()) | ||
} | ||
return nil | ||
} | ||
|
||
func Batch[T any](input IterCloser[T], batchSize int) IterCloser[[]T] { | ||
return &batchIter[T]{ | ||
input: input, | ||
batchSize: batchSize, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package iterator_test | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"io" | ||
"strings" | ||
"testing" | ||
"testing/iotest" | ||
|
||
"golang.org/x/exp/slices" | ||
|
||
"github.com/ossf/criticality_score/internal/iterator" | ||
) | ||
|
||
func TestBatchIter_Empty(t *testing.T) { | ||
var b bytes.Buffer | ||
i := iterator.Batch(iterator.Lines(io.NopCloser(&b)), 10) | ||
|
||
if got := i.Next(); got { | ||
t.Errorf("Next() = %v; want false", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err() = %v; want no error", err) | ||
} | ||
} | ||
|
||
func TestBatchIter_SingleLine(t *testing.T) { | ||
want := []string{"test line"} | ||
b := bytes.NewBuffer([]byte(strings.Join(want, "\n"))) | ||
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 10) | ||
|
||
if got := i.Next(); !got { | ||
t.Errorf("Next() = %v; want true", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err() = %v; want no error", err) | ||
} | ||
if got := i.Item(); !slices.Equal(got, want) { | ||
t.Errorf("Item() = %v; want %v", got, want) | ||
} | ||
if got := i.Next(); got { | ||
t.Errorf("Next()#2 = %v; want false", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err()#2 = %v; want no error", err) | ||
} | ||
} | ||
|
||
func TestBatchIter_MultiLineSingleBatch(t *testing.T) { | ||
want := []string{"line one", "line two", "line three"} | ||
b := bytes.NewBuffer([]byte(strings.Join(want, "\n"))) | ||
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 10) | ||
|
||
if got := i.Next(); !got { | ||
t.Errorf("Next() = %v; want true", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err() = %v; want no error", err) | ||
} | ||
if got := i.Item(); !slices.Equal(got, want) { | ||
t.Errorf("Item() = %v; want %v", got, want) | ||
} | ||
if got := i.Next(); got { | ||
t.Errorf("Next()#2 = %v; want false", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err()#2 = %v; want no error", err) | ||
} | ||
} | ||
|
||
func TestBatchIter_MultiLineMultiBatch(t *testing.T) { | ||
want1 := []string{"line one", "line two"} | ||
want2 := []string{"line three"} | ||
b := bytes.NewBuffer([]byte(strings.Join(append(want1, want2...), "\n"))) | ||
i := iterator.Batch(iterator.Lines(io.NopCloser(b)), 2) | ||
|
||
if got := i.Next(); !got { | ||
t.Errorf("Next() = %v; want true", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err() = %v; want no error", err) | ||
} | ||
if got := i.Item(); !slices.Equal(got, want1) { | ||
t.Errorf("Item()#1 = %v; want %v", got, want1) | ||
} | ||
if got := i.Next(); !got { | ||
t.Errorf("Next()#2 = %v; want true", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err()#2 = %v; want no error", err) | ||
} | ||
if got := i.Item(); !slices.Equal(got, want2) { | ||
t.Errorf("Item()#2 = %v; want %v", got, want2) | ||
} | ||
if got := i.Next(); got { | ||
t.Errorf("Next()#3 = %v; want false", got) | ||
} | ||
if err := i.Err(); err != nil { | ||
t.Errorf("Err()#3 = %v; want no error", err) | ||
} | ||
} | ||
|
||
func TestBatchIter_Error(t *testing.T) { | ||
want := errors.New("error") | ||
r := iotest.ErrReader(want) | ||
i := iterator.Batch(iterator.Lines(io.NopCloser(r)), 10) | ||
|
||
if got := i.Next(); got { | ||
t.Errorf("Next() = %v; want false", got) | ||
} | ||
if err := i.Err(); err == nil || !errors.Is(err, want) { | ||
t.Errorf("Err() = %v; want %v", err, want) | ||
} | ||
} | ||
|
||
func TestBatchIter_Close(t *testing.T) { | ||
got := 0 | ||
i := iterator.Batch(iterator.Lines(&struct { | ||
closerFn | ||
io.Reader | ||
}{ | ||
closerFn: closerFn(func() error { | ||
got++ | ||
return nil | ||
}), | ||
Reader: &bytes.Buffer{}, | ||
}), 10) | ||
err := i.Close() | ||
|
||
if got != 1 { | ||
t.Errorf("Close() called %d times; want 1", got) | ||
} | ||
if err != nil { | ||
t.Errorf("Err() = %v; want no error", err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package iterator | ||
|
||
import ( | ||
"bufio" | ||
"fmt" | ||
"io" | ||
) | ||
|
||
// scannerIter iterates over the lines of a stream using a bufio.Scanner. Each
// item is a string holding the text of one scanned line.
type scannerIter struct {
	// scanner produces the lines.
	scanner *bufio.Scanner
	// c is closed when the iterator is closed.
	c io.Closer
}
|
||
func (i *scannerIter) Item() string { | ||
return i.scanner.Text() | ||
} | ||
|
||
func (i *scannerIter) Next() bool { | ||
return i.scanner.Scan() | ||
} | ||
|
||
func (i *scannerIter) Err() error { | ||
if err := i.scanner.Err(); err != nil { | ||
return fmt.Errorf("scanner: %w", i.scanner.Err()) | ||
} | ||
return nil | ||
} | ||
|
||
func (i *scannerIter) Close() error { | ||
if err := i.c.Close(); err != nil { | ||
return fmt.Errorf("closer: %w", i.c.Close()) | ||
} | ||
return nil | ||
} | ||
|
||
func Lines(r io.ReadCloser) IterCloser[string] { | ||
return &scannerIter{ | ||
c: r, | ||
scanner: bufio.NewScanner(r), | ||
} | ||
} |
Oops, something went wrong.