Skip to content

Commit

Permalink
Allow registering a constructor for RegisterCustomField (#1006)
Browse files Browse the repository at this point in the history
* s{/v2/}{/v3}g

* Add v3 to workflows

* more s{v2}{v3}g

* a few more v2 -> v3

* tweak for v3

* tweaks for v3

* Change Get() API (#990)

* Chage Get(string) (interface{}, bool) to Get(string, interface{}) error

* fix example code

* run go mod tidy

* more bazel tweaks

* Documentation fixes

* Generate jws docs

* fix go.mod

* [WIP] Make JWK Key Parsing pluggable (#991)

* s{/v2/}{/v3}g

* Make the probe phase generic, and allow multiple parsers

* panic if default probe registration fails

* Fix documentation

* tweak

* tweak go.mod

* rip out ECDSA specific stuff into its own package

* appease linter

* Add missing file

* Reinstate the relevant changes from 81ba77f

* Update bazel files

* Add missing bazel file

* docs and locks around keyParsers

* docs

* Remove x25519, use crypto/ecdh

* Update go versions to use

* Update version

* remove toolchain directive

* Run make tidy + gazelle-update-repos

* Remove iterators  (#999)

* First pass removing iterators from jwk

* Remove iterators from jwe, remove Range

* remove iterate from jws

* Remove iterate from jwt

* Remove remaining iterator from jws

* Remove iterators from jwk

* remove more references to iterators and makePairs

* fix lint

* Fix jwk.Set example

* deterministic token serialization

* remove iterate from cmd

* Rip out iterator library

* do away with context.Context

* appease linter

* Remove ctx from jws

* Add incomplete list of changes

* Fix after rebase

* Allow registering a constructor for RegisterCustomField

* appease linter and add docs

* Add tests in JWE, tweak docs

* Add to JWK

* appease linter

* Fix to use "portable" versions

* Add it to jws

* Add example
  • Loading branch information
lestrrat authored Oct 31, 2023
1 parent cab40bb commit 12bba39
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 99 deletions.
22 changes: 22 additions & 0 deletions examples/jwt_get_claims_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"time"

"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
)

Expand All @@ -15,6 +16,7 @@ func ExampleJWT_GetClaims() {
Subject(`example`).
Claim(`claim1`, `value1`).
Claim(`claim2`, `2022-05-16T07:35:56+00:00`).
Claim(`claim3`, `{"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}`).
Build()
if err != nil {
fmt.Printf("failed to build token: %s\n", err)
Expand Down Expand Up @@ -43,6 +45,7 @@ func ExampleJWT_GetClaims() {
var dummy interface{}
_ = tok.Get(`claim1`, &dummy)
_ = tok.Get(`claim2`, &dummy)
_ = tok.Get(`claim3`, &dummy)

// However, it is possible to globally specify that a private
// claim should be parsed into a custom type.
Expand All @@ -62,5 +65,24 @@ func ExampleJWT_GetClaims() {
return
}

// It's also possible to specify a custom decoder for a private claim.
// For example, in the case of `claim3`, it needs to call `jwk.ParseKey`
// which returns an interface that can't be instantiated like the
// `time.Time` value for `claim2`.
jwt.RegisterCustomField(`claim3`, jwt.CustomDecodeFunc(func(data []byte) (interface{}, error) {
return jwk.ParseKey(data)
}))

tok = jwt.New()
if err := json.Unmarshal([]byte(`{"claim3": {"kty": "oct", "alg":"A128KW", "k":"GawgguFyGrWKav7AX4VKUg"}}`), tok); err != nil {
fmt.Printf(`failed to parse token: %s`, err)
return
}
var claim3 jwk.Key
if err := tok.Get(`claim3`, &claim3); err != nil {
fmt.Printf("failed to get private claim \"claim3\": %s\n", err)
return
}

// OUTPUT:
}
56 changes: 47 additions & 9 deletions internal/json/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,80 @@ import (
"sync"
)

// CustomDecoder is the interface we expect from RegisterCustomField in jws, jwe, jwk, and jwt packages.
type CustomDecoder interface {
// Decode takes a JSON encoded byte slice and returns the desired
// decoded value,which will be used as the value for that field
// registered through RegisterCustomField
Decode([]byte) (interface{}, error)
}

// CustomDecodeFunc is a stateless, function-based implementation of CustomDecoder
type CustomDecodeFunc func([]byte) (interface{}, error)

func (fn CustomDecodeFunc) Decode(data []byte) (interface{}, error) {
return fn(data)
}

type objectTypeDecoder struct {
typ reflect.Type
name string
}

func (dec *objectTypeDecoder) Decode(data []byte) (interface{}, error) {
ptr := reflect.New(dec.typ).Interface()
if err := Unmarshal(data, ptr); err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, dec.name, err)
}
return reflect.ValueOf(ptr).Elem().Interface(), nil
}

type Registry struct {
mu *sync.RWMutex
data map[string]reflect.Type
ctrs map[string]CustomDecoder
}

func NewRegistry() *Registry {
return &Registry{
mu: &sync.RWMutex{},
data: make(map[string]reflect.Type),
ctrs: make(map[string]CustomDecoder),
}
}

func (r *Registry) Register(name string, object interface{}) {
if object == nil {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.data, name)
delete(r.ctrs, name)
return
}

typ := reflect.TypeOf(object)
r.mu.Lock()
defer r.mu.Unlock()
r.data[name] = typ
if ctr, ok := object.(CustomDecoder); ok {
r.ctrs[name] = ctr
} else {
r.ctrs[name] = &objectTypeDecoder{
typ: reflect.TypeOf(object),
name: name,
}
}
}

func (r *Registry) Decode(dec *Decoder, name string) (interface{}, error) {
r.mu.RLock()
defer r.mu.RUnlock()

if typ, ok := r.data[name]; ok {
ptr := reflect.New(typ).Interface()
if err := dec.Decode(ptr); err != nil {
if ctr, ok := r.ctrs[name]; ok {
var raw RawMessage
if err := dec.Decode(&raw); err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err)
}
v, err := ctr.Decode([]byte(raw))
if err != nil {
return nil, fmt.Errorf(`failed to decode field %s: %w`, name, err)
}
return reflect.ValueOf(ptr).Elem().Interface(), nil
return v, nil
}

var decoded interface{}
Expand Down
21 changes: 21 additions & 0 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,9 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
return m, nil
}

type CustomDecoder = json.CustomDecoder
type CustomDecodeFunc = json.CustomDecodeFunc

// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
Expand All @@ -803,6 +806,24 @@ func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
//
// var bday time.Time
// _ = hdr.Get(`x-birthday`, &bday)
//
// If you need a more fine-tuned control over the decoding process,
// you can register a `CustomDecoder`. For example, below shows
// how to register a decoder that can parse RFC1123 format string:
//
// jwe.RegisterCustomField(`x-birthday`, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) {
// return time.Parse(time.RFC1123, string(data))
// }))
//
// Please note that use of custom fields can be problematic if you
// are using a library that does not implement MarshalJSON/UnmarshalJSON
// and you try to roundtrip from an object to JSON, and then back to an object.
// For example, in the above example, you can _parse_ time values formatted
// in the format specified in RFC822, but when you convert an object into
// JSON, it will be formatted in RFC3339, because that's what `time.Time`
// likes to do. To avoid this, it's always better to use a custom type
// that wraps your desired type (in this case `time.Time`) and implement
// MarshalJSON and UnmashalJSON.
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}
77 changes: 43 additions & 34 deletions jwe/jwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,57 +689,66 @@ func TestReadFile(t *testing.T) {

func TestCustomField(t *testing.T) {
// XXX has global effect!!!
jwe.RegisterCustomField(`x-birthday`, time.Time{})
defer jwe.RegisterCustomField(`x-birthday`, nil)
const rfc3339Key = `x-test-rfc3339`
const rfc1123Key = `x-test-rfc1123`
jwe.RegisterCustomField(rfc3339Key, time.Time{})
jwe.RegisterCustomField(rfc1123Key, jwe.CustomDecodeFunc(func(data []byte) (interface{}, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return time.Parse(time.RFC1123, s)
}))

defer jwe.RegisterCustomField(rfc3339Key, nil)
defer jwe.RegisterCustomField(rfc1123Key, nil)

expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
rfc3339bytes, _ := expected.MarshalText() // RFC3339
rfc1123bytes := expected.Format(time.RFC1123)

plaintext := []byte("Hello, World!")
rsakey, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`) {
return
}
pubkey, err := jwk.PublicKeyOf(rsakey)
if !assert.NoError(t, err, `jwk.PublicKeyOf() should succeed`) {
return
}
require.NoError(t, err, `jwxtest.GenerateRsaJwk() should succeed`)

protected := jwe.NewHeaders()
protected.Set(`x-birthday`, string(bdaybytes))
pubkey, err := jwk.PublicKeyOf(rsakey)
require.NoError(t, err, `jwk.PublicKeyOf() should succeed`)

encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected))
if !assert.NoError(t, err, `jwe.Encrypt should succeed`) {
return
}
t.Run("jwe.Parse", func(t *testing.T) {
protected := jwe.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

t.Run("jwe.Parse + json.Unmarshal", func(t *testing.T) {
encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected))
require.NoError(t, err, `jwe.Encrypt should succeed`)
msg, err := jwe.Parse(encrypted)
if !assert.NoError(t, err, `jwe.Parse should succeed`) {
t.Logf("%q", encrypted)
return
}

var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(`x-birthday`, &v), `msg.ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
}

// Create JSON from jwe.Message
buf, err := json.Marshal(msg)
if !assert.NoError(t, err, `json.Marshal should succeed`) {
return
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}

var msg2 jwe.Message
if !assert.NoError(t, json.Unmarshal(buf, &msg2), `json.Unmarshal should succeed`) {
})
t.Run("json.Unmarshal", func(t *testing.T) {
protected := jwe.NewHeaders()
protected.Set(rfc3339Key, string(rfc3339bytes))
protected.Set(rfc1123Key, rfc1123bytes)

encrypted, err := jwe.Encrypt(plaintext, jwe.WithKey(jwa.RSA_OAEP, pubkey), jwe.WithProtectedHeaders(protected), jwe.WithJSON())
require.NoError(t, err, `jwe.Encrypt should succeed`)
msg := jwe.NewMessage()
if !assert.NoError(t, json.Unmarshal(encrypted, msg), `json.Unmarshal should succeed`) {
return
}

v = time.Time{} // reset
require.NoError(t, msg2.ProtectedHeaders().Get(`x-birthday`, &v), `msg2.ProtectedHeaders().Get("x-birthday") should succeed`)
if !assert.Equal(t, expected, v, `values should match`) {
return
for _, key := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, msg.ProtectedHeaders().Get(key, &v), `msg.Get(%q) should succeed`, key)
require.Equal(t, expected, v, `values should match`)
}
})
}
Expand Down
21 changes: 21 additions & 0 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,9 @@ func asnEncode(key Key) (string, []byte, error) {
}
}

type CustomDecoder = json.CustomDecoder
type CustomDecodeFunc = json.CustomDecodeFunc

// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
Expand All @@ -752,6 +755,24 @@ func asnEncode(key Key) (string, []byte, error) {
//
// var bday time.Time
// _ = key.Get(`x-birthday`, &bday)
//
// If you need a more fine-tuned control over the decoding process,
// you can register a `CustomDecoder`. For example, below shows
// how to register a decoder that can parse RFC1123 format string:
//
// jwk.RegisterCustomField(`x-birthday`, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) {
// return time.Parse(time.RFC1123, string(data))
// }))
//
// Please note that use of custom fields can be problematic if you
// are using a library that does not implement MarshalJSON/UnmarshalJSON
// and you try to roundtrip from an object to JSON, and then back to an object.
// For example, in the above example, you can _parse_ time values formatted
// in the format specified in RFC822, but when you convert an object into
// JSON, it will be formatted in RFC3339, because that's what `time.Time`
// likes to do. To avoid this, it's always better to use a custom type
// that wraps your desired type (in this case `time.Time`) and implement
// MarshalJSON and UnmashalJSON.
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}
Expand Down
37 changes: 27 additions & 10 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1393,16 +1393,34 @@ func TestOKP(t *testing.T) {
}

func TestCustomField(t *testing.T) {
const rfc3339Key = `x-rfc3339-key`
const rfc1123Key = `x-rfc1123-key`

// XXX has global effect!!!
jwk.RegisterCustomField(`x-birthday`, time.Time{})
defer jwk.RegisterCustomField(`x-birthday`, nil)
jwk.RegisterCustomField(rfc3339Key, time.Time{})
jwk.RegisterCustomField(rfc1123Key, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return time.Parse(time.RFC1123, s)
}))
defer jwk.RegisterCustomField(rfc3339Key, nil)
defer jwk.RegisterCustomField(rfc1123Key, nil)

expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
bdaybytes, _ := expected.MarshalText() // RFC3339
rfc3339bytes, _ := expected.MarshalText() // RFC3339
rfc1123bytes := expected.Format(time.RFC1123)

var b strings.Builder
b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","x-birthday":"`)
b.Write(bdaybytes)
b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","`)
b.WriteString(rfc3339Key)
b.WriteString(`":"`)
b.Write(rfc3339bytes)
b.WriteString(`","`)
b.WriteString(rfc1123Key)
b.WriteString(`":"`)
b.WriteString(rfc1123bytes)
b.WriteString(`"}`)
src := b.String()

Expand All @@ -1412,11 +1430,10 @@ func TestCustomField(t *testing.T) {
return
}

var v interface{}
require.NoError(t, key.Get(`x-birthday`, &v), `key.Get("x-birthday") should succeed`)

if !assert.Equal(t, expected, v, `values should match`) {
return
for _, name := range []string{rfc3339Key, rfc1123Key} {
var v time.Time
require.NoError(t, key.Get(name, &v), `key.Get(%q) should succeed`, name)
require.Equal(t, expected, v, `values should match`)
}
})
}
Expand Down
Loading

0 comments on commit 12bba39

Please sign in to comment.