From 911467a7a29005b2a99d431e37b43c3e60960431 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 25 Oct 2024 11:44:18 +0900 Subject: [PATCH] jwk.Parse --- jwk/BUILD.bazel | 1 + jwk/errors.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++++ jwk/jwk.go | 48 +++++++++++----------------- jwk/jwk_test.go | 28 +++++++++++++++++ 4 files changed, 131 insertions(+), 30 deletions(-) create mode 100644 jwk/errors.go diff --git a/jwk/BUILD.bazel b/jwk/BUILD.bazel index 2079c5d1..851aa603 100644 --- a/jwk/BUILD.bazel +++ b/jwk/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "convert.go", "ecdsa.go", "ecdsa_gen.go", + "errors.go", "fetch.go", "interface.go", "interface_gen.go", diff --git a/jwk/errors.go b/jwk/errors.go new file mode 100644 index 00000000..0e20b111 --- /dev/null +++ b/jwk/errors.go @@ -0,0 +1,84 @@ +package jwk + +import ( + "errors" + "fmt" +) + +var cpe = &continueError{} + +// ContinueError returns an opaque error that can be returned +// when a `KeyParser`, `KeyImporter`, or `KeyExporter` cannot handle the given payload, +// but would like the process to continue with the next handler. +func ContinueError() error { + return cpe +} + +type continueError struct{} + +func (e *continueError) Error() string { + return "continue parsing" +} + +// IsContinueError returns true if the given error is a ContinueError. +func IsContinueError(err error) bool { + return errors.Is(err, &continueError{}) +} + +type importError struct { + error +} + +func (e importError) Unwrap() error { + return e.error +} + +func (importError) Is(err error) bool { + _, ok := err.(importError) + return ok +} + +func importerr(f string, args ...any) error { + return importError{fmt.Errorf(`jwk.Import: `+f, args...)} +} + +var errDefaultImportError = importError{errors.New(`import error`)} + +func ImportError() error { + return errDefaultImportError +} + +type parseError struct { + error +} + +func (e parseError) Unwrap() error { + return e.error +} + +func (parseError) Is(err error) bool { + _, ok := err.(parseError) + return ok +} + +func bparseerr(prefix string, f string, args ...any) error { + return parseError{fmt.Errorf(prefix+`: `+f, args...)} +} + +func parseerr(f string, args ...any) error { + return bparseerr(`jwk.Parse`, f, args...) +} + +func rparseerr(f string, args ...any) error { + return bparseerr(`jwk.ParseReader`, f, args...) +} + +func sparseerr(f string, args ...any) error { + return bparseerr(`jwk.ParseString`, f, args...) +} + +var errDefaultParseError = parseError{errors.New(`parse error`)} + +func ParseError() error { + return errDefaultParseError +} diff --git a/jwk/jwk.go b/jwk/jwk.go index 14532f26..55e73d3c 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -44,26 +44,6 @@ func init() { } } -var cpe = &continueError{} - -// ContinueError returns an opaque error that can be returned -// when a `KeyParser`, `KeyImporter`, or `KeyExporter` cannot handle the given payload, -// but would like the process to continue with the next handler. -func ContinueError() error { - return cpe -} - -type continueError struct{} - -func (e *continueError) Error() string { - return "continue parsing" -} - -// IsContinueError returns true if the given error is a ContinueError. -func IsContinueError(err error) bool { - return errors.Is(err, &continueError{}) -} - // Import creates a jwk.Key from the given key (RSA/ECDSA/symmetric keys). // // The constructor auto-detects the type of key to be instantiated @@ -76,14 +56,14 @@ func IsContinueError(err error) bool { // - []byte creates a symmetric key func Import(raw interface{}) (Key, error) { if raw == nil { - return nil, fmt.Errorf(`jwk.Import requires a non-nil key`) + return nil, importerr(`a non-nil key is required`) } muKeyImporters.RLock() conv, ok := keyImporters[reflect.TypeOf(raw)] muKeyImporters.RUnlock() if !ok { - return nil, fmt.Errorf(`jwk.Import: failed to convert %T to jwk.Key: no converters were able to convert`, raw) + return nil, importerr(`failed to convert %T to jwk.Key: no converters were able to convert`, raw) } return conv.Import(raw) @@ -328,14 +308,14 @@ func Parse(src []byte, options ...ParseOption) (Set, error) { for len(src) > 0 { raw, rest, err := pemDecoder.Decode(src) if err != nil { - return nil, fmt.Errorf(`failed to parse PEM encoded key: %w`, err) + return nil, parseerr(`failed to parse PEM encoded key: %w`, err) } key, err := Import(raw) if err != nil { - return nil, fmt.Errorf(`failed to create jwk.Key from %T: %w`, raw, err) + return nil, parseerr(`failed to create jwk.Key from %T: %w`, raw, err) } if err := s.AddKey(key); err != nil { - return nil, fmt.Errorf(`failed to add jwk.Key to set: %w`, err) + return nil, parseerr(`failed to add jwk.Key to set: %w`, err) } src = bytes.TrimSpace(rest) } @@ -345,7 +325,7 @@ func Parse(src []byte, options ...ParseOption) (Set, error) { if localReg != nil || ignoreParseError { dcKs, ok := s.(KeyWithDecodeCtx) if !ok { - return nil, fmt.Errorf(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s) + return nil, parseerr(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s) } dc := &setDecodeCtx{ DecodeCtx: json.NewDecodeCtx(localReg), @@ -356,7 +336,7 @@ func Parse(src []byte, options ...ParseOption) (Set, error) { } if err := json.Unmarshal(src, s); err != nil { - return nil, fmt.Errorf(`failed to unmarshal JWK set: %w`, err) + return nil, parseerr(`failed to unmarshal JWK set: %w`, err) } return s, nil @@ -368,15 +348,23 @@ func ParseReader(src io.Reader, options ...ParseOption) (Set, error) { // JWKs except when we encounter an EOF, so just... ReadAll buf, err := io.ReadAll(src) if err != nil { - return nil, fmt.Errorf(`failed to read from io.Reader: %w`, err) + return nil, rparseerr(`failed to read from io.Reader: %w`, err) } - return Parse(buf, options...) + set, err := Parse(buf, options...) + if err != nil { + return nil, rparseerr(`failed to parse reader: %w`, err) + } + return set, nil } // ParseString parses a JWK set from the incoming string. func ParseString(s string, options ...ParseOption) (Set, error) { - return Parse([]byte(s), options...) + set, err := Parse([]byte(s), options...) + if err != nil { + return nil, sparseerr(`failed to parse string: %w`, err) + } + return set, nil } // AssignKeyID is a convenience function to automatically assign the "kid" diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index d4f474a5..b9a48d5c 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1968,3 +1968,31 @@ func TestValidation(t *testing.T) { require.Error(t, key.Validate(), `key.Validate should fail`) } } + +func TestParse_fail(t *testing.T) { + t.Run(`malformed json`, func(t *testing.T) { + t.Parallel() + const src = `{blah}` + t.Run("string", func(t *testing.T) { + t.Parallel() + _, err := jwk.ParseString(src) + require.Error(t, err, `jwk.ParseString should fail`) + require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`) + require.True(t, strings.HasPrefix(err.Error(), `jwk.ParseString: `)) + }) + t.Run("[]byte", func(t *testing.T) { + t.Parallel() + _, err := jwk.Parse([]byte(src)) + require.Error(t, err, `jwk.Parse should fail`) + require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`) + require.True(t, strings.HasPrefix(err.Error(), `jwk.Parse: `)) + }) + t.Run("io.Reader", func(t *testing.T) { + t.Parallel() + _, err := jwk.ParseReader(strings.NewReader(src)) + require.Error(t, err, `jwk.ParseReader should fail`) + require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`) + require.True(t, strings.HasPrefix(err.Error(), `jwk.ParseReader: `)) + }) + }) +}