Skip to content

Commit

Permalink
fix: Improve error types and their visibility
Browse files Browse the repository at this point in the history
Signed-off-by: Steffen Vogel <[email protected]>
  • Loading branch information
stv0g committed Aug 8, 2023
1 parent 9b096f5 commit 885c82a
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 41 deletions.
4 changes: 3 additions & 1 deletion conn_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"golang.org/x/exp/slog"
)

var errInvalidEndpoint = errors.New("invalid endpoint type")

type udpEndpoint struct {
*net.UDPAddr
}
Expand Down Expand Up @@ -59,7 +61,7 @@ func (s *udpConn) Close() error {
func (s *udpConn) Send(pl payload, spkt spk, ep endpoint) error {
uep, ok := ep.(*udpEndpoint)
if !ok {
return errors.New("invalid endpoint type")
return errInvalidEndpoint
}

e := envelope{
Expand Down
14 changes: 7 additions & 7 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (

var (
// TODO: Only expose errors which are need on the public API.
ErrUnexpectedMsgType = errors.New("received unexpected message type")
ErrPeerNotFound = errors.New("peer not found")
ErrSessionNotFound = errors.New("session not found")
ErrInvalidAuthTag = errors.New("invalid authentication tag")
ErrReplayDetected = errors.New("detected replay")
ErrStaleNonce = errors.New("stale nonce")
ErrInvalidBiscuit = errors.New("failed decrypt biscuit")
errUnexpectedMsgType = errors.New("received unexpected message type")
errPeerNotFound = errors.New("peer not found")
errSessionNotFound = errors.New("session not found")
errInvalidAuthTag = errors.New("invalid authentication tag")
errReplayDetected = errors.New("detected replay")
errStaleNonce = errors.New("stale nonce")
errInvalidBiscuit = errors.New("failed decrypt biscuit")
)

type handshake struct {
Expand Down
6 changes: 3 additions & 3 deletions handshake_initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (hs *initiatorHandshake) handleRespHello(r *respHello) error {

// RHI7: Add a message authentication code for the same reason as above.
if _, err := hs.decryptAndMix(r.auth[:]); err != nil {
return fmt.Errorf("%w (RHI7): %w", ErrInvalidAuthTag, err)
return fmt.Errorf("%w (RHI7): %w", errInvalidAuthTag, err)
}

return nil
Expand Down Expand Up @@ -149,7 +149,7 @@ func (hs *initiatorHandshake) handleEmptyData(e *emptyData) error {

// TODO: Check nonce counter
if txnt < hs.txnt {
return ErrStaleNonce
return errStaleNonce
}
hs.txnt = txnt

Expand All @@ -159,7 +159,7 @@ func (hs *initiatorHandshake) handleEmptyData(e *emptyData) error {
}

if _, err := aead.Open(nil, n, e.auth[:], []byte{}); err != nil {
return ErrInvalidAuthTag
return errInvalidAuthTag
}

return nil
Expand Down
14 changes: 7 additions & 7 deletions handshake_responder.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (hs *responderHandshake) handleInitHello(h *initHello) error {

var ok bool
if hs.peer, ok = hs.server.peers[pid(pidi)]; !ok {
return fmt.Errorf("failed to lookup peer %s (IHR6): %w", pid(pidi), ErrPeerNotFound)
return fmt.Errorf("failed to lookup peer %s (IHR6): %w", pid(pidi), errPeerNotFound)
}

// IHR7: Ensure the responder has the correct view on spki. Mix in the PSK as optional
Expand All @@ -50,7 +50,7 @@ func (hs *responderHandshake) handleInitHello(h *initHello) error {
// IHR8: Add a message authentication code to ensure both participants agree on the
// session state and protocol transcript at this point.
if _, err := hs.decryptAndMix(h.auth[:]); err != nil {
return fmt.Errorf("%w (IHR8): %w", ErrInvalidAuthTag, err)
return fmt.Errorf("%w (IHR8): %w", errInvalidAuthTag, err)
}

return nil
Expand Down Expand Up @@ -127,12 +127,12 @@ func (hs *responderHandshake) handleInitConf(i *initConf) error {
// ICR4: Message authentication code for the same reason as above, which in particular
// ensures that both participants agree on the final chaining key.
if _, err := hs.decryptAndMix(i.auth[:]); err != nil {
return fmt.Errorf("%w (ICR4): %w", ErrInvalidAuthTag, err)
return fmt.Errorf("%w (ICR4): %w", errInvalidAuthTag, err)
}

// ICR5: Biscuit replay detection.
if !bNo.Larger(hs.peer.biscuitUsed) {
return fmt.Errorf("%w (ICR5)", ErrReplayDetected)
return fmt.Errorf("%w (ICR5)", errReplayDetected)
} else if bNo.Equal(hs.peer.biscuitUsed) {
// This is a retransmitted InitConf message.
// We skip ICR6 & ICR6 and just reply with EmptyData
Expand Down Expand Up @@ -233,12 +233,12 @@ func (hs *responderHandshake) loadBiscuit(sb sealedBiscuit) (biscuitNo, error) {
// Find the peer and apply retransmission protection
var ok bool
if hs.peer, ok = hs.server.peers[b.pidi]; !ok {
return biscuitNo{}, ErrPeerNotFound
return biscuitNo{}, errPeerNotFound
}

// assert(pt.biscuit_no ≤ peer.biscuit_used);
if hs.peer.biscuitUsed.LargerOrEqual(b.biscuitNo) {
return biscuitNo{}, ErrReplayDetected
return biscuitNo{}, errReplayDetected
}

// Restore the chaining key
Expand All @@ -252,7 +252,7 @@ func (hs *responderHandshake) loadBiscuit(sb sealedBiscuit) (biscuitNo, error) {
return b.biscuitNo, nil
}

return biscuitNo{}, ErrInvalidBiscuit
return biscuitNo{}, errInvalidBiscuit
}

func (hs *responderHandshake) enterLive() {
Expand Down
26 changes: 13 additions & 13 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
)

var (
ErrMsgTruncated = errors.New("message is truncated")
ErrInvalidLen = errors.New("invalid message length")
ErrInvalidMsgType = errors.New("invalid message type")
ErrInvalidMAC = errors.New("invalid mac")
errMsgTruncated = errors.New("message is truncated")
errInvalidLen = errors.New("invalid message length")
errInvalidMsgType = errors.New("invalid message type")
errInvalidMAC = errors.New("invalid mac")
)

type payload interface {
Expand Down Expand Up @@ -40,7 +40,7 @@ func (e *envelope) MarshalBinaryAndSeal(spkt spk) []byte {

func (e *envelope) CheckAndUnmarshalBinary(buf []byte, spkm spk) (int, error) {
if len(buf) < envelopeSize {
return -1, ErrMsgTruncated
return -1, errMsgTruncated
}

macOffset := len(buf) - macSize - cookieSize
Expand All @@ -49,7 +49,7 @@ func (e *envelope) CheckAndUnmarshalBinary(buf []byte, spkm spk) (int, error) {
macCalc := macKey[:macSize]

if subtle.ConstantTimeCompare(macWire, macCalc) != 1 {
return -1, ErrInvalidMAC
return -1, errInvalidMAC
}

return e.UnmarshalBinary(buf)
Expand All @@ -71,7 +71,7 @@ func (e *envelope) MarshalBinary() []byte {
func (e *envelope) UnmarshalBinary(buf []byte) (int, error) {
lenPayload := len(buf) - envelopeSize
if lenPayload <= 0 {
return -1, ErrMsgTruncated
return -1, errMsgTruncated
}

mtype := msgType(buf[0])
Expand All @@ -86,7 +86,7 @@ func (e *envelope) UnmarshalBinary(buf []byte) (int, error) {
case msgTypeEmptyData:
e.payload = &emptyData{}
default:
return -1, ErrInvalidMsgType
return -1, errInvalidMsgType
}

o := 4
Expand Down Expand Up @@ -118,7 +118,7 @@ func (b *biscuit) MarshalBinary() []byte {

func (b *biscuit) UnmarshalBinary(buf []byte) (int, error) {
if len(buf) != biscuitSize {
return -1, ErrInvalidLen
return -1, errInvalidLen
}

o := copy(b.pidi[:], buf)
Expand Down Expand Up @@ -147,7 +147,7 @@ func (m *initHello) MarshalBinary() []byte {

func (m *initHello) UnmarshalBinary(buf []byte) (int, error) {
if len(buf) != initHelloMsgSize {
return -1, ErrInvalidLen
return -1, errInvalidLen
}

o := copy(m.sidi[:], buf)
Expand Down Expand Up @@ -184,7 +184,7 @@ func (m *respHello) MarshalBinary() []byte {

func (m *respHello) UnmarshalBinary(buf []byte) (int, error) {
if len(buf) != respHelloMsgSize {
return -1, ErrInvalidLen
return -1, errInvalidLen
}

o := copy(m.sidr[:], buf)
Expand Down Expand Up @@ -218,7 +218,7 @@ func (m *initConf) MarshalBinary() []byte {

func (m *initConf) UnmarshalBinary(buf []byte) (int, error) {
if len(buf) != initConfMsgSize {
return -1, ErrInvalidLen
return -1, errInvalidLen
}

o := copy(m.sidi[:], buf)
Expand All @@ -244,7 +244,7 @@ func (m *emptyData) MarshalBinary() []byte {

func (m *emptyData) UnmarshalBinary(buf []byte) (int, error) {
if len(buf) != emptyDataMsgSize {
return -1, ErrInvalidLen
return -1, errInvalidLen
}

o := copy(m.sid[:], buf)
Expand Down
4 changes: 3 additions & 1 deletion output.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ const (
KeyOutputReasonStale KeyOutputReason = "stale"
)

var errInvalidOutputFormat = errors.New("invalid output format")

// Output format:
// output-key peer {} key-file {of:?} {why}.
type KeyOutput struct {
Expand All @@ -31,7 +33,7 @@ func ParseKeyOutput(str string) (o KeyOutput, err error) {
if tokens[0] != "output-key" ||
tokens[1] != "peer" ||
tokens[3] != "key-file" {
return o, errors.New("invalid output format")
return o, errInvalidOutputFormat
}

if o.Peer, err = ParsePeerID(tokens[2]); err != nil {
Expand Down
9 changes: 6 additions & 3 deletions peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import (
"golang.org/x/exp/slog"
)

var ErrMissingEndpoint = errors.New("missing endpoint")
var (
errMissingEndpoint = errors.New("missing endpoint")
errMissingPublicKey = errors.New("missing public key")
)

type PeerConfig struct {
PublicKey spk // The peer’s public key
Expand Down Expand Up @@ -43,7 +46,7 @@ type peer struct {

func (s *Server) newPeer(cfg PeerConfig) (*peer, error) {
if cfg.PublicKey == nil {
return nil, errors.New("missing public key")
return nil, errMissingPublicKey
}

p := &peer{
Expand Down Expand Up @@ -73,7 +76,7 @@ func (p *peer) PID() pid {

func (p *peer) initiateHandshake() (*initiatorHandshake, error) {
if p.endpoint == nil {
return nil, ErrMissingEndpoint
return nil, errMissingEndpoint
}

hs := &initiatorHandshake{
Expand Down
12 changes: 6 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ func (s *Server) handle(pl payload, from endpoint) (err error) {
case *respHello:
hs, ok := s.getHandshake(req.sidi)
if !ok {
return fmt.Errorf("%s: %s", ErrSessionNotFound, req.sidi)
return fmt.Errorf("%s: %s", errSessionNotFound, req.sidi)
}

if hs.nextMsg != msgTypeRespHello {
return fmt.Errorf("%w: %s", ErrUnexpectedMsgType, mTyp)
return fmt.Errorf("%w: %s", errUnexpectedMsgType, mTyp)
}

if err = hs.handleRespHello(req); err != nil {
Expand Down Expand Up @@ -245,11 +245,11 @@ func (s *Server) handle(pl payload, from endpoint) (err error) {
case *emptyData:
hs, ok := s.getHandshake(req.sid)
if !ok {
return fmt.Errorf("%s: %s", ErrSessionNotFound, req.sid)
return fmt.Errorf("%s: %s", errSessionNotFound, req.sid)
}

if hs.nextMsg != msgTypeEmptyData {
return fmt.Errorf("%w: %s", ErrUnexpectedMsgType, mTyp)
return fmt.Errorf("%w: %s", errUnexpectedMsgType, mTyp)
}

if err = hs.handleEmptyData(req); err != nil {
Expand All @@ -264,7 +264,7 @@ func (s *Server) handle(pl payload, from endpoint) (err error) {
s.completeHandshake(&hs.handshake, from, RekeyAfterTimeInitiator)

default:
return ErrInvalidMsgType
return errInvalidMsgType
}

return nil
Expand Down Expand Up @@ -292,7 +292,7 @@ func (s *Server) removeHandshake(hs *initiatorHandshake) {

func (s *Server) initiateHandshake(p *peer) {
if hs, err := p.initiateHandshake(); err != nil {
if errors.Is(err, ErrMissingEndpoint) {
if errors.Is(err, errMissingEndpoint) {
p.logger.Debug("Skipping handshake due to missing endpoint")
} else {
p.logger.Error("Failed to initiate handshake for peer", slog.Any("error", err))
Expand Down

0 comments on commit 885c82a

Please sign in to comment.