Skip to content

Commit

Permalink
Rework context handling for headscale client. (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerrit91 authored Jul 28, 2023
1 parent c475c1b commit d023770
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 37 deletions.
36 changes: 16 additions & 20 deletions cmd/metal-api/internal/headscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ type HeadscaleClient struct {
address string
controlPlaneAddress string

ctx context.Context
conn *grpc.ClientConn
cancelFunc context.CancelFunc
logger *zap.SugaredLogger
conn *grpc.ClientConn
logger *zap.SugaredLogger
}

func NewHeadscaleClient(addr, controlPlaneAddr, apiKey string, logger *zap.SugaredLogger) (client *HeadscaleClient, err error) {
Expand All @@ -45,7 +43,6 @@ func NewHeadscaleClient(addr, controlPlaneAddr, apiKey string, logger *zap.Sugar

logger: logger,
}
h.ctx, h.cancelFunc = context.WithCancel(context.Background())

if err = h.connect(apiKey); err != nil {
return nil, fmt.Errorf("failed to connect to Headscale server: %w", err)
Expand All @@ -56,7 +53,7 @@ func NewHeadscaleClient(addr, controlPlaneAddr, apiKey string, logger *zap.Sugar

// Connect or reconnect to Headscale server
func (h *HeadscaleClient) connect(apiKey string) (err error) {
ctx, cancel := context.WithTimeout(h.ctx, 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

grpcOptions := []grpc.DialOption{
Expand All @@ -81,22 +78,22 @@ func (h *HeadscaleClient) GetControlPlaneAddress() string {
return h.controlPlaneAddress
}

func (h *HeadscaleClient) UserExists(name string) bool {
func (h *HeadscaleClient) UserExists(ctx context.Context, name string) bool {
req := &headscalev1.GetUserRequest{
Name: name,
}
if _, err := h.client.GetUser(h.ctx, req); err != nil {
if _, err := h.client.GetUser(ctx, req); err != nil {
return false
}

return true
}

func (h *HeadscaleClient) CreateUser(name string) error {
func (h *HeadscaleClient) CreateUser(ctx context.Context, name string) error {
req := &headscalev1.CreateUserRequest{
Name: name,
}
_, err := h.client.CreateUser(h.ctx, req)
_, err := h.client.CreateUser(ctx, req)
// TODO: this error check is pretty rough, but it's not easily possible to compare the proto error directly :/
if err != nil && !strings.Contains(err.Error(), hscontrol.ErrUserExists.Error()) {
return fmt.Errorf("failed to create new VPN user: %w", err)
Expand All @@ -105,13 +102,13 @@ func (h *HeadscaleClient) CreateUser(name string) error {
return nil
}

func (h *HeadscaleClient) CreatePreAuthKey(user string, expiration time.Time, isEphemeral bool) (key string, err error) {
func (h *HeadscaleClient) CreatePreAuthKey(ctx context.Context, user string, expiration time.Time, isEphemeral bool) (key string, err error) {
req := &headscalev1.CreatePreAuthKeyRequest{
User: user,
Expiration: timestamppb.New(expiration),
Ephemeral: isEphemeral,
}
resp, err := h.client.CreatePreAuthKey(h.ctx, req)
resp, err := h.client.CreatePreAuthKey(ctx, req)
if err != nil || resp == nil || resp.PreAuthKey == nil {
return "", fmt.Errorf("failed to create new Auth Key: %w", err)
}
Expand All @@ -121,8 +118,8 @@ func (h *HeadscaleClient) CreatePreAuthKey(user string, expiration time.Time, is

type connectedMap map[string]bool

func (h *HeadscaleClient) MachinesConnected() (connectedMap, error) {
resp, err := h.client.ListMachines(h.ctx, &headscalev1.ListMachinesRequest{})
func (h *HeadscaleClient) MachinesConnected(ctx context.Context) (connectedMap, error) {
resp, err := h.client.ListMachines(ctx, &headscalev1.ListMachinesRequest{})
if err != nil || resp == nil {
return nil, fmt.Errorf("failed to list machines: %w", err)
}
Expand All @@ -135,27 +132,27 @@ func (h *HeadscaleClient) MachinesConnected() (connectedMap, error) {
}

// DeleteMachine removes the node entry from headscale DB
func (h *HeadscaleClient) DeleteMachine(machineID, projectID string) (err error) {
machine, err := h.getMachine(machineID, projectID)
func (h *HeadscaleClient) DeleteMachine(ctx context.Context, machineID, projectID string) (err error) {
machine, err := h.getMachine(ctx, machineID, projectID)
if err != nil || machine == nil {
return err
}

req := &headscalev1.DeleteMachineRequest{
MachineId: machine.Id,
}
if _, err := h.client.DeleteMachine(h.ctx, req); err != nil {
if _, err := h.client.DeleteMachine(ctx, req); err != nil {
return fmt.Errorf("failed to delete machine: %w", err)
}

return nil
}

func (h *HeadscaleClient) getMachine(machineID, projectID string) (machine *headscalev1.Machine, err error) {
func (h *HeadscaleClient) getMachine(ctx context.Context, machineID, projectID string) (machine *headscalev1.Machine, err error) {
req := &headscalev1.ListMachinesRequest{
User: projectID,
}
resp, err := h.client.ListMachines(h.ctx, req)
resp, err := h.client.ListMachines(ctx, req)
if err != nil || resp == nil {
return nil, fmt.Errorf("failed to list machines: %w", err)
}
Expand All @@ -171,6 +168,5 @@ func (h *HeadscaleClient) getMachine(machineID, projectID string) (machine *head

// Close client
func (h *HeadscaleClient) Close() error {
h.cancelFunc()
return h.conn.Close()
}
5 changes: 3 additions & 2 deletions cmd/metal-api/internal/service/async-actor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"context"
"errors"
"fmt"

Expand Down Expand Up @@ -42,14 +43,14 @@ func newAsyncActor(l *zap.SugaredLogger, ep *bus.Endpoints, ds *datastore.Rethin
return actor, nil
}

func (a *asyncActor) freeMachine(pub bus.Publisher, m *metal.Machine, headscaleClient *headscale.HeadscaleClient, logger *zap.SugaredLogger) error {
func (a *asyncActor) freeMachine(ctx context.Context, pub bus.Publisher, m *metal.Machine, headscaleClient *headscale.HeadscaleClient, logger *zap.SugaredLogger) error {
if m.State.Value == metal.LockedState {
return errors.New("machine is locked")
}

if headscaleClient != nil && m.Allocation != nil {
// always call DeleteMachine, in case machine is not registered it will return nil
if err := headscaleClient.DeleteMachine(m.ID, m.Allocation.Project); err != nil {
if err := headscaleClient.DeleteMachine(ctx, m.ID, m.Allocation.Project); err != nil {
logger.Error("unable to delete Node entry from headscale DB", zap.String("machineID", m.ID), zap.Error(err))
}
}
Expand Down
9 changes: 5 additions & 4 deletions cmd/metal-api/internal/service/firewall-service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -211,7 +212,7 @@ func (r *firewallResource) allocateFirewall(request *restful.Request, response *
return
}

if err := r.setVPNConfigInSpec(spec); err != nil {
if err := r.setVPNConfigInSpec(request.Request.Context(), spec); err != nil {
r.sendError(request, response, defaultError(err))
return
}
Expand All @@ -231,19 +232,19 @@ func (r *firewallResource) allocateFirewall(request *restful.Request, response *
r.send(request, response, http.StatusOK, resp)
}

func (r firewallResource) setVPNConfigInSpec(allocationSpec *machineAllocationSpec) error {
func (r firewallResource) setVPNConfigInSpec(ctx context.Context, allocationSpec *machineAllocationSpec) error {
if r.headscaleClient == nil {
return nil
}

// Try to create user in Headscale DB
projectID := allocationSpec.ProjectID
if err := r.headscaleClient.CreateUser(projectID); err != nil {
if err := r.headscaleClient.CreateUser(ctx, projectID); err != nil {
return fmt.Errorf("failed to create new VPN user for the project: %w", err)
}

expiration := time.Now().Add(2 * time.Hour)
key, err := r.headscaleClient.CreatePreAuthKey(projectID, expiration, false)
key, err := r.headscaleClient.CreatePreAuthKey(ctx, projectID, expiration, false)
if err != nil {
return fmt.Errorf("failed to create new auth key for the firewall: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions cmd/metal-api/internal/service/machine-service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ func (r machineResource) freeMachine(request *restful.Request, response *restful
logger.Error("unable to publish machine command", zap.String("command", string(metal.ChassisIdentifyLEDOffCmd)), zap.String("machineID", m.ID), zap.Error(err))
}

err = r.actor.freeMachine(r.Publisher, m, r.headscaleClient, logger)
err = r.actor.freeMachine(request.Request.Context(), r.Publisher, m, r.headscaleClient, logger)
if err != nil {
r.sendError(request, response, defaultError(err))
return
Expand Down Expand Up @@ -1811,7 +1811,7 @@ func evaluateMachineLiveliness(ds *datastore.RethinkStore, m metal.Machine) (met
}

// ResurrectMachines attempts to resurrect machines that are obviously dead
func ResurrectMachines(ds *datastore.RethinkStore, publisher bus.Publisher, ep *bus.Endpoints, ipamer ipam.IPAMer, headscaleClient *headscale.HeadscaleClient, logger *zap.SugaredLogger) error {
func ResurrectMachines(ctx context.Context, ds *datastore.RethinkStore, publisher bus.Publisher, ep *bus.Endpoints, ipamer ipam.IPAMer, headscaleClient *headscale.HeadscaleClient, logger *zap.SugaredLogger) error {
logger.Info("machine resurrection was requested")

machines, err := ds.ListMachines()
Expand Down Expand Up @@ -1843,7 +1843,7 @@ func ResurrectMachines(ds *datastore.RethinkStore, publisher bus.Publisher, ep *

if provisioningEvents.Liveliness == metal.MachineLivelinessDead && time.Since(*provisioningEvents.LastEventTime) > metal.MachineResurrectAfter {
logger.Infow("resurrecting dead machine", "machineID", m.ID, "liveliness", provisioningEvents.Liveliness, "since", time.Since(*provisioningEvents.LastEventTime).String())
err = act.freeMachine(publisher, &m, headscaleClient, logger)
err = act.freeMachine(ctx, publisher, &m, headscaleClient, logger)
if err != nil {
logger.Errorw("error during machine resurrection", "machineID", m.ID, "error", err)
}
Expand All @@ -1852,7 +1852,7 @@ func ResurrectMachines(ds *datastore.RethinkStore, publisher bus.Publisher, ep *

if provisioningEvents.FailedMachineReclaim {
logger.Infow("resurrecting machine with failed reclaim", "machineID", m.ID, "liveliness", provisioningEvents.Liveliness, "since", time.Since(*provisioningEvents.LastEventTime).String())
err = act.freeMachine(publisher, &m, headscaleClient, logger)
err = act.freeMachine(ctx, publisher, &m, headscaleClient, logger)
if err != nil {
logger.Errorw("error during machine resurrection", "machineID", m.ID, "error", err)
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/metal-api/internal/service/vpn-service.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (r *vpnResource) getVPNAuthKey(request *restful.Request, response *restful.
}

pid := requestPayload.Pid
if ok := r.headscaleClient.UserExists(pid); !ok {
if ok := r.headscaleClient.UserExists(request.Request.Context(), pid); !ok {
r.sendError(
request, response,
httperrors.NotFound(fmt.Errorf("vpn user doesn't exist for project with ID %s", pid)),
Expand All @@ -80,7 +80,7 @@ func (r *vpnResource) getVPNAuthKey(request *restful.Request, response *restful.
} else {
expiration = expiration.Add(time.Hour)
}
key, err := r.headscaleClient.CreatePreAuthKey(pid, expiration, requestPayload.Ephemeral)
key, err := r.headscaleClient.CreatePreAuthKey(request.Request.Context(), pid, expiration, requestPayload.Ephemeral)
if err != nil {
r.sendError(
request, response,
Expand Down
7 changes: 5 additions & 2 deletions cmd/metal-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ func resurrectDeadMachines() error {
p = nsqer.Publisher
ep = nsqer.Endpoints
}
err = service.ResurrectMachines(ds, p, ep, ipamer, headscaleClient, logger)
err = service.ResurrectMachines(context.Background(), ds, p, ep, ipamer, headscaleClient, logger)
if err != nil {
return fmt.Errorf("unable to resurrect machines: %w", err)
}
Expand Down Expand Up @@ -876,7 +876,10 @@ func evaluateVPNConnected() error {
return err
}

connectedMap, err := headscaleClient.MachinesConnected()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

connectedMap, err := headscaleClient.MachinesConnected(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ require (
github.com/deckarep/golang-set/v2 v2.3.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
github.com/docker/docker v24.0.4+incompatible // indirect
github.com/docker/docker v24.0.5+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ github.com/docker/distribution v2.7.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4Kfc
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker v20.10.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v24.0.4+incompatible h1:s/LVDftw9hjblvqIeTiGYXBCD95nOEEl7qRsRrIOuQI=
github.com/docker/docker v24.0.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v24.0.5+incompatible h1:WmgcE4fxyI6EEXxBRxsHnZXrO1pQ3smi0k/jho4HLeY=
github.com/docker/docker v24.0.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-events v0.0.0-20170721190031-9461782956ad/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA=
Expand Down

0 comments on commit d023770

Please sign in to comment.