diff --git a/p2p/http/auth/auth.go b/p2p/http/auth/auth.go new file mode 100644 index 0000000000..692464fa22 --- /dev/null +++ b/p2p/http/auth/auth.go @@ -0,0 +1,11 @@ +package httppeeridauth + +import ( + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake" +) + +const PeerIDAuthScheme = handshake.PeerIDAuthScheme +const ProtocolID = "/http-peer-id-auth/1.0.0" + +var log = logging.Logger("http-peer-id-auth") diff --git a/p2p/http/auth/auth_test.go b/p2p/http/auth/auth_test.go new file mode 100644 index 0000000000..d080b19511 --- /dev/null +++ b/p2p/http/auth/auth_test.go @@ -0,0 +1,243 @@ +package httppeeridauth + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "hash" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestMutualAuth tests that we can do a mutually authenticated round trip +func TestMutualAuth(t *testing.T) { + logging.SetLogLevel("httppeeridauth", "DEBUG") + + zeroBytes := make([]byte, 64) + serverKey, _, err := crypto.GenerateEd25519Key(bytes.NewReader(zeroBytes)) + require.NoError(t, err) + + type clientTestCase struct { + name string + clientKeyGen func(t *testing.T) crypto.PrivKey + } + + clientTestCases := []clientTestCase{ + { + name: "ED25519", + clientKeyGen: func(t *testing.T) crypto.PrivKey { + t.Helper() + clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + return clientKey + }, + }, + { + name: "RSA", + clientKeyGen: func(t *testing.T) crypto.PrivKey { + t.Helper() + clientKey, _, err := crypto.GenerateRSAKeyPair(2048, rand.Reader) + require.NoError(t, err) + return clientKey + }, + }, + } + + type serverTestCase struct { + name string + serverGen func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) + } + + serverTestCases := []serverTestCase{ + { + name: "no TLS", + serverGen: func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) { + t.Helper() + auth := ServerPeerIDAuth{ + PrivKey: serverKey, + ValidHostnameFn: func(s string) bool { + return s == "example.com" + }, + TokenTTL: time.Hour, + NoTLS: true, + } + + ts := httptest.NewServer(&auth) + t.Cleanup(ts.Close) + return ts, &auth + }, + }, + { + name: "TLS", + serverGen: func(t *testing.T) (*httptest.Server, *ServerPeerIDAuth) { + t.Helper() + auth := ServerPeerIDAuth{ + PrivKey: serverKey, + ValidHostnameFn: func(s string) bool { + return s == "example.com" + }, + TokenTTL: time.Hour, + } + + ts := httptest.NewTLSServer(&auth) + t.Cleanup(ts.Close) + return ts, &auth + }, + }, + } + + for _, ctc := range clientTestCases { + for _, stc := range serverTestCases { + t.Run(ctc.name+"+"+stc.name, func(t *testing.T) { + ts, server := stc.serverGen(t) + client := ts.Client() + roundTripper := instrumentedRoundTripper{client.Transport, 0} + client.Transport = &roundTripper + requestsSent := func() int { + defer func() { roundTripper.timesRoundtripped = 0 }() + return roundTripper.timesRoundtripped + } + + tlsClientConfig := roundTripper.TLSClientConfig() + if tlsClientConfig != nil { + // If we're using TLS, we need to set the SNI so that the + // server can verify the request Host matches it. + tlsClientConfig.ServerName = "example.com" + } + clientKey := ctc.clientKeyGen(t) + clientAuth := ClientPeerIDAuth{PrivKey: clientKey} + + expectedServerID, err := peer.IDFromPrivateKey(serverKey) + require.NoError(t, err) + + req, err := http.NewRequest("POST", ts.URL, nil) + require.NoError(t, err) + req.Host = "example.com" + serverID, resp, err := clientAuth.AuthenticatedDo(client, req) + require.NoError(t, err) + require.Equal(t, expectedServerID, serverID) + require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 2, requestsSent()) + + // Once more with the auth token + req, err = http.NewRequest("POST", ts.URL, nil) + require.NoError(t, err) + req.Host = "example.com" + serverID, resp, err = clientAuth.AuthenticatedDo(client, req) + require.NotEmpty(t, req.Header.Get("Authorization")) + require.NoError(t, err) + require.Equal(t, expectedServerID, serverID) + require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 1, requestsSent(), "should only call newRequest once since we have a token") + + t.Run("Tokens Expired", func(t *testing.T) { + // Clear the auth token on the server side + server.TokenTTL = 1 // Small TTL + time.Sleep(100 * time.Millisecond) + resetServerTokenTTL := sync.OnceFunc(func() { + server.TokenTTL = time.Hour + }) + + req, err := http.NewRequest("POST", ts.URL, nil) + require.NoError(t, err) + req.Host = "example.com" + req.GetBody = func() (io.ReadCloser, error) { + resetServerTokenTTL() + return nil, nil + } + serverID, resp, err = clientAuth.AuthenticatedDo(client, req) + require.NoError(t, err) + require.NotEmpty(t, req.Header.Get("Authorization")) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, expectedServerID, serverID) + require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) + require.Equal(t, 3, requestsSent(), "should call newRequest 3x since our token expired") + }) + + t.Run("Tokens Invalidated", func(t *testing.T) { + // Clear the auth token on the server side + server.Hmac = func() hash.Hash { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + panic(err) + } + return hmac.New(sha256.New, key) + }() + + req, err := http.NewRequest("POST", ts.URL, nil) + req.GetBody = func() (io.ReadCloser, error) { + return nil, nil + } + require.NoError(t, err) + req.Host = "example.com" + serverID, resp, err = clientAuth.AuthenticatedDo(client, req) + require.NoError(t, err) + require.NotEmpty(t, req.Header.Get("Authorization")) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, expectedServerID, serverID) + require.NotZero(t, clientAuth.tm.tokenMap["example.com"]) + require.Equal(t, 3, requestsSent(), "should call have sent 3 reqs since our token expired") + }) + + }) + } + } +} + +func TestBodyNotSentDuringRedirect(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Empty(t, string(b)) + if r.URL.Path != "/redirected" { + w.Header().Set("Location", "/redirected") + w.WriteHeader(http.StatusTemporaryRedirect) + return + } + })) + t.Cleanup(ts.Close) + client := ts.Client() + clientKey, _, _ := crypto.GenerateEd25519Key(rand.Reader) + clientAuth := ClientPeerIDAuth{PrivKey: clientKey} + + req, err := + http.NewRequest( + "POST", + ts.URL, + strings.NewReader("Only for authenticated servers"), + ) + req.Host = "example.com" + require.NoError(t, err) + _, _, err = clientAuth.AuthenticatedDo(client, req) + require.ErrorContains(t, err, "signature not set") // server doesn't actually handshake +} + +type instrumentedRoundTripper struct { + http.RoundTripper + timesRoundtripped int +} + +func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + irt.timesRoundtripped++ + return irt.RoundTripper.RoundTrip(req) +} + +func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config { + return irt.RoundTripper.(*http.Transport).TLSClientConfig +} diff --git a/p2p/http/auth/client.go b/p2p/http/auth/client.go new file mode 100644 index 0000000000..a6bdece61a --- /dev/null +++ b/p2p/http/auth/client.go @@ -0,0 +1,194 @@ +package httppeeridauth + +import ( + "errors" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake" +) + +type ClientPeerIDAuth struct { + PrivKey crypto.PrivKey + TokenTTL time.Duration + + tm tokenMap +} + +// AuthenticatedDo is like http.Client.Do, but it does the libp2p peer ID auth +// handshake if needed. +// +// It is recommended to pass in an http.Request with `GetBody` set, so that this +// method can retry sending the request in case a previously used token has +// expired. +func (a *ClientPeerIDAuth) AuthenticatedDo(client *http.Client, req *http.Request) (peer.ID, *http.Response, error) { + hostname := req.Host + ti, hasToken := a.tm.get(hostname, a.TokenTTL) + handshake := handshake.PeerIDAuthHandshakeClient{ + Hostname: hostname, + PrivKey: a.PrivKey, + } + + if hasToken { + // We have a token. Attempt to use that, but fallback to server initiated challenge if it fails. + peer, resp, err := a.doWithToken(client, req, ti) + switch { + case err == nil: + return peer, resp, nil + case errors.Is(err, errTokenRejected): + // Token was rejected, we need to re-authenticate + break + default: + return "", nil, err + } + + // Token didn't work, we need to re-authenticate. + // Run the server-initiated handshake + req = req.Clone(req.Context()) + req.Body, err = req.GetBody() + if err != nil { + return "", nil, err + } + + handshake.ParseHeader(resp.Header) + } else { + // We didn't have a handshake token, so we initiate the handshake. + // If our token was rejected, the server initiates the handshake. + handshake.SetInitiateChallenge() + } + + serverPeerID, resp, err := a.runHandshake(client, req, clearBody(req), &handshake) + if err != nil { + return "", nil, fmt.Errorf("failed to run handshake: %w", err) + } + a.tm.set(hostname, tokenInfo{ + token: handshake.BearerToken(), + insertedAt: time.Now(), + peerID: serverPeerID, + }) + return serverPeerID, resp, nil +} + +func (a *ClientPeerIDAuth) runHandshake(client *http.Client, req *http.Request, b bodyMeta, hs *handshake.PeerIDAuthHandshakeClient) (peer.ID, *http.Response, error) { + maxSteps := 5 // Avoid infinite loops in case of buggy handshake. Shouldn't happen. + var resp *http.Response + + err := hs.Run() + if err != nil { + return "", nil, err + } + + sentBody := false + for !hs.HandshakeDone() || !sentBody { + req = req.Clone(req.Context()) + hs.AddHeader(req.Header) + if hs.ServerAuthenticated() { + sentBody = true + b.setBody(req) + } + + resp, err = client.Do(req) + if err != nil { + return "", nil, err + } + + hs.ParseHeader(resp.Header) + err = hs.Run() + if err != nil { + resp.Body.Close() + return "", nil, err + } + + if maxSteps--; maxSteps == 0 { + return "", nil, errors.New("handshake took too many steps") + } + } + + p, err := hs.PeerID() + if err != nil { + resp.Body.Close() + return "", nil, err + } + return p, resp, nil +} + +var errTokenRejected = errors.New("token rejected") + +func (a *ClientPeerIDAuth) doWithToken(client *http.Client, req *http.Request, ti tokenInfo) (peer.ID, *http.Response, error) { + // Try to make the request with the token + req.Header.Set("Authorization", ti.token) + resp, err := client.Do(req) + if err != nil { + return "", nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + // our token is still valid + return ti.peerID, resp, nil + } + if req.GetBody == nil { + // We can't retry this request even if we wanted to. + // Return the response and an error + return "", resp, errors.New("expired token. Couldn't run handshake because req.GetBody is nil") + } + resp.Body.Close() + + return "", resp, errTokenRejected +} + +type bodyMeta struct { + body io.ReadCloser + contentLength int64 + getBody func() (io.ReadCloser, error) +} + +func clearBody(req *http.Request) bodyMeta { + defer func() { + req.Body = nil + req.ContentLength = 0 + req.GetBody = nil + }() + return bodyMeta{body: req.Body, contentLength: req.ContentLength, getBody: req.GetBody} +} + +func (b *bodyMeta) setBody(req *http.Request) { + req.Body = b.body + req.ContentLength = b.contentLength + req.GetBody = b.getBody +} + +type tokenInfo struct { + token string + insertedAt time.Time + peerID peer.ID +} + +type tokenMap struct { + tokenMapMu sync.Mutex + tokenMap map[string]tokenInfo +} + +func (tm *tokenMap) get(hostname string, ttl time.Duration) (tokenInfo, bool) { + tm.tokenMapMu.Lock() + defer tm.tokenMapMu.Unlock() + + ti, ok := tm.tokenMap[hostname] + if ok && ttl != 0 && time.Since(ti.insertedAt) > ttl { + delete(tm.tokenMap, hostname) + return tokenInfo{}, false + } + return ti, ok +} + +func (tm *tokenMap) set(hostname string, ti tokenInfo) { + tm.tokenMapMu.Lock() + defer tm.tokenMapMu.Unlock() + if tm.tokenMap == nil { + tm.tokenMap = make(map[string]tokenInfo) + } + tm.tokenMap[hostname] = ti +} diff --git a/p2p/http/auth/internal/handshake/alloc_test.go b/p2p/http/auth/internal/handshake/alloc_test.go new file mode 100644 index 0000000000..333bad4f0d --- /dev/null +++ b/p2p/http/auth/internal/handshake/alloc_test.go @@ -0,0 +1,20 @@ +//go:build nocover + +package handshake + +import "testing" + +func TestParsePeerIDAuthSchemeParamsNoAllocNoCover(t *testing.T) { + str := []byte(`libp2p-PeerID peer-id="", sig="", public-key="", bearer=""`) + + allocs := testing.AllocsPerRun(1000, func() { + p := params{} + err := p.parsePeerIDAuthSchemeParams(str) + if err != nil { + t.Fatal(err) + } + }) + if allocs > 0 { + t.Fatalf("alloc test failed expected 0 received %0.2f", allocs) + } +} diff --git a/p2p/http/auth/internal/handshake/client.go b/p2p/http/auth/internal/handshake/client.go new file mode 100644 index 0000000000..f8d39e9c14 --- /dev/null +++ b/p2p/http/auth/internal/handshake/client.go @@ -0,0 +1,247 @@ +package handshake + +import ( + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +type peerIDAuthClientState int + +const ( + peerIDAuthClientStateSignChallenge peerIDAuthClientState = iota + peerIDAuthClientStateVerifyChallenge + peerIDAuthClientStateDone // We have the bearer token, and there's nothing left to do + + // Client initiated handshake + peerIDAuthClientInitiateChallenge + peerIDAuthClientStateVerifyAndSignChallenge + peerIDAuthClientStateWaitingForBearer +) + +type PeerIDAuthHandshakeClient struct { + Hostname string + PrivKey crypto.PrivKey + + serverPeerID peer.ID + serverPubKey crypto.PubKey + state peerIDAuthClientState + p params + hb headerBuilder + challengeServer []byte + buf [128]byte +} + +var errMissingChallenge = errors.New("missing challenge") + +func (h *PeerIDAuthHandshakeClient) SetInitiateChallenge() { + h.state = peerIDAuthClientInitiateChallenge +} + +func (h *PeerIDAuthHandshakeClient) ParseHeader(header http.Header) error { + if h.state == peerIDAuthClientStateDone || h.state == peerIDAuthClientInitiateChallenge { + return nil + } + h.p = params{} + + var headerVal []byte + switch h.state { + case peerIDAuthClientStateSignChallenge, peerIDAuthClientStateVerifyAndSignChallenge: + headerVal = []byte(header.Get("WWW-Authenticate")) + case peerIDAuthClientStateVerifyChallenge, peerIDAuthClientStateWaitingForBearer: + headerVal = []byte(header.Get("Authentication-Info")) + } + + if len(headerVal) == 0 { + return errMissingChallenge + } + + err := h.p.parsePeerIDAuthSchemeParams(headerVal) + if err != nil { + return err + } + + if h.serverPubKey == nil && len(h.p.publicKeyB64) > 0 { + serverPubKeyBytes, err := base64.URLEncoding.AppendDecode(nil, h.p.publicKeyB64) + if err != nil { + return err + } + h.serverPubKey, err = crypto.UnmarshalPublicKey(serverPubKeyBytes) + if err != nil { + return err + } + h.serverPeerID, err = peer.IDFromPublicKey(h.serverPubKey) + if err != nil { + return err + } + } + + return err +} + +func (h *PeerIDAuthHandshakeClient) Run() error { + if h.state == peerIDAuthClientStateDone { + return nil + } + + h.hb.clear() + clientPubKeyBytes, err := crypto.MarshalPublicKey(h.PrivKey.GetPublic()) + if err != nil { + return err + } + switch h.state { + case peerIDAuthClientInitiateChallenge: + h.hb.writeScheme(PeerIDAuthScheme) + h.addChallengeServerParam() + h.hb.writeParamB64(nil, "public-key", clientPubKeyBytes) + h.state = peerIDAuthClientStateVerifyAndSignChallenge + return nil + case peerIDAuthClientStateVerifyAndSignChallenge: + if len(h.p.sigB64) == 0 && len(h.p.challengeClient) != 0 { + // The server refused a client initiated handshake, so we need run the server initiated handshake + h.state = peerIDAuthClientStateSignChallenge + return h.Run() + } + if err := h.verifySig(clientPubKeyBytes); err != nil { + return err + } + + h.hb.writeScheme(PeerIDAuthScheme) + h.hb.writeParam("opaque", h.p.opaqueB64) + h.addSigParam() + h.state = peerIDAuthClientStateWaitingForBearer + return nil + + case peerIDAuthClientStateWaitingForBearer: + h.hb.writeScheme(PeerIDAuthScheme) + h.hb.writeParam("bearer", h.p.bearerTokenB64) + h.state = peerIDAuthClientStateDone + return nil + + case peerIDAuthClientStateSignChallenge: + if len(h.p.challengeClient) < challengeLen { + return errors.New("challenge too short") + } + + h.hb.writeScheme(PeerIDAuthScheme) + h.hb.writeParamB64(nil, "public-key", clientPubKeyBytes) + if err := h.addChallengeServerParam(); err != nil { + return err + } + if err := h.addSigParam(); err != nil { + return err + } + h.hb.writeParam("opaque", h.p.opaqueB64) + + h.state = peerIDAuthClientStateVerifyChallenge + return nil + case peerIDAuthClientStateVerifyChallenge: + if err := h.verifySig(clientPubKeyBytes); err != nil { + return err + } + + h.hb.writeScheme(PeerIDAuthScheme) + h.hb.writeParam("bearer", h.p.bearerTokenB64) + h.state = peerIDAuthClientStateDone + + return nil + } + + return errors.New("unhandled state") +} + +func (h *PeerIDAuthHandshakeClient) addChallengeServerParam() error { + _, err := io.ReadFull(randReader, h.buf[:challengeLen]) + if err != nil { + return err + } + h.challengeServer = base64.URLEncoding.AppendEncode(nil, h.buf[:challengeLen]) + clear(h.buf[:challengeLen]) + h.hb.writeParam("challenge-server", h.challengeServer) + return nil +} + +func (h *PeerIDAuthHandshakeClient) verifySig(clientPubKeyBytes []byte) error { + if len(h.p.sigB64) == 0 { + return errors.New("signature not set") + } + sig, err := base64.URLEncoding.AppendDecode(nil, h.p.sigB64) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + err = verifySig(h.serverPubKey, PeerIDAuthScheme, []sigParam{ + {"challenge-server", h.challengeServer}, + {"client-public-key", clientPubKeyBytes}, + {"hostname", []byte(h.Hostname)}, + }, sig) + return err +} + +func (h *PeerIDAuthHandshakeClient) addSigParam() error { + if h.serverPubKey == nil { + return errors.New("server public key not set") + } + serverPubKeyBytes, err := crypto.MarshalPublicKey(h.serverPubKey) + if err != nil { + return err + } + clientSig, err := sign(h.PrivKey, PeerIDAuthScheme, []sigParam{ + {"challenge-client", h.p.challengeClient}, + {"server-public-key", serverPubKeyBytes}, + {"hostname", []byte(h.Hostname)}, + }) + if err != nil { + return fmt.Errorf("failed to sign challenge: %w", err) + } + h.hb.writeParamB64(nil, "sig", clientSig) + return nil + +} + +// PeerID returns the peer ID of the authenticated client. +func (h *PeerIDAuthHandshakeClient) PeerID() (peer.ID, error) { + switch h.state { + case peerIDAuthClientStateDone: + case peerIDAuthClientStateWaitingForBearer: + default: + return "", errors.New("server not authenticated yet") + } + + if h.serverPeerID == "" { + return "", errors.New("peer ID not set") + } + return h.serverPeerID, nil +} + +func (h *PeerIDAuthHandshakeClient) AddHeader(hdr http.Header) { + hdr.Set("Authorization", h.hb.b.String()) +} + +// BearerToken returns the server given bearer token for the client. Set this on +// the Authorization header in the client's request. +func (h *PeerIDAuthHandshakeClient) BearerToken() string { + if h.state != peerIDAuthClientStateDone { + return "" + } + return h.hb.b.String() +} + +func (h *PeerIDAuthHandshakeClient) ServerAuthenticated() bool { + switch h.state { + case peerIDAuthClientStateDone: + case peerIDAuthClientStateWaitingForBearer: + default: + return false + } + + return h.serverPeerID != "" +} + +func (h *PeerIDAuthHandshakeClient) HandshakeDone() bool { + return h.state == peerIDAuthClientStateDone +} diff --git a/p2p/http/auth/internal/handshake/handshake.go b/p2p/http/auth/internal/handshake/handshake.go new file mode 100644 index 0000000000..1c237ae3a3 --- /dev/null +++ b/p2p/http/auth/internal/handshake/handshake.go @@ -0,0 +1,218 @@ +package handshake + +import ( + "bufio" + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "slices" + "strings" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + + pool "github.com/libp2p/go-buffer-pool" +) + +const PeerIDAuthScheme = "libp2p-PeerID" +const challengeLen = 32 +const maxHeaderSize = 2048 + +var peerIDAuthSchemeBytes = []byte(PeerIDAuthScheme) + +var errTooBig = errors.New("header value too big") +var errInvalid = errors.New("invalid header value") +var errNotRan = errors.New("not ran. call Run() first") + +var randReader = rand.Reader // A var so it can be changed in tests +var nowFn = time.Now // A var so it can be changed in tests + +// params represent params passed in via headers. All []byte fields to avoid allocations. +type params struct { + bearerTokenB64 []byte + challengeClient []byte + challengeServer []byte + opaqueB64 []byte + publicKeyB64 []byte + sigB64 []byte +} + +// parsePeerIDAuthSchemeParams parses the parameters of the PeerID auth scheme +// from the header string. zero alloc. +func (p *params) parsePeerIDAuthSchemeParams(headerVal []byte) error { + if len(headerVal) > maxHeaderSize { + return errTooBig + } + startIdx := bytes.Index(headerVal, peerIDAuthSchemeBytes) + if startIdx == -1 { + return nil + } + + headerVal = headerVal[startIdx+len(PeerIDAuthScheme):] + advance, token, err := splitAuthHeaderParams(headerVal, true) + for ; err == nil; advance, token, err = splitAuthHeaderParams(headerVal, true) { + headerVal = headerVal[advance:] + bs := token + splitAt := bytes.Index(bs, []byte("=")) + if splitAt == -1 { + return errInvalid + } + kB := bs[:splitAt] + v := bs[splitAt+1:] + if len(v) < 2 || v[0] != '"' || v[len(v)-1] != '"' { + return errInvalid + } + v = v[1 : len(v)-1] // drop quotes + switch string(kB) { + case "bearer": + p.bearerTokenB64 = v + case "challenge-client": + p.challengeClient = v + case "challenge-server": + p.challengeServer = v + case "opaque": + p.opaqueB64 = v + case "public-key": + p.publicKeyB64 = v + case "sig": + p.sigB64 = v + } + } + if err == bufio.ErrFinalToken { + err = nil + } + return err +} + +func splitAuthHeaderParams(data []byte, atEOF bool) (advance int, token []byte, err error) { + if len(data) == 0 && atEOF { + return 0, nil, bufio.ErrFinalToken + } + + start := 0 + for start < len(data) && (data[start] == ' ' || data[start] == ',') { + // Ignore leading spaces and commas + start++ + } + if start == len(data) { + return len(data), nil, nil + } + end := start + 1 + for end < len(data) && data[end] != ' ' && data[end] != ',' { + // Consume until we hit a space or comma + end++ + } + token = data[start:end] + if !bytes.ContainsAny(token, "=") { + // This isn't a param. It's likely the next scheme. We're done + return len(data), nil, bufio.ErrFinalToken + } + + return end, token, nil +} + +type headerBuilder struct { + b strings.Builder + pastFirstField bool +} + +func (h *headerBuilder) clear() { + h.b.Reset() + h.pastFirstField = false +} + +func (h *headerBuilder) writeScheme(scheme string) { + h.b.WriteString(scheme) + h.b.WriteByte(' ') +} + +func (h *headerBuilder) maybeAddComma() { + if !h.pastFirstField { + h.pastFirstField = true + return + } + h.b.WriteString(", ") +} + +// writeParam writes a key value pair to the header. It first b64 encodes the +// value. It uses buf as scratch space. +func (h *headerBuilder) writeParamB64(buf []byte, key string, val []byte) { + if buf == nil { + buf = make([]byte, base64.URLEncoding.EncodedLen(len(val))) + } + encodedVal := base64.URLEncoding.AppendEncode(buf[:0], val) + h.writeParam(key, encodedVal) +} + +// writeParam writes a key value pair to the header. It writes the val as-is. +func (h *headerBuilder) writeParam(key string, val []byte) { + if len(val) == 0 { + return + } + h.maybeAddComma() + + h.b.Grow(len(key) + len(`="`) + len(val) + 1) + // Not doing fmt.Fprintf here to avoid one allocation + h.b.WriteString(key) + h.b.WriteString(`="`) + h.b.Write(val) + h.b.WriteByte('"') +} + +type sigParam struct { + k string + v []byte +} + +func verifySig(publicKey crypto.PubKey, prefix string, signedParts []sigParam, sig []byte) error { + if publicKey == nil { + return fmt.Errorf("no public key to verify signature") + } + + b := pool.Get(4096) + defer pool.Put(b) + buf, err := genDataToSign(b[:0], prefix, signedParts) + if err != nil { + return fmt.Errorf("failed to generate signed data: %w", err) + } + ok, err := publicKey.Verify(buf, sig) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("signature verification failed") + } + + return nil +} + +func sign(privKey crypto.PrivKey, prefix string, partsToSign []sigParam) ([]byte, error) { + if privKey == nil { + return nil, fmt.Errorf("no private key available to sign") + } + b := pool.Get(4096) + defer pool.Put(b) + buf, err := genDataToSign(b[:0], prefix, partsToSign) + if err != nil { + return nil, fmt.Errorf("failed to generate data to sign: %w", err) + } + return privKey.Sign(buf) +} + +func genDataToSign(buf []byte, prefix string, parts []sigParam) ([]byte, error) { + // Sort the parts in lexicographic order + slices.SortFunc(parts, func(a, b sigParam) int { + return strings.Compare(a.k, b.k) + }) + buf = append(buf, prefix...) + for _, p := range parts { + buf = binary.AppendUvarint(buf, uint64(len(p.k)+1+len(p.v))) // +1 for '=' + buf = append(buf, p.k...) + buf = append(buf, '=') + buf = append(buf, p.v...) + } + return buf, nil +} diff --git a/p2p/http/auth/internal/handshake/handshake_test.go b/p2p/http/auth/internal/handshake/handshake_test.go new file mode 100644 index 0000000000..0579b95163 --- /dev/null +++ b/p2p/http/auth/internal/handshake/handshake_test.go @@ -0,0 +1,652 @@ +package handshake + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestHandshake(t *testing.T) { + for _, clientInitiated := range []bool{true, false} { + t.Run(fmt.Sprintf("clientInitiated=%t", clientInitiated), func(t *testing.T) { + hostname := "example.com" + serverPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + clientPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + + serverHandshake := PeerIDAuthHandshakeServer{ + Hostname: hostname, + PrivKey: serverPriv, + TokenTTL: time.Hour, + Hmac: hmac.New(sha256.New, make([]byte, 32)), + } + + clientHandshake := PeerIDAuthHandshakeClient{ + Hostname: hostname, + PrivKey: clientPriv, + } + if clientInitiated { + clientHandshake.state = peerIDAuthClientInitiateChallenge + } + + headers := make(http.Header) + + // Start the handshake + if !clientInitiated { + require.NoError(t, serverHandshake.ParseHeaderVal(nil)) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + } + + // Server Inititated: Client receives the challenge and signs it. Also sends the challenge server + // Client Inititated: Client forms the challenge and sends it + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + + // Server Inititated: Server receives the sig and verifies it. Also signs the challenge-server (client authenticated) + // Client Inititated: Server receives the challenge and signs it. Also sends the challenge-client + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + // Server Inititated: Client verifies sig and sets the bearer token for future requests (server authenticated) + // Client Inititated: Client verifies sig, and signs challenge. Sends it along with any application data (server authenticated) + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + + // Server Inititated: Server verifies the bearer token + // Client Inititated: Server verifies the sig, sets the bearer token (client authenticated) + // and processes any application data + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + expectedClientPeerID, _ := peer.IDFromPrivateKey(clientPriv) + expectedServerPeerID, _ := peer.IDFromPrivateKey(serverPriv) + clientPeerID, err := serverHandshake.PeerID() + require.NoError(t, err) + require.Equal(t, expectedClientPeerID, clientPeerID) + + serverPeerID, err := clientHandshake.PeerID() + require.NoError(t, err) + require.Equal(t, expectedServerPeerID, serverPeerID) + }) + } +} + +func TestServerRefusesClientInitiatedHandshake(t *testing.T) { + hostname := "example.com" + serverPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + clientPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + + serverHandshake := PeerIDAuthHandshakeServer{ + Hostname: hostname, + PrivKey: serverPriv, + TokenTTL: time.Hour, + Hmac: hmac.New(sha256.New, make([]byte, 32)), + } + + clientHandshake := PeerIDAuthHandshakeClient{ + Hostname: hostname, + PrivKey: clientPriv, + } + clientHandshake.SetInitiateChallenge() + + headers := make(http.Header) + // Client initiates the handshake + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + + // Server receives the challenge-server, but chooses to reject it (simulating this by not passing the challenge) + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal(nil)) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + // Client now runs the server-initiated handshake. Signs challenge-client; sends challenge-server + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + + // Server verifies the challenge-client and signs the challenge-server + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + // Client verifies the challenge-server and sets the bearer token + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + + expectedClientPeerID, _ := peer.IDFromPrivateKey(clientPriv) + expectedServerPeerID, _ := peer.IDFromPrivateKey(serverPriv) + clientPeerID, err := serverHandshake.PeerID() + require.NoError(t, err) + require.Equal(t, expectedClientPeerID, clientPeerID) + + serverPeerID, err := clientHandshake.PeerID() + require.NoError(t, err) + require.True(t, clientHandshake.HandshakeDone()) + require.Equal(t, expectedServerPeerID, serverPeerID) +} + +func BenchmarkServerHandshake(b *testing.B) { + clientHeader1 := make(http.Header) + clientHeader2 := make(http.Header) + headers := make(http.Header) + + hostname := "example.com" + serverPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + clientPriv, _, _ := crypto.GenerateEd25519Key(rand.Reader) + + serverHandshake := PeerIDAuthHandshakeServer{ + Hostname: hostname, + PrivKey: serverPriv, + TokenTTL: time.Hour, + Hmac: hmac.New(sha256.New, make([]byte, 32)), + } + + clientHandshake := PeerIDAuthHandshakeClient{ + Hostname: hostname, + PrivKey: clientPriv, + } + require.NoError(b, serverHandshake.ParseHeaderVal(nil)) + require.NoError(b, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + // Client receives the challenge and signs it. Also sends the challenge server + require.NoError(b, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(b, clientHandshake.Run()) + clientHandshake.AddHeader(clientHeader1) + + // Server receives the sig and verifies it. Also signs the challenge server + serverHandshake.Reset() + require.NoError(b, serverHandshake.ParseHeaderVal([]byte(clientHeader1.Get("Authorization")))) + clear(headers) + require.NoError(b, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + // Client verifies sig and sets the bearer token for future requests + require.NoError(b, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(b, clientHandshake.Run()) + clientHandshake.AddHeader(clientHeader2) + + // Server verifies the bearer token + serverHandshake.Reset() + require.NoError(b, serverHandshake.ParseHeaderVal([]byte(clientHeader2.Get("Authorization")))) + clear(headers) + require.NoError(b, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + + initialClientAuth := []byte(clientHeader1.Get("Authorization")) + bearerClientAuth := []byte(clientHeader2.Get("Authorization")) + _ = initialClientAuth + _ = bearerClientAuth + + b.ResetTimer() + for i := 0; i < b.N; i++ { + serverHandshake.Reset() + serverHandshake.ParseHeaderVal(nil) + serverHandshake.Run() + + serverHandshake.Reset() + serverHandshake.ParseHeaderVal(initialClientAuth) + serverHandshake.Run() + + serverHandshake.Reset() + serverHandshake.ParseHeaderVal(bearerClientAuth) + serverHandshake.Run() + } + +} + +func TestParsePeerIDAuthSchemeParams(t *testing.T) { + str := `libp2p-PeerID sig="", public-key="", bearer=""` + p := params{} + expectedParam := params{ + sigB64: []byte(``), + publicKeyB64: []byte(``), + bearerTokenB64: []byte(``), + } + err := p.parsePeerIDAuthSchemeParams([]byte(str)) + require.NoError(t, err) + require.Equal(t, expectedParam, p) +} + +func BenchmarkParsePeerIDAuthSchemeParams(b *testing.B) { + str := []byte(`libp2p-PeerID peer-id="", sig="", public-key="", bearer=""`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p := params{} + err := p.parsePeerIDAuthSchemeParams(str) + if err != nil { + b.Fatal(err) + } + } +} + +func TestHeaderBuilder(t *testing.T) { + hb := headerBuilder{} + hb.writeScheme(PeerIDAuthScheme) + hb.writeParam("peer-id", []byte("foo")) + hb.writeParam("challenge-client", []byte("something-else")) + hb.writeParam("hostname", []byte("example.com")) + + expected := `libp2p-PeerID peer-id="foo", challenge-client="something-else", hostname="example.com"` + require.Equal(t, expected, hb.b.String()) +} + +func BenchmarkHeaderBuilder(b *testing.B) { + h := headerBuilder{} + scratch := make([]byte, 256) + scratch = scratch[:0] + + b.ResetTimer() + for i := 0; i < b.N; i++ { + h.b.Grow(256) + h.writeParamB64(scratch, "foo", []byte("bar")) + h.clear() + } +} + +// Test Vectors +var zeroBytes = make([]byte, 64) +var zeroKey, _, _ = crypto.GenerateEd25519Key(bytes.NewReader(zeroBytes)) + +// Peer ID derived from the zero key +var zeroID, _ = peer.IDFromPublicKey(zeroKey.GetPublic()) + +func TestOpaqueStateRoundTrip(t *testing.T) { + zeroBytes := [32]byte{} + + // To drop the monotonic clock reading + timeAfterUnmarshal := time.Now() + b, err := json.Marshal(timeAfterUnmarshal) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(b, &timeAfterUnmarshal)) + hmac := hmac.New(sha256.New, zeroBytes[:]) + + o := opaqueState{ + ChallengeClient: "foo-bar", + CreatedTime: timeAfterUnmarshal, + IsToken: true, + PeerID: zeroID, + Hostname: "example.com", + } + + hmac.Reset() + b, err = o.Marshal(hmac, nil) + require.NoError(t, err) + + o2 := opaqueState{} + + hmac.Reset() + err = o2.Unmarshal(hmac, b) + require.NoError(t, err) + require.EqualValues(t, o, o2) +} + +func FuzzServerHandshakeNoPanic(f *testing.F) { + zeroBytes := [32]byte{} + hmac := hmac.New(sha256.New, zeroBytes[:]) + + f.Fuzz(func(t *testing.T, data []byte) { + hmac.Reset() + h := PeerIDAuthHandshakeServer{ + Hostname: "example.com", + PrivKey: zeroKey, + Hmac: hmac, + } + err := h.ParseHeaderVal(data) + if err != nil { + return + } + err = h.Run() + if err != nil { + return + } + h.PeerID() + }) +} + +func BenchmarkOpaqueStateWrite(b *testing.B) { + zeroBytes := [32]byte{} + hmac := hmac.New(sha256.New, zeroBytes[:]) + o := opaqueState{ + ChallengeClient: "foo-bar", + CreatedTime: time.Now(), + } + d := make([]byte, 512) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hmac.Reset() + _, err := o.Marshal(hmac, d[:0]) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkOpaqueStateRead(b *testing.B) { + zeroBytes := [32]byte{} + hmac := hmac.New(sha256.New, zeroBytes[:]) + o := opaqueState{ + ChallengeClient: "foo-bar", + CreatedTime: time.Now(), + } + d := make([]byte, 256) + d, err := o.Marshal(hmac, d[:0]) + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hmac.Reset() + err := o.Unmarshal(hmac, d) + if err != nil { + b.Fatal(err) + } + } +} + +func FuzzParsePeerIDAuthSchemeParamsNoPanic(f *testing.F) { + p := params{} + // Just check that we don't panic + f.Fuzz(func(t *testing.T, data []byte) { + p.parsePeerIDAuthSchemeParams(data) + }) +} + +type specsExampleParameters struct { + hostname string + serverPriv crypto.PrivKey + serverHmacKey [32]byte + clientPriv crypto.PrivKey +} + +func TestSpecsExample(t *testing.T) { + originalRandReader := randReader + originalNowFn := nowFn + randReader = bytes.NewReader(append( + bytes.Repeat([]byte{0x11}, 32), + bytes.Repeat([]byte{0x33}, 32)..., + )) + nowFn = func() time.Time { + return time.Unix(0, 0) + } + defer func() { + randReader = originalRandReader + nowFn = originalNowFn + }() + + parameters := specsExampleParameters{ + hostname: "example.com", + } + serverPrivBytes, err := hex.AppendDecode(nil, []byte("0801124001010101010101010101010101010101010101010101010101010101010101018a88e3dd7409f195fd52db2d3cba5d72ca6709bf1d94121bf3748801b40f6f5c")) + require.NoError(t, err) + clientPrivBytes, err := hex.AppendDecode(nil, []byte("0801124002020202020202020202020202020202020202020202020202020202020202028139770ea87d175f56a35466c34c7ecccb8d8a91b4ee37a25df60f5b8fc9b394")) + require.NoError(t, err) + + parameters.serverPriv, err = crypto.UnmarshalPrivateKey(serverPrivBytes) + require.NoError(t, err) + + parameters.clientPriv, err = crypto.UnmarshalPrivateKey(clientPrivBytes) + require.NoError(t, err) + + serverHandshake := PeerIDAuthHandshakeServer{ + Hostname: parameters.hostname, + PrivKey: parameters.serverPriv, + TokenTTL: time.Hour, + Hmac: hmac.New(sha256.New, parameters.serverHmacKey[:]), + } + + clientHandshake := PeerIDAuthHandshakeClient{ + Hostname: parameters.hostname, + PrivKey: parameters.clientPriv, + } + + headers := make(http.Header) + + // Start the handshake + require.NoError(t, serverHandshake.ParseHeaderVal(nil)) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + initialWWWAuthenticate := headers.Get("WWW-Authenticate") + + // Client receives the challenge and signs it. Also sends the challenge server + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + clientAuthentication := headers.Get("Authorization") + + // Server receives the sig and verifies it. Also signs the challenge server + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + serverAuthentication := headers.Get("Authentication-Info") + + // Client verifies sig and sets the bearer token for future requests + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + clientBearerToken := headers.Get("Authorization") + + params := params{} + params.parsePeerIDAuthSchemeParams([]byte(initialWWWAuthenticate)) + challengeClient := params.challengeClient + params.parsePeerIDAuthSchemeParams([]byte(clientAuthentication)) + challengeServer := params.challengeServer + + fmt.Println("### Parameters") + fmt.Println("| Parameter | Value |") + fmt.Println("| --- | --- |") + fmt.Printf("| hostname | %s |\n", parameters.hostname) + fmt.Printf("| Server Private Key (pb encoded as hex) | %s |\n", hex.EncodeToString(serverPrivBytes)) + fmt.Printf("| Server HMAC Key (hex) | %s |\n", hex.EncodeToString(parameters.serverHmacKey[:])) + fmt.Printf("| Challenge Client | %s |\n", string(challengeClient)) + fmt.Printf("| Client Private Key (pb encoded as hex) | %s |\n", hex.EncodeToString(clientPrivBytes)) + fmt.Printf("| Challenge Server | %s |\n", string(challengeServer)) + fmt.Printf("| \"Now\" time | %s |\n", nowFn()) + fmt.Println() + fmt.Println("### Handshake Diagram") + + fmt.Println("```mermaid") + fmt.Printf(`sequenceDiagram +Client->>Server: Initial request +Server->>Client: WWW-Authenticate=%s +Client->>Server: Authorization=%s +Note left of Server: Server has authenticated Client +Server->>Client: Authentication-Info=%s +Note right of Client: Client has authenticated Server + +Note over Client: Future requests use the bearer token +Client->>Server: Authorization=%s +`, initialWWWAuthenticate, clientAuthentication, serverAuthentication, clientBearerToken) + fmt.Println("```") + +} + +func TestSpecsClientInitiatedExample(t *testing.T) { + originalRandReader := randReader + originalNowFn := nowFn + randReader = bytes.NewReader(append( + bytes.Repeat([]byte{0x33}, 32), + bytes.Repeat([]byte{0x11}, 32)..., + )) + nowFn = func() time.Time { + return time.Unix(0, 0) + } + defer func() { + randReader = originalRandReader + nowFn = originalNowFn + }() + + parameters := specsExampleParameters{ + hostname: "example.com", + } + serverPrivBytes, err := hex.AppendDecode(nil, []byte("0801124001010101010101010101010101010101010101010101010101010101010101018a88e3dd7409f195fd52db2d3cba5d72ca6709bf1d94121bf3748801b40f6f5c")) + require.NoError(t, err) + clientPrivBytes, err := hex.AppendDecode(nil, []byte("0801124002020202020202020202020202020202020202020202020202020202020202028139770ea87d175f56a35466c34c7ecccb8d8a91b4ee37a25df60f5b8fc9b394")) + require.NoError(t, err) + + parameters.serverPriv, err = crypto.UnmarshalPrivateKey(serverPrivBytes) + require.NoError(t, err) + + parameters.clientPriv, err = crypto.UnmarshalPrivateKey(clientPrivBytes) + require.NoError(t, err) + + serverHandshake := PeerIDAuthHandshakeServer{ + Hostname: parameters.hostname, + PrivKey: parameters.serverPriv, + TokenTTL: time.Hour, + Hmac: hmac.New(sha256.New, parameters.serverHmacKey[:]), + } + + clientHandshake := PeerIDAuthHandshakeClient{ + Hostname: parameters.hostname, + PrivKey: parameters.clientPriv, + } + + headers := make(http.Header) + + // Start the handshake + clientHandshake.SetInitiateChallenge() + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + clientChallenge := headers.Get("Authorization") + + // Server receives the challenge and signs it. Also sends challenge-client + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + serverAuthentication := headers.Get("WWW-Authenticate") + params := params{} + params.parsePeerIDAuthSchemeParams([]byte(serverAuthentication)) + challengeClient := params.challengeClient + + // Client verifies sig and signs the challenge-client + require.NoError(t, clientHandshake.ParseHeader(headers)) + clear(headers) + require.NoError(t, clientHandshake.Run()) + clientHandshake.AddHeader(headers) + clientAuthentication := headers.Get("Authorization") + + // Server verifies sig and sets the bearer token + serverHandshake.Reset() + require.NoError(t, serverHandshake.ParseHeaderVal([]byte(headers.Get("Authorization")))) + clear(headers) + require.NoError(t, serverHandshake.Run()) + serverHandshake.SetHeader(headers) + serverReplayWithBearer := headers.Get("Authentication-Info") + + params.parsePeerIDAuthSchemeParams([]byte(clientChallenge)) + challengeServer := params.challengeServer + + fmt.Println("### Parameters") + fmt.Println("| Parameter | Value |") + fmt.Println("| --- | --- |") + fmt.Printf("| hostname | %s |\n", parameters.hostname) + fmt.Printf("| Server Private Key (pb encoded as hex) | %s |\n", hex.EncodeToString(serverPrivBytes)) + fmt.Printf("| Server HMAC Key (hex) | %s |\n", hex.EncodeToString(parameters.serverHmacKey[:])) + fmt.Printf("| Challenge Client | %s |\n", string(challengeClient)) + fmt.Printf("| Client Private Key (pb encoded as hex) | %s |\n", hex.EncodeToString(clientPrivBytes)) + fmt.Printf("| Challenge Server | %s |\n", string(challengeServer)) + fmt.Printf("| \"Now\" time | %s |\n", nowFn()) + fmt.Println() + fmt.Println("### Handshake Diagram") + + fmt.Println("```mermaid") + fmt.Printf(`sequenceDiagram +Client->>Server: Authorization=%s +Server->>Client: WWW-Authenticate=%s +Note right of Client: Client has authenticated Server + +Client->>Server: Authorization=%s +Note left of Server: Server has authenticated Client +Server->>Client: Authentication-Info=%s +Note over Client: Future requests use the bearer token +`, clientChallenge, serverAuthentication, clientAuthentication, serverReplayWithBearer) + fmt.Println("```") + +} + +func TestSigningExample(t *testing.T) { + serverPrivBytes, err := hex.AppendDecode(nil, []byte("0801124001010101010101010101010101010101010101010101010101010101010101018a88e3dd7409f195fd52db2d3cba5d72ca6709bf1d94121bf3748801b40f6f5c")) + require.NoError(t, err) + serverPriv, err := crypto.UnmarshalPrivateKey(serverPrivBytes) + require.NoError(t, err) + clientPrivBytes, err := hex.AppendDecode(nil, []byte("0801124002020202020202020202020202020202020202020202020202020202020202028139770ea87d175f56a35466c34c7ecccb8d8a91b4ee37a25df60f5b8fc9b394")) + require.NoError(t, err) + clientPriv, err := crypto.UnmarshalPrivateKey(clientPrivBytes) + require.NoError(t, err) + clientPubKeyBytes, err := crypto.MarshalPublicKey(clientPriv.GetPublic()) + require.NoError(t, err) + + require.NoError(t, err) + challenge := "ERERERERERERERERERERERERERERERERERERERERERE=" + + hostname := "example.com" + dataToSign, err := genDataToSign(nil, PeerIDAuthScheme, []sigParam{ + {"challenge-server", []byte(challenge)}, + {"client-public-key", clientPubKeyBytes}, + {"hostname", []byte(hostname)}, + }) + require.NoError(t, err) + + sig, err := sign(serverPriv, PeerIDAuthScheme, []sigParam{ + {"challenge-server", []byte(challenge)}, + {"client-public-key", clientPubKeyBytes}, + {"hostname", []byte(hostname)}, + }) + require.NoError(t, err) + + fmt.Println("### Signing Example") + + fmt.Println("| Parameter | Value |") + fmt.Println("| --- | --- |") + fmt.Printf("| hostname | %s |\n", hostname) + fmt.Printf("| Server Private Key (pb encoded as hex) | %s |\n", hex.EncodeToString(serverPrivBytes)) + fmt.Printf("| challenge-server | %s |\n", string(challenge)) + fmt.Printf("| Client Public Key (pb encoded as hex) | %s |\n", hex.EncodeToString(clientPubKeyBytes)) + fmt.Printf("| data to sign ([percent encoded](https://datatracker.ietf.org/doc/html/rfc3986#section-2.1)) | %s |\n", url.PathEscape(string(dataToSign))) + fmt.Printf("| data to sign (hex encoded) | %s |\n", hex.EncodeToString(dataToSign)) + fmt.Printf("| signature (base64 encoded) | %s |\n", base64.URLEncoding.EncodeToString(sig)) + fmt.Println() + + fmt.Println("Note that the `=` after the libp2p-PeerID scheme is actually the varint length of the challenge-server parameter.") + +} diff --git a/p2p/http/auth/internal/handshake/server.go b/p2p/http/auth/internal/handshake/server.go new file mode 100644 index 0000000000..6b84038d93 --- /dev/null +++ b/p2p/http/auth/internal/handshake/server.go @@ -0,0 +1,373 @@ +package handshake + +import ( + "crypto/hmac" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "hash" + "io" + "net/http" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +var ( + ErrExpiredChallenge = errors.New("challenge expired") + ErrExpiredToken = errors.New("token expired") + ErrInvalidHMAC = errors.New("invalid HMAC") +) + +const challengeTTL = 5 * time.Minute + +type peerIDAuthServerState int + +const ( + // Server initiated + peerIDAuthServerStateChallengeClient peerIDAuthServerState = iota + peerIDAuthServerStateVerifyChallenge + peerIDAuthServerStateVerifyBearer + + // Client initiated + peerIDAuthServerStateSignChallenge +) + +type opaqueState struct { + IsToken bool `json:"is-token,omitempty"` + ClientPublicKey []byte `json:"client-public-key,omitempty"` + PeerID peer.ID `json:"peer-id,omitempty"` + ChallengeClient string `json:"challenge-client,omitempty"` + Hostname string `json:"hostname"` + CreatedTime time.Time `json:"created-time"` +} + +// Marshal serializes the state by appending it to the byte slice. +func (o *opaqueState) Marshal(hmac hash.Hash, b []byte) ([]byte, error) { + hmac.Reset() + fieldsMarshalled, err := json.Marshal(o) + if err != nil { + return b, err + } + _, err = hmac.Write(fieldsMarshalled) + if err != nil { + return b, err + } + b = hmac.Sum(b) + b = append(b, fieldsMarshalled...) + return b, nil +} + +func (o *opaqueState) Unmarshal(hmacImpl hash.Hash, d []byte) error { + hmacImpl.Reset() + if len(d) < hmacImpl.Size() { + return ErrInvalidHMAC + } + hmacVal := d[:hmacImpl.Size()] + fields := d[hmacImpl.Size():] + _, err := hmacImpl.Write(fields) + if err != nil { + return err + } + expectedHmac := hmacImpl.Sum(nil) + if !hmac.Equal(hmacVal, expectedHmac) { + return ErrInvalidHMAC + } + + err = json.Unmarshal(fields, &o) + if err != nil { + return err + } + return nil +} + +type PeerIDAuthHandshakeServer struct { + Hostname string + PrivKey crypto.PrivKey + TokenTTL time.Duration + // used to authenticate opaque blobs and tokens + Hmac hash.Hash + + ran bool + buf [1024]byte + + state peerIDAuthServerState + p params + hb headerBuilder + + opaque opaqueState +} + +var errInvalidHeader = errors.New("invalid header") + +func (h *PeerIDAuthHandshakeServer) Reset() { + h.Hmac.Reset() + h.ran = false + clear(h.buf[:]) + h.state = 0 + h.p = params{} + h.hb.clear() + h.opaque = opaqueState{} +} + +func (h *PeerIDAuthHandshakeServer) ParseHeaderVal(headerVal []byte) error { + if len(headerVal) == 0 { + // We are in the initial state. Nothing to parse. + return nil + } + err := h.p.parsePeerIDAuthSchemeParams(headerVal) + if err != nil { + return err + } + switch { + case h.p.sigB64 != nil && h.p.opaqueB64 != nil: + h.state = peerIDAuthServerStateVerifyChallenge + case h.p.bearerTokenB64 != nil: + h.state = peerIDAuthServerStateVerifyBearer + case h.p.challengeServer != nil && h.p.publicKeyB64 != nil: + h.state = peerIDAuthServerStateSignChallenge + default: + return errInvalidHeader + + } + return nil +} + +func (h *PeerIDAuthHandshakeServer) Run() error { + h.ran = true + switch h.state { + case peerIDAuthServerStateSignChallenge: + h.hb.writeScheme(PeerIDAuthScheme) + if err := h.addChallengeClientParam(); err != nil { + return err + } + if err := h.addPublicKeyParam(); err != nil { + return err + } + + publicKeyBytes, err := base64.URLEncoding.AppendDecode(nil, h.p.publicKeyB64) + if err != nil { + return err + } + h.opaque.ClientPublicKey = publicKeyBytes + if err := h.addServerSigParam(publicKeyBytes); err != nil { + return err + } + if err := h.addOpaqueParam(); err != nil { + return err + } + case peerIDAuthServerStateChallengeClient: + h.hb.writeScheme(PeerIDAuthScheme) + if err := h.addChallengeClientParam(); err != nil { + return err + } + if err := h.addPublicKeyParam(); err != nil { + return err + } + if err := h.addOpaqueParam(); err != nil { + return err + } + case peerIDAuthServerStateVerifyChallenge: + opaque, err := base64.URLEncoding.AppendDecode(h.buf[:0], h.p.opaqueB64) + if err != nil { + return err + } + err = h.opaque.Unmarshal(h.Hmac, opaque) + if err != nil { + return err + } + + if nowFn().After(h.opaque.CreatedTime.Add(challengeTTL)) { + return ErrExpiredChallenge + } + if h.opaque.IsToken { + return errors.New("expected challenge, got token") + } + + if h.Hostname != h.opaque.Hostname { + return errors.New("hostname in opaque mismatch") + } + + var publicKeyBytes []byte + clientInitiatedHandshake := h.opaque.ClientPublicKey != nil + + if clientInitiatedHandshake { + publicKeyBytes = h.opaque.ClientPublicKey + } else { + if len(h.p.publicKeyB64) == 0 { + return errors.New("missing public key") + } + var err error + publicKeyBytes, err = base64.URLEncoding.AppendDecode(nil, h.p.publicKeyB64) + if err != nil { + return err + } + } + pubKey, err := crypto.UnmarshalPublicKey(publicKeyBytes) + if err != nil { + return err + } + if err := h.verifySig(pubKey); err != nil { + return err + } + + peerID, err := peer.IDFromPublicKey(pubKey) + if err != nil { + return err + } + + // And create a bearer token for the client + h.opaque = opaqueState{ + IsToken: true, + PeerID: peerID, + Hostname: h.Hostname, + CreatedTime: nowFn(), + } + + h.hb.writeScheme(PeerIDAuthScheme) + + if !clientInitiatedHandshake { + if err := h.addServerSigParam(publicKeyBytes); err != nil { + return err + } + } + if err := h.addBearerParam(); err != nil { + return err + } + case peerIDAuthServerStateVerifyBearer: + bearerToken, err := base64.URLEncoding.AppendDecode(h.buf[:0], h.p.bearerTokenB64) + if err != nil { + return err + } + err = h.opaque.Unmarshal(h.Hmac, bearerToken) + if err != nil { + return err + } + + if !h.opaque.IsToken { + return errors.New("expected token, got challenge") + } + + if nowFn().After(h.opaque.CreatedTime.Add(h.TokenTTL)) { + return ErrExpiredToken + } + + return nil + default: + return errors.New("unhandled state") + } + + return nil +} + +func (h *PeerIDAuthHandshakeServer) addChallengeClientParam() error { + _, err := io.ReadFull(randReader, h.buf[:challengeLen]) + if err != nil { + return err + } + encodedChallenge := base64.URLEncoding.AppendEncode(h.buf[challengeLen:challengeLen], h.buf[:challengeLen]) + h.opaque.ChallengeClient = string(encodedChallenge) + h.opaque.Hostname = h.Hostname + h.opaque.CreatedTime = nowFn() + h.hb.writeParam("challenge-client", encodedChallenge) + return nil +} + +func (h *PeerIDAuthHandshakeServer) addOpaqueParam() error { + opaqueVal, err := h.opaque.Marshal(h.Hmac, h.buf[:0]) + if err != nil { + return err + } + h.hb.writeParamB64(h.buf[len(opaqueVal):], "opaque", opaqueVal) + return nil +} + +func (h *PeerIDAuthHandshakeServer) addServerSigParam(clientPublicKeyBytes []byte) error { + if len(h.p.challengeServer) < challengeLen { + return errors.New("challenge too short") + } + serverSig, err := sign(h.PrivKey, PeerIDAuthScheme, []sigParam{ + {"challenge-server", h.p.challengeServer}, + {"client-public-key", clientPublicKeyBytes}, + {"hostname", []byte(h.Hostname)}, + }) + if err != nil { + return fmt.Errorf("failed to sign challenge: %w", err) + } + h.hb.writeParamB64(h.buf[:], "sig", serverSig) + return nil +} + +func (h *PeerIDAuthHandshakeServer) addBearerParam() error { + bearerToken, err := h.opaque.Marshal(h.Hmac, h.buf[:0]) + if err != nil { + return err + } + h.hb.writeParamB64(h.buf[len(bearerToken):], "bearer", bearerToken) + return nil +} + +func (h *PeerIDAuthHandshakeServer) addPublicKeyParam() error { + serverPubKey := h.PrivKey.GetPublic() + pubKeyBytes, err := crypto.MarshalPublicKey(serverPubKey) + if err != nil { + return err + } + h.hb.writeParamB64(h.buf[:], "public-key", pubKeyBytes) + return nil +} + +func (h *PeerIDAuthHandshakeServer) verifySig(clientPubKey crypto.PubKey) error { + serverPubKey := h.PrivKey.GetPublic() + serverPubKeyBytes, err := crypto.MarshalPublicKey(serverPubKey) + if err != nil { + return err + } + sig, err := base64.URLEncoding.AppendDecode(h.buf[:0], h.p.sigB64) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + err = verifySig(clientPubKey, PeerIDAuthScheme, []sigParam{ + {k: "challenge-client", v: []byte(h.opaque.ChallengeClient)}, + {k: "server-public-key", v: serverPubKeyBytes}, + {k: "hostname", v: []byte(h.Hostname)}, + }, sig) + if err != nil { + return err + } + return nil +} + +// PeerID returns the peer ID of the authenticated client. +func (h *PeerIDAuthHandshakeServer) PeerID() (peer.ID, error) { + if !h.ran { + return "", errNotRan + } + switch h.state { + case peerIDAuthServerStateVerifyChallenge: + case peerIDAuthServerStateVerifyBearer: + default: + return "", errors.New("not in proper state") + } + if h.opaque.PeerID == "" { + return "", errors.New("peer ID not set") + } + return h.opaque.PeerID, nil +} + +func (h *PeerIDAuthHandshakeServer) SetHeader(hdr http.Header) { + if !h.ran { + return + } + defer h.hb.clear() + switch h.state { + case peerIDAuthServerStateChallengeClient, peerIDAuthServerStateSignChallenge: + hdr.Set("WWW-Authenticate", h.hb.b.String()) + case peerIDAuthServerStateVerifyChallenge: + hdr.Set("Authentication-Info", h.hb.b.String()) + case peerIDAuthServerStateVerifyBearer: + // For completeness. Nothing to do + } +} diff --git a/p2p/http/auth/server.go b/p2p/http/auth/server.go new file mode 100644 index 0000000000..3ee4f96dc8 --- /dev/null +++ b/p2p/http/auth/server.go @@ -0,0 +1,128 @@ +package httppeeridauth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "hash" + "net/http" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake" +) + +type ServerPeerIDAuth struct { + PrivKey crypto.PrivKey + TokenTTL time.Duration + Next func(peer peer.ID, w http.ResponseWriter, r *http.Request) + // NoTLS is a flag that allows the server to accept requests without a TLS + // ServerName. Used when something else is terminating the TLS connection. + NoTLS bool + // Required when NoTLS is true. The server will only accept requests for + // which the Host header returns true. + ValidHostnameFn func(hostname string) bool + + Hmac hash.Hash + initHmac sync.Once +} + +// ServeHTTP implements the http.Handler interface for PeerIDAuth. It will +// attempt to authenticate the request using using the libp2p peer ID auth +// scheme. If a Next handler is set, it will be called on authenticated +// requests. +func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.initHmac.Do(func() { + if a.Hmac == nil { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + panic(err) + } + a.Hmac = hmac.New(sha256.New, key) + } + }) + + hostname := r.Host + if a.NoTLS { + if a.ValidHostnameFn == nil { + log.Error("No ValidHostnameFn set. Required for NoTLS") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !a.ValidHostnameFn(hostname) { + log.Debugf("Unauthorized request for host %s: hostname returned false for ValidHostnameFn", hostname) + w.WriteHeader(http.StatusBadRequest) + return + } + } else { + if r.TLS == nil { + log.Warn("No TLS connection, and NoTLS is false") + w.WriteHeader(http.StatusBadRequest) + return + } + if hostname != r.TLS.ServerName { + log.Debugf("Unauthorized request for host %s: hostname mismatch. Expected %s", hostname, r.TLS.ServerName) + w.WriteHeader(http.StatusBadRequest) + return + } + if a.ValidHostnameFn != nil && !a.ValidHostnameFn(hostname) { + log.Debugf("Unauthorized request for host %s: hostname returned false for ValidHostnameFn", hostname) + w.WriteHeader(http.StatusBadRequest) + return + } + } + + hs := handshake.PeerIDAuthHandshakeServer{ + Hostname: hostname, + PrivKey: a.PrivKey, + TokenTTL: a.TokenTTL, + Hmac: a.Hmac, + } + err := hs.ParseHeaderVal([]byte(r.Header.Get("Authorization"))) + if err != nil { + log.Debugf("Failed to parse header: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + err = hs.Run() + if err != nil { + switch { + case errors.Is(err, handshake.ErrInvalidHMAC), + errors.Is(err, handshake.ErrExpiredChallenge), + errors.Is(err, handshake.ErrExpiredToken): + + hs := handshake.PeerIDAuthHandshakeServer{ + Hostname: hostname, + PrivKey: a.PrivKey, + TokenTTL: a.TokenTTL, + Hmac: a.Hmac, + } + hs.Run() + hs.SetHeader(w.Header()) + w.WriteHeader(http.StatusUnauthorized) + + return + } + + log.Debugf("Failed to run handshake: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + hs.SetHeader(w.Header()) + + peer, err := hs.PeerID() + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if a.Next == nil { + w.WriteHeader(http.StatusOK) + return + } + a.Next(peer, w, r) +} diff --git a/test-plans/go.mod b/test-plans/go.mod index 6bace7c9a6..d1842149a9 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -2,6 +2,8 @@ module github.com/libp2p/go-libp2p/test-plans/m/v2 go 1.22 +toolchain go1.22.1 + require ( github.com/go-redis/redis/v8 v8.11.5 github.com/libp2p/go-libp2p v0.0.0