Skip to content

Commit

Permalink
client-cache invalidation message parsing code
Browse files Browse the repository at this point in the history
  • Loading branch information
mhelmich committed May 31, 2024
1 parent 2d8fa02 commit 90cc0d3
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
14 changes: 14 additions & 0 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,20 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
return &Pong{
Payload: reply[1].(string),
}, nil
case "invalidate":
switch payload := reply[1].(type) {
case []interface{}:
s := make([]string, len(payload))
for idx := range payload {
s[idx] = payload[idx].(string)
}
return &Message{
Channel: "invalidate",
PayloadSlice: s,
}, nil
default:
return nil, fmt.Errorf("redis: unsupported invalidate message payload: %#v", payload)
}
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
}
Expand Down
97 changes: 97 additions & 0 deletions pubsub_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package redis_test

import (
"context"
"fmt"
"io"
"net"
"sync"
Expand Down Expand Up @@ -567,4 +569,99 @@ var _ = Describe("PubSub", func() {
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal(text))
})

It("supports client-cache invalidation messages", func() {
ch := make(chan []string, 2)
defer close(ch)
client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error {
ch <- keys
return nil
}))
defer client.Close()

v1 := client.Get(context.Background(), "foo")
Expect(v1.Val()).To(Equal(""))
s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute))
Expect(s1.Val()).To(Equal("OK"))
v2 := client.Get(context.Background(), "foo")
Expect(v2.Val()).To(Equal("bar"))
// sleep a little to allow time for the first invalidation message to come through
time.Sleep(time.Second)
s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute))
Expect(s2.Val()).To(Equal("OK"))

for i := 0; i < 2; i++ {
select {
case keys := <-ch:
Expect(keys).ToNot(BeEmpty())
Expect(keys[0]).To(Equal("foo"))
case <-time.After(10 * time.Second):
// fail on timeouts
Fail("invalidation message wait timed out")
}
}
})

})

func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options {
var mu sync.Mutex
invalidateClientID := int64(-1)
invalidateOpts := *opt
invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) {
invalidateClientID, err = conn.ClientID(ctx).Result()
return
}

startBackgroundInvalidationSubscription := func(ctx context.Context) int64 {
mu.Lock()
defer mu.Unlock()

if invalidateClientID != -1 {
return invalidateClientID
}

invalidateClient := redis.NewClient(&invalidateOpts)
invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate")

go func() {
defer func() {
invalidations.Close()
invalidateClient.Close()

mu.Lock()
invalidateClientID = -1
mu.Unlock()
}()

for {
msg, err := invalidations.ReceiveMessage(context.Background())
if err == io.EOF || err == context.Canceled {
return
} else if err != nil {
fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error())
// send back empty []string to fail the test
processInvalidKeysFunc([]string{})
return
}

processInvalidKeysFunc(msg.PayloadSlice)
}
}()

return invalidateClientID
}

opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error {
invalidateClientID := startBackgroundInvalidationSubscription(ctx)
return conn.Process(
ctx,
redis.NewBoolCmd(
ctx,
"CLIENT", "TRACKING", "on",
"REDIRECT", fmt.Sprintf("%d", invalidateClientID),
),
)
}
return opt
}

0 comments on commit 90cc0d3

Please sign in to comment.