From e18cf8fd06f99797622a0cfe4e8f790bbca96de9 Mon Sep 17 00:00:00 2001 From: "Jonathan A. Sternberg" Date: Tue, 16 Jul 2024 18:56:51 -0500 Subject: [PATCH] Refactor errdefs with an API more similar to the standard library This updates errdefs to have an API similar to the standard library. The `New` method is the equivalent of `fmt.Errorf` and `Join` is the equivalent of `errors.Join`. These methods also support the proper output of stack traces by ensuring the proper handling of the formatting for collapsible errors. Similarly, the `stack` package has been updated to remove the context-based helper and instead directly expose the function that creates the stack error. The stack error is also renamed to `Error` and exposed rather than left unexported. The `stack` package can be used to directly create errors with stacks when one isn't present. It is also possible to add multiple stacks to a single error through `errdefs.Join` and `errdefs.ErrStack` to manually create a stack error. Signed-off-by: Jonathan A. Sternberg --- errors.go | 68 ++++++++++++++ join.go | 74 +++++++++++++++ stack/stack.go | 223 ++++++++++++++------------------------------ stack/stack_test.go | 7 +- 4 files changed, 214 insertions(+), 158 deletions(-) create mode 100644 join.go diff --git a/errors.go b/errors.go index 4827d8c..022628b 100644 --- a/errors.go +++ b/errors.go @@ -26,6 +26,10 @@ package errdefs import ( "context" "errors" + "fmt" + "io" + + "github.com/containerd/errdefs/internal/types" ) // Definitions of common error types used throughout containerd. All containerd @@ -409,3 +413,67 @@ func (c customMessage) As(target any) bool { func (c customMessage) Error() string { return c.msg } + +// errorValue is a general purpose container for errors. +// +// It is constructed through New and is used to ensure +// the proper formatting behavior for the contents of +// the error. +type errorValue struct { + error +} + +func (e errorValue) Format(st fmt.State, verb rune) { + format := fmt.FormatString(st, verb) + fmt.Fprintf(st, format, e.error) + if verb == 'v' && st.Flag('+') { + printStackTraces(st, e.error) + } +} + +func (e errorValue) Unwrap() error { + return e.error +} + +// New constructs a new error with the given format string. +func New(format string, args ...interface{}) error { + return &errorValue{ + error: fmt.Errorf(format, args...), + } +} + +func printStackTraces(w io.Writer, err error) { + // Collect all stack traces from the error. + // Stored in a stack for efficiency and to prevent + // a recursive stack from piling up. + unvisited := []error{err} + for len(unvisited) > 0 { + // Pop the end. + cur := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + + // Print the stack trace if this is one. + if _, ok := cur.(types.CollapsibleError); ok { + fmt.Fprintf(w, "\n%+v", cur) + } + + switch cur := cur.(type) { + case interface{ Unwrap() error }: + if err := cur.Unwrap(); err != nil { + unvisited = append(unvisited, err) + } + case interface{ Unwrap() []error }: + errs := cur.Unwrap() + if len(errs) > 0 { + // Append in the proper order just for + // memory efficiency and then reverse the + // contents since we want to pop the first + // error first. + unvisited = append(unvisited, errs...) + for i, j := len(unvisited)-len(errs), len(unvisited)-1; i < j; i, j = i+1, j-1 { + unvisited[i], unvisited[j] = unvisited[j], unvisited[i] + } + } + } + } +} diff --git a/join.go b/join.go new file mode 100644 index 0000000..b4b0900 --- /dev/null +++ b/join.go @@ -0,0 +1,74 @@ +package errdefs + +import ( + "fmt" + "strings" + + "github.com/containerd/errdefs/internal/types" +) + +type joinError struct { + errs []error +} + +// Join will join the errors together and ensure stack traces +// are appropriately formatted. +func Join(errs ...error) error { + var e error + n := 0 + for _, err := range errs { + if err != nil { + e = err + n++ + } + } + + switch n { + case 0: + return nil + case 1: + switch e.(type) { + case *errorValue, *joinError: + // Don't wrap the types defined by this package + // as that could interfere with the formatting. + return e + } + return &errorValue{e} + } + + joined := make([]error, 0, n) + for _, err := range errs { + if err != nil { + joined = append(joined, err) + } + } + return &joinError{errs: joined} +} + +func (e *joinError) Error() string { + var b strings.Builder + fmt.Fprintf(&b, "%v", e) + return b.String() +} + +func (e *joinError) Format(st fmt.State, verb rune) { + format := fmt.FormatString(st, verb) + collapsed := verb == 'v' && st.Flag('+') + first := true + for _, err := range e.errs { + if !collapsed { + if _, ok := err.(types.CollapsibleError); ok { + continue + } + } + if !first { + fmt.Fprintln(st) + } + fmt.Fprintf(st, format, err) + first = false + } +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/stack/stack.go b/stack/stack.go index befbf3c..68df8c7 100644 --- a/stack/stack.go +++ b/stack/stack.go @@ -17,24 +17,22 @@ package stack import ( - "context" "encoding/json" "errors" "fmt" + "io" "os" - "path" "runtime" "strings" "sync/atomic" "unsafe" + "github.com/containerd/errdefs" "github.com/containerd/typeurl/v2" - - "github.com/containerd/errdefs/internal/types" ) func init() { - typeurl.Register((*stack)(nil), "github.com/containerd/errdefs", "stack+json") + typeurl.Register((*Error)(nil), "github.com/containerd/errdefs", "stack+json") } var ( @@ -45,14 +43,13 @@ var ( Revision string = "dirty" ) -type stack struct { +type Error struct { decoded *Trace callers []uintptr helpers []uintptr } -// Trace is a stack trace along with process information about the source type Trace struct { Version string `json:"version,omitempty"` Revision string `json:"revision,omitempty"` @@ -68,20 +65,16 @@ type Frame struct { Line int32 `json:"Line,omitempty"` } -func (f Frame) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - switch { - case s.Flag('+'): - fmt.Fprintf(s, "%s\n\t%s:%d\n", f.Name, f.File, f.Line) - default: - fmt.Fprint(s, f.Name) - } - case 's': - fmt.Fprint(s, path.Base(f.Name)) - case 'q': - fmt.Fprintf(s, "%q", path.Base(f.Name)) - } +func (f Frame) WriteTo(w io.Writer) { + fmt.Fprintf(w, "%s\n\t%s:%d\n", f.Name, f.File, f.Line) +} + +// Callers returns a stack.Error with a customized amount +// of skipped frames. +func Callers(skip int) *Error { + // This function calls two other functions so set the + // default minimum skip to 2. + return callers(skip + 2) } // callers returns the current stack, skipping over the number of frames mentioned @@ -90,23 +83,23 @@ func (f Frame) Format(s fmt.State, verb rune) { // frame[0] runtime.Callers // frame[1] github.com/containerd/errdefs/stack.callers // frame[2] (Use skip=2 to have this be first frame) -func callers(skip int) *stack { +func callers(skip int) *Error { const depth = 32 var pcs [depth]uintptr n := runtime.Callers(skip, pcs[:]) - return &stack{ + return &Error{ callers: pcs[0:n], } } -func (s *stack) getDecoded() *Trace { - if s.decoded == nil { - var unsafeDecoded = (*unsafe.Pointer)(unsafe.Pointer(&s.decoded)) +func (e *Error) getDecoded() *Trace { + if e.decoded == nil { + unsafeDecoded := (*unsafe.Pointer)(unsafe.Pointer(&e.decoded)) var helpers map[string]struct{} - if len(s.helpers) > 0 { + if len(e.helpers) > 0 { helpers = make(map[string]struct{}) - frames := runtime.CallersFrames(s.helpers) + frames := runtime.CallersFrames(e.helpers) for { frame, more := frames.Next() helpers[frame.Function] = struct{}{} @@ -116,9 +109,9 @@ func (s *stack) getDecoded() *Trace { } } - f := make([]Frame, 0, len(s.callers)) - if len(s.callers) > 0 { - frames := runtime.CallersFrames(s.callers) + f := make([]Frame, 0, len(e.callers)) + if len(e.callers) > 0 { + frames := runtime.CallersFrames(e.callers) for { frame, more := frames.Next() if _, ok := helpers[frame.Function]; !ok { @@ -145,20 +138,34 @@ func (s *stack) getDecoded() *Trace { atomic.StorePointer(unsafeDecoded, unsafe.Pointer(&t)) } - return s.decoded + return e.decoded } -func (s *stack) Error() string { - return fmt.Sprintf("%+v", s.getDecoded()) +// Error implements the error interface. This method is rarely +// called because this is a collapsible error and the New/Join +// function will remove this error from non-verbose output. +func (e *Error) Error() string { + return fmt.Sprintf("%+v", e.getDecoded()) } -func (s *stack) MarshalJSON() ([]byte, error) { - return json.Marshal(s.getDecoded()) +func (e *Error) Format(st fmt.State, verb rune) { + if verb == 'v' && st.Flag('+') { + t := e.getDecoded() + fmt.Fprintf(st, "%d %s %s\n", t.Pid, t.Version, strings.Join(t.Cmdline, " ")) + for _, f := range t.Frames { + f.WriteTo(st) + } + fmt.Fprintln(st) + } } -func (s *stack) UnmarshalJSON(b []byte) error { - var unsafeDecoded = (*unsafe.Pointer)(unsafe.Pointer(&s.decoded)) - var t Trace +func (e *Error) MarshalJSON() ([]byte, error) { + return json.Marshal(e.getDecoded()) +} + +func (e *Error) UnmarshalJSON(b []byte) error { + unsafeDecoded := (*unsafe.Pointer)(unsafe.Pointer(&e.decoded)) + var t Error if err := json.Unmarshal(b, &t); err != nil { return err @@ -169,128 +176,38 @@ func (s *stack) UnmarshalJSON(b []byte) error { return nil } -func (s *stack) Format(st fmt.State, verb rune) { - switch verb { - case 'v': - if st.Flag('+') { - t := s.getDecoded() - fmt.Fprintf(st, "%d %s %s\n", t.Pid, t.Version, strings.Join(t.Cmdline, " ")) - for _, f := range t.Frames { - f.Format(st, verb) - } - fmt.Fprintln(st) - return - } - } -} - -func (s *stack) StackTrace() Trace { - return *s.getDecoded() -} - -func (s *stack) CollapseError() {} +func (e *Error) CollapseError() {} -// ErrStack returns a new error for the callers stack, -// this can be wrapped or joined into an existing error. -// NOTE: When joined with errors.Join, the stack -// will show up in the error string output. -// Use with `stack.Join` to force addition of the -// error stack. +// ErrStack is a convenience method for calling Callers +// with the correct skip value for non-helper functions +// directly calling this package. func ErrStack() error { - return callers(3) -} - -// Join adds a stack if there is no stack included to the errors -// and returns a joined error with the stack hidden from the error -// output. The stack error shows up when Unwrapped or formatted -// with `%+v`. -func Join(errs ...error) error { - return joinErrors(nil, errs) -} - -// WithStack will check if the error already has a stack otherwise -// return a new error with the error joined with a stack error -// Any helpers will be skipped. -func WithStack(ctx context.Context, errs ...error) error { - return joinErrors(ctx.Value(helperKey{}), errs) + // Skip the call to Callers and ErrStack. + return Callers(2) } -func joinErrors(helperVal any, errs []error) error { - var filtered []error - var collapsible []error - var hasStack bool - for _, err := range errs { - if err != nil { - if !hasStack && hasLocalStackTrace(err) { - hasStack = true - } - if _, ok := err.(types.CollapsibleError); ok { - collapsible = append(collapsible, err) - } else { - filtered = append(filtered, err) - } - - } - } - if len(filtered) == 0 { - return nil - } - if !hasStack { - s := callers(4) - if helpers, ok := helperVal.([]uintptr); ok { - s.helpers = helpers - } - collapsible = append(collapsible, s) - } - var err error - if len(filtered) > 1 { - err = errors.Join(filtered...) - } else { - err = filtered[0] - } - if len(collapsible) == 0 { - return err +// Errorf creates a new error with the given format and +// arguments and adds a stack trace if one isn't already +// included. +func Errorf(format string, args ...any) error { + err := errdefs.New(format, args...) + if !hasStack(err) { + err = errdefs.Join(err, Callers(2)) } - - return types.CollapsedError(err, collapsible...) + return err } -func hasLocalStackTrace(err error) bool { - switch e := err.(type) { - case *stack: - return true - case interface{ Unwrap() error }: - if hasLocalStackTrace(e.Unwrap()) { - return true - } - case interface{ Unwrap() []error }: - for _, ue := range e.Unwrap() { - if hasLocalStackTrace(ue) { - return true - } - } +// Join joins the errors and adds a stack trace if one +// isn't already present. +func Join(errs ...error) error { + err := errdefs.Join(errs...) + if !hasStack(err) { + err = errdefs.Join(err, Callers(2)) } - - // TODO: Consider if pkg/errors compatibility is needed - // NOTE: This was implemented before the standard error package - // so it may unwrap and have this interface. - //if _, ok := err.(interface{ StackTrace() pkgerrors.StackTrace }); ok { - // return true - //} - - return false + return err } -type helperKey struct{} - -// WithHelper marks the context as from a helper function -// This will add an additional skip to the error stack trace -func WithHelper(ctx context.Context) context.Context { - helpers, _ := ctx.Value(helperKey{}).([]uintptr) - var pcs [1]uintptr - n := runtime.Callers(2, pcs[:]) - if n == 1 { - ctx = context.WithValue(ctx, helperKey{}, append(helpers, pcs[0])) - } - return ctx +func hasStack(err error) bool { + se := &Error{} + return errors.As(err, &se) } diff --git a/stack/stack_test.go b/stack/stack_test.go index 00cb81d..bba53aa 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -17,7 +17,6 @@ package stack import ( - "context" "errors" "fmt" "strings" @@ -56,7 +55,6 @@ func TestCollapsed(t *testing.T) { expected := "some error" checkError(Join(errors.New(expected)), expected) checkError(Join(errors.New(expected), ErrStack()), expected) - checkError(WithStack(context.Background(), errors.New(expected)), expected) } func TestHelpers(t *testing.T) { @@ -90,9 +88,8 @@ func TestHelpers(t *testing.T) { func testHelper(msg string, withHelper bool) error { if withHelper { - return WithStack(WithHelper(context.Background()), errors.New(msg)) + return Join(errors.New(msg), Callers(2)) } else { - return WithStack(context.Background(), errors.New(msg)) + return Join(errors.New(msg), ErrStack()) } - }