Skip to content

Commit

Permalink
jwk.Parse
Browse files Browse the repository at this point in the history
  • Loading branch information
Daisuke Maki committed Oct 25, 2024
1 parent c5278c6 commit 911467a
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 30 deletions.
1 change: 1 addition & 0 deletions jwk/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
"convert.go",
"ecdsa.go",
"ecdsa_gen.go",
"errors.go",
"fetch.go",
"interface.go",
"interface_gen.go",
Expand Down
84 changes: 84 additions & 0 deletions jwk/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package jwk

import (
"errors"
"fmt"
)

var cpe = &continueError{}

// ContinueError returns an opaque error that can be returned
// when a `KeyParser`, `KeyImporter`, or `KeyExporter` cannot handle the given payload,
// but would like the process to continue with the next handler.
func ContinueError() error {
return cpe
}

type continueError struct{}

func (e *continueError) Error() string {
return "continue parsing"
}

// IsContinueError returns true if the given error is a ContinueError.
func IsContinueError(err error) bool {
return errors.Is(err, &continueError{})
}

type importError struct {
error
}

func (e importError) Unwrap() error {
return e.error
}

func (importError) Is(err error) bool {
_, ok := err.(importError)
return ok
}

func importerr(f string, args ...any) error {
return importError{fmt.Errorf(`jwk.Import: `+f, args...)}
}

var errDefaultImportError = importError{errors.New(`import error`)}

func ImportError() error {
return errDefaultImportError
}

type parseError struct {
error
}

func (e parseError) Unwrap() error {
return e.error
}

func (parseError) Is(err error) bool {
_, ok := err.(parseError)
return ok
}

func bparseerr(prefix string, f string, args ...any) error {
return parseError{fmt.Errorf(prefix+`: `+f, args...)}
}

func parseerr(f string, args ...any) error {
return bparseerr(`jwk.Parse`, f, args...)
}

func rparseerr(f string, args ...any) error {
return bparseerr(`jwk.ParseReader`, f, args...)
}

func sparseerr(f string, args ...any) error {
return bparseerr(`jwk.ParseString`, f, args...)
}

var errDefaultParseError = parseError{errors.New(`parse error`)}

func ParseError() error {
return errDefaultParseError
}
48 changes: 18 additions & 30 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,6 @@ func init() {
}
}

var cpe = &continueError{}

// ContinueError returns an opaque error that can be returned
// when a `KeyParser`, `KeyImporter`, or `KeyExporter` cannot handle the given payload,
// but would like the process to continue with the next handler.
func ContinueError() error {
return cpe
}

type continueError struct{}

func (e *continueError) Error() string {
return "continue parsing"
}

// IsContinueError returns true if the given error is a ContinueError.
func IsContinueError(err error) bool {
return errors.Is(err, &continueError{})
}

// Import creates a jwk.Key from the given key (RSA/ECDSA/symmetric keys).
//
// The constructor auto-detects the type of key to be instantiated
Expand All @@ -76,14 +56,14 @@ func IsContinueError(err error) bool {
// - []byte creates a symmetric key
func Import(raw interface{}) (Key, error) {
if raw == nil {
return nil, fmt.Errorf(`jwk.Import requires a non-nil key`)
return nil, importerr(`a non-nil key is required`)
}

muKeyImporters.RLock()
conv, ok := keyImporters[reflect.TypeOf(raw)]
muKeyImporters.RUnlock()
if !ok {
return nil, fmt.Errorf(`jwk.Import: failed to convert %T to jwk.Key: no converters were able to convert`, raw)
return nil, importerr(`failed to convert %T to jwk.Key: no converters were able to convert`, raw)
}

return conv.Import(raw)
Expand Down Expand Up @@ -328,14 +308,14 @@ func Parse(src []byte, options ...ParseOption) (Set, error) {
for len(src) > 0 {
raw, rest, err := pemDecoder.Decode(src)
if err != nil {
return nil, fmt.Errorf(`failed to parse PEM encoded key: %w`, err)
return nil, parseerr(`failed to parse PEM encoded key: %w`, err)
}
key, err := Import(raw)
if err != nil {
return nil, fmt.Errorf(`failed to create jwk.Key from %T: %w`, raw, err)
return nil, parseerr(`failed to create jwk.Key from %T: %w`, raw, err)
}
if err := s.AddKey(key); err != nil {
return nil, fmt.Errorf(`failed to add jwk.Key to set: %w`, err)
return nil, parseerr(`failed to add jwk.Key to set: %w`, err)
}
src = bytes.TrimSpace(rest)
}
Expand All @@ -345,7 +325,7 @@ func Parse(src []byte, options ...ParseOption) (Set, error) {
if localReg != nil || ignoreParseError {
dcKs, ok := s.(KeyWithDecodeCtx)
if !ok {
return nil, fmt.Errorf(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s)
return nil, parseerr(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s)
}
dc := &setDecodeCtx{
DecodeCtx: json.NewDecodeCtx(localReg),
Expand All @@ -356,7 +336,7 @@ func Parse(src []byte, options ...ParseOption) (Set, error) {
}

if err := json.Unmarshal(src, s); err != nil {
return nil, fmt.Errorf(`failed to unmarshal JWK set: %w`, err)
return nil, parseerr(`failed to unmarshal JWK set: %w`, err)
}

return s, nil
Expand All @@ -368,15 +348,23 @@ func ParseReader(src io.Reader, options ...ParseOption) (Set, error) {
// JWKs except when we encounter an EOF, so just... ReadAll
buf, err := io.ReadAll(src)
if err != nil {
return nil, fmt.Errorf(`failed to read from io.Reader: %w`, err)
return nil, rparseerr(`failed to read from io.Reader: %w`, err)
}

return Parse(buf, options...)
set, err := Parse(buf, options...)
if err != nil {
return nil, rparseerr(`failed to parse reader: %w`, err)
}
return set, nil
}

// ParseString parses a JWK set from the incoming string.
func ParseString(s string, options ...ParseOption) (Set, error) {
return Parse([]byte(s), options...)
set, err := Parse([]byte(s), options...)
if err != nil {
return nil, sparseerr(`failed to parse string: %w`, err)
}
return set, nil
}

// AssignKeyID is a convenience function to automatically assign the "kid"
Expand Down
28 changes: 28 additions & 0 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1968,3 +1968,31 @@ func TestValidation(t *testing.T) {
require.Error(t, key.Validate(), `key.Validate should fail`)
}
}

func TestParse_fail(t *testing.T) {

Check failure on line 1972 in jwk/jwk_test.go

View workflow job for this annotation

GitHub Actions / lint

TestParse_fail should call t.Parallel on the top level as well as its subtests (tparallel)
t.Run(`malformed json`, func(t *testing.T) {
t.Parallel()
const src = `{blah}`
t.Run("string", func(t *testing.T) {
t.Parallel()
_, err := jwk.ParseString(src)
require.Error(t, err, `jwk.ParseString should fail`)
require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`)
require.True(t, strings.HasPrefix(err.Error(), `jwk.ParseString: `))
})
t.Run("[]byte", func(t *testing.T) {
t.Parallel()
_, err := jwk.Parse([]byte(src))
require.Error(t, err, `jwk.Parse should fail`)
require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`)
require.True(t, strings.HasPrefix(err.Error(), `jwk.Parse: `))
})
t.Run("io.Reader", func(t *testing.T) {
t.Parallel()
_, err := jwk.ParseReader(strings.NewReader(src))
require.Error(t, err, `jwk.ParseReader should fail`)
require.ErrorIs(t, err, jwk.ParseError(), `error should be ParseError`)
require.True(t, strings.HasPrefix(err.Error(), `jwk.ParseReader: `))
})
})
}

0 comments on commit 911467a

Please sign in to comment.