Skip to content

Commit

Permalink
Fixes (#148)
Browse files Browse the repository at this point in the history
This PR contains various fixes:
* underscore imports are now preserved
* interfaces with embeddeds are now supported
* `ast.Ellipsis` is now handled correctly when used in a closure struct
* `ast.IndexListExpr` is now supported
* generic type params are now translated correctly when they're present
in an anonymous function return value
* `reflect.Type` can now be serialized
  • Loading branch information
chriso authored Jun 17, 2024
2 parents 7013f86 + 4eca87e commit 892d764
Show file tree
Hide file tree
Showing 7 changed files with 619 additions and 9 deletions.
15 changes: 13 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
c.generateFunctypes(p, gen, colorsByFunc)

// Find all the required imports for this file.
gen = addImports(p, gen)
gen = addImports(p, f, gen)

outputPath := strings.TrimSuffix(p.GoFiles[i], ".go")
outputPath += "_durable.go"
Expand Down Expand Up @@ -400,7 +400,7 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types
return
}

func addImports(p *packages.Package, gen *ast.File) *ast.File {
func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File {
imports := map[string]string{}

ast.Inspect(gen, func(n ast.Node) bool {
Expand Down Expand Up @@ -438,6 +438,15 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
}

importspecs := make([]ast.Spec, 0, len(imports))

// Preserve underscore (side effect) imports.
for _, imp := range f.Imports {
if imp.Name != nil && imp.Name.Name == "_" {
importspecs = append(importspecs, imp)
}
}

// Add imports for all packages used in the file.
for name, path := range imports {
importspecs = append(importspecs, &ast.ImportSpec{
Name: ast.NewIdent(name),
Expand Down Expand Up @@ -526,6 +535,8 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

p.TypesInfo.Types[gen] = types.TypeAndValue{Type: p.TypesInfo.TypeOf(fn)}

if !isExpr(gen.Body) {
scope.colors[gen] = color
}
Expand Down
27 changes: 27 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package compiler

import (
"math"
"reflect"
"slices"
"testing"

Expand Down Expand Up @@ -220,6 +222,11 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { StructClosure(3) },
yields: []int{10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},
{
name: "generic closure capturing receiver and param",
coro: func() { StructGenericClosure(3) },
yields: []int{10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},
{
name: "generic function",
coro: func() { IdentityGenericInt(11) },
Expand Down Expand Up @@ -255,6 +262,26 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { RangeOverInt(3) },
yields: []int{0, 1, 2},
},

{
name: "reflect type",
coro: func() {
ReflectType(reflect.TypeFor[uint8](), reflect.TypeFor[uint16]())
},
yields: []int{math.MaxUint8, math.MaxUint16},
},

{
name: "ellipsis closure",
coro: func() { EllipsisClosure(3) },
yields: []int{-1, 0, 1, 2},
},

{
name: "interface embedded",
coro: func() { InterfaceEmbedded() },
yields: []int{1, 1, 1},
},
}

// This emulates the installation of function type information by the
Expand Down
7 changes: 6 additions & 1 deletion compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
typeArg = g.typeArgOf
}

signature := copyFunctionType(functionTypeOf(fn))
signature := copyFunctionType(funcTypeWithNamedResults(p, fn))
signature.TypeParams = nil

recv := copyFieldList(functionRecvOf(fn))
Expand Down Expand Up @@ -182,6 +182,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
fieldName := ast.NewIdent(fmt.Sprintf("X%d", i))
fieldType := freeVar.typ

// Convert ellipsis into slice (...X => []X).
if e, ok := fieldType.(*ast.Ellipsis); ok {
fieldType = &ast.ArrayType{Elt: e.Elt}
}

// The Go compiler uses a more advanced mechanism to determine if a
// free variable should be captured by pointer or by value: it looks
// at whether the variable is reassigned, its address taken, and if
Expand Down
79 changes: 79 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package testdata

import (
"math"
"reflect"
"time"
"unsafe"

Expand Down Expand Up @@ -586,6 +588,34 @@ func StructClosure(n int) {
}
}

type GenericBox[T integer] struct {
x T
}

func (b *GenericBox[T]) YieldAndInc() {
coroutine.Yield[T, any](b.x)
b.x++
}

func (b *GenericBox[T]) Closure(y T) func(T) {
return func(z T) {
coroutine.Yield[T, any](b.x)
coroutine.Yield[T, any](y)
coroutine.Yield[T, any](z)
b.x++
y++
z++ // mutation is lost
}
}

func StructGenericClosure(n int) {
box := GenericBox[int]{10}
fn := box.Closure(100)
for i := 0; i < n; i++ {
fn(1000)
}
}

func IdentityGeneric[T any](n T) {
coroutine.Yield[T, any](n)
}
Expand Down Expand Up @@ -662,3 +692,52 @@ func RangeOverInt(n int) {
coroutine.Yield[int, any](i)
}
}

func ReflectType(types ...reflect.Type) {
for _, t := range types {
v := reflect.New(t).Elem()
if !v.CanUint() {
panic("expected uint type")
}
v.SetUint(math.MaxUint64)
coroutine.Yield[int, any](int(v.Uint()))
}
}

func MakeEllipsisClosure(ints ...int) func() {
return func() {
x := ints
for _, v := range x {
coroutine.Yield[int, any](v)
}
}
}

func EllipsisClosure(n int) {
ints := make([]int, n)
for i := range ints {
ints[i] = i
}
c := MakeEllipsisClosure(ints...)
coroutine.Yield[int, any](-1)
c()
}

type innerInterface interface {
Value() int
}

type innerInterfaceImpl int

func (i innerInterfaceImpl) Value() int { return int(i) }

type outerInterface interface {
innerInterface
}

func InterfaceEmbedded() {
var x interface{ outerInterface } = innerInterfaceImpl(1)
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
}
Loading

0 comments on commit 892d764

Please sign in to comment.