diff --git a/bench/performance/go.mod b/bench/performance/go.mod index 824617a2b..ebc7c6118 100644 --- a/bench/performance/go.mod +++ b/bench/performance/go.mod @@ -11,7 +11,6 @@ require ( github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.4 // indirect - github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.0 // indirect diff --git a/bench/performance/go.sum b/bench/performance/go.sum index 7d75a5090..7fb294b48 100644 --- a/bench/performance/go.sum +++ b/bench/performance/go.sum @@ -11,8 +11,6 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= -github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= -github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/cmd/jwx/go.mod b/cmd/jwx/go.mod index 56b575fc1..0c38a1b82 100644 --- a/cmd/jwx/go.mod +++ b/cmd/jwx/go.mod @@ -15,7 +15,6 @@ require ( github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.4 // indirect - github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/segmentio/asm v1.2.0 // indirect diff --git a/cmd/jwx/go.sum b/cmd/jwx/go.sum index 95f412b8d..5cd43505f 100644 --- a/cmd/jwx/go.sum +++ b/cmd/jwx/go.sum @@ -12,8 +12,6 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= -github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= -github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/deps.bzl b/deps.bzl index e9d82e69d..f39a1df77 100644 --- a/deps.bzl +++ b/deps.bzl @@ -50,13 +50,6 @@ def go_dependencies(): sum = "h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=", version = "v1.0.4", ) - go_repository( - name = "com_github_lestrrat_go_iter", - build_file_proto_mode = "disable_global", - importpath = "github.com/lestrrat-go/iter", - sum = "h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=", - version = "v1.0.2", - ) go_repository( name = "com_github_lestrrat_go_option", diff --git a/examples/go.mod b/examples/go.mod index 251fb00f0..5d2ebcaf4 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -14,7 +14,6 @@ require ( github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.4 // indirect - github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.0 // indirect diff --git a/examples/go.sum b/examples/go.sum index a6b37ec99..db79a400b 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -13,8 +13,6 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= -github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= -github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/go.mod b/go.mod index f768e4a51..0c26ca64b 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/goccy/go-json v0.10.2 github.com/lestrrat-go/blackmagic v1.0.2 github.com/lestrrat-go/httprc v1.0.4 - github.com/lestrrat-go/iter v1.0.2 github.com/lestrrat-go/option v1.0.1 github.com/segmentio/asm v1.2.0 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 7d75a5090..7fb294b48 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,6 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= -github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= -github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/iter/BUILD.bazel b/internal/iter/BUILD.bazel deleted file mode 100644 index 0987e4a99..000000000 --- a/internal/iter/BUILD.bazel +++ /dev/null @@ -1,15 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "iter", - srcs = ["mapiter.go"], - importpath = "github.com/lestrrat-go/jwx/v3/internal/iter", - visibility = ["//:__subpackages__"], - deps = ["@com_github_lestrrat_go_iter//mapiter"], -) - -alias( - name = "go_default_library", - actual = ":iter", - visibility = ["//:__subpackages__"], -) diff --git a/internal/iter/mapiter.go b/internal/iter/mapiter.go deleted file mode 100644 index c98fd46c3..000000000 --- a/internal/iter/mapiter.go +++ /dev/null @@ -1,36 +0,0 @@ -package iter - -import ( - "context" - "fmt" - - "github.com/lestrrat-go/iter/mapiter" -) - -// MapVisitor is a specialized visitor for our purposes. -// Whereas mapiter.Visitor supports any type of key, this -// visitor assumes the key is a string -type MapVisitor interface { - Visit(string, interface{}) error -} - -type MapVisitorFunc func(string, interface{}) error - -func (fn MapVisitorFunc) Visit(s string, v interface{}) error { - return fn(s, v) -} - -func WalkMap(ctx context.Context, src mapiter.Source, visitor MapVisitor) error { - return mapiter.Walk(ctx, src, mapiter.VisitorFunc(func(k, v interface{}) error { - //nolint:forcetypeassert - return visitor.Visit(k.(string), v) - })) -} - -func AsMap(ctx context.Context, src mapiter.Source) (map[string]interface{}, error) { - var m map[string]interface{} - if err := mapiter.AsMap(ctx, src, &m); err != nil { - return nil, fmt.Errorf(`mapiter.AsMap failed: %w`, err) - } - return m, nil -} diff --git a/jwe/BUILD.bazel b/jwe/BUILD.bazel index ebf015a24..d072f4588 100644 --- a/jwe/BUILD.bazel +++ b/jwe/BUILD.bazel @@ -20,7 +20,6 @@ go_library( deps = [ "//cert", "//internal/base64", - "//internal/iter", "//internal/json", "//internal/keyconv", "//internal/pool", @@ -31,7 +30,6 @@ go_library( "//jwe/internal/keygen", "//jwk", "@com_github_lestrrat_go_blackmagic//:blackmagic", - "@com_github_lestrrat_go_iter//mapiter", "@com_github_lestrrat_go_option//:option", "@org_golang_x_crypto//pbkdf2", ], diff --git a/jwe/headers.go b/jwe/headers.go index ff7b23a4d..c4890736b 100644 --- a/jwe/headers.go +++ b/jwe/headers.go @@ -41,11 +41,14 @@ func (h *stdHeaders) Clone(ctx context.Context) (Headers, error) { } func (h *stdHeaders) Copy(_ context.Context, dst Headers) error { - for _, pair := range h.makePairs() { - //nolint:forcetypeassert - key := pair.Key.(string) - if err := dst.Set(key, pair.Value); err != nil { - return fmt.Errorf(`failed to set header %q: %w`, key, err) + for _, key := range h.Keys() { + var v interface{} + if err := h.Get(key, &v); err != nil { + return fmt.Errorf(`jwe.Headers: Copy: failed to get header %q: %w`, key, err) + } + + if err := dst.Set(key, v); err != nil { + return fmt.Errorf(`jwe.Headers: Copy: failed to set header %q: %w`, key, err) } } return nil diff --git a/jwe/headers_gen.go b/jwe/headers_gen.go index 1b909610d..9ed5c4aa2 100644 --- a/jwe/headers_gen.go +++ b/jwe/headers_gen.go @@ -241,64 +241,6 @@ func (h *stdHeaders) X509URL() string { return *(h.x509URL) } -func (h *stdHeaders) makePairs() []*HeaderPair { - h.mu.RLock() - defer h.mu.RUnlock() - var pairs []*HeaderPair - if h.agreementPartyUInfo != nil { - pairs = append(pairs, &HeaderPair{Key: AgreementPartyUInfoKey, Value: h.agreementPartyUInfo}) - } - if h.agreementPartyVInfo != nil { - pairs = append(pairs, &HeaderPair{Key: AgreementPartyVInfoKey, Value: h.agreementPartyVInfo}) - } - if h.algorithm != nil { - pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)}) - } - if h.compression != nil { - pairs = append(pairs, &HeaderPair{Key: CompressionKey, Value: *(h.compression)}) - } - if h.contentEncryption != nil { - pairs = append(pairs, &HeaderPair{Key: ContentEncryptionKey, Value: *(h.contentEncryption)}) - } - if h.contentType != nil { - pairs = append(pairs, &HeaderPair{Key: ContentTypeKey, Value: *(h.contentType)}) - } - if h.critical != nil { - pairs = append(pairs, &HeaderPair{Key: CriticalKey, Value: h.critical}) - } - if h.ephemeralPublicKey != nil { - pairs = append(pairs, &HeaderPair{Key: EphemeralPublicKeyKey, Value: h.ephemeralPublicKey}) - } - if h.jwk != nil { - pairs = append(pairs, &HeaderPair{Key: JWKKey, Value: h.jwk}) - } - if h.jwkSetURL != nil { - pairs = append(pairs, &HeaderPair{Key: JWKSetURLKey, Value: *(h.jwkSetURL)}) - } - if h.keyID != nil { - pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)}) - } - if h.typ != nil { - pairs = append(pairs, &HeaderPair{Key: TypeKey, Value: *(h.typ)}) - } - if h.x509CertChain != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: h.x509CertChain}) - } - if h.x509CertThumbprint != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)}) - } - if h.x509CertThumbprintS256 != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)}) - } - if h.x509URL != nil { - pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)}) - } - for k, v := range h.privateParams { - pairs = append(pairs, &HeaderPair{Key: k, Value: v}) - } - return pairs -} - func (h *stdHeaders) PrivateParams() map[string]interface{} { h.mu.RLock() defer h.mu.RUnlock() @@ -828,25 +770,91 @@ func (h *stdHeaders) Keys() []string { func (h stdHeaders) MarshalJSON() ([]byte, error) { data := make(map[string]interface{}) - fields := make([]string, 0, 16) - for _, pair := range h.makePairs() { - fields = append(fields, pair.Key.(string)) - data[pair.Key.(string)] = pair.Value + keys := make([]string, 0, 16+len(h.privateParams)) + h.mu.RLock() + if h.agreementPartyUInfo != nil { + data[AgreementPartyUInfoKey] = h.agreementPartyUInfo + keys = append(keys, AgreementPartyUInfoKey) + } + if h.agreementPartyVInfo != nil { + data[AgreementPartyVInfoKey] = h.agreementPartyVInfo + keys = append(keys, AgreementPartyVInfoKey) + } + if h.algorithm != nil { + data[AlgorithmKey] = *(h.algorithm) + keys = append(keys, AlgorithmKey) + } + if h.compression != nil { + data[CompressionKey] = *(h.compression) + keys = append(keys, CompressionKey) + } + if h.contentEncryption != nil { + data[ContentEncryptionKey] = *(h.contentEncryption) + keys = append(keys, ContentEncryptionKey) + } + if h.contentType != nil { + data[ContentTypeKey] = *(h.contentType) + keys = append(keys, ContentTypeKey) + } + if h.critical != nil { + data[CriticalKey] = h.critical + keys = append(keys, CriticalKey) + } + if h.ephemeralPublicKey != nil { + data[EphemeralPublicKeyKey] = h.ephemeralPublicKey + keys = append(keys, EphemeralPublicKeyKey) + } + if h.jwk != nil { + data[JWKKey] = h.jwk + keys = append(keys, JWKKey) + } + if h.jwkSetURL != nil { + data[JWKSetURLKey] = *(h.jwkSetURL) + keys = append(keys, JWKSetURLKey) } + if h.keyID != nil { + data[KeyIDKey] = *(h.keyID) + keys = append(keys, KeyIDKey) + } + if h.typ != nil { + data[TypeKey] = *(h.typ) + keys = append(keys, TypeKey) + } + if h.x509CertChain != nil { + data[X509CertChainKey] = h.x509CertChain + keys = append(keys, X509CertChainKey) + } + if h.x509CertThumbprint != nil { + data[X509CertThumbprintKey] = *(h.x509CertThumbprint) + keys = append(keys, X509CertThumbprintKey) + } + if h.x509CertThumbprintS256 != nil { + data[X509CertThumbprintS256Key] = *(h.x509CertThumbprintS256) + keys = append(keys, X509CertThumbprintS256Key) + } + if h.x509URL != nil { + data[X509URLKey] = *(h.x509URL) + keys = append(keys, X509URLKey) + } + for k, v := range h.privateParams { + data[k] = v + keys = append(keys, k) + } + h.mu.RUnlock() - sort.Strings(fields) + sort.Strings(keys) buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) - buf.WriteByte('{') enc := json.NewEncoder(buf) - for i, f := range fields { + buf.WriteByte('{') + for i, k := range keys { if i > 0 { buf.WriteRune(',') } buf.WriteRune('"') - buf.WriteString(f) + buf.WriteString(k) buf.WriteString(`":`) - v := data[f] + v := data[k] switch v := v.(type) { case []byte: buf.WriteRune('"') @@ -854,7 +862,7 @@ func (h stdHeaders) MarshalJSON() ([]byte, error) { buf.WriteRune('"') default: if err := enc.Encode(v); err != nil { - return nil, fmt.Errorf(`failed to encode value for field %s`, f) + return nil, fmt.Errorf(`failed to encode value for field %s`, k) } buf.Truncate(buf.Len() - 1) } diff --git a/jwe/interface.go b/jwe/interface.go index 3606808c9..d68742f16 100644 --- a/jwe/interface.go +++ b/jwe/interface.go @@ -1,8 +1,6 @@ package jwe import ( - "github.com/lestrrat-go/iter/mapiter" - "github.com/lestrrat-go/jwx/v3/internal/iter" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe/internal/keygen" ) @@ -207,8 +205,3 @@ type Message struct { type populater interface { Populate(keygen.Setter) error } - -type Visitor = iter.MapVisitor -type VisitorFunc = iter.MapVisitorFunc -type HeaderPair = mapiter.Pair -type Iterator = mapiter.Iterator diff --git a/jwk/BUILD.bazel b/jwk/BUILD.bazel index 50058b47a..bb5efa8e2 100644 --- a/jwk/BUILD.bazel +++ b/jwk/BUILD.bazel @@ -32,15 +32,12 @@ go_library( "//cert", "//internal/base64", "//internal/ecutil", - "//internal/iter", "//internal/json", "//internal/pool", "//jwa", "//jwk/ecdsa", "@com_github_lestrrat_go_blackmagic//:blackmagic", "@com_github_lestrrat_go_httprc//:httprc", - "@com_github_lestrrat_go_iter//arrayiter", - "@com_github_lestrrat_go_iter//mapiter", "@com_github_lestrrat_go_option//:option", ], ) diff --git a/jws/BUILD.bazel b/jws/BUILD.bazel index 631a9f315..13dbf9f6b 100644 --- a/jws/BUILD.bazel +++ b/jws/BUILD.bazel @@ -24,14 +24,12 @@ go_library( deps = [ "//cert", "//internal/base64", - "//internal/iter", "//internal/json", "//internal/keyconv", "//internal/pool", "//jwa", "//jwk", "@com_github_lestrrat_go_blackmagic//:blackmagic", - "@com_github_lestrrat_go_iter//mapiter", "@com_github_lestrrat_go_option//:option", ], ) diff --git a/jws/headers_gen.go b/jws/headers_gen.go index 23d65a90a..6cf67d567 100644 --- a/jws/headers_gen.go +++ b/jws/headers_gen.go @@ -216,52 +216,6 @@ func (h *stdHeaders) rawBuffer() []byte { return h.raw } -func (h *stdHeaders) makePairs() []*HeaderPair { - h.mu.RLock() - defer h.mu.RUnlock() - var pairs []*HeaderPair - if h.algorithm != nil { - pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)}) - } - if h.contentType != nil { - pairs = append(pairs, &HeaderPair{Key: ContentTypeKey, Value: *(h.contentType)}) - } - if h.critical != nil { - pairs = append(pairs, &HeaderPair{Key: CriticalKey, Value: h.critical}) - } - if h.jwk != nil { - pairs = append(pairs, &HeaderPair{Key: JWKKey, Value: h.jwk}) - } - if h.jwkSetURL != nil { - pairs = append(pairs, &HeaderPair{Key: JWKSetURLKey, Value: *(h.jwkSetURL)}) - } - if h.keyID != nil { - pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)}) - } - if h.typ != nil { - pairs = append(pairs, &HeaderPair{Key: TypeKey, Value: *(h.typ)}) - } - if h.x509CertChain != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: h.x509CertChain}) - } - if h.x509CertThumbprint != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)}) - } - if h.x509CertThumbprintS256 != nil { - pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)}) - } - if h.x509URL != nil { - pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)}) - } - for k, v := range h.privateParams { - pairs = append(pairs, &HeaderPair{Key: k, Value: v}) - } - sort.Slice(pairs, func(i, j int) bool { - return pairs[i].Key.(string) < pairs[j].Key.(string) - }) - return pairs -} - func (h *stdHeaders) PrivateParams() map[string]interface{} { h.mu.RLock() defer h.mu.RUnlock() @@ -660,26 +614,78 @@ func (h *stdHeaders) Keys() []string { } func (h stdHeaders) MarshalJSON() ([]byte, error) { + h.mu.RLock() + data := make(map[string]interface{}) + keys := make([]string, 0, 11+len(h.privateParams)) + if h.algorithm != nil { + data[AlgorithmKey] = *(h.algorithm) + keys = append(keys, AlgorithmKey) + } + if h.contentType != nil { + data[ContentTypeKey] = *(h.contentType) + keys = append(keys, ContentTypeKey) + } + if h.critical != nil { + data[CriticalKey] = h.critical + keys = append(keys, CriticalKey) + } + if h.jwk != nil { + data[JWKKey] = h.jwk + keys = append(keys, JWKKey) + } + if h.jwkSetURL != nil { + data[JWKSetURLKey] = *(h.jwkSetURL) + keys = append(keys, JWKSetURLKey) + } + if h.keyID != nil { + data[KeyIDKey] = *(h.keyID) + keys = append(keys, KeyIDKey) + } + if h.typ != nil { + data[TypeKey] = *(h.typ) + keys = append(keys, TypeKey) + } + if h.x509CertChain != nil { + data[X509CertChainKey] = h.x509CertChain + keys = append(keys, X509CertChainKey) + } + if h.x509CertThumbprint != nil { + data[X509CertThumbprintKey] = *(h.x509CertThumbprint) + keys = append(keys, X509CertThumbprintKey) + } + if h.x509CertThumbprintS256 != nil { + data[X509CertThumbprintS256Key] = *(h.x509CertThumbprintS256) + keys = append(keys, X509CertThumbprintS256Key) + } + if h.x509URL != nil { + data[X509URLKey] = *(h.x509URL) + keys = append(keys, X509URLKey) + } + for k, v := range h.privateParams { + data[k] = v + keys = append(keys, k) + } + h.mu.RUnlock() + sort.Strings(keys) buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) - buf.WriteByte('{') enc := json.NewEncoder(buf) - for i, p := range h.makePairs() { + buf.WriteByte('{') + for i, k := range keys { if i > 0 { buf.WriteRune(',') } buf.WriteRune('"') - buf.WriteString(p.Key.(string)) + buf.WriteString(k) buf.WriteString(`":`) - v := p.Value - switch v := v.(type) { + switch v := data[k].(type) { case []byte: buf.WriteRune('"') buf.WriteString(base64.EncodeToString(v)) buf.WriteRune('"') default: if err := enc.Encode(v); err != nil { - return nil, fmt.Errorf(`failed to encode value for field %s: %w`, p.Key, err) + return nil, fmt.Errorf(`failed to encode value for field %s: %w`, k, err) } buf.Truncate(buf.Len() - 1) } diff --git a/jws/interface.go b/jws/interface.go index 556d80ed5..1f5292fdb 100644 --- a/jws/interface.go +++ b/jws/interface.go @@ -1,8 +1,6 @@ package jws import ( - "github.com/lestrrat-go/iter/mapiter" - "github.com/lestrrat-go/jwx/v3/internal/iter" "github.com/lestrrat-go/jwx/v3/jwa" ) @@ -64,11 +62,6 @@ type Signature struct { detached bool } -type Visitor = iter.MapVisitor -type VisitorFunc = iter.MapVisitorFunc -type HeaderPair = mapiter.Pair -type Iterator = mapiter.Iterator - // Signer generates the signature for a given payload. type Signer interface { // Sign creates a signature for the given payload. diff --git a/jwt/BUILD.bazel b/jwt/BUILD.bazel index 955aa5375..a7cddae89 100644 --- a/jwt/BUILD.bazel +++ b/jwt/BUILD.bazel @@ -21,7 +21,6 @@ go_library( deps = [ "//:jwx", "//internal/base64", - "//internal/iter", "//internal/json", "//internal/pool", "//jwa", @@ -30,7 +29,6 @@ go_library( "//jws", "//jwt/internal/types", "@com_github_lestrrat_go_blackmagic//:blackmagic", - "@com_github_lestrrat_go_iter//mapiter", "@com_github_lestrrat_go_option//:option", ], ) diff --git a/jwt/openid/BUILD.bazel b/jwt/openid/BUILD.bazel index bbfacef19..51cb1dd9b 100644 --- a/jwt/openid/BUILD.bazel +++ b/jwt/openid/BUILD.bazel @@ -14,13 +14,11 @@ go_library( visibility = ["//visibility:public"], deps = [ "//internal/base64", - "//internal/iter", "//internal/json", "//internal/pool", "//jwt", "//jwt/internal/types", "@com_github_lestrrat_go_blackmagic//:blackmagic", - "@com_github_lestrrat_go_iter//mapiter", ], ) diff --git a/tools/cmd/genjwe/main.go b/tools/cmd/genjwe/main.go index 0f46667fb..55f65b0dd 100644 --- a/tools/cmd/genjwe/main.go +++ b/tools/cmd/genjwe/main.go @@ -164,28 +164,6 @@ func generateHeaders(obj *codegen.Object) error { o.L("}") // func (h *stdHeaders) %s() %s } - // Generate a function that iterates through all of the keys - // in this header. - o.LL("func (h *stdHeaders) makePairs() []*HeaderPair {") - o.L("h.mu.RLock()") - o.L("defer h.mu.RUnlock()") - // NOTE: building up an array is *slow*? - o.L("var pairs []*HeaderPair") - for _, f := range obj.Fields() { - o.L("if h.%s != nil {", f.Name(false)) - if fieldStorageTypeIsIndirect(f.Type()) { - o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: *(h.%s)})", f.Name(true), f.Name(false)) - } else { - o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: h.%s})", f.Name(true), f.Name(false)) - } - o.L("}") - } - o.L("for k, v := range h.privateParams {") - o.L("pairs = append(pairs, &HeaderPair{Key: k, Value: v})") - o.L("}") - o.L("return pairs") - o.L("}") // end of (h *stdHeaders) iterate(...) - o.LL("func (h *stdHeaders) PrivateParams() map[string]interface{} {") o.L("h.mu.RLock()") o.L("defer h.mu.RUnlock()") @@ -396,24 +374,36 @@ func generateHeaders(obj *codegen.Object) error { o.LL("func (h stdHeaders) MarshalJSON() ([]byte, error) {") o.L("data := make(map[string]interface{})") - o.L("fields := make([]string, 0, %d)", len(obj.Fields())) - o.L("for _, pair := range h.makePairs() {") - o.L("fields = append(fields, pair.Key.(string))") - o.L("data[pair.Key.(string)] = pair.Value") + o.L("keys := make([]string, 0, %d+len(h.privateParams))", len(obj.Fields())) + o.L("h.mu.RLock()") + for _, f := range obj.Fields() { + o.L("if h.%s != nil {", f.Name(false)) + if fieldStorageTypeIsIndirect(f.Type()) { + o.L("data[%sKey] = *(h.%s)", f.Name(true), f.Name(false)) + } else { + o.L("data[%sKey] = h.%s", f.Name(true), f.Name(false)) + } + o.L("keys = append(keys, %sKey)", f.Name(true)) + o.L("}") + } + o.L("for k, v := range h.privateParams {") + o.L("data[k] = v") + o.L("keys = append(keys, k)") o.L("}") - o.LL("sort.Strings(fields)") + o.L("h.mu.RUnlock()") + o.LL("sort.Strings(keys)") o.L("buf := pool.GetBytesBuffer()") o.L("defer pool.ReleaseBytesBuffer(buf)") - o.L("buf.WriteByte('{')") o.L("enc := json.NewEncoder(buf)") - o.L("for i, f := range fields {") + o.L("buf.WriteByte('{')") + o.L("for i, k := range keys {") o.L("if i > 0 {") o.L("buf.WriteRune(',')") o.L("}") o.L("buf.WriteRune('\"')") - o.L("buf.WriteString(f)") + o.L("buf.WriteString(k)") o.L("buf.WriteString(`\":`)") - o.L("v := data[f]") + o.L("v := data[k]") o.L("switch v := v.(type) {") o.L("case []byte:") o.L("buf.WriteRune('\"')") @@ -421,7 +411,7 @@ func generateHeaders(obj *codegen.Object) error { o.L("buf.WriteRune('\"')") o.L("default:") o.L("if err := enc.Encode(v); err != nil {") - o.L("return nil, fmt.Errorf(`failed to encode value for field %%s`, f)") + o.L("return nil, fmt.Errorf(`failed to encode value for field %%s`, k)") o.L("}") o.L("buf.Truncate(buf.Len()-1)") o.L("}") diff --git a/tools/cmd/genjws/main.go b/tools/cmd/genjws/main.go index 13600e00c..c06ef0b9c 100644 --- a/tools/cmd/genjws/main.go +++ b/tools/cmd/genjws/main.go @@ -185,31 +185,6 @@ func generateHeaders(obj *codegen.Object) error { o.L("return h.raw") o.L("}") - // Generate a function that iterates through all of the keys - // in this header. - o.LL("func (h *stdHeaders) makePairs() []*HeaderPair {") - o.L("h.mu.RLock()") - o.L("defer h.mu.RUnlock()") - // NOTE: building up an array is *slow*? - o.L("var pairs []*HeaderPair") - for _, f := range obj.Fields() { - o.L("if h.%s != nil {", f.Name(false)) - if fieldStorageTypeIsIndirect(f.Type()) { - o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: *(h.%s)})", f.Name(true), f.Name(false)) - } else { - o.L("pairs = append(pairs, &HeaderPair{Key: %sKey, Value: h.%s})", f.Name(true), f.Name(false)) - } - o.L("}") - } - o.L("for k, v := range h.privateParams {") - o.L("pairs = append(pairs, &HeaderPair{Key: k, Value: v})") - o.L("}") - o.L("sort.Slice(pairs, func(i, j int) bool {") - o.L("return pairs[i].Key.(string) < pairs[j].Key.(string)") - o.L("})") - o.L("return pairs") - o.L("}") // end of (h *stdHeaders) iterate(...) - o.LL("func (h *stdHeaders) PrivateParams() map[string]interface{} {") o.L("h.mu.RLock()") o.L("defer h.mu.RUnlock()") @@ -414,26 +389,46 @@ func generateHeaders(obj *codegen.Object) error { o.L("}") o.LL("func (h stdHeaders) MarshalJSON() ([]byte, error) {") + o.L("h.mu.RLock()") + o.L("data := make(map[string]interface{})") + o.L("keys := make([]string, 0, %d+len(h.privateParams))", len(obj.Fields())) + for _, f := range obj.Fields() { + o.L("if h.%s != nil {", f.Name(false)) + if fieldStorageTypeIsIndirect(f.Type()) { + o.L("data[%sKey] = *(h.%s)", f.Name(true), f.Name(false)) + } else { + o.L("data[%sKey] = h.%s", f.Name(true), f.Name(false)) + } + o.L("keys = append(keys, %sKey)", f.Name(true)) + o.L("}") + } + o.L("for k, v := range h.privateParams {") + o.L("data[k] = v") + o.L("keys = append(keys, k)") + o.L("}") + o.L("h.mu.RUnlock()") + o.L("sort.Strings(keys)") + o.L("buf := pool.GetBytesBuffer()") o.L("defer pool.ReleaseBytesBuffer(buf)") - o.L("buf.WriteByte('{')") o.L("enc := json.NewEncoder(buf)") - o.L("for i, p := range h.makePairs() {") + + o.L("buf.WriteByte('{')") + o.L("for i, k := range keys {") o.L("if i > 0 {") o.L("buf.WriteRune(',')") o.L("}") o.L("buf.WriteRune('\"')") - o.L("buf.WriteString(p.Key.(string))") + o.L("buf.WriteString(k)") o.L("buf.WriteString(`\":`)") - o.L("v := p.Value") - o.L("switch v := v.(type) {") + o.L("switch v := data[k].(type) {") o.L("case []byte:") o.L("buf.WriteRune('\"')") o.L("buf.WriteString(base64.EncodeToString(v))") o.L("buf.WriteRune('\"')") o.L("default:") o.L("if err := enc.Encode(v); err != nil {") - o.L("return nil, fmt.Errorf(`failed to encode value for field %%s: %%w`, p.Key, err)") + o.L("return nil, fmt.Errorf(`failed to encode value for field %%s: %%w`, k, err)") o.L("}") o.L("buf.Truncate(buf.Len()-1)") o.L("}")