diff --git a/coroutine.go b/coroutine.go
index c51a86e..4a27c6a 100644
--- a/coroutine.go
+++ b/coroutine.go
@@ -104,6 +104,12 @@ func LoadContext[R, S any]() *Context[R, S] {
 	}
 }
 
-// ErrNotDurable is an error that occurs when attempting to
-// serialize a coroutine that is not durable.
-var ErrNotDurable = errors.New("only durable coroutines can be serialized")
+var (
+	// ErrNotDurable is an error that occurs when attempting to
+	// serialize a coroutine that is not durable.
+	ErrNotDurable = errors.New("only durable coroutines can be serialized")
+
+	// ErrInvalidState is returned by Context.Unmarshal when the state was
+	// serialized by a different build of the program (build ID mismatch).
+	ErrInvalidState = errors.New("durable coroutine was serialized in another build")
+)
diff --git a/coroutine_durable.go b/coroutine_durable.go
index e565e4d..d1cfa44 100644
--- a/coroutine_durable.go
+++ b/coroutine_durable.go
@@ -3,6 +3,7 @@
 package coroutine
 
 import (
+	"errors"
 	"runtime"
 	"unsafe"
 
@@ -93,7 +94,13 @@ func (c *Context[R, S]) Marshal() ([]byte, error) {
 // context.
func (c *Context[R, S]) Unmarshal(b []byte) (int, error) { start := len(b) - v, b := types.Deserialize(b) + v, b, err := types.Deserialize(b) + if err != nil { + if errors.Is(err, types.ErrBuildIDMismatch) { + return 0, ErrInvalidState + } + return 0, err + } s := v.(*serializedCoroutine) c.entry = s.entry c.Stack = s.stack diff --git a/examples/scrape/go.mod b/examples/scrape/go.mod index d0e001d..d849268 100644 --- a/examples/scrape/go.mod +++ b/examples/scrape/go.mod @@ -3,3 +3,5 @@ module scrape go 1.21.0 require github.com/stealthrocket/coroutine v0.0.0-20230927150141-7c62a3508ce8 + +replace github.com/stealthrocket/coroutine => ../../ diff --git a/examples/scrape/main.go b/examples/scrape/main.go index d72330f..12920ae 100644 --- a/examples/scrape/main.go +++ b/examples/scrape/main.go @@ -28,7 +28,11 @@ func main() { log.Fatal(err) } } else if _, err := coro.Context().Unmarshal(state); err != nil { - log.Fatal(err) + if errors.Is(err, coroutine.ErrInvalidState) { + log.Println("warning: coroutine state is no longer valid. Starting fresh") + } else { + log.Fatal(err) + } } } diff --git a/types/buildid.go b/types/buildid.go new file mode 100644 index 0000000..ea97546 --- /dev/null +++ b/types/buildid.go @@ -0,0 +1,4 @@ +package types + +// buildID is the build identifier for the binary. 
+var buildID string diff --git a/types/func.go b/types/func.go index 2d08929..7dca674 100644 --- a/types/func.go +++ b/types/func.go @@ -4,6 +4,7 @@ package types import ( "debug/gosym" + "errors" "io" "reflect" "runtime" @@ -150,7 +151,10 @@ func initFunctionTables(pclntab, symtab []byte) { } } -func readAll(r io.ReaderAt, size uint64) ([]byte, error) { +func readSection(r io.ReaderAt, size uint64) ([]byte, error) { + if r == nil { + return nil, errors.New("section missing") + } b := make([]byte, size) n, err := r.ReadAt(b, 0) if err != nil && n < len(b) { diff --git a/types/func_darwin.go b/types/func_darwin.go deleted file mode 100644 index 9016e5f..0000000 --- a/types/func_darwin.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -import ( - "debug/macho" - "os" -) - -func init() { - f, err := macho.Open(os.Args[0]) - if err != nil { - panic("cannot read Mach-O binary: " + err.Error()) - } - - pclntab := f.Section("__gopclntab") - pclntabData, err := readAll(pclntab, pclntab.Size) - if err != nil { - panic("cannot read pclntab: " + err.Error()) - } - - symtab := f.Section("__gosymtab") - symtabData, err := readAll(symtab, symtab.Size) - if err != nil { - panic("cannot read symtab: " + err.Error()) - } - - initFunctionTables(pclntabData, symtabData) -} diff --git a/types/func_linux.go b/types/func_linux.go deleted file mode 100644 index a0f6759..0000000 --- a/types/func_linux.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -import ( - "debug/elf" - "os" -) - -func init() { - f, err := elf.Open(os.Args[0]) - if err != nil { - panic("cannot read elf binary: " + err.Error()) - } - - pclntab := f.Section(".gopclntab") - pclntabData, err := readAll(pclntab, pclntab.Size) - if err != nil { - panic("cannot read pclntab: " + err.Error()) - } - - symtab := f.Section(".gosymtab") - symtabData, err := readAll(symtab, symtab.Size) - if err != nil { - panic("cannot read symtab: " + err.Error()) - } - - initFunctionTables(pclntabData, symtabData) -} diff --git 
a/types/obj_darwin.go b/types/obj_darwin.go
new file mode 100644
index 0000000..755164e
--- /dev/null
+++ b/types/obj_darwin.go
@@ -0,0 +1,92 @@
+package types
+
+import (
+	"bytes"
+	"debug/macho"
+	"os"
+	"strconv"
+)
+
+// init loads the function tables and the build ID from the Mach-O
+// image of the running program. It panics when either is unavailable.
+func init() {
+	// os.Args[0] may be a bare command name resolved through $PATH,
+	// so ask the OS where the executable actually lives.
+	path, err := os.Executable()
+	if err != nil {
+		panic("cannot locate executable: " + err.Error())
+	}
+	f, err := macho.Open(path)
+	if err != nil {
+		panic("cannot read Mach-O binary: " + err.Error())
+	}
+	defer f.Close()
+
+	initMachOFunctionTables(f)
+	initMachOBuildID(f)
+}
+
+// machoSection returns the named section of f. It panics with a
+// descriptive message when the section is absent, because a nil
+// *macho.Section must not be dereferenced and, once wrapped in an
+// io.ReaderAt, no longer compares equal to nil.
+func machoSection(f *macho.File, name string) *macho.Section {
+	s := f.Section(name)
+	if s == nil {
+		panic("cannot read " + name + ": section missing")
+	}
+	return s
+}
+
+// initMachOFunctionTables reads the __gopclntab and __gosymtab sections
+// and hands their contents to initFunctionTables.
+func initMachOFunctionTables(f *macho.File) {
+	pclntab := machoSection(f, "__gopclntab")
+	pclntabData, err := readSection(pclntab, pclntab.Size)
+	if err != nil {
+		panic("cannot read pclntab: " + err.Error())
+	}
+	symtab := machoSection(f, "__gosymtab")
+	symtabData, err := readSection(symtab, symtab.Size)
+	if err != nil {
+		panic("cannot read symtab: " + err.Error())
+	}
+	initFunctionTables(pclntabData, symtabData)
+}
+
+// initMachOBuildID searches the first 32KB of the __text section for the
+// Go build ID marker and stores the ID it finds in buildID.
+func initMachOBuildID(f *macho.File) {
+	text := machoSection(f, "__text")
+
+	// Read up to 32KB from the text section.
+	// See https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/note.go#L199
+	data, err := readSection(text, min(text.Size, 32*1024))
+	if err != nil {
+		panic("cannot read __text: " + err.Error())
+	}
+
+	// From https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/buildid.go#L300
+	i := bytes.Index(data, buildIDPrefix)
+	if i < 0 {
+		panic("build ID not found")
+	}
+	j := bytes.Index(data[i+len(buildIDPrefix):], buildIDEnd)
+	if j < 0 {
+		panic("build ID not found")
+	}
+	// Keep the opening quote (last byte of the prefix) and the closing
+	// quote (first byte of the end marker) so the ID can be unquoted.
+	quoted := data[i+len(buildIDPrefix)-1 : i+len(buildIDPrefix)+j+1]
+	id, err := strconv.Unquote(string(quoted))
+	if err != nil {
+		panic("build ID not found")
+	}
+	buildID = id
+}
+
+// Markers that delimit the quoted build ID inside the text section.
+var (
+	buildIDPrefix = []byte("\xff Go build ID: \"")
+	buildIDEnd    = []byte("\"\n \xff")
+)
diff --git a/types/obj_linux.go b/types/obj_linux.go
new file mode 100644
index 0000000..29559c7
--- /dev/null
+++ b/types/obj_linux.go
@@ -0,0 +1,90 @@
+package types
+
+import (
+	"bytes"
+	"debug/elf"
+	"os"
+)
+
+// init loads the function tables and the build ID from the ELF image
+// of the running program. It panics when either is unavailable.
+func init() {
+	// os.Args[0] may be a bare command name resolved through $PATH,
+	// so ask the OS where the executable actually lives.
+	path, err := os.Executable()
+	if err != nil {
+		panic("cannot locate executable: " + err.Error())
+	}
+	f, err := elf.Open(path)
+	if err != nil {
+		panic("cannot read elf binary: " + err.Error())
+	}
+	defer f.Close()
+
+	initELFFunctionTables(f)
+	initELFBuildID(f)
+}
+
+// elfSection returns the named section of f. It panics with a
+// descriptive message when the section is absent, because a nil
+// *elf.Section must not be dereferenced and, once wrapped in an
+// io.ReaderAt, no longer compares equal to nil.
+func elfSection(f *elf.File, name string) *elf.Section {
+	s := f.Section(name)
+	if s == nil {
+		panic("cannot read " + name + ": section missing")
+	}
+	return s
+}
+
+// initELFFunctionTables reads the .gopclntab and .gosymtab sections
+// and hands their contents to initFunctionTables.
+func initELFFunctionTables(f *elf.File) {
+	pclntab := elfSection(f, ".gopclntab")
+	pclntabData, err := readSection(pclntab, pclntab.Size)
+	if err != nil {
+		panic("cannot read pclntab: " + err.Error())
+	}
+	symtab := elfSection(f, ".gosymtab")
+	symtabData, err := readSection(symtab, symtab.Size)
+	if err != nil {
+		panic("cannot read symtab: " + err.Error())
+	}
+	initFunctionTables(pclntabData, symtabData)
+}
+
+// initELFBuildID parses the .note.go.buildid section and stores the
+// build ID it carries in buildID.
+func initELFBuildID(f *elf.File) {
+	noteSection := elfSection(f, ".note.go.buildid")
+	note, err := readSection(noteSection, noteSection.Size)
+	if err != nil {
+		panic("cannot read build ID: " + err.Error())
+	}
+
+	// The note starts with three 32-bit words (name size, value size,
+	// tag) followed by the 4-byte name; anything shorter is malformed.
+	const headerSize = 16
+	if len(note) < headerSize {
+		panic("build ID not found")
+	}
+
+	// See https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/note.go#L135C3-L135C3
+	nameSize := f.ByteOrder.Uint32(note)
+	valSize := f.ByteOrder.Uint32(note[4:])
+	tag := f.ByteOrder.Uint32(note[8:])
+	nname := note[12:headerSize]
+	// Compare in 64 bits so a huge valSize cannot wrap around and
+	// slip past the bounds check.
+	end := uint64(headerSize) + uint64(valSize)
+	if nameSize != 4 || tag != buildIDTag || !bytes.Equal(nname, buildIDNote) || end > uint64(len(note)) {
+		panic("build ID not found")
+	}
+	buildID = string(note[headerSize:end])
+}
+
+// Name and tag identifying the Go build ID note.
+var (
+	buildIDNote = []byte("Go\x00\x00")
+	buildIDTag  = uint32(4)
+)
diff --git a/types/serde.go b/types/serde.go
index a550c1c..3ec13a1 100644
--- a/types/serde.go
+++ b/types/serde.go
@@ -7,6 +7,7 @@ package types
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"reflect"
 	"unsafe"
@@ -15,6 +16,10 @@ import (
 // sID is the unique sID of a pointer or type in the serialized format.
 type sID int64
 
+// ErrBuildIDMismatch is an error that occurs when a program attempts
+// to deserialize objects from another build.
+var ErrBuildIDMismatch = errors.New("build ID mismatch")
+
 // Serialize x.
 //
 // The output of Serialize can be reconstructed back to a Go value using
@@ -35,14 +40,17 @@ func Serialize(x any) []byte {
 }
 
 // Deserialize value from b.
Return left over bytes.
-func Deserialize(b []byte) (interface{}, []byte) {
-	d := newDeserializer(b)
+func Deserialize(b []byte) (interface{}, []byte, error) {
+	d, err := newDeserializer(b)
+	if err != nil {
+		return nil, nil, err
+	}
 	var x interface{}
 	px := &x
 	t := reflect.TypeOf(px).Elem()
 	p := unsafe.Pointer(px)
 	deserializeInterface(d, t, p)
-	return x, d.b
+	return x, d.b, nil
 }
 
 type Deserializer struct {
@@ -54,11 +62,30 @@ type Deserializer struct {
 	b []byte
 }
 
-func newDeserializer(b []byte) *Deserializer {
+// newDeserializer validates the build ID header written by newSerializer
+// (a varint length followed by the ID bytes) and returns a Deserializer
+// positioned on the payload that follows it. An ID that differs from the
+// running binary's yields an error wrapping ErrBuildIDMismatch.
+func newDeserializer(b []byte) (*Deserializer, error) {
+	buildIDLength, n := binary.Varint(b)
+	// Only structural problems are "invalid": an unreadable varint, a
+	// negative length, or fewer bytes than the length promises. An empty ID
+	// is accepted because that is what newSerializer emits when buildID is
+	// unset, and an ID of a different length is a mismatch, not corruption.
+	if n <= 0 || buildIDLength < 0 || int64(len(b)-n) < buildIDLength {
+		return nil, errors.New("missing or invalid build ID")
+	}
+	b = b[n:]
+	serializedBuildID := string(b[:buildIDLength])
+	b = b[buildIDLength:]
+	if serializedBuildID != buildID {
+		return nil, fmt.Errorf("%w: got %v, expect %v", ErrBuildIDMismatch, serializedBuildID, buildID)
+	}
+
 	return &Deserializer{
 		ptrs: make(map[sID]unsafe.Pointer),
 		b:    b,
-	}
+	}, nil
 }
 
 func (d *Deserializer) readPtr() (unsafe.Pointer, sID) {
@@ -123,9 +150,14 @@ type Serializer struct {
 }
 
 func newSerializer() *Serializer {
+	b := make([]byte, 0, 128)
+	b = binary.AppendVarint(b, int64(len(buildID)))
+	b = append(b, buildID...)
+ return &Serializer{ ptrs: make(map[unsafe.Pointer]sID), scanptrs: make(map[reflect.Value]struct{}), + b: b, } } diff --git a/types/serde_test.go b/types/serde_test.go index 4e08de4..5ac3094 100644 --- a/types/serde_test.go +++ b/types/serde_test.go @@ -40,7 +40,10 @@ func TestSerdeTime(t *testing.T) { func testSerdeTime(t *testing.T, x time.Time) { b := Serialize(x) - out, _ := Deserialize(b) + out, _, err := Deserialize(b) + if err != nil { + t.Fatal(err) + } if !x.Equal(out.(time.Time)) { t.Errorf("expected %v, got %v", x, out) @@ -120,7 +123,10 @@ func TestReflect(t *testing.T) { typ := reflect.TypeOf(x) t.Run(fmt.Sprintf("%d-%s", i, typ), func(t *testing.T) { b := Serialize(x) - out, b := Deserialize(b) + out, b, err := Deserialize(b) + if err != nil { + t.Fatal(err) + } assertEqual(t, x, out) @@ -302,7 +308,10 @@ func TestReflectCustom(t *testing.T) { // unserializable function in CheckRedirect. b := Serialize(x) - out, b := Deserialize(b) + out, b, err := Deserialize(b) + if err != nil { + t.Fatal(err) + } assertEqual(t, x.Timeout, out.(http.Client).Timeout) @@ -615,7 +624,10 @@ func assertRoundTrip[T any](t *testing.T, orig T) T { t.Helper() b := Serialize(orig) - out, b := Deserialize(b) + out, b, err := Deserialize(b) + if err != nil { + t.Fatal(err) + } assertEqual(t, orig, out)