From 57a023d90eda695d879a300b22c3281e35c8fe76 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Fri, 27 Oct 2023 21:48:32 +0900 Subject: [PATCH] Add to JWK --- jwk/jwk.go | 21 +++++++++++++++++++++ jwk/jwk_test.go | 37 +++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/jwk/jwk.go b/jwk/jwk.go index 4cdf589a2..2c97d7d7f 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -734,6 +734,9 @@ func asnEncode(key Key) (string, []byte, error) { } } +type CustomDecoder = json.CustomDecoder +type CustomDecodeFunc = json.CustomDecodeFunc + // RegisterCustomField allows users to specify that a private field // be decoded as an instance of the specified type. This option has // a global effect. @@ -752,6 +755,24 @@ func asnEncode(key Key) (string, []byte, error) { // // var bday time.Time // _ = key.Get(`x-birthday`, &bday) +// +// If you need a more fine-tuned control over the decoding process, +// you can register a `CustomDecoder`. For example, below shows +// how to register a decoder that can parse RFC1123 format string: +// +// jwk.RegisterCustomField(`x-birthday`, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) { +// return time.Parse(time.RFC1123, string(data)) +// })) +// +// Please note that use of custom fields can be problematic if you +// are using a library that does not implement MarshalJSON/UnmarshalJSON +// and you try to roundtrip from an object to JSON, and then back to an object. +// For example, in the above example, you can _parse_ time values formatted +// in the format specified in RFC822, but when you convert an object into +// JSON, it will be formatted in RFC3339, because that's what `time.Time` +// likes to do. To avoid this, it's always better to use a custom type +// that wraps your desired type (in this case `time.Time`) and implement +// MarshalJSON and UnmashalJSON. func RegisterCustomField(name string, object interface{}) { registry.Register(name, object) } diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index d711d5649..1c35b9a17 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1393,16 +1393,34 @@ func TestOKP(t *testing.T) { } func TestCustomField(t *testing.T) { + const rfc3339Key = `x-rfc3339-key` + const rfc1123Key = `x-rfc1123-key` + // XXX has global effect!!! - jwk.RegisterCustomField(`x-birthday`, time.Time{}) - defer jwk.RegisterCustomField(`x-birthday`, nil) + jwk.RegisterCustomField(rfc3339Key, time.Time{}) + jwk.RegisterCustomField(rfc1123Key, jwk.CustomDecodeFunc(func(data []byte) (interface{}, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return time.Parse(time.RFC1123, s) + })) + defer jwk.RegisterCustomField(rfc3339Key, nil) + defer jwk.RegisterCustomField(rfc1123Key, nil) expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC) - bdaybytes, _ := expected.MarshalText() // RFC3339 + rfc3339bytes, _ := expected.MarshalText() // RFC3339 + rfc1123bytes := expected.Format(time.RFC1123) var b strings.Builder - b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","x-birthday":"`) - b.Write(bdaybytes) + b.WriteString(`{"e":"AQAB", "kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw","`) + b.WriteString(rfc3339Key) + b.WriteString(`":"`) + b.Write(rfc3339bytes) + b.WriteString(`","`) + b.WriteString(rfc1123Key) + b.WriteString(`":"`) + b.WriteString(rfc1123bytes) b.WriteString(`"}`) src := b.String() @@ -1412,11 +1430,10 @@ func TestCustomField(t *testing.T) { return } - var v interface{} - require.NoError(t, key.Get(`x-birthday`, &v), `key.Get("x-birthday") should succeed`) - - if !assert.Equal(t, expected, v, `values should match`) { - return + for _, name := range []string{rfc3339Key, rfc1123Key} { + var v time.Time + require.NoError(t, key.Get(name, &v), `key.Get(%q) should succeed`, name) + require.Equal(t, expected, v, `values should match`) } }) }