Skip to content

Commit

Permalink
Fix bug on checking headscale connection. (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerrit91 authored Jun 10, 2024
1 parent c4ef2d4 commit a8b9325
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 43 deletions.
10 changes: 2 additions & 8 deletions cmd/metal-api/internal/headscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,13 @@ func (h *HeadscaleClient) CreatePreAuthKey(ctx context.Context, user string, exp
return resp.PreAuthKey.Key, nil
}

type connectedMap map[string]bool

func (h *HeadscaleClient) MachinesConnected(ctx context.Context) (connectedMap, error) {
func (h *HeadscaleClient) MachinesConnected(ctx context.Context) ([]*headscalev1.Machine, error) {
resp, err := h.client.ListMachines(ctx, &headscalev1.ListMachinesRequest{})
if err != nil || resp == nil {
return nil, fmt.Errorf("failed to list machines: %w", err)
}
result := connectedMap{}
for _, m := range resp.Machines {
result[m.Name] = m.Online
}

return result, nil
return resp.Machines, nil
}

// DeleteMachine removes the node entry from headscale DB
Expand Down
73 changes: 73 additions & 0 deletions cmd/metal-api/internal/service/vpn-service.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package service

import (
"context"
"fmt"
"log/slog"
"net/http"
"slices"
"time"

restfulspec "github.com/emicklei/go-restful-openapi/v2"

"github.com/emicklei/go-restful/v3"

headscalev1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/metal-stack/metal-api/cmd/metal-api/internal/datastore"
"github.com/metal-stack/metal-api/cmd/metal-api/internal/headscale"
v1 "github.com/metal-stack/metal-api/cmd/metal-api/internal/service/v1"
"github.com/metal-stack/metal-lib/httperrors"
"github.com/metal-stack/metal-lib/pkg/pointer"
)

type vpnResource struct {
Expand Down Expand Up @@ -100,3 +105,71 @@ func (r *vpnResource) getVPNAuthKey(request *restful.Request, response *restful.

r.send(request, response, http.StatusOK, authKeyResp)
}

type headscaleMachineLister interface {
MachinesConnected(ctx context.Context) ([]*headscalev1.Machine, error)
}

func EvaluateVPNConnected(log *slog.Logger, ds *datastore.RethinkStore, lister headscaleMachineLister) error {
ms, err := ds.ListMachines()
if err != nil {
return err
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

headscaleMachines, err := lister.MachinesConnected(ctx)
if err != nil {
return err
}

var errs []error
for _, m := range ms {
m := m
if m.Allocation == nil || m.Allocation.VPN == nil {
continue
}

index := slices.IndexFunc(headscaleMachines, func(hm *headscalev1.Machine) bool {
if hm.Name != m.ID {
return false
}

if pointer.SafeDeref(hm.User).Name != m.Allocation.Project {
return false
}

return true
})

if index < 0 {
continue
}

connected := headscaleMachines[index].Online

if m.Allocation.VPN.Connected == connected {
log.Info("not updating vpn because already up-to-date", "machine", m.ID, "connected", connected)
continue
}

old := m
m.Allocation.VPN.Connected = connected

err := ds.UpdateMachine(&old, &m)
if err != nil {
errs = append(errs, err)
log.Error("unable to update vpn connected state, continue anyway", "machine", m.ID, "error", err)
continue
}

log.Info("updated vpn connected state", "machine", m.ID, "connected", connected)
}

if len(errs) > 0 {
return fmt.Errorf("errors occurred when evaluating machine vpn connections")
}

return nil
}
112 changes: 112 additions & 0 deletions cmd/metal-api/internal/service/vpn-service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package service

import (
"context"
"log/slog"
"testing"

"github.com/google/go-cmp/cmp"
headscalev1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/metal-stack/metal-api/cmd/metal-api/internal/datastore"
"github.com/metal-stack/metal-api/cmd/metal-api/internal/metal"
"github.com/metal-stack/metal-api/cmd/metal-api/internal/testdata"
"github.com/metal-stack/metal-lib/pkg/testcommon"
r "gopkg.in/rethinkdb/rethinkdb-go.v6"
)

func Test_EvaluateVPNConnected(t *testing.T) {
tests := []struct {
name string
mockFn func(mock *r.Mock)
headscaleMachines []*headscalev1.Machine
wantErr error
}{
{
name: "machines are correctly evaluated",
mockFn: func(mock *r.Mock) {
mock.On(r.DB("mockdb").Table("machine")).Return(metal.Machines{
{
Base: metal.Base{
ID: "toggle",
},
Allocation: &metal.MachineAllocation{
Project: "p1",
VPN: &metal.MachineVPN{
Connected: false,
},
},
},
{
Base: metal.Base{
ID: "already-connected",
},
Allocation: &metal.MachineAllocation{
Project: "p2",
VPN: &metal.MachineVPN{
Connected: true,
},
},
},
{
Base: metal.Base{
ID: "no-vpn",
},
Allocation: &metal.MachineAllocation{
Project: "p3",
},
},
}, nil)

// unfortunately, it's too hard to check the replace exactly for specific fields...
mock.On(r.DB("mockdb").Table("machine").Get("toggle").Replace(r.MockAnything())).Return(testdata.EmptyResult, nil)
},
headscaleMachines: []*headscalev1.Machine{
{
Name: "toggle",
User: &headscalev1.User{
Name: "previous-allocation",
},
Online: false,
},
{
Name: "toggle",
User: &headscalev1.User{
Name: "p1",
},
Online: true,
},
{
Name: "already-connected",
User: &headscalev1.User{
Name: "p2",
},
Online: true,
},
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds, mock := datastore.InitMockDB(t)
if tt.mockFn != nil {
tt.mockFn(mock)
}

err := EvaluateVPNConnected(slog.Default(), ds, &headscaleTest{ms: tt.headscaleMachines})
if diff := cmp.Diff(tt.wantErr, err, testcommon.ErrorStringComparer()); diff != "" {
t.Errorf("error diff (-want +got):\n%s", diff)
}

mock.AssertExpectations(t)
})
}
}

type headscaleTest struct {
ms []*headscalev1.Machine
}

func (h *headscaleTest) MachinesConnected(ctx context.Context) ([]*headscalev1.Machine, error) {
return h.ms, nil
}
36 changes: 1 addition & 35 deletions cmd/metal-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -888,41 +888,7 @@ func evaluateVPNConnected() error {
return err
}

ms, err := ds.ListMachines()
if err != nil {
return err
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

connectedMap, err := headscaleClient.MachinesConnected(ctx)
if err != nil {
return err
}

var errs []error
for _, m := range ms {
m := m
if m.Allocation == nil || m.Allocation.VPN == nil {
continue
}
connected := connectedMap[m.ID]
if m.Allocation.VPN.Connected == connected {
continue
}

old := m
m.Allocation.VPN.Connected = connected
err := ds.UpdateMachine(&old, &m)
if err != nil {
errs = append(errs, err)
logger.Error("unable to update vpn connected state, continue anyway", "machine", m.ID, "error", err)
continue
}
logger.Info("updated vpn connected state", "machine", m.ID, "connected", connected)
}
return errors.Join(errs...)
return service.EvaluateVPNConnected(logger, ds, headscaleClient)
}

// might return (nil, nil) if auditing is disabled!
Expand Down

0 comments on commit a8b9325

Please sign in to comment.