Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store ClientSnapshots in the CapTable. #525

Closed
wants to merge 13 commits into from
Closed
2 changes: 1 addition & 1 deletion answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestPromiseFulfill(t *testing.T) {
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
res.SetPtr(1, NewInterface(seg, msg.CapTable().Add(c.AddRef())).ToPtr())
res.SetPtr(1, NewInterface(seg, msg.CapTable().AddClient(c.AddRef())).ToPtr())

p.Fulfill(res.ToPtr())

Expand Down
20 changes: 18 additions & 2 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (i Interface) value(paddr address) rawPointer {
// or nil if the pointer is invalid.
func (i Interface) Client() (c Client) {
if msg := i.Message(); msg != nil {
c = msg.CapTable().Get(i)
c = msg.CapTable().GetClient(i)
}

return
Expand Down Expand Up @@ -550,6 +550,13 @@ func (c Client) AddRef() Client {
})
}

// Steal steals the receiver, and returns a new client for the same capability
// owned by the caller. This can be useful for tracking down ownership bugs.
func (c Client) Steal() Client {
defer c.Release()
return c.AddRef()
}

// WeakRef creates a new WeakClient that refers to the same capability
// as c. If c is nil or has resolved to null, then WeakRef returns nil.
func (c Client) WeakRef() WeakClient {
Expand Down Expand Up @@ -606,6 +613,9 @@ func (cs ClientSnapshot) Recv(ctx context.Context, r Recv) PipelineCaller {

// Client returns a client pointing at the most-resolved version of the snapshot.
func (cs ClientSnapshot) Client() Client {
if !cs.IsValid() {
return Client{}
}
cursor := rc.NewRefInPlace(func(c *clientCursor) func() {
*c = clientCursor{hook: mutex.New(cs.hook.AddRef())}
c.compress()
Expand Down Expand Up @@ -639,6 +649,12 @@ func (cs ClientSnapshot) AddRef() ClientSnapshot {
return cs
}

// Steal is like Client.Steal() but for snapshots.
func (cs ClientSnapshot) Steal() ClientSnapshot {
defer cs.Release()
return cs.AddRef()
}

// Release the reference to the hook.
func (cs ClientSnapshot) Release() {
cs.hook.Release()
Expand Down Expand Up @@ -746,7 +762,7 @@ func (c Client) Release() {
}

func (c Client) EncodeAsPtr(seg *Segment) Ptr {
capId := seg.Message().CapTable().Add(c)
capId := seg.Message().CapTable().AddClient(c)
return NewInterface(seg, capId).ToPtr()
}

Expand Down
2 changes: 1 addition & 1 deletion capnpc-go/templates/structCapabilityField
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ func (s {{.Node.Name}}) Set{{.Field.Name|title}}(c {{.FieldType}}) error {
return capnp.Struct(s).SetPtr({{.Field.Slot.Offset}}, capnp.Ptr{})
}
seg := s.Segment()
in := capnp.NewInterface(seg, seg.Message().CapTable().Add(c))
in := capnp.NewInterface(seg, seg.Message().CapTable().AddClient(c))
return capnp.Struct(s).SetPtr({{.Field.Slot.Offset}}, in.ToPtr())
}
2 changes: 1 addition & 1 deletion capnpc-go/templates/structInterfaceField
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func (s {{.Node.Name}}) Set{{.Field.Name|title}}(v {{.FieldType}}) error {
return capnp.Struct(s).SetPtr({{.Field.Slot.Offset}}, capnp.Ptr{})
}
seg := s.Segment()
in := capnp.NewInterface(seg, seg.Message().CapTable().Add(capnp.Client(v)))
in := capnp.NewInterface(seg, seg.Message().CapTable().AddClient(capnp.Client(v)))
return capnp.Struct(s).SetPtr({{.Field.Slot.Offset}}, in.ToPtr())
}

75 changes: 55 additions & 20 deletions captable.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,40 @@ package capnp
//
// https://capnproto.org/encoding.html#capabilities-interfaces
type CapTable struct {
cs []Client
// We maintain two parallel structurs of clients and corresponding
// snapshots. We need to store both, so that Get*() can hand out
// borrowed references in both cases.
clients []Client
snapshots []ClientSnapshot
}

// Reset the cap table, releasing all capabilities and setting
// the length to zero. Clients passed as arguments are added
// to the table after zeroing, such that ct.Len() == len(cs).
func (ct *CapTable) Reset(cs ...Client) {
for _, c := range ct.cs {
// the length to zero.
func (ct *CapTable) Reset() {
for _, c := range ct.clients {
c.Release()
}
for _, s := range ct.snapshots {
s.Release()
}

ct.cs = append(ct.cs[:0], cs...)
ct.clients = ct.clients[:0]
ct.snapshots = ct.snapshots[:0]
}

// Len returns the number of capabilities in the table.
func (ct CapTable) Len() int {
return len(ct.cs)
return len(ct.clients)
}

// ClientAt returns the client at the given index of the table.
func (ct CapTable) ClientAt(i int) Client {
return ct.clients[i]
}

// At returns the capability at the given index of the table.
func (ct CapTable) At(i int) Client {
return ct.cs[i]
// SnapshotAt is like ClientAt, except that it returns a snapshot.
func (ct CapTable) SnapshotAt(i int) ClientSnapshot {
return ct.snapshots[i]
}

// Contains returns true if the supplied interface corresponds
Expand All @@ -37,28 +49,51 @@ func (ct CapTable) Contains(ifc Interface) bool {
return ifc.IsValid() && ifc.Capability() < CapabilityID(ct.Len())
}

// Get the client corresponding to the supplied interface. It
// returns a null client if the interface's CapabilityID isn't
// GetClient gets the client corresponding to the supplied interface.
// It returns a null client if the interface's CapabilityID isn't
// in the table.
func (ct CapTable) Get(ifc Interface) (c Client) {
func (ct CapTable) GetClient(ifc Interface) (c Client) {
if ct.Contains(ifc) {
c = ct.cs[ifc.Capability()]
c = ct.clients[ifc.Capability()]
}
return
}

// GetSnapshot is like GetClient, except that it returns a snapshot
// instead of a Client.
func (ct CapTable) GetSnapshot(ifc Interface) (s ClientSnapshot) {
if ct.Contains(ifc) {
s = ct.snapshots[ifc.Capability()]
}
return
}

// Set the client for the supplied capability ID. If a client
// SetClient sets the client for the supplied capability ID. If a client
// for the given ID already exists, it will be replaced without
// releasing.
func (ct CapTable) Set(id CapabilityID, c Client) {
ct.cs[id] = c
func (ct CapTable) SetClient(id CapabilityID, c Client) {
ct.snapshots[id] = c.Snapshot()
ct.clients[id] = c.Steal()
}

// SetSnapshot is like SetClient, but takes a snapshot.
func (ct CapTable) SetSnapshot(id CapabilityID, s ClientSnapshot) {
ct.clients[id] = s.Client()
ct.snapshots[id] = s.Steal()
}

// Add appends a capability to the message's capability table and
// AddClient appends a capability to the message's capability table and
// returns its ID. It "steals" c's reference: the Message will release
// the client when calling Reset.
func (ct *CapTable) Add(c Client) CapabilityID {
ct.cs = append(ct.cs, c)
func (ct *CapTable) AddClient(c Client) CapabilityID {
ct.snapshots = append(ct.snapshots, c.Snapshot())
ct.clients = append(ct.clients, c.Steal())
return CapabilityID(ct.Len() - 1)
}

// AddSnapshot is like AddClient, except that it takes a snapshot rather
// than a Client.
func (ct *CapTable) AddSnapshot(s ClientSnapshot) CapabilityID {
defer s.Release()
return ct.AddClient(s.Client())
}
12 changes: 7 additions & 5 deletions captable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@ func TestCapTable(t *testing.T) {

assert.Zero(t, ct.Len(),
"zero-value CapTable should be empty")
assert.Zero(t, ct.Add(capnp.Client{}),
assert.Zero(t, ct.AddClient(capnp.Client{}),
"first entry should have CapabilityID(0)")
assert.Equal(t, 1, ct.Len(),
"should increase length after adding capability")

ct.Reset()
assert.Zero(t, ct.Len(),
"zero-value CapTable should be empty after Reset()")
ct.Reset(capnp.Client{}, capnp.Client{})

ct.AddClient(capnp.Client{})
ct.AddClient(capnp.Client{})
assert.Equal(t, 2, ct.Len(),
"zero-value CapTable should be empty after Reset(c, c)")
"zero-value CapTable should be empty after Reset() & add twice")

errTest := errors.New("test")
ct.Set(capnp.CapabilityID(0), capnp.ErrorClient(errTest))
snapshot := ct.At(0).Snapshot()
ct.SetClient(capnp.CapabilityID(0), capnp.ErrorClient(errTest))
snapshot := ct.ClientAt(0).Snapshot()
defer snapshot.Release()
err := snapshot.Brand().Value.(error)
assert.ErrorIs(t, errTest, err, "should update client at index 0")
Expand Down
8 changes: 4 additions & 4 deletions internal/aircraftlib/aircraft.capnp.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion list.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ func (c CapList[T]) At(i int) (T, error) {
func (c CapList[T]) Set(i int, v T) error {
pl := PointerList(c)
seg := pl.Segment()
capId := seg.Message().CapTable().Add(Client(v))
capId := seg.Message().CapTable().AddClient(Client(v))
return pl.Set(i, NewInterface(seg, capId).ToPtr())
}

Expand Down
4 changes: 2 additions & 2 deletions localpromise.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (lp localPromise) String() string {

func (lp localPromise) Fulfill(c Client) {
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(c)
capID := msg.CapTable().AddClient(c)
lp.aq.Fulfill(NewInterface(seg, capID).ToPtr())
}

Expand All @@ -60,7 +60,7 @@ type localResolver[C ~ClientKind] struct {
}

func (lf localResolver[C]) Fulfill(c C) {
lf.lp.Fulfill(Client(c))
lf.lp.Fulfill(Client(c).AddRef())
lf.clientResolver.Fulfill(Client(c))
}

Expand Down
16 changes: 8 additions & 8 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,38 +400,38 @@ func TestAddCap(t *testing.T) {
msg := &Message{Arena: SingleSegment(nil)}

// Simple case: distinct non-nil clients.
id1 := msg.CapTable().Add(client1.AddRef())
id1 := msg.CapTable().AddClient(client1.AddRef())
assert.Equal(t, CapabilityID(0), id1,
"first capability ID should be 0")
assert.Equal(t, 1, msg.CapTable().Len(),
"should have exactly one capability in the capTable")
assert.True(t, msg.CapTable().At(0).IsSame(client1),
assert.True(t, msg.CapTable().ClientAt(0).IsSame(client1),
"client does not match entry in cap table")

id2 := msg.CapTable().Add(client2.AddRef())
id2 := msg.CapTable().AddClient(client2.AddRef())
assert.Equal(t, CapabilityID(1), id2,
"second capability ID should be 1")
assert.Equal(t, 2, msg.CapTable().Len(),
"should have exactly two capabilities in the capTable")
assert.True(t, msg.CapTable().At(1).IsSame(client2),
assert.True(t, msg.CapTable().ClientAt(1).IsSame(client2),
"client does not match entry in cap table")

// nil client
id3 := msg.CapTable().Add(Client{})
id3 := msg.CapTable().AddClient(Client{})
assert.Equal(t, CapabilityID(2), id3,
"third capability ID should be 2")
assert.Equal(t, 3, msg.CapTable().Len(),
"should have exactly three capabilities in the capTable")
assert.True(t, msg.CapTable().At(2).IsSame(Client{}),
assert.True(t, msg.CapTable().ClientAt(2).IsSame(Client{}),
"client does not match entry in cap table")

// Add should not attempt to deduplicate.
id4 := msg.CapTable().Add(client1.AddRef())
id4 := msg.CapTable().AddClient(client1.AddRef())
assert.Equal(t, CapabilityID(3), id4,
"fourth capability ID should be 3")
assert.Equal(t, 4, msg.CapTable().Len(),
"should have exactly four capabilities in the capTable")
assert.True(t, msg.CapTable().At(3).IsSame(client1),
assert.True(t, msg.CapTable().ClientAt(3).IsSame(client1),
"client does not match entry in cap table")

// Verify that Add steals the reference: once client1 and client2
Expand Down
4 changes: 2 additions & 2 deletions pogs/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (ins *inserter) insertField(s capnp.Struct, f schema.Field, val reflect.Val
if !c.IsValid() {
return s.SetPtr(off, capnp.Ptr{})
}
id := s.Message().CapTable().Add(c)
id := s.Message().CapTable().AddClient(c)
return s.SetPtr(off, capnp.NewInterface(s.Segment(), id).ToPtr())
default:
panic("unreachable")
Expand All @@ -255,7 +255,7 @@ func capPtr(seg *capnp.Segment, val reflect.Value) capnp.Ptr {
if !client.IsValid() {
return capnp.Ptr{}
}
cap := seg.Message().CapTable().Add(client)
cap := seg.Message().CapTable().AddClient(client)
iface := capnp.NewInterface(seg, cap)
return iface.ToPtr()
}
Expand Down
Loading