Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Handle incoming resolve messages, take 2 #530

Merged
merged 16 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Promise struct {
// - Resolved. Fulfill or Reject has finished.

state mutex.Mutex[promiseState]

resolver Resolver[Ptr]
}

type promiseState struct {
Expand Down Expand Up @@ -64,11 +66,13 @@ type clientAndPromise struct {
}

// NewPromise creates a new unresolved promise. The PipelineCaller will
// be used to make pipelined calls before the promise resolves.
func NewPromise(m Method, pc PipelineCaller) *Promise {
// be used to make pipelined calls before the promise resolves. If resolver
// is not nil, calls to Fulfill will be forwarded to it.
func NewPromise(m Method, pc PipelineCaller, resolver Resolver[Ptr]) *Promise {
if pc == nil {
panic("NewPromise(nil)")
}

resolved := make(chan struct{})
p := &Promise{
method: m,
Expand All @@ -77,6 +81,7 @@ func NewPromise(m Method, pc PipelineCaller) *Promise {
signals: []func(){func() { close(resolved) }},
caller: pc,
}),
resolver: resolver,
}
p.ans.f.promise = p
p.ans.metadata = *NewMetadata()
Expand Down Expand Up @@ -152,6 +157,14 @@ func (p *Promise) Resolve(r Ptr, e error) {
return p.clients
})

if p.resolver != nil {
if e == nil {
p.resolver.Fulfill(r)
} else {
p.resolver.Reject(e)
}
}

// Pending resolution state: wait for clients to be fulfilled
// and calls to have answers.
res := resolution{p.method, r, e}
Expand Down
12 changes: 6 additions & 6 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var dummyMethod = Method{

func TestPromiseReject(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
p.Reject(errors.New("omg bbq"))
select {
Expand All @@ -27,7 +27,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
p.Reject(errors.New("omg bbq"))
Expand All @@ -36,7 +36,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()
p.Reject(errors.New("omg bbq"))
Expand All @@ -57,7 +57,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Parallel()

t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Release()
Expand All @@ -72,7 +72,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
Expand All @@ -92,7 +92,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()

Expand Down
2 changes: 1 addition & 1 deletion answerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
}
}
}
sr.p = NewPromise(m, pcall)
sr.p = NewPromise(m, pcall, nil)
ans := sr.p.Answer()
return ans, func() {
<-ans.Done()
Expand Down
19 changes: 19 additions & 0 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,22 @@ func (cs ClientSnapshot) IsPromise() bool {
return ret
}

// IsResolved returns true if the snapshot has resolved to its final value.
// If IsPromise() returns false, then this will also return false. Otherwise,
// it returns false before resolution and true afterwards.
func (cs ClientSnapshot) IsResolved() bool {
if cs.hook == nil {
return false
}
res, ok := cs.hook.Value().resolution.Get()
if !ok {
return false
}
return mutex.With1(res, func(s *resolveState) bool {
return s.isResolved()
})
}

// Send implements ClientHook.Send
func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
if cs.hook == nil {
Expand Down Expand Up @@ -805,6 +821,9 @@ func SetClientLeakFunc(clientLeakFunc func(msg string)) {
clientLeakFunc("leaked client created at:\n\n" + stack)
})
case ClientSnapshot:
if !c.IsValid() {
return
}
runtime.SetFinalizer(c.hook, func(c *rc.Ref[clientHook]) {
if !c.IsValid() {
return
Expand Down
2 changes: 1 addition & 1 deletion capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestResolve(t *testing.T) {
}
t.Run("Clients", func(t *testing.T) {
test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
r1.Fulfill(p2)
r1.Fulfill(p2.AddRef())
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
defer cancel()
require.NotNil(t, p1.Resolve(ctx), "blocks on second promise")
Expand Down
61 changes: 14 additions & 47 deletions localpromise.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package capnp

import (
"context"
)

// ClientHook for a promise that will be resolved to some other capability
// at some point. Buffers calls in a queue until the promsie is fulfilled,
// then forwards them.
Expand All @@ -12,59 +8,30 @@ type localPromise struct {
}

// NewLocalPromise returns a client that will eventually resolve to a capability,
// supplied via the fulfiller.
// supplied via the resolver.
func NewLocalPromise[C ~ClientKind]() (C, Resolver[C]) {
lp := newLocalPromise()
p, f := NewPromisedClient(lp)
aq := NewAnswerQueue(Method{})
f := NewPromise(Method{}, aq, aq)
p := f.Answer().Client().AddRef()
return C(p), localResolver[C]{
lp: lp,
clientResolver: f,
p: f,
}
}

func newLocalPromise() localPromise {
return localPromise{aq: NewAnswerQueue(Method{})}
}

func (lp localPromise) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
return lp.aq.PipelineSend(ctx, nil, s)
}

func (lp localPromise) Recv(ctx context.Context, r Recv) PipelineCaller {
return lp.aq.PipelineRecv(ctx, nil, r)
}

func (lp localPromise) Brand() Brand {
return Brand{}
}

func (lp localPromise) Shutdown() {}

func (lp localPromise) String() string {
return "localPromise{...}"
}

func (lp localPromise) Fulfill(c Client) {
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(c)
lp.aq.Fulfill(NewInterface(seg, capID).ToPtr())
}

func (lp localPromise) Reject(err error) {
lp.aq.Reject(err)
}

type localResolver[C ~ClientKind] struct {
lp localPromise
clientResolver Resolver[Client]
p *Promise
}

func (lf localResolver[C]) Fulfill(c C) {
lf.lp.Fulfill(Client(c))
lf.clientResolver.Fulfill(Client(c))
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(Client(c))
iface := NewInterface(seg, capID)
lf.p.Fulfill(iface.ToPtr())
lf.p.ReleaseClients()
msg.Release()
}

func (lf localResolver[C]) Reject(err error) {
lf.lp.Reject(err)
lf.clientResolver.Reject(err)
lf.p.Reject(err)
lf.p.ReleaseClients()
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ er
func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
ans.promise = capnp.NewPromise(m, pcall, nil)
}
}

Expand Down
51 changes: 29 additions & 22 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ type exportID uint32

// expent is an entry in a Conn's export table.
type expent struct {
snapshot capnp.ClientSnapshot
wireRefs uint32
isPromise bool
snapshot capnp.ClientSnapshot
wireRefs uint32

// Should be called when removing this entry from the exports table:
cancel context.CancelFunc
Expand Down Expand Up @@ -74,9 +73,11 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
metadata := snapshot.Metadata()
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
if metadata != nil {
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
}
return snapshot, nil
case count > ent.wireRefs:
return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references"))
Expand Down Expand Up @@ -203,7 +204,7 @@ func (c *lockedConn) sendSenderPromise(id exportID, d rpccp.CapDescriptor) {
// Conn before trying to use it again:
unlockedConn := (*Conn)(c)

waitErr := waitRef.Resolve(ctx)
waitErr := waitRef.Resolve1(ctx)
unlockedConn.withLocked(func(c *lockedConn) {
if len(c.lk.exports) <= int(id) || c.lk.exports[id] != ee {
// Export was removed from the table at some point;
Expand Down Expand Up @@ -366,33 +367,39 @@ func (e *embargo) Shutdown() {
// senderLoopback holds the salient information for a sender-loopback
// Disembargo message.
type senderLoopback struct {
id embargoID
question questionID
transform []capnp.PipelineOp
id embargoID
target parsedMessageTarget
}

func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error {
d, err := msg.NewDisembargo()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
d.Context().SetSenderLoopback(uint32(sl.id))
tgt, err := d.NewTarget()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
pa, err := tgt.NewPromisedAnswer()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
oplist, err := pa.NewTransform(int32(len(sl.transform)))
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
switch sl.target.which {
case rpccp.MessageTarget_Which_promisedAnswer:
pa, err := tgt.NewPromisedAnswer()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
oplist, err := pa.NewTransform(int32(len(sl.target.transform)))
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}

d.Context().SetSenderLoopback(uint32(sl.id))
pa.SetQuestionId(uint32(sl.question))
for i, op := range sl.transform {
oplist.At(i).SetGetPointerField(op.Field)
pa.SetQuestionId(uint32(sl.target.promisedAnswer))
for i, op := range sl.target.transform {
oplist.At(i).SetGetPointerField(op.Field)
}
case rpccp.MessageTarget_Which_importedCap:
tgt.SetImportedCap(uint32(sl.target.importedCap))
default:
return errors.New("unknown variant for MessageTarget: " + str.Utod(sl.target.which))
}
return nil
}
21 changes: 18 additions & 3 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,19 @@ type impent struct {
// importClient's generation matches the entry's generation before
// removing the entry from the table and sending a release message.
generation uint64

// If resolver is non-nil, then this is a promise (received as
// CapDescriptor_Which_senderPromise), and when a resolve message
// arrives we should use this to fulfill the promise locally.
resolver capnp.Resolver[capnp.Client]
}

// addImport returns a client that represents the given import,
// incrementing the number of references to this import from this vat.
// This is separate from the reference counting that capnp.Client does.
//
// The caller must be holding onto c.mu.
func (c *lockedConn) addImport(id importID) capnp.Client {
func (c *lockedConn) addImport(id importID, isPromise bool) capnp.Client {
if ent := c.lk.imports[id]; ent != nil {
ent.wireRefs++
client, ok := ent.wc.AddRef()
Expand All @@ -67,13 +72,23 @@ func (c *lockedConn) addImport(id importID) capnp.Client {
}
return client
}
client := capnp.NewClient(&importClient{
hook := &importClient{
c: (*Conn)(c),
id: id,
})
}
var (
client capnp.Client
resolver capnp.Resolver[capnp.Client]
)
if isPromise {
client, resolver = capnp.NewPromisedClient(hook)
} else {
client = capnp.NewClient(hook)
}
c.lk.imports[id] = &impent{
wc: client.WeakRef(),
wireRefs: 1,
resolver: resolver,
}
return client
}
Expand Down
Loading
Loading