Skip to content

Commit

Permalink
Merge pull request #613 from imeoer/fix-supervisor
Browse files Browse the repository at this point in the history
supervisor: fix large opaque data handle
  • Loading branch information
changweige authored Sep 2, 2024
2 parents 8fa319b + 4f4dbcd commit 30f3d65
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 117 deletions.
202 changes: 111 additions & 91 deletions pkg/supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package supervisor

import (
"fmt"
"io"
"net"
"os"
"sync"
Expand All @@ -21,49 +22,44 @@ import (
"github.com/pkg/errors"

"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"golang.org/x/sys/unix"
)

const MaxOpaqueLen = 1024 * 32 // Bytes

// oobSpace is the size of the oob slice required to store for multiple FDs. Note
// that unix.UnixRights appears to make the assumption that fd is always int32,
// so sizeof(fd) = 4.
// At most can accommodate 64 fds
var oobSpace = unix.CmsgSpace(4) * 64

type StatesStorage interface {
// Appended write states to the storage space.
Write([]byte)
// Read out the previously written states to fill `buf` which should be large enough.
Read(buf []byte) (uint, error)
// Mark all data as stale, the previously written data is cleaned
// Save state to storage space.
Save([]byte)
// Load state from storage space.
Load() ([]byte, error)
// Clean the previously saved state.
Clean()
}

// Store daemon states in memory
type MemStatesStorage struct {
data []byte
head int
}

func newMemStatesStorage() *MemStatesStorage {
return &MemStatesStorage{data: make([]byte, MaxOpaqueLen)}
return &MemStatesStorage{
data: []byte{},
}
}

func (mss *MemStatesStorage) Write(data []byte) {
l := copy(mss.data[mss.head:], data)
mss.head += l
func (mss *MemStatesStorage) Save(data []byte) {
mss.data = make([]byte, len(data))
copy(mss.data, data)
}

func (mss *MemStatesStorage) Read(data []byte) (uint, error) {
l := copy(data, mss.data[:mss.head])
return uint(l), nil
func (mss *MemStatesStorage) Load() ([]byte, error) {
data := make([]byte, len(mss.data))
copy(data, mss.data)
return data, nil
}

func (mss *MemStatesStorage) Clean() {
mss.head = 0
mss.data = []byte{}
}

// Use daemon ID as the supervisor ID
Expand All @@ -88,29 +84,97 @@ func (su *Supervisor) save(data []byte, fd int) {
if fd > 0 {
su.fd = fd
}
su.dataStorage.Write(data)
su.dataStorage.Save(data)
}

// Load resources kept by this supervisor
// 1. daemon runtime states
// 2. file descriptor
//
// Note: the resources should be not be consumed.
func (su *Supervisor) load(data []byte, oob []byte) (nData uint, nOob int, err error) {
func (su *Supervisor) load() ([]byte, int, error) {
su.mu.Lock()
defer su.mu.Unlock()

if su.fd > 0 {
b := syscall.UnixRights(su.fd)
nOob = copy(oob, b)
data, err := su.dataStorage.Load()
if err != nil {
return nil, 0, err
}

return data, su.fd, nil
}

func recv(uc *net.UnixConn) ([]byte, int, error) {
data := make([]byte, 0)
oob := make([]byte, 0)

var dataBufLen = 1024 * 256 // Bytes

// oobSpace is the size of the oob slice required to store for multiple FDs. Note
// that unix.UnixRights appears to make the assumption that fd is always int32,
// so sizeof(fd) = 4.
// At most can accommodate 64 fds
var oobSpace = unix.CmsgSpace(4) * 64

for {
dataBuf := make([]byte, dataBufLen)
oobBuf := make([]byte, oobSpace)

n, oobn, _, _, err := uc.ReadMsgUnix(dataBuf, oobBuf)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return nil, 0, errors.Wrap(err, "receive message")
}
if n == 0 {
break // EOF
}

data = append(data, dataBuf[:n]...)
oob = append(oob, oobBuf[:oobn]...)
}

scms, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return nil, 0, errors.Wrap(err, "parse control message")
}

var fds []int
if len(scms) == 0 {
return nil, 0, fmt.Errorf("received no control file descriptor")
}

nData, err = su.dataStorage.Read(data)
scm := scms[0]
fds, err = unix.ParseUnixRights(&scm)
if err != nil {
return 0, 0, err
return nil, 0, errors.Wrap(err, "extract file descriptors")
}

return nData, nOob, nil
var fd int
if len(fds) > 0 {
fd = fds[0]
} else {
fd = -1
}

return data, fd, nil
}

func send(uc *net.UnixConn, data []byte, fd int) error {
oob := syscall.UnixRights(fd)

for len(data) > 0 || len(oob) > 0 {
n, oobn, err := uc.WriteMsgUnix(data, oob, nil)
if err != nil {
return errors.Wrapf(err, "send message, datan %d oobn %d", n, oobn)
}

data = data[n:]
oob = oob[oobn:]
}

return nil
}

// There are several stages from different goroutines to trigger sending daemon states
Expand Down Expand Up @@ -139,52 +203,16 @@ func (su *Supervisor) waitStatesTimeout(to time.Duration) (func() error, error)
if err != nil {
return errors.Wrapf(err, "Listener is closed")
}

defer conn.Close()

unixConn := conn.(*net.UnixConn)
uf, err := unixConn.File()
data, fd, err := recv(conn.(*net.UnixConn))
if err != nil {
return err
}
log.L.Infof("Supervisor %s receives states. data %d", su.id, len(data))

defer uf.Close()

data := make([]byte, MaxOpaqueLen)
oob := make([]byte, oobSpace) // Out-of-band data

// TODO: Handle EAGAIN EOF and EINTR
n, oobn, _, _, err := unix.Recvmsg(int(uf.Fd()), data, oob, 0)
if err != nil {
return errors.Wrap(err, "receive message")
}

log.L.Infof("Supervisor %s receives states. data %d oob %d", su.id, n, oobn)

scms, err := unix.ParseSocketControlMessage(oob[:oobn])
if err != nil {
return errors.Wrap(err, "parse control message")
}

var fds []int
if len(scms) > 0 {
scm := scms[0]
fds, err = unix.ParseUnixRights(&scm)
if err != nil {
return errors.Wrap(err, "extract file descriptors")
}
} else {
log.L.Warn("received no control file descriptor")
}
su.save(data, fd)

var fd int
if len(fds) > 0 {
fd = fds[0]
} else {
fd = -1
}

su.save(data[:n], fd)
return nil
}

Expand Down Expand Up @@ -240,32 +268,18 @@ func (su *Supervisor) SendStatesTimeout(to time.Duration) error {
if err != nil {
return errors.Wrapf(err, "Listener is closed")
}

defer conn.Close()

unixConn := conn.(*net.UnixConn)
uf, err := unixConn.File()
if err != nil {
return err
}
defer uf.Close()

data := make([]byte, MaxOpaqueLen)
oob := make([]byte, oobSpace)

// FIXME: It's possible that sending states happens before storing state to the storage.

datan, oobn, err := su.load(data, oob)
data, fd, err := su.load()
if err != nil {
return errors.Wrapf(err, "load resources for %s", su.id)
}
// TODO: validate returned length
_, _, err = unixConn.WriteMsgUnix(data[:datan], oob[:oobn], nil)
if err != nil {
return errors.Wrapf(err, "send message, datan %d oobn %d", datan, oobn)
if err := send(conn.(*net.UnixConn), data, fd); err != nil {
return err
}

log.L.Infof("Supervisor %s sends states. data %d oob %d", su.id, datan, oobn)
log.L.Infof("Supervisor %s sends states. data %d", su.id, len(data))

return nil
}
Expand Down Expand Up @@ -311,13 +325,19 @@ func (su *Supervisor) FetchDaemonStates(trigger func() error) error {
return errors.Wrapf(err, "wait states on %s", su.Sock())
}

err = trigger()
if err != nil {
eg := errgroup.Group{}
eg.Go(func() error {
err := trigger()
return errors.Wrapf(err, "trigger on %s", su.Sock())
}
})

eg.Go(func() error {
err := receiver()
return errors.Wrapf(err, "receiver on %s", su.Sock())
})

// FIXME: With Timeout context!
return receiver()
return eg.Wait()
}

// The unix domain socket on which nydus daemon is connected to
Expand Down
62 changes: 36 additions & 26 deletions pkg/supervisor/supervisor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
package supervisor

import (
"crypto/rand"
"net"
"os"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
)

func TestSupervisor(t *testing.T) {
Expand All @@ -26,49 +26,59 @@ func TestSupervisor(t *testing.T) {
})

supervisorSet, err := NewSupervisorSet(rootDir)
assert.Nil(t, err, "%v", err)
assert.Nil(t, err)

su1 := supervisorSet.NewSupervisor("su1")
assert.NotNil(t, su1)
defer func() {
err = supervisorSet.DestroySupervisor("su1")
assert.NotNil(t, su1)
}()

_, err = su1.waitStatesTimeout(2 * time.Second)
assert.Nil(t, err, "%v", err)
sock := su1.Sock()

addr, err := net.ResolveUnixAddr("unix", sock)
assert.Nil(t, err)

conn, err := net.DialUnix("unix", nil, addr)
assert.Nil(t, err, "%v", err)

sentData := []byte("abcde")
// Build a large data to test the multiple recvmsg / sendmsg
// syscalls can handle all the data.
sentData := make([]byte, 1024*1024*2)
_, err = rand.Read(sentData)
assert.Nil(t, err)

sentLen, err := conn.Write(sentData)
tmpFile, err := os.CreateTemp("", "nydus-supervisor-test")
assert.Nil(t, err)
defer tmpFile.Close()
defer os.Remove(tmpFile.Name())

conn.Close()
nydusdSendFd := func() error {
conn, err := net.DialUnix("unix", nil, addr)
assert.Nil(t, err)
defer conn.Close()

// FIXME: Delay for some time until states are stored
time.Sleep(500 * time.Millisecond)
err = send(conn, sentData, int(tmpFile.Fd()))
assert.Nil(t, err)

// Must set length not only capacity
receivedData := make([]byte, 16, 32)
oob := make([]byte, 16, 32)
err = su1.SendStatesTimeout(0)
assert.Nil(t, err, "%v", err)
return nil
}

conn1, err := net.DialUnix("unix", nil, addr)
assert.Nil(t, err, "%v", err)
err = su1.FetchDaemonStates(nydusdSendFd)
assert.NoError(t, err)

f, _ := conn1.File()
nydusdTakeover := func() {
err = su1.SendStatesTimeout(0)
assert.Nil(t, err)

//nolint:dogsled
receivedLen, _, _, _, err := unix.Recvmsg(int(f.Fd()), receivedData, oob, 0)
assert.Nil(t, err)
conn, err := net.DialUnix("unix", nil, addr)
assert.Nil(t, err)

recvData, _, err := recv(conn)
assert.Nil(t, err)

assert.Equal(t, sentLen, receivedLen)
assert.True(t, reflect.DeepEqual(receivedData[:receivedLen], sentData), "%v", receivedData)
assert.Equal(t, len(sentData), len(recvData))
assert.True(t, reflect.DeepEqual(recvData, sentData))
}

nydusdTakeover()
}

func TestSupervisorTimeout(t *testing.T) {
Expand Down

0 comments on commit 30f3d65

Please sign in to comment.