Skip to content

Commit

Permalink
Merge pull request #519 from zenhack/clienthook-snapshot
Browse files Browse the repository at this point in the history
Clienthook snapshot
  • Loading branch information
zenhack authored May 26, 2023
2 parents 2ad05d6 + b57e496 commit 59cfdc9
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 92 deletions.
4 changes: 3 additions & 1 deletion answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ func (pc PipelineClient) Brand() Brand {
r := mutex.With1(&pc.p.state, func(p *promiseState) resolution {
return p.resolution(pc.p.method)
})
return r.client(pc.transform).State().Brand
snapshot := r.client(pc.transform).Snapshot()
defer snapshot.Release()
return snapshot.Brand()
default:
return Brand{Value: pc}
}
Expand Down
190 changes: 139 additions & 51 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,32 +519,17 @@ func (c Client) IsSame(c2 Client) bool {
// Resolve only returns an error if the context is canceled; it returns nil even
// if the capability resolves to an error.
func (c Client) Resolve(ctx context.Context) error {
for {
h, resolved, released := c.startCall()
defer h.Release()
if released {
return errors.New("cannot resolve released client")
}

if resolved {
return nil
}

r, ok := h.Value().resolution.Get()
if !ok {
return nil
}

resolvedCh := mutex.With1(r, func(s *resolveState) <-chan struct{} {
return s.resolved
})

select {
case <-resolvedCh:
case <-ctx.Done():
return ctx.Err()
}
h, resolved, released := c.startCall()
defer h.Release()
if released {
return errors.New("cannot resolve released client")
}
if resolved {
return nil
}
h, err := resolveClientHook(ctx, h)
h.Release()
return err
}

// AddRef creates a new Client that refers to the same capability as c.
Expand Down Expand Up @@ -577,39 +562,142 @@ func (c Client) WeakRef() WeakClient {
return WeakClient{r: cursor}
}

// State reads the current state of the client. It returns the zero
// ClientState if c is nil, has resolved to null, or has been released.
func (c Client) State() ClientState {
h, resolved, _ := c.startCall()
defer h.Release()
if h == nil {
return ClientState{}
}
return ClientState{
Brand: h.Value().Brand(),
IsPromise: !resolved,
Metadata: &h.Value().metadata,
}
// Snapshot reads the current state of the client. It returns the zero
// ClientSnapshot if c is nil, has resolved to null, or has been released.
func (c Client) Snapshot() ClientSnapshot {
h, _, _ := c.startCall()
return ClientSnapshot{hook: h}
}

// A Brand is an opaque value used to identify a capability.
type Brand struct {
Value any
}

// ClientState is a snapshot of a client's identity.
type ClientState struct {
// Brand is the value returned from the hook's Brand method.
Brand Brand
// IsPromise is true if the client has not resolved yet.
IsPromise bool
// Arbitrary metadata. Note that, if a Client is a promise,
// when it resolves its metadata will be replaced with that
// of its resolution.
//
// TODO: this might change before the v3 API is stabilized;
// we are not sure the above is the correct semantics.
Metadata *Metadata
// ClientSnapshot is a snapshot of a client's identity. If the Client
// is a promise, then the corresponding ClientSnapshot will *not*
// redirect to point at the resolution.
type ClientSnapshot struct {
hook *rc.Ref[clientHook]
}

func (cs ClientSnapshot) IsValid() bool {
return cs.hook.IsValid()
}

// IsPromise returns true if the snapshot is a promise.
func (cs ClientSnapshot) IsPromise() bool {
if cs.hook == nil {
return false
}
_, ret := cs.hook.Value().resolution.Get()
return ret
}

// Send implements ClientHook.Send
func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
return cs.hook.Value().Send(ctx, s)
}

// Recv implements ClientHook.Recv
func (cs ClientSnapshot) Recv(ctx context.Context, r Recv) PipelineCaller {
return cs.hook.Value().Recv(ctx, r)
}

// Client returns a client pointing at the most-resolved version of the snapshot.
func (cs ClientSnapshot) Client() Client {
cursor := rc.NewRefInPlace(func(c *clientCursor) func() {
*c = clientCursor{hook: mutex.New(cs.hook.AddRef())}
c.compress()
return c.Release
})
c := Client{client: &client{
state: mutex.New(clientState{cursor: cursor}),
}}
setupLeakReporting(c)
return c
}

// Brand is the value returned from the ClientHook's Brand method.
// Returns the zero Brand if the receiver is the zero ClientSnapshot.
func (cs ClientSnapshot) Brand() Brand {
if cs.hook == nil {
return Brand{}
}
return cs.hook.Value().Brand()
}

// Return a the reference to the Metadata associated with this client hook.
// Callers may store whatever they need here.
func (cs ClientSnapshot) Metadata() *Metadata {
return &cs.hook.Value().metadata
}

// Create a copy of the snapshot, with its own underlying reference.
func (cs ClientSnapshot) AddRef() ClientSnapshot {
cs.hook = cs.hook.AddRef()
return cs
}

// Release the reference to the hook.
func (cs ClientSnapshot) Release() {
cs.hook.Release()
}

func (cs *ClientSnapshot) Resolve1(ctx context.Context) error {
var err error
cs.hook, _, err = resolve1ClientHook(ctx, cs.hook)
return err
}

func (cs *ClientSnapshot) resolve1(ctx context.Context) (more bool, err error) {
cs.hook, more, err = resolve1ClientHook(ctx, cs.hook)
return
}

func (cs *ClientSnapshot) Resolve(ctx context.Context) error {
var err error
cs.hook, err = resolveClientHook(ctx, cs.hook)
return err
}

func resolveClientHook(ctx context.Context, h *rc.Ref[clientHook]) (_ *rc.Ref[clientHook], err error) {
for {
var more bool
h, more, err = resolve1ClientHook(ctx, h)
if !more || err != nil {
return h, err
}
}
}

func resolve1ClientHook(ctx context.Context, h *rc.Ref[clientHook]) (_ *rc.Ref[clientHook], more bool, err error) {
if !h.IsValid() {
return h, false, nil
}
defer h.Release()

r, ok := h.Value().resolution.Get()
if !ok {
return h.AddRef(), false, nil
}

resolvedCh := mutex.With1(r, func(s *resolveState) <-chan struct{} {
return s.resolved
})

select {
case <-resolvedCh:
rh := mutex.With1(r, func(r *resolveState) *rc.Ref[clientHook] {
return r.resolvedHook
})
if rh == nil {
return nil, false, nil
}
return rh.AddRef(), true, nil
case <-ctx.Done():
return h.AddRef(), true, ctx.Err()
}
}

// String returns a string that identifies this capability for debugging
Expand Down
97 changes: 73 additions & 24 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestClient(t *testing.T) {
Expand All @@ -23,13 +24,14 @@ func TestClient(t *testing.T) {
if !c.IsValid() {
t.Error("new client is not valid")
}
state := c.State()
if state.IsPromise {
state := c.Snapshot()
if state.IsPromise() {
t.Error("c.State().IsPromise = true; want false")
}
if state.Brand.Value != int(42) {
t.Errorf("c.State().Brand.Value = %#v; want 42", state.Brand.Value)
if state.Brand().Value != int(42) {
t.Errorf("c.State().Brand().Value = %#v; want 42", state.Brand().Value)
}
state.Release()
ans, finish := c.SendCall(ctx, Send{})
if _, err := ans.Struct(); err != nil {
t.Error("SendCall:", err)
Expand Down Expand Up @@ -78,13 +80,14 @@ func TestReleasedClient(t *testing.T) {
if c.IsValid() {
t.Error("released client is valid")
}
state := c.State()
if state.Brand.Value != nil {
t.Errorf("c.State().Brand.Value = %#v; want <nil>", state.Brand.Value)
state := c.Snapshot()
if state.Brand().Value != nil {
t.Errorf("c.Snapshot().Brand().Value = %#v; want <nil>", state.Brand().Value)
}
if state.IsPromise {
t.Error("c.State().IsPromise = true; want false")
if state.IsPromise() {
t.Error("c.Snapshot().IsPromise = true; want false")
}
state.Release()
ans, finish := c.SendCall(ctx, Send{})
if _, err := ans.Struct(); err == nil {
t.Error("SendCall did not return error")
Expand Down Expand Up @@ -116,6 +119,49 @@ func TestReleasedClient(t *testing.T) {
t.Error("second Release made more calls to ClientHook.Shutdown")
}
}
func TestResolve(t *testing.T) {
test := func(t *testing.T, name string, f func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client])) {
t.Run(name, func(t *testing.T) {
t.Parallel()
p1, r1 := NewLocalPromise[Client]()
p2, r2 := NewLocalPromise[Client]()
defer p1.Release()
defer p2.Release()
f(t, p1, p2, r1, r2)
})
}
t.Run("Clients", func(t *testing.T) {
test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
r1.Fulfill(p2)
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
defer cancel()
require.NotNil(t, p1.Resolve(ctx), "blocks on second promise")
r2.Fulfill(Client{})
require.NoError(t, p1.Resolve(context.Background()), "resolves after second resolution")
assert.True(t, p1.IsSame(Client{}), "p1 resolves to null")
assert.True(t, p2.IsSame(Client{}), "p2 resolves to null")
assert.True(t, p1.IsSame(p2), "p1 & p2 are the same")
})
})
t.Run("Snapshots", func(t *testing.T) {
test(t, "Resolve1 only waits for one link", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
s1 := p1.Snapshot()
defer s1.Release()
r1.Fulfill(p2)
require.NoError(t, s1.Resolve1(context.Background()), "Resolve1 returns after first resolution")
})
test(t, "Resolve waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
s1 := p1.Snapshot()
defer s1.Release()
r1.Fulfill(p2)
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
defer cancel()
require.NotNil(t, s1.Resolve(ctx), "blocks on second promise")
r2.Fulfill(Client{})
require.NoError(t, s1.Resolve(context.Background()), "resolves after second resolution")
})
})
}

func TestNullClient(t *testing.T) {
ctx := context.Background()
Expand All @@ -141,13 +187,14 @@ func TestNullClient(t *testing.T) {
if c.IsValid() {
t.Error("null client is valid")
}
state := c.State()
if state.Brand.Value != nil {
t.Errorf("c.State().Brand = %#v; want <nil>", state.Brand)
state := c.Snapshot()
if state.Brand().Value != nil {
t.Errorf("c.Snapshot().Brand() = %#v; want <nil>", state.Brand())
}
if state.IsPromise {
t.Error("c.State().IsPromise = true; want false")
if state.IsPromise() {
t.Error("c.Snapshot().IsPromise = true; want false")
}
state.Release()
ans, finish := c.SendCall(ctx, Send{})
if _, err := ans.Struct(); err == nil {
t.Error("SendCall did not return error")
Expand Down Expand Up @@ -186,13 +233,14 @@ func TestPromisedClient(t *testing.T) {
if ca.IsSame(cb) {
t.Error("before resolution, ca == cb")
}
state := ca.State()
if state.Brand.Value != int(111) {
t.Errorf("before resolution, ca.State().Brand.Value = %#v; want 111", state.Brand.Value)
state := ca.Snapshot()
if state.Brand().Value != int(111) {
t.Errorf("before resolution, ca.Snapshot().Brand().Value = %#v; want 111", state.Brand().Value)
}
if !state.IsPromise {
t.Error("before resolution, ca.State().IsPromise = false; want true")
if !state.IsPromise() {
t.Error("before resolution, ca.Snapshot().IsPromise = false; want true")
}
state.Release()
_, finish := ca.SendCall(ctx, Send{})
finish()
pa.Fulfill(cb)
Expand All @@ -207,13 +255,14 @@ func TestPromisedClient(t *testing.T) {
if !ca.IsSame(cb) {
t.Errorf("after resolution, ca != cb (%v vs. %v)", ca, cb)
}
state = ca.State()
if state.Brand.Value != int(222) {
t.Errorf("after resolution, ca.State().Brand.Value = %#v; want 222", state.Brand.Value)
state = ca.Snapshot()
if state.Brand().Value != int(222) {
t.Errorf("after resolution, ca.Snapshot().Brand().Value = %#v; want 222", state.Brand().Value)
}
if state.IsPromise {
t.Error("after resolution, ca.State().IsPromise = true; want false")
if state.IsPromise() {
t.Error("after resolution, ca.Snapshot().IsPromise = true; want false")
}
state.Release()

if b.shutdowns > 0 {
t.Error("b shut down before clients released")
Expand Down
4 changes: 3 additions & 1 deletion captable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func TestCapTable(t *testing.T) {

errTest := errors.New("test")
ct.Set(capnp.CapabilityID(0), capnp.ErrorClient(errTest))
err := ct.At(0).State().Brand.Value.(error)
snapshot := ct.At(0).Snapshot()
defer snapshot.Release()
err := snapshot.Brand().Value.(error)
assert.ErrorIs(t, errTest, err, "should update client at index 0")
}
Loading

0 comments on commit 59cfdc9

Please sign in to comment.