Skip to content

Commit

Permalink
handle panic
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrre committed Aug 5, 2024
1 parent 4c47133 commit ddffcdb
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 88 deletions.
10 changes: 10 additions & 0 deletions _assertauto/TestConfig/PanicError.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"entries": [
{
"equal": "(string) <panic>: error\n"
},
{
"equal": 0
}
]
}
10 changes: 10 additions & 0 deletions _assertauto/TestConfig/PanicOther.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"entries": [
{
"equal": "(string) <panic>: 123\n"
},
{
"equal": 0
}
]
}
10 changes: 10 additions & 0 deletions _assertauto/TestConfig/PanicString.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"entries": [
{
"equal": "(string) <panic>: string\n"
},
{
"equal": 0
}
]
}
216 changes: 128 additions & 88 deletions pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ var (
indentCache = map[string][]byte{}
)

func getIndent(s string, n int) []byte {
indentCacheLock.Lock()
defer indentCacheLock.Unlock()
b := indentCache[s]
l := len(s) * n
if len(b) < l {
b = bytes.Repeat([]byte(s), n)
indentCache[s] = b
}
return b[:l]
}

// WriteIndent writes the indentation to the writer.
func (c *Config) WriteIndent(w io.Writer, st *State) {
if st.Indent <= 0 {
Expand All @@ -134,35 +146,25 @@ func (c *Config) WriteIndent(w io.Writer, st *State) {
_, _ = WriteString(w, c.Indent)
return
}
indentCacheLock.Lock()
b := indentCache[c.Indent]
l := len(c.Indent) * st.Indent
if len(b) < l {
b = bytes.Repeat([]byte(c.Indent), st.Indent)
indentCache[c.Indent] = b
}
indentCacheLock.Unlock()
_, _ = w.Write(b[:l])
_, _ = w.Write(getIndent(c.Indent, st.Indent))
}

func (c *Config) checkRecursion(w io.Writer, st *State, v reflect.Value) bool {
func (c *Config) runRecursion(w io.Writer, st *State, v reflect.Value, f func(st *State)) {
vp := v.Pointer()
if slices.Contains(st.Visited, vp) {
_, _ = WriteString(w, "<recursion>")
return true
return
}
st.Visited = append(st.Visited, vp)
return false
}

func (c *Config) endRecursion(st *State) {
st.Visited = st.Visited[:len(st.Visited)-1]
st.RunVisited(vp, f)
}

// WriteTypeAndValue writes the type and value to the writer.
//
// It writes "(TYPE) VALUE".
func (c *Config) WriteTypeAndValue(w io.Writer, st *State, v reflect.Value) {
defer func() {
c.checkRecover(w, recover())
}()
if !v.IsValid() {
WriteNil(w)
return
Expand All @@ -175,12 +177,28 @@ func (c *Config) WriteTypeAndValue(w io.Writer, st *State, v reflect.Value) {
_, _ = WriteString(w, "<max depth>")
return
}
st.Depth++
_, _ = WriteString(w, "(")
c.WriteType(w, v.Type())
_, _ = WriteString(w, ") ")
c.WriteValue(w, st, v)
st.Depth--
st.RunDepth(func(st *State) {
_, _ = WriteString(w, "(")
c.WriteType(w, v.Type())
_, _ = WriteString(w, ") ")
c.WriteValue(w, st, v)
})
}

func (c *Config) checkRecover(w io.Writer, r any) {
if r == nil {
return
}
_, _ = writeString(w, "<panic>: ")
switch r := r.(type) {
case string:
_, _ = WriteString(w, r)
case error:
_, _ = WriteString(w, r.Error())
default:
_, _ = fmt.Fprint(w, r)
}
_, _ = WriteString(w, "\n")
}

// WriteType writes the type to the writer.
Expand Down Expand Up @@ -321,12 +339,10 @@ func (c *Config) writeFunc(w io.Writer, v reflect.Value) {
}

func (c *Config) writePointer(w io.Writer, st *State, v reflect.Value) {
if c.checkRecursion(w, st, v) {
return
}
WriteArrow(w)
c.WriteTypeAndValue(w, st, v.Elem())
c.endRecursion(st)
c.runRecursion(w, st, v, func(st *State) {
WriteArrow(w)
c.WriteTypeAndValue(w, st, v.Elem())
})
}

func (c *Config) writeUnsafePointer(w io.Writer, v reflect.Value) {
Expand All @@ -343,18 +359,18 @@ func (c *Config) writeArray(w io.Writer, st *State, v reflect.Value) {
}
_, _ = WriteString(w, "{\n")
if v.Len() > 0 {
st.Indent++
for i := range l {
c.WriteIndent(w, st)
c.WriteTypeAndValue(w, st, v.Index(i))
_, _ = WriteString(w, ",\n")
}
if truncated {
c.WriteIndent(w, st)
writeTruncated(w)
_, _ = WriteString(w, "\n")
}
st.Indent--
st.RunIndent(func(st *State) {
for i := range l {
c.WriteIndent(w, st)
c.WriteTypeAndValue(w, st, v.Index(i))
_, _ = WriteString(w, ",\n")
}
if truncated {
c.WriteIndent(w, st)
writeTruncated(w)
_, _ = WriteString(w, "\n")
}
})
}
c.WriteIndent(w, st)
_, _ = WriteString(w, "}")
Expand All @@ -365,38 +381,34 @@ func (c *Config) writeSlice(w io.Writer, st *State, v reflect.Value) {
WriteNil(w)
return
}
if c.checkRecursion(w, st, v) {
return
}
writeLenCapReflect(w, v)
_, _ = WriteString(w, " ")
c.writeArray(w, st, v)
c.endRecursion(st)
c.runRecursion(w, st, v, func(st *State) {
writeLenCapReflect(w, v)
_, _ = WriteString(w, " ")
c.writeArray(w, st, v)
})
}

func (c *Config) writeMap(w io.Writer, st *State, v reflect.Value) {
if v.IsNil() {
WriteNil(w)
return
}
if c.checkRecursion(w, st, v) {
return
}
_, _ = WriteString(w, "(len=")
_, _ = strconvio.WriteInt(w, int64(v.Len()), 10)
_, _ = WriteString(w, ") {\n")
if v.Len() > 0 {
st.Indent++
if c.MapSortKeys {
c.writeMapSorted(w, st, v)
} else {
c.writeMapUnsorted(w, st, v)
c.runRecursion(w, st, v, func(st *State) {
_, _ = WriteString(w, "(len=")
_, _ = strconvio.WriteInt(w, int64(v.Len()), 10)
_, _ = WriteString(w, ") {\n")
if v.Len() > 0 {
st.RunIndent(func(st *State) {
if c.MapSortKeys {
c.writeMapSorted(w, st, v)
} else {
c.writeMapUnsorted(w, st, v)
}
})
}
st.Indent--
}
c.WriteIndent(w, st)
_, _ = WriteString(w, "}")
c.endRecursion(st)
c.WriteIndent(w, st)
_, _ = WriteString(w, "}")
})
}

func (c *Config) writeMapSorted(w io.Writer, st *State, v reflect.Value) {
Expand Down Expand Up @@ -516,19 +528,19 @@ func (c *Config) writeMapEntry(w io.Writer, st *State, key reflect.Value, value

func (c *Config) writeStruct(w io.Writer, st *State, v reflect.Value) {
_, _ = WriteString(w, "{\n")
st.Indent++
fields := getStructFields(v.Type())
for i, field := range fields {
if !c.StructUnexported && !field.IsExported() {
continue
st.RunIndent(func(st *State) {
fields := getStructFields(v.Type())
for i, field := range fields {
if !c.StructUnexported && !field.IsExported() {
continue
}
c.WriteIndent(w, st)
_, _ = WriteString(w, field.Name)
_, _ = WriteString(w, ": ")
c.WriteTypeAndValue(w, st, v.Field(i))
_, _ = WriteString(w, ",\n")
}
c.WriteIndent(w, st)
_, _ = WriteString(w, field.Name)
_, _ = WriteString(w, ": ")
c.WriteTypeAndValue(w, st, v.Field(i))
_, _ = WriteString(w, ",\n")
}
st.Indent--
})
c.WriteIndent(w, st)
_, _ = WriteString(w, "}")
}
Expand Down Expand Up @@ -567,6 +579,34 @@ type State struct {
Visited []uintptr
}

// RunDepth runs the function with increased depth and restores the original depth after.
func (st *State) RunDepth(f func(st *State)) {
st.Depth++
defer func() {
st.Depth--
}()
f(st)
}

// RunIndent runs the function with increased indentation and restores the original indentation after.
func (st *State) RunIndent(f func(st *State)) {
st.Indent++
defer func() {
st.Indent--
}()
f(st)
}

// RunVisited runs the function with the visited pointer and restores the original visited pointers after.
func (st *State) RunVisited(p uintptr, f func(st *State)) {
l := len(st.Visited)
st.Visited = append(st.Visited, p)
defer func() {
st.Visited = st.Visited[:l]
}()
f(st)
}

func (st *State) reset() {
st.Depth = 0
st.Indent = 0
Expand Down Expand Up @@ -689,17 +729,17 @@ func writeBytesCommon(c *Config, w io.Writer, st *State, b []byte, maxLen int) {
truncated = true
}
_, _ = WriteString(w, "\n")
st.Indent++
iw := GetIndentWriter(w, c, st, false)
d := hex.Dumper(iw)
_, _ = d.Write(b)
_ = d.Close()
iw.Release()
if truncated {
c.WriteIndent(w, st)
writeTruncated(w)
}
st.Indent--
st.RunIndent(func(st *State) {
iw := GetIndentWriter(w, c, st, false)
defer iw.Release()
d := hex.Dumper(iw)
_, _ = d.Write(b)
_ = d.Close()
if truncated {
c.WriteIndent(w, st)
writeTruncated(w)
}
})
}

var typeStringer = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
Expand Down
28 changes: 28 additions & 0 deletions pretty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,34 @@ var testCases = []testCase{
name: "Nil",
value: nil,
},
{
name: "PanicString",
value: "test",
configure: func(c *Config) {
c.ValueWriters = []ValueWriter{func(c *Config, w io.Writer, st *State, v reflect.Value) bool {
panic("string")
}}
},
},
{
name: "PanicError",
value: "test",
configure: func(c *Config) {
err := errors.New("error")
c.ValueWriters = []ValueWriter{func(c *Config, w io.Writer, st *State, v reflect.Value) bool {
panic(err)
}}
},
},
{
name: "PanicOther",
value: "test",
configure: func(c *Config) {
c.ValueWriters = []ValueWriter{func(c *Config, w io.Writer, st *State, v reflect.Value) bool {
panic(123)
}}
},
},
{
name: "Bool",
value: true,
Expand Down

0 comments on commit ddffcdb

Please sign in to comment.