Skip to content

Commit

Permalink
Merge pull request #520 from zenhack/export-snapshot
Browse files Browse the repository at this point in the history
Use ClientSnapshot in the export table, rather than Client
  • Loading branch information
zenhack authored Jun 17, 2023
2 parents 6d3f260 + cc72fcb commit b6db31e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 62 deletions.
41 changes: 29 additions & 12 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ func (c Client) WeakRef() WeakClient {
// ClientSnapshot if c is nil, has resolved to null, or has been released.
func (c Client) Snapshot() ClientSnapshot {
h, _, _ := c.startCall()
return ClientSnapshot{hook: h}
s := ClientSnapshot{hook: h}
setupLeakReporting(s)
return s
}

// A Brand is an opaque value used to identify a capability.
Expand Down Expand Up @@ -643,17 +645,19 @@ func (cs ClientSnapshot) Metadata() *Metadata {
// Create a copy of the snapshot, with its own underlying reference.
func (cs ClientSnapshot) AddRef() ClientSnapshot {
cs.hook = cs.hook.AddRef()
setupLeakReporting(cs)
return cs
}

// Release the reference to the hook.
func (cs ClientSnapshot) Release() {
func (cs *ClientSnapshot) Release() {
cs.hook.Release()
}

func (cs *ClientSnapshot) Resolve1(ctx context.Context) error {
var err error
cs.hook, _, err = resolve1ClientHook(ctx, cs.hook)
setupLeakReporting(*cs)
return err
}

Expand All @@ -665,6 +669,7 @@ func (cs *ClientSnapshot) resolve1(ctx context.Context) (more bool, err error) {
func (cs *ClientSnapshot) Resolve(ctx context.Context) error {
var err error
cs.hook, err = resolveClientHook(ctx, cs.hook)
setupLeakReporting(*cs)
return err
}

Expand Down Expand Up @@ -773,7 +778,7 @@ func (s *resolveState) isResolved() bool {
}
}

var setupLeakReporting func(Client) = func(Client) {}
var setupLeakReporting func(any) = func(any) {}

// SetClientLeakFunc sets a callback for reporting Clients that went
// out of scope without being released. The callback is not guaranteed
Expand All @@ -783,20 +788,32 @@ var setupLeakReporting func(Client) = func(Client) {}
// SetClientLeakFunc must not be called after any calls to NewClient or
// NewPromisedClient.
func SetClientLeakFunc(clientLeakFunc func(msg string)) {
setupLeakReporting = func(c Client) {
setupLeakReporting = func(v any) {
buf := bufferpool.Default.Get(1e6)
n := runtime.Stack(buf, false)
stack := string(buf[:n])
bufferpool.Default.Put(buf)
runtime.SetFinalizer(c.client, func(c *client) {
released := mutex.With1(&c.state, func(c *clientState) bool {
return c.released
switch c := v.(type) {
case Client:
runtime.SetFinalizer(c.client, func(c *client) {
released := mutex.With1(&c.state, func(c *clientState) bool {
return c.released
})
if released {
return
}
clientLeakFunc("leaked client created at:\n\n" + stack)
})
if released {
return
}
clientLeakFunc("leaked client created at:\n\n" + stack)
})
case ClientSnapshot:
runtime.SetFinalizer(c.hook, func(c *rc.Ref[clientHook]) {
if !c.IsValid() {
return
}
clientLeakFunc("leaked client snapshot created at:\n\n" + stack)
})
default:
panic("setupLeakReporting called on unrecognized type!")
}
}
}

Expand Down
71 changes: 33 additions & 38 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type exportID uint32

// expent is an entry in a Conn's export table.
type expent struct {
client capnp.Client
snapshot capnp.ClientSnapshot
wireRefs uint32
isPromise bool

Expand Down Expand Up @@ -60,68 +60,64 @@ func (c *lockedConn) findExport(id exportID) *expent {
// releaseExport decreases the number of wire references to an export
// by a given number. If the export's reference count reaches zero,
// then releaseExport will pop export from the table and return the
// export's client. The caller must be holding onto c.mu, and the
// caller is responsible for releasing the client once the caller is no
// longer holding onto c.mu.
func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.Client, error) {
// export's ClientSnapshot. The caller is responsible for releasing
// the snapshot once the caller is no longer holding onto c.mu.
func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnapshot, error) {
ent := c.findExport(id)
if ent == nil {
return capnp.Client{}, rpcerr.Failed(errors.New("unknown export ID " + str.Utod(id)))
return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("unknown export ID " + str.Utod(id)))
}
switch {
case count == ent.wireRefs:
defer ent.cancel()
client := ent.client
snapshot := ent.snapshot
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
snapshot := client.Snapshot()
defer snapshot.Release()
metadata := snapshot.Metadata()
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
return client, nil
return snapshot, nil
case count > ent.wireRefs:
return capnp.Client{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references"))
return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references"))
default:
ent.wireRefs -= count
return capnp.Client{}, nil
return capnp.ClientSnapshot{}, nil
}
}

func (c *lockedConn) releaseExportRefs(dq *deferred.Queue, refs map[exportID]uint32) error {
n := len(refs)
var firstErr error
for id, count := range refs {
client, err := c.releaseExport(id, count)
snapshot, err := c.releaseExport(id, count)
if err != nil {
if firstErr == nil {
firstErr = err
}
n--
continue
}
if (client == capnp.Client{}) {
if (snapshot == capnp.ClientSnapshot{}) {
n--
continue
}
dq.Defer(client.Release)
dq.Defer(snapshot.Release)
n--
}
return firstErr
}

// sendCap writes a capability descriptor, returning an export ID if
// this vat is hosting the capability.
func (c *lockedConn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID, isExport bool, _ error) {
if !client.IsValid() {
// this vat is hosting the capability. Steals the snapshot.
func (c *lockedConn) sendCap(d rpccp.CapDescriptor, snapshot capnp.ClientSnapshot) (_ exportID, isExport bool, _ error) {
if !snapshot.IsValid() {
d.SetNone()
return 0, false, nil
}

state := client.Snapshot()
defer state.Release()
bv := state.Brand().Value
defer snapshot.Release()
bv := snapshot.Brand().Value
if ic, ok := bv.(*importClient); ok {
if ic.c == (*Conn)(c) {
if ent := c.lk.imports[ic.id]; ent != nil && ent.generation == ic.generation {
Expand Down Expand Up @@ -159,7 +155,7 @@ func (c *lockedConn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ expo
}

// Default to export.
metadata := state.Metadata()
metadata := snapshot.Metadata()
metadata.Lock()
defer metadata.Unlock()
id, ok := c.findExportID(metadata)
Expand All @@ -170,10 +166,9 @@ func (c *lockedConn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ expo
} else {
// Not already present; allocate an export id for it:
ee = &expent{
client: client.AddRef(),
wireRefs: 1,
isPromise: state.IsPromise(),
cancel: func() {},
snapshot: snapshot.AddRef(),
wireRefs: 1,
cancel: func() {},
}
id = exportID(c.lk.exportID.next())
if int64(id) == int64(len(c.lk.exports)) {
Expand All @@ -183,23 +178,23 @@ func (c *lockedConn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ expo
}
c.setExportID(metadata, id)
}
if ee.isPromise {
c.sendSenderPromise(id, client, d)
if ee.snapshot.IsPromise() {
c.sendSenderPromise(id, d)
} else {
d.SetSenderHosted(uint32(id))
}
return id, true, nil
}

// sendSenderPromise is a helper for sendCap that handles the senderPromise case.
func (c *lockedConn) sendSenderPromise(id exportID, client capnp.Client, d rpccp.CapDescriptor) {
func (c *lockedConn) sendSenderPromise(id exportID, d rpccp.CapDescriptor) {
// Send a promise, wait for the resolution asynchronously, then send
// a resolve message:
ee := c.lk.exports[id]
d.SetSenderPromise(uint32(id))
ctx, cancel := context.WithCancel(c.bgctx)
ee.cancel = cancel
waitRef := client.AddRef()
waitRef := ee.snapshot.AddRef()
go func() {
defer cancel()
defer waitRef.Release()
Expand All @@ -210,10 +205,10 @@ func (c *lockedConn) sendSenderPromise(id exportID, client capnp.Client, d rpccp

waitErr := waitRef.Resolve(ctx)
unlockedConn.withLocked(func(c *lockedConn) {
// Export was removed from the table at some point;
// remote peer is uninterested in the resolution, so
// drop the reference and we're done:
if c.lk.exports[id] != ee {
if len(c.lk.exports) <= int(id) || c.lk.exports[id] != ee {
// Export was removed from the table at some point;
// remote peer is uninterested in the resolution, so
// drop the reference and we're done
return
}

Expand Down Expand Up @@ -245,9 +240,9 @@ func (c *lockedConn) sendSenderPromise(id exportID, client capnp.Client, d rpccp
sendRef.Release()
if err != nil && isExport {
// release 1 ref of the thing it resolved to.
client, err := withLockedConn2(
snapshot, err := withLockedConn2(
unlockedConn,
func(c *lockedConn) (capnp.Client, error) {
func(c *lockedConn) (capnp.ClientSnapshot, error) {
return c.releaseExport(resolvedID, 1)
},
)
Expand All @@ -256,7 +251,7 @@ func (c *lockedConn) sendSenderPromise(id exportID, client capnp.Client, d rpccp
exc.WrapError("releasing export due to failure to send resolve", err),
)
} else {
client.Release()
snapshot.Release()
}
}
})
Expand All @@ -281,7 +276,7 @@ func (c *lockedConn) fillPayloadCapTable(payload rpccp.Payload) (map[exportID]ui
}
var refs map[exportID]uint32
for i := 0; i < clients.Len(); i++ {
id, isExport, err := c.sendCap(list.At(i), clients.At(i))
id, isExport, err := c.sendCap(list.At(i), clients.At(i).Snapshot())
if err != nil {
return nil, rpcerr.WrapFailed("Serializing capability", err)
}
Expand Down
21 changes: 9 additions & 12 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,11 @@ func (c *lockedConn) releaseBootstrap(dq *deferred.Queue) {
func (c *lockedConn) releaseExports(dq *deferred.Queue, exports []*expent) {
for _, e := range exports {
if e != nil {
snapshot := e.client.Snapshot()
metadata := snapshot.Metadata()
metadata := e.snapshot.Metadata()
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
snapshot.Release()
dq.Defer(e.client.Release)
dq.Defer(e.snapshot.Release)
}
}
}
Expand Down Expand Up @@ -665,13 +663,13 @@ func (c *Conn) handleUnimplemented(in transport.IncomingMessage) error {
default:
return nil
}
client, err := withLockedConn2(c, func(c *lockedConn) (capnp.Client, error) {
snapshot, err := withLockedConn2(c, func(c *lockedConn) (capnp.ClientSnapshot, error) {
return c.releaseExport(id, 1)
})
if err != nil {
return err
}
client.Release()
snapshot.Release()
return nil
}
}
Expand Down Expand Up @@ -859,7 +857,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err
pcall := newPromisedPipelineCaller()
ans.setPipelineCaller(p.method, pcall)
dq.Defer(func() {
pcall.resolve(ent.client.RecvCall(callCtx, recv))
pcall.resolve(ent.snapshot.Recv(callCtx, recv))
})
return nil
case rpccp.MessageTarget_Which_promisedAnswer:
Expand Down Expand Up @@ -1359,7 +1357,7 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) {
"receive capability: invalid export " + str.Utod(id),
))
}
return ent.client.AddRef(), nil
return ent.snapshot.Client(), nil
case rpccp.CapDescriptor_Which_receiverAnswer:
promisedAnswer, err := d.ReceiverAnswer()
if err != nil {
Expand Down Expand Up @@ -1535,14 +1533,13 @@ func (c *Conn) handleRelease(ctx context.Context, in transport.IncomingMessage)
id := exportID(rel.Id())
count := rel.ReferenceCount()

var client capnp.Client
c.withLocked(func(c *lockedConn) {
client, err = c.releaseExport(id, count)
snapshot, err := withLockedConn2(c, func(c *lockedConn) (capnp.ClientSnapshot, error) {
return c.releaseExport(id, count)
})
if err != nil {
return rpcerr.Annotate(err, "incoming release")
}
client.Release() // no-ops for nil
snapshot.Release() // no-ops for nil
return nil
}

Expand Down

0 comments on commit b6db31e

Please sign in to comment.