diff --git a/ctriface/iface.go b/ctriface/iface.go index 30fb0bc5a..6b19f445f 100644 --- a/ctriface/iface.go +++ b/ctriface/iface.go @@ -527,38 +527,26 @@ func (o *Orchestrator) LoadSnapshot(ctx context.Context, originVmID string, vmID os.Remove(conf.MemBackend.BackendPath) } - // connChan := make(chan *net.UnixConn, 1) - errChan := make(chan error, 1) go func() { listener, err := net.Listen("unix", conf.MemBackend.BackendPath) if err != nil { - errChan <- errors.Wrapf(err, "failed to listen to uffd socket") + logger.Error("failed to listen to uffd socket") return - // return nil, nil, errors.Wrapf(err, "failed to listen to uffd socket") } defer listener.Close() - + logger.Debug("Listening ...") conn, err := listener.Accept() if err != nil { - errChan <- errors.Wrapf(err, "failed to accept connection") - return - // return nil, nil, errors.Wrapf(err, "failed to accept connection") + logger.Error("failed to accept connection to uffd socket") + return } - sendfdConn, _ = conn.(*net.UnixConn) + sendfdConn, _ = conn.(*net.UnixConn) close(uffdListenerCh) - - // connChan <- sendfdConn }() - // select { - // case sendfdConn = <-connChan: - // logger.Debug("Connection accepted and type-asserted to *net.UnixConn") - // case err := <-errChan: - // logger.Errorf("Error occurred: %v\n", err) - // } - time.Sleep(10 * time.Second) + time.Sleep(10 * time.Second) // TODO: sleep for 10 seconds to wait for the uffd socket to be ready } tStart = time.Now() @@ -622,7 +610,7 @@ func (o *Orchestrator) LoadSnapshot(ctx context.Context, originVmID string, vmID } logger.Debug("TEST: activate VM in mm") - if activateErr = o.memoryManager.Activate(vmID, sendfdConn); activateErr != nil { + if activateErr = o.memoryManager.Activate(originVmID, sendfdConn); activateErr != nil { logger.Warn("Failed to activate VM in the memory manager", activateErr) } } diff --git a/memory/manager/snapshot_state.go b/memory/manager/snapshot_state.go index 1c70720e8..a0bb7d076 100644 --- a/memory/manager/snapshot_state.go +++ b/memory/manager/snapshot_state.go @@ -29,6 +29,7 @@ import "C" import ( "encoding/binary" + "encoding/json" "errors" "fmt" "net" @@ -39,7 +40,6 @@ import ( "syscall" "time" - "github.com/ftrvxmtrx/fd" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" @@ -133,44 +133,49 @@ func (s *SnapshotState) setupStateOnActivate() { } } -func (s *SnapshotState) getUFFD(sendfdConn *net.UnixConn) error { - // var d net.Dialer - // ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - // defer cancel() - - // for { - // c, err := d.DialContext(ctx, "unix", s.InstanceSockAddr) - // if err != nil { - // if ctx.Err() != nil { - // log.Error("Failed to dial within the context timeout") - // return err - // } - // time.Sleep(1 * time.Millisecond) - // continue - // } - - // defer c.Close() - - // sendfdConn := c.(*net.UnixConn) +type GuestRegionUffdMapping struct { + BaseHostVirtAddr uint64 `json:"base_host_virt_addr"` + Size uint64 `json:"size"` + Offset uint64 `json:"offset"` + PageSizeKiB uint64 `json:"page_size_kib"` +} - // fs, err := fd.Get(sendfdConn, 1, []string{"a file"}) - // if err != nil { - // log.Error("Failed to receive the uffd") - // return err - // } +func (s *SnapshotState) getUFFD(sendfdConn *net.UnixConn) error { + buff := make([]byte, 256) // set a maximum buffer size + oobBuff := make([]byte, unix.CmsgSpace(4)) - // s.userFaultFD = fs[0] + n, oobn, _, _, err := sendfdConn.ReadMsgUnix(buff, oobBuff) + if err != nil { + return fmt.Errorf("error reading message: %w", err) + } + buff = buff[:n] - // return nil - // } + var fd int + if oobn > 0 { + scms, err := unix.ParseSocketControlMessage(oobBuff[:oobn]) + if err != nil { + return fmt.Errorf("error parsing socket control message: %w", err) + } + for _, scm := range scms { + fds, err := unix.ParseUnixRights(&scm) + if err != nil { + return fmt.Errorf("error parsing unix rights: %w", err) + } + if len(fds) > 0 { + fd = fds[0] // Assuming only one fd is sent. + break + } + } + } + userfaultFD := os.NewFile(uintptr(fd), "userfaultfd") - fs, err := fd.Get(sendfdConn, 1, []string{"a file"}) - if err != nil { - log.Error("Failed to receive the uffd") - return err + var mapping []GuestRegionUffdMapping + if err := json.Unmarshal(buff, &mapping); err != nil { + return fmt.Errorf("error unmarshaling data: %w", err) } - s.userFaultFD = fs[0] + s.startAddress = mapping[0].BaseHostVirtAddr + s.userFaultFD = userfaultFD return nil } @@ -401,7 +406,6 @@ func (s *SnapshotState) servePageFault(fd int, address uint64) error { s.firstPageFaultOnce.Do( func() { - s.startAddress = address log.Debugf("TEST: first page fault address %d", address) if s.isRecordReady && !s.IsLazyMode {