Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return values #103

Merged
merged 5 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func TestCoroutineYield(t *testing.T) {
tests := []struct {
name string
coro func()
coroR func() int
yields []int
result int
skip bool
}{
{
Expand Down Expand Up @@ -198,15 +200,28 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { VarArgs(3) },
yields: []int{0, 1, 2},
},

{
name: "return values",
coroR: func() int { return NestedLoops(3) },
yields: []int{1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18, 3, 6, 9, 6, 12, 18, 9, 18, 27},
result: 27,
},
}

// This emulates the installation of function type information by the
// compiler because we are not doing codegen for the test files in this
// package.
for _, test := range tests {
a := types.FuncAddr(test.coro)
f := types.FuncByAddr(a)
types.RegisterFunc[func()](f.Name)
if test.coro != nil {
addr := types.FuncAddr(test.coro)
fn := types.FuncByAddr(addr)
types.RegisterFunc[func()](fn.Name)
} else {
addr := types.FuncAddr(test.coroR)
fn := types.FuncByAddr(addr)
types.RegisterFunc[func() int](fn.Name)
}
}

for _, test := range tests {
Expand All @@ -215,7 +230,12 @@ func TestCoroutineYield(t *testing.T) {
t.Skip("test is disabled")
}

g := coroutine.New[int, any](test.coro)
var g coroutine.Coroutine[int, any]
if test.coro != nil {
g = coroutine.New[int, any](test.coro)
} else {
g = coroutine.NewWithReturn[int, any](test.coroR)
}

var yield int
for g.Next() {
Expand Down Expand Up @@ -251,6 +271,11 @@ func TestCoroutineYield(t *testing.T) {
if yield < len(test.yields) {
t.Errorf("coroutine did not yield the correct number of times: got %d, expect %d", yield, len(test.yields))
}
if test.coroR != nil {
if got := g.Result(); got != test.result {
t.Errorf("unexpected coroutine return value: got %v, want %v", got, test.result)
}
}
})
}
}
Expand Down
5 changes: 4 additions & 1 deletion compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ func EvenSquareGenerator(n int) {
}
}

func NestedLoops(n int) {
func NestedLoops(n int) int {
var count int
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
for k := 1; k <= n; k++ {
coroutine.Yield[int, any](i * j * k)
count++
}
}
}
return count
}

func FizzBuzzIfGenerator(n int) {
Expand Down
62 changes: 42 additions & 20 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,22 @@ func EvenSquareGenerator(_fn0 int) {
}

//go:noinline
func NestedLoops(_fn0 int) {
func NestedLoops(_fn0 int) (_ int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
X1 int
X2 int
X3 int
X4 int
} = coroutine.Push[struct {
IP int
X0 int
X1 int
X2 int
X3 int
X4 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
Expand All @@ -189,6 +191,7 @@ func NestedLoops(_fn0 int) {
X1 int
X2 int
X3 int
X4 int
}{X0: _fn0}
}
defer func() {
Expand All @@ -198,32 +201,51 @@ func NestedLoops(_fn0 int) {
}()
switch {
case _f0.IP < 2:
_f0.X1 = 1
_f0.IP = 2
fallthrough
case _f0.IP < 5:
for ; _f0.X1 <= _f0.X0; _f0.X1, _f0.IP = _f0.X1+1, 2 {
switch {
case _f0.IP < 3:
_f0.X2 = 1
_f0.IP = 3
fallthrough
case _f0.IP < 5:
for ; _f0.X2 <= _f0.X0; _f0.X2, _f0.IP = _f0.X2+1, 3 {
switch {
case _f0.IP < 4:
_f0.X3 = 1
_f0.IP = 4
fallthrough
case _f0.IP < 5:
for ; _f0.X3 <= _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 {
coroutine.Yield[int, any](_f0.X1 * _f0.X2 * _f0.X3)
case _f0.IP < 7:
switch {
case _f0.IP < 3:
_f0.X2 = 1
_f0.IP = 3
fallthrough
case _f0.IP < 7:
for ; _f0.X2 <= _f0.X0; _f0.X2, _f0.IP = _f0.X2+1, 3 {
switch {
case _f0.IP < 4:
_f0.X3 = 1
_f0.IP = 4
fallthrough
case _f0.IP < 7:
for ; _f0.X3 <= _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 {
switch {
case _f0.IP < 5:
_f0.X4 = 1
_f0.IP = 5
fallthrough
case _f0.IP < 7:
for ; _f0.X4 <= _f0.X0; _f0.X4, _f0.IP = _f0.X4+1, 5 {
switch {
case _f0.IP < 6:
coroutine.Yield[int, any](_f0.X2 * _f0.X3 * _f0.X4)
_f0.IP = 6
fallthrough
case _f0.IP < 7:
_f0.X1++
}
}
}
}
}
}
}
_f0.IP = 7
fallthrough
case _f0.IP < 8:

return _f0.X1
}
return
}

//go:noinline
Expand Down Expand Up @@ -3182,7 +3204,7 @@ func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue")
_types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops")
_types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops")
_types.RegisterFunc[func(_fn0 int, _fn1 func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers")
_types.RegisterClosure[func() (_ bool), struct {
Expand Down
10 changes: 9 additions & 1 deletion coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ func (c Coroutine[R, S]) Recv() R { return c.ctx.recv }
// by the coroutine.
func (c Coroutine[R, S]) Send(v S) { c.ctx.send = v }

// Result is the return value of the coroutine, if it was constructed with
// NewWithReturn. Result should only be called once Next returns false,
// indicating that the coroutine finished executing.
func (c Coroutine[R, S]) Result() R { return c.ctx.result }

// Stop interrupts the coroutine. On the next call to Next, the coroutine will
// not return from its yield point; instead, it unwinds its call stack, calling
// each defer statement in the inverse order that they were declared.
Expand Down Expand Up @@ -53,12 +58,15 @@ type Context[R, S any] struct {
recv R
send S

// Value returned from the coroutine.
result R
chriso marked this conversation as resolved.
Show resolved Hide resolved

// Booleans managing the state of the coroutine.
done bool
stop bool
resume bool //nolint

context
context[R]
}

// Run executes a coroutine to completion, calling f for each value that the
Expand Down
44 changes: 37 additions & 7 deletions coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@ func New[R, S any](f func()) Coroutine[R, S] {
// the compiler cannot track.
return Coroutine[R, S]{
ctx: &Context[R, S]{
context: context{entry: f},
context: context[R]{entry: f},
},
}
}

// New creates a new coroutine which executes f as entry point.
//
//go:noinline
func NewWithReturn[R, S any](f func() R) Coroutine[R, S] {
// The function has the go:noinline tag because we want to ensure that the
// context will be allocated on the heap. If the context remains allocated
// on the stack it might escape when returned by a call to LoadContext that
// the compiler cannot track.
return Coroutine[R, S]{
ctx: &Context[R, S]{
context: context[R]{entryR: f},
},
}
}
Expand Down Expand Up @@ -74,16 +89,18 @@ func (s *Stack) isTop() bool {
return s.FP == len(s.Frames)-1
}

type serializedCoroutine struct {
type serializedCoroutine[R any] struct {
entry func()
entryR func() R
stack Stack
resume bool
}

// Marshal returns a serialized Context.
func (c *Context[R, S]) Marshal() ([]byte, error) {
return types.Serialize(&serializedCoroutine{
return types.Serialize(&serializedCoroutine[R]{
entry: c.entry,
entryR: c.entryR,
stack: c.Stack,
resume: c.resume,
}), nil
Expand All @@ -101,8 +118,9 @@ func (c *Context[R, S]) Unmarshal(b []byte) (int, error) {
}
return 0, err
}
s := v.(*serializedCoroutine)
s := v.(*serializedCoroutine[R])
c.entry = s.entry
c.entryR = s.entryR
c.Stack = s.stack
c.resume = s.resume
sn := start - len(b)
Expand Down Expand Up @@ -157,17 +175,29 @@ func (c Coroutine[R, S]) Next() (hasNext bool) {
}()

c.ctx.Stack.FP = -1
c.ctx.entry()
if c.ctx.entry != nil {
c.ctx.entry()
} else {
c.ctx.result = c.ctx.entryR()
}
})

return hasNext
}

type context struct {
type context[R any] struct {
// Entry point of the coroutine, this is captured so the associated
// generator can call into the coroutine to start or resume it at the
// last yield point.
entry func()
//
// The raw func (via New) and func returning R (via NewWithReturn)
// are stored separately to work around a limitation with the compiler.
// In volatile mode we only store the latter, and support the former
// by creating a closure that calls the func() and returns the zero
// value R. The compiler does not yet support compiling generic
// functions so the strategy doesn't work in durable mode.
entry func()
entryR func() R
Stack
}

Expand Down
14 changes: 11 additions & 3 deletions coroutine_volatile.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,16 @@ const Durable = false

// New creates a new coroutine which executes f as entry point.
func New[R, S any](f func()) Coroutine[R, S] {
return NewWithReturn[R, S](func() (_ R) {
f()
return
})
}

// New creates a new coroutine which executes f as entry point.
func NewWithReturn[R, S any](f func() R) Coroutine[R, S] {
c := &Context[R, S]{
context: context{
context: context[R]{
next: make(chan struct{}),
},
}
Expand All @@ -30,7 +38,7 @@ func New[R, S any](f func()) Coroutine[R, S] {
<-c.next

if !c.stop {
f()
c.result = f()
}
})
}()
Expand All @@ -51,7 +59,7 @@ func (c Coroutine[R, S]) Next() bool {
return ok
}

type context struct {
type context[R any] struct {
next chan struct{}
}

Expand Down
Loading