-
Notifications
You must be signed in to change notification settings - Fork 19
/
auth_test.go
78 lines (60 loc) · 2.07 KB
/
auth_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package wire
import (
"bytes"
"context"
"fmt"
"strconv"
"testing"
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/jeroenrinzema/psql-wire/pkg/types"
"github.com/neilotoole/slogt"
"github.com/stretchr/testify/require"
)
// TestDefaultHandleAuth verifies the default path of handleAuth, where the
// Server has no Auth strategy configured. handleAuth must succeed and write an
// authentication ('R') message to the client whose status is authOK.
func TestDefaultHandleAuth(t *testing.T) {
	input := bytes.NewBuffer([]byte{})
	sink := bytes.NewBuffer([]byte{})

	ctx := context.Background()
	reader := buffer.NewReader(slogt.New(t), input, buffer.DefaultBufferSize)
	writer := buffer.NewWriter(slogt.New(t), sink)

	server := &Server{logger: slogt.New(t)}
	_, err := server.handleAuth(ctx, reader, writer)
	require.NoError(t, err)

	// Decode what the server wrote to the client.
	result := buffer.NewReader(slogt.New(t), sink, buffer.DefaultBufferSize)
	ty, ln, err := result.ReadTypedMsg()
	require.NoError(t, err)

	// A malformed header makes the status check below meaningless, so stop
	// here instead of reporting a confusing secondary failure.
	if ln == 0 {
		t.Fatal("unexpected length, expected typed message length to be greater than 0")
	}
	if ty != 'R' {
		t.Fatalf("unexpected message type %s, expected 'R'", strconv.QuoteRune(rune(ty)))
	}

	status, err := result.GetUint32()
	require.NoError(t, err)
	if authType(status) != authOK {
		t.Errorf("unexpected auth status %d, expected OK", status)
	}
}
// TestClearTextPassword verifies that handleAuth, configured with the
// ClearTextPassword strategy, succeeds when the client's password message
// carries the expected password. It also checks that the context returned by
// the validate callback (here the unchanged input context) is passed back.
func TestClearTextPassword(t *testing.T) {
	expected := "password"

	// Pre-fill the connection input with the client's password message.
	input := bytes.NewBuffer([]byte{})
	incoming := buffer.NewWriter(slogt.New(t), input)

	// NOTE: the same buffered writer type the server uses encodes this client
	// message. The client message type is cast to types.ServerMessage so that
	// Start accepts it.
	incoming.Start(types.ServerMessage(types.ClientPassword))
	incoming.AddString(expected)
	incoming.AddNullTerminate()
	// Fail fast if the simulated client message could not be encoded;
	// otherwise handleAuth would fail later with a misleading error.
	require.NoError(t, incoming.End())

	// validate accepts only the expected password and rejects anything else
	// with an error so a mismatch surfaces through handleAuth.
	validate := func(ctx context.Context, username, password string) (context.Context, bool, error) {
		if password != expected {
			return ctx, false, fmt.Errorf("unexpected password: %s", password)
		}
		return ctx, true, nil
	}

	sink := bytes.NewBuffer([]byte{})
	ctx := context.Background()
	reader := buffer.NewReader(slogt.New(t), input, buffer.DefaultBufferSize)
	writer := buffer.NewWriter(slogt.New(t), sink)

	server := &Server{logger: slogt.New(t), Auth: ClearTextPassword(validate)}
	out, err := server.handleAuth(ctx, reader, writer)
	require.NoError(t, err)
	require.Equal(t, ctx, out)
}