Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add token to Unix Socket API message #90

Merged
merged 3 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions internal/api/message.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package api

import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"errors"
"io"
)

const (
// MaxPayloadLength is the maximum allowed length of a message payload.
MaxPayloadLength = 2097152
// TokenLength is the length of the message token in bytes.
TokenLength = 16
)

var (
// token is the message token.
token [TokenLength]byte
)

// Message types.
Expand All @@ -24,6 +31,7 @@ const (
type Header struct {
Type uint16
Length uint32
Token [TokenLength]byte
}

// Message is an API message.
Expand All @@ -34,13 +42,11 @@ type Message struct {

// NewMessage returns a new message with type t and payload p.
func NewMessage(t uint16, p []byte) *Message {
if len(p) > MaxPayloadLength {
return nil
}
return &Message{
Header: Header{
Type: t,
Length: uint32(len(p)),
Token: token,
},
Value: p,
}
Expand Down Expand Up @@ -69,8 +75,8 @@ func ReadMessage(r io.Reader) (*Message, error) {
if h.Type == TypeNone || h.Type >= TypeUndefined {
return nil, errors.New("invalid message type")
}
if h.Length > MaxPayloadLength {
return nil, errors.New("invalid message length")
if h.Token != token {
return nil, errors.New("invalid message token")
}

// read payload
Expand Down Expand Up @@ -107,3 +113,26 @@ func WriteMessage(w io.Writer, m *Message) error {

return nil
}

// GetToken generates and returns the message token as string. This should be
// used once on the server side before the server is started. Token must be
// passed to the client side.
func GetToken() (string, error) {
_, err := rand.Read(token[:])
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(token[:]), nil
}

// SetToken sets the message token from string. This should be used on the
// client side before sending requests to the server. Token must match token on
// the server side.
func SetToken(s string) error {
b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return err
}
copy(token[:], b)
return nil
}
47 changes: 37 additions & 10 deletions internal/api/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"encoding/base64"
"errors"
"log"
"reflect"
Expand All @@ -24,12 +25,6 @@ func TestNewMessage(t *testing.T) {
t.Errorf("got %d, want %d", msg.Type, typ)
}
}

// invalid payload length
p := [MaxPayloadLength + 1]byte{}
if NewMessage(TypeOK, p[:]) != nil {
t.Error("should not create message with invalid payload length")
}
}

// TestNewOK tests NewOK.
Expand Down Expand Up @@ -62,11 +57,11 @@ func TestReadMessageErrors(t *testing.T) {
// invalid type
{Header: Header{Type: TypeUndefined}},

// invalid length
{Header: Header{Type: TypeOK, Length: MaxPayloadLength + 1}},

// short message
{Header: Header{Type: TypeOK, Length: MaxPayloadLength}},
{Header: Header{Type: TypeOK, Length: 4096}},

// invalid token
{Header: Header{Type: TypeOK, Token: [16]byte{1}}},
} {
if err := WriteMessage(buf, msg); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -132,3 +127,35 @@ func TestReadWriteMessage(t *testing.T) {
t.Errorf("got %v, want %v", got, want)
}
}

// TestGetSetToken tests GetToken and SetToken.
func TestGetSetToken(t *testing.T) {
// reset token after tests
defer func() { token = [TokenLength]byte{} }()

// get new test token
testToken, err := GetToken()
if err != nil {
t.Fatal(err)
}
s := base64.RawURLEncoding.EncodeToString(token[:])
if testToken != s {
t.Fatal("encoded token should match internal token")
}

// set token
if err := SetToken(testToken); err != nil {
t.Fatal(err)
}

// check token
s = base64.RawURLEncoding.EncodeToString(token[:])
if s != testToken {
t.Fatal("internal token should match encoded token")
}

// setting invalid token
if err := SetToken("not a valid encoded token!"); err == nil {
t.Fatal("invalid token should return error")
}
}
15 changes: 2 additions & 13 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package daemon

import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net"
"reflect"
Expand Down Expand Up @@ -344,13 +342,6 @@ func (d *Daemon) updateVPNConfig(request *api.Request) {
return
}

// check token
if configUpdate.Token != d.token {
log.Error("Daemon got invalid token in vpn config update")
request.Error("invalid token in config update message")
return
}

// handle config update for vpn (dis)connect
if configUpdate.Reason == "disconnect" {
d.updateVPNConfigDown()
Expand Down Expand Up @@ -508,13 +499,11 @@ func (d *Daemon) cleanup(ctx context.Context) {

// initToken creates the daemon token for client authentication.
func (d *Daemon) initToken() error {
// TODO: is this good enough for us?
b := make([]byte, 16)
_, err := rand.Read(b)
token, err := api.GetToken()
if err != nil {
return err
}
d.token = base64.RawURLEncoding.EncodeToString(b)
d.token = token
return nil
}

Expand Down
12 changes: 4 additions & 8 deletions internal/daemon/vpnconfigupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,20 @@ import (
// VPNConfigUpdate is a VPN configuration update.
type VPNConfigUpdate struct {
Reason string
Token string
Config *vpnconfig.Config
}

// Valid returns whether the config update is valid.
func (c *VPNConfigUpdate) Valid() bool {
switch c.Reason {
case "disconnect":
// token must be valid and config nil
if c.Token == "" || c.Config != nil {
// config must be nil
if c.Config != nil {
return false
}
case "connect":
// token and config must be valid
if c.Token == "" || c.Config == nil {
return false
}
if !c.Config.Valid() {
// config must be valid
if c.Config == nil || !c.Config.Valid() {
return false
}
default:
Expand Down
8 changes: 2 additions & 6 deletions internal/daemon/vpnconfigupdate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test invalid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Config = vpnconfig.New()

got = u.Valid()
want = false
if got != want {
t.Errorf("got %t, want %t", got, want)
}

// test invalid connect, no token and no config
// test invalid connect, no config
u = NewVPNConfigUpdate()
u.Reason = "connect"

Expand All @@ -42,7 +43,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test invalid connect, invalid config
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()
u.Config.Device.Name = "name is too long for a network device"

Expand All @@ -55,7 +55,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test valid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Token = "some test token"

got = u.Valid()
want = true
Expand All @@ -66,7 +65,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test valid connect
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()

got = u.Valid()
Expand All @@ -87,13 +85,11 @@ func TestVPNConfigUpdateJSON(t *testing.T) {
// valid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Token = "some test token"
updates = append(updates, u)

// valid connect
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()
updates = append(updates, u)

Expand Down
23 changes: 6 additions & 17 deletions internal/vpncscript/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestRunClient(t *testing.T) {
return confUpdate
}

// test with maximum payload length
// test with varying payload lengths
server = api.NewServer(config)
go func() {
for r := range server.Requests() {
Expand All @@ -79,23 +79,12 @@ func TestRunClient(t *testing.T) {
if err := server.Start(); err != nil {
t.Fatal(err)
}
if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength)); err != nil {
t.Fatal(err)
}
server.Stop()

// test with more than maximum payload length
server = api.NewServer(config)
go func() {
for r := range server.Requests() {
r.Close()
for _, length := range []int{
2048, 4096, 8192, 65536, 2097152,
} {
if err := runClient(sockfile, getConfUpdate(length)); err != nil {
t.Errorf("length %d returned error: %v", length, err)
}
}()
if err := server.Start(); err != nil {
t.Fatal(err)
}
if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength+1)); err == nil {
t.Fatal("too long message should return error")
}
server.Stop()
}
6 changes: 6 additions & 0 deletions internal/vpncscript/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"

log "github.com/sirupsen/logrus"
"github.com/telekom-mms/oc-daemon/internal/api"
"github.com/telekom-mms/oc-daemon/internal/daemon"
)

Expand Down Expand Up @@ -47,6 +48,11 @@ func run(args []string) error {
socketFile = e.socketFile
}

// set token from environemt
if err := api.SetToken(e.token); err != nil {
return fmt.Errorf("VPNCScript could not set token: %w", err)
}

printDebugEnvironment()
log.WithField("env", e).Debug("VPNCScript parsed environment")

Expand Down
6 changes: 6 additions & 0 deletions internal/vpncscript/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ func TestRun(t *testing.T) {
t.Errorf("help should return ErrHelp, got: %v", err)
}

// test with invalid token
t.Setenv("oc_daemon_token", "this is not a valid encoded token!")
if err := run([]string{"test"}); err == nil {
t.Errorf("invalid token should return error")
}

// prepare environment with not existing sockfile
os.Clearenv()
sockfile := filepath.Join(t.TempDir(), "sockfile")
Expand Down
1 change: 0 additions & 1 deletion internal/vpncscript/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ func createConfig(env *env) (*vpnconfig.Config, error) {
func createConfigUpdate(env *env) (*daemon.VPNConfigUpdate, error) {
update := daemon.NewVPNConfigUpdate()
update.Reason = env.reason
update.Token = env.token
if env.reason == "connect" {
c, err := createConfig(env)
if err != nil {
Expand Down
4 changes: 0 additions & 4 deletions internal/vpncscript/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ func TestCreateConfigUpdate(t *testing.T) {

// create expected values based on test environment
reason := "connect"
token := "some token"
config := &vpnconfig.Config{
Gateway: net.IPv4(10, 1, 1, 1),
PID: 12345,
Expand Down Expand Up @@ -129,9 +128,6 @@ func TestCreateConfigUpdate(t *testing.T) {
if got.Reason != reason {
t.Errorf("got %s, want %s", got.Reason, reason)
}
if got.Token != token {
t.Errorf("got %s, want %s", got.Token, token)
}
if !reflect.DeepEqual(got.Config, config) {
t.Errorf("got:\n%#v\nwant:\n%#v", got.Config, config)
}
Expand Down
Loading