Skip to content

Commit

Permalink
add FieldHooks (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvndaai authored Feb 28, 2023
1 parent 2e8b227 commit 2817a2f
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 13 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
on: [push, pull_request]
on:
push:
branches:
- main
pull_request:

name: Test
jobs:
test:
Expand Down
68 changes: 56 additions & 12 deletions ctxerr.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,20 @@ import (
"strings"
)

var global Instance

func init() {
global = NewInstance()
}

// Instance creates a local instance so you can have a different setup than global
type Instance struct {
// CreateHooks are functions that run on creation to set fields on context
CreateHooks []func(ctx context.Context, code string, wrapping error) context.Context
// HandleHooks are functions that run on ctxerr.Handle
HandleHooks []func(error)
// FieldHooks are functions that run on ctxerr.SetField(s)
FieldHooks []func(any) any
// FieldsAsSlice are keys that get gathered as a slice in ctxerr.AllFields
FieldsAsSlice []string
}
Expand All @@ -114,11 +122,11 @@ func NewInstance() Instance {
in.AddCreateHook(SetLocationHook)
// Gather keys like location as slice instead of just the deepest value
in.FieldsAsSlice = []string{FieldKeyLocation}
// No built in hooks
in.FieldHooks = []func(any) any{}
return in
}

var global = NewInstance()

const (
// FieldKeyCode should be unique to the error
FieldKeyCode = "error_code"
Expand Down Expand Up @@ -175,6 +183,16 @@ func (in *Instance) AddHandleHook(f func(error)) {
in.HandleHooks = append(in.HandleHooks, f)
}

// AddFieldHooks adds a hook to be run on handling of an error
func AddFieldHook(f func(any) any) { global.AddFieldHook(f) }
func (in *Instance) AddFieldHook(f func(any) any) {
if in == nil {
// cannot return an error so adding info to panic
panic("cannot call AddFieldHooks because ctxerr.Instance is nil")
}
in.FieldHooks = append(in.FieldHooks, f)
}

// CtxErr is the interface that should be checked in a errors.As function
type CtxErr interface {
error
Expand Down Expand Up @@ -207,7 +225,6 @@ func (in Instance) New(ctx context.Context, code string, message ...interface{})
func Newf(ctx context.Context, code, message string, messageArgs ...interface{}) error {
return global.Newf(ctx, code, message, messageArgs...)
}

func (in Instance) Newf(ctx context.Context, code, message string, messageArgs ...interface{}) error {
for _, hook := range in.CreateHooks {
ctx = hook(ctx, code, nil)
Expand Down Expand Up @@ -289,6 +306,12 @@ func Fields(ctx context.Context) map[string]interface{} {

// SetField adds a field onto the context
func SetField(ctx context.Context, key string, value interface{}) context.Context {
return global.SetField(ctx, key, value)
}
func (in Instance) SetField(ctx context.Context, key string, value interface{}) context.Context {
for _, f := range in.FieldHooks {
value = f(value)
}
f := map[string]interface{}{}
for k, v := range Fields(ctx) {
f[k] = v
Expand All @@ -299,11 +322,17 @@ func SetField(ctx context.Context, key string, value interface{}) context.Contex

// SetFields can add multiple fields onto the context
func SetFields(ctx context.Context, fields map[string]interface{}) context.Context {
return global.SetFields(ctx, fields)
}
func (in Instance) SetFields(ctx context.Context, fields map[string]interface{}) context.Context {
f := map[string]interface{}{}
for k, v := range Fields(ctx) {
f[k] = v
}
for k, v := range fields {
for _, f := range in.FieldHooks {
v = f(v)
}
f[k] = v
}
return context.WithValue(ctx, FieldsKey, f)
Expand Down Expand Up @@ -459,17 +488,26 @@ func (im *impl) WithContext(ctx context.Context) { im.ctx = ctx }

// SetHTTPStatusCode is equivelent to ctxerr.SetField(ctx, FieldKeyStatusCode, code)
func SetHTTPStatusCode(ctx context.Context, code int) context.Context {
return SetField(ctx, FieldKeyStatusCode, code)
return global.SetHTTPStatusCode(ctx, code)
}
func (in Instance) SetHTTPStatusCode(ctx context.Context, code int) context.Context {
return in.SetField(ctx, FieldKeyStatusCode, code)
}

// SetAction is equivelent to ctxerr.SetField(ctx, FieldKeyAction, action)
func SetAction(ctx context.Context, action string) context.Context {
return SetField(ctx, FieldKeyAction, action)
return global.SetAction(ctx, action)
}
func (in Instance) SetAction(ctx context.Context, action string) context.Context {
return in.SetField(ctx, FieldKeyAction, action)
}

// SetCategory is equivelent to ctxerr.SetField(ctx, FieldKeyStatusCode, category)
func SetCategory(ctx context.Context, category interface{}) context.Context {
return SetField(ctx, FieldKeyCategory, category)
return global.SetCategory(ctx, category)
}
func (in Instance) SetCategory(ctx context.Context, category interface{}) context.Context {
return in.SetField(ctx, FieldKeyCategory, category)
}

// ** Hooks ** //
Expand All @@ -489,6 +527,9 @@ func (in Instance) DefaultLogHook(err error) {

// SetCodeHook takes the code and adds it to the context
func SetCodeHook(ctx context.Context, code string, wrapping error) context.Context {
return global.SetCodeHook(ctx, code, wrapping)
}
func (in Instance) SetCodeHook(ctx context.Context, code string, wrapping error) context.Context {
if code != "" {
ctx = SetField(ctx, FieldKeyCode, code)
}
Expand All @@ -497,6 +538,9 @@ func SetCodeHook(ctx context.Context, code string, wrapping error) context.Conte

// SetLocationHook get the location of where the error happened and adds it to the context
func SetLocationHook(ctx context.Context, code string, wrapping error) context.Context {
return global.SetLocationHook(ctx, code, wrapping)
}
func (in Instance) SetLocationHook(ctx context.Context, code string, wrapping error) context.Context {
ctx = SetField(ctx, FieldKeyLocation, CallerFunc(2))
return ctx
}
Expand All @@ -523,10 +567,10 @@ func NewHTTPf(ctx context.Context, code, action string, statusCode int, message
}
func (in Instance) NewHTTPf(ctx context.Context, code, action string, statusCode int, message string, messageArgs ...interface{}) error {
if action != "" {
ctx = SetAction(ctx, action)
ctx = in.SetAction(ctx, action)
}
if statusCode != 0 {
ctx = SetHTTPStatusCode(ctx, statusCode)
ctx = in.SetHTTPStatusCode(ctx, statusCode)
}
return in.Newf(ctx, code, message, messageArgs...)
}
Expand All @@ -537,10 +581,10 @@ func WrapHTTP(ctx context.Context, err error, code, action string, statusCode in
}
func (in Instance) WrapHTTP(ctx context.Context, err error, code, action string, statusCode int, message ...interface{}) error {
if action != "" {
ctx = SetAction(ctx, action)
ctx = in.SetAction(ctx, action)
}
if statusCode != 0 {
ctx = SetHTTPStatusCode(ctx, statusCode)
ctx = in.SetHTTPStatusCode(ctx, statusCode)
}
return in.Wrap(ctx, err, code, message...)
}
Expand All @@ -551,10 +595,10 @@ func WrapHTTPf(ctx context.Context, err error, code, action string, statusCode i
}
func (in Instance) WrapHTTPf(ctx context.Context, err error, code, action string, statusCode int, message string, messageArgs ...interface{}) error {
if action != "" {
ctx = SetAction(ctx, action)
ctx = in.SetAction(ctx, action)
}
if statusCode != 0 {
ctx = SetHTTPStatusCode(ctx, statusCode)
ctx = in.SetHTTPStatusCode(ctx, statusCode)
}
return in.Wrapf(ctx, err, code, message, messageArgs...)
}
85 changes: 85 additions & 0 deletions ctxerr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ func TestNil(t *testing.T) {
var in *ctxerr.Instance
in.AddHandleHook(ctxerr.DefaultLogHook)
}()

func() {
defer func() {
if r := recover(); r != nil {
if !strings.HasSuffix(fmt.Sprint(r), "ctxerr.Instance is nil") {
t.Error("recovered with wrong message:", r)
}
} else {
t.Error("expected to recover")
}
}()
var in *ctxerr.Instance
in.AddFieldHook(func(v any) any { return v })
}()
}

func TestOverall(t *testing.T) {
Expand Down Expand Up @@ -252,6 +266,14 @@ func TestQuickWrap(t *testing.T) {
},
expectedMessage: "ctxerr",
expectedCode: "code",
}, {
name: "ctxerr instance",
err: func(ctx context.Context) error {
in := ctxerr.NewInstance()
return in.QuickWrap(ctx, ctxerr.New(ctx, "code", "ctxerr"))
},
expectedMessage: "ctxerr",
expectedCode: "code",
},
{
name: "triple wrap",
Expand Down Expand Up @@ -841,3 +863,66 @@ func TestFeildsWithNilCtx(t *testing.T) {
t.Error("expected a nil map")
}
}

type redactable string

func (r redactable) Redact() any {
return "redacted"
}

type IRedactable interface {
Redact() any
}

func RedactItem(a any) any {
if v, ok := a.(IRedactable); ok {
return v.Redact()
}
return a
}

func TestFieldHook(t *testing.T) {
in := ctxerr.Instance{}
in.AddFieldHook(RedactItem)

var key = "key"
tests := []struct {
name string
f func(context.Context, any) context.Context
}{
{
name: "in.SetField",
f: func(ctx context.Context, v any) context.Context {
return in.SetField(ctx, key, v)
},
},
{
name: "in.SetFields",
f: func(ctx context.Context, v any) context.Context {
return in.SetFields(ctx, map[string]any{key: v})
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := redactable("hello")

ctx := context.Background()
ctx = tt.f(ctx, r)

v, ok := ctxerr.Fields(ctx)[key]
if !ok {
t.Fatal("key not set")
}

if v != "redacted" {
t.Error("hook did not work", v)
}
})
}
}

func TestGlobalFieldsHook(t *testing.T) {
ctxerr.AddFieldHook(func(a any) any { return a })
}

0 comments on commit 2817a2f

Please sign in to comment.