diff --git a/internal/client/error.go b/internal/client/error.go index ade7496f..d4662275 100644 --- a/internal/client/error.go +++ b/internal/client/error.go @@ -23,9 +23,7 @@ import ( "io" "net" "strconv" - "strings" "syscall" - "unicode/utf8" "github.com/edgedb/edgedb-go/internal/buff" ) @@ -73,47 +71,41 @@ const ( lineStart = 0xfff3 ) -type position struct { - lineNo int - byteNo int -} - -func positionFromHeaders(headers map[uint16]string) (position, bool, error) { +func positionFromHeaders(headers map[uint16]string) (*int, *int, error) { lineNoRaw, ok := headers[lineStart] if !ok { - return position{}, false, nil + return nil, nil, nil } byteNoRaw, ok := headers[positionStart] if !ok { - return position{}, false, nil + return nil, nil, nil } lineNo, err := strconv.Atoi(lineNoRaw) if err != nil { - return position{}, false, &binaryProtocolError{ + return nil, nil, &binaryProtocolError{ err: fmt.Errorf("decode lineNo: %q: %w", lineNoRaw, err), } } byteNo, err := strconv.Atoi(byteNoRaw) if err != nil { - return position{}, false, &binaryProtocolError{ + return nil, nil, &binaryProtocolError{ err: fmt.Errorf("decode byteNo: %q: %w", byteNoRaw, err), } } - return position{ - lineNo: lineNo - 1, - byteNo: byteNo, - }, true, nil + return &lineNo, &byteNo, nil } // decodeErrorResponseMsg decodes an error response // https://www.edgedb.com/docs/internals/protocol/messages#errorresponse func decodeErrorResponseMsg(r *buff.Reader, query string) error { r.Discard(1) // severity - code := r.PopUint32() - msg := r.PopString() + w := Warning{ + Code: r.PopUint32(), + Message: r.PopString(), + } n := int(r.PopUint16()) headers := make(map[uint16]string, n) @@ -121,49 +113,14 @@ func decodeErrorResponseMsg(r *buff.Reader, query string) error { headers[r.PopUint16()] = r.PopString() } - pos, ok, err := positionFromHeaders(headers) + var err error + w.Line, w.Start, err = positionFromHeaders(headers) if err != nil { - return err - } - if !ok { - return errorFromCode(code, msg) + return errors.Join(w.Err(query), err) } - hintmsg, ok := headers[hint] - if !ok { - hintmsg = "error" - } - - lines := strings.Split(query, "\n") - if pos.lineNo >= len(lines) { - return errorFromCode(code, msg) - } - - // replace tabs with a single space - // because we don't know how they will be printed. - line := strings.ReplaceAll(lines[pos.lineNo], "\t", " ") - - for i := 0; i < pos.lineNo; i++ { - pos.byteNo -= 1 + len(lines[i]) - } - - if pos.byteNo >= len(line) { - pos.byteNo = 0 - } - - runeCount := utf8.RuneCountInString(line[:pos.byteNo]) - padding := strings.Repeat(" ", runeCount) - - msg += fmt.Sprintf( - "\nquery:%v:%v\n\n%v\n%v^ %v", - 1+pos.lineNo, - 1+runeCount, - line, - padding, - hintmsg, - ) - - return errorFromCode(code, msg) + w.Hint = headers[hint] + return w.Err(query) } type wrappedManyError struct { diff --git a/internal/client/granularflow1pX.go b/internal/client/granularflow1pX.go index 08dd9464..df90e687 100644 --- a/internal/client/granularflow1pX.go +++ b/internal/client/granularflow1pX.go @@ -125,7 +125,7 @@ func (c *protocolConnection) decodeCommandDataDescriptionMsg1pX( r *buff.Reader, q *query, ) (*CommandDescription, error) { - _, err := decodeHeaders1pX(r, q.warningHandler) + _, err := decodeHeaders1pX(r, q.cmd, q.warningHandler) if err != nil { return nil, err } diff --git a/internal/client/granularflow2pX.go b/internal/client/granularflow2pX.go index cb2aae3f..df7e38e6 100644 --- a/internal/client/granularflow2pX.go +++ b/internal/client/granularflow2pX.go @@ -131,7 +131,7 @@ func (c *protocolConnection) decodeCommandDataDescriptionMsg2pX( r *buff.Reader, q *query, ) (*CommandDescriptionV2, error) { - _, err := decodeHeaders2pX(r, q.warningHandler) + _, err := decodeHeaders2pX(r, q.cmd, q.warningHandler) if err != nil { return nil, err } diff --git a/internal/client/scriptflow.go b/internal/client/scriptflow.go index 959f9f38..c392ccf7 100644 --- a/internal/client/scriptflow.go +++ b/internal/client/scriptflow.go @@ -57,6 +57,7 @@ func discardHeaders0pX(r *buff.Reader) { func decodeHeaders1pX( r *buff.Reader, + query string, warningHandler WarningHandler, ) (header.Header1pX, error) { n := int(r.PopUint16()) @@ -75,7 +76,7 @@ func decodeHeaders1pX( errors := make([]error, len(warnings)) for i, warning := range warnings { - errors[i] = errorFromCode(warning.Code, warning.Message) + errors[i] = warning.Err(query) } err = warningHandler(errors) diff --git a/internal/client/warning.go b/internal/client/warning.go index eb3c8f46..f9d09e07 100644 --- a/internal/client/warning.go +++ b/internal/client/warning.go @@ -18,13 +18,62 @@ package edgedb import ( "errors" + "fmt" "log" + "strings" + "unicode/utf8" ) // Warning is used to decode warnings in the protocol. type Warning struct { Code uint32 `json:"code"` Message string `json:"message"` + Hint string `json:"hint,omitempty"` + Line *int `json:"line,omitempty"` + Start *int `json:"start,omitempty"` +} + +func (w *Warning) Err(query string) error { + if w.Line == nil || w.Start == nil { + return errorFromCode(w.Code, w.Message) + } + + lineNo := *w.Line - 1 + byteNo := *w.Start + lines := strings.Split(query, "\n") + if lineNo >= len(lines) { + return errorFromCode(w.Code, w.Message) + } + + // replace tabs with a single space + // because we don't know how they will be printed. + line := strings.ReplaceAll(lines[lineNo], "\t", " ") + + for i := 0; i < lineNo; i++ { + byteNo -= 1 + len(lines[i]) + } + + if byteNo >= len(line) { + byteNo = 0 + } + + hint := w.Hint + if hint == "" { + hint = "error" + } + + runeCount := utf8.RuneCountInString(line[:byteNo]) + padding := strings.Repeat(" ", runeCount) + msg := w.Message + fmt.Sprintf( + "\nquery:%v:%v\n\n%v\n%v^ %v", + 1+lineNo, + 1+runeCount, + line, + padding, + hint, + ) + + return errorFromCode(w.Code, msg) } // LogWarnings is an edgedb.WarningHandler that logs warnings.