diff --git a/pubsub.go b/pubsub.go index aea96241f..00a25f903 100644 --- a/pubsub.go +++ b/pubsub.go @@ -409,6 +409,20 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { return &Pong{ Payload: reply[1].(string), }, nil + case "invalidate": + switch payload := reply[1].(type) { + case []interface{}: + s := make([]string, len(payload)) + for idx := range payload { + s[idx] = payload[idx].(string) + } + return &Message{ + Channel: "invalidate", + PayloadSlice: s, + }, nil + default: + return nil, fmt.Errorf("redis: unsupported invalidate message payload: %#v", payload) + } default: return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind) } diff --git a/pubsub_test.go b/pubsub_test.go index a76100659..17d29fba7 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -1,6 +1,8 @@ package redis_test import ( + "context" + "fmt" "io" "net" "sync" @@ -567,4 +569,99 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal(text)) }) + + It("supports client-cache invalidation messages", func() { + ch := make(chan []string, 2) + defer close(ch) + client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error { + ch <- keys + return nil + })) + defer client.Close() + + v1 := client.Get(context.Background(), "foo") + Expect(v1.Val()).To(Equal("")) + s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute)) + Expect(s1.Val()).To(Equal("OK")) + v2 := client.Get(context.Background(), "foo") + Expect(v2.Val()).To(Equal("bar")) + // sleep a little to allow time for the first invalidation message to come through + time.Sleep(time.Second) + s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute)) + Expect(s2.Val()).To(Equal("OK")) + + for i := 0; i < 2; i++ { + select { + case keys := <-ch: + Expect(keys).ToNot(BeEmpty()) + Expect(keys[0]).To(Equal("foo")) + case <-time.After(10 * time.Second): + // fail on timeouts + Fail("invalidation message wait timed out") + } + } + }) + }) + +func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options { + var mu sync.Mutex + invalidateClientID := int64(-1) + invalidateOpts := *opt + invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) { + invalidateClientID, err = conn.ClientID(ctx).Result() + return + } + + startBackgroundInvalidationSubscription := func(ctx context.Context) int64 { + mu.Lock() + defer mu.Unlock() + + if invalidateClientID != -1 { + return invalidateClientID + } + + invalidateClient := redis.NewClient(&invalidateOpts) + invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate") + + go func() { + defer func() { + invalidations.Close() + invalidateClient.Close() + + mu.Lock() + invalidateClientID = -1 + mu.Unlock() + }() + + for { + msg, err := invalidations.ReceiveMessage(context.Background()) + if err == io.EOF || err == context.Canceled { + return + } else if err != nil { + fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error()) + // send back empty []string to fail the test + processInvalidKeysFunc([]string{}) + return + } + + processInvalidKeysFunc(msg.PayloadSlice) + } + }() + + return invalidateClientID + } + + opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error { + invalidateClientID := startBackgroundInvalidationSubscription(ctx) + return conn.Process( + ctx, + redis.NewBoolCmd( + ctx, + "CLIENT", "TRACKING", "on", + "REDIRECT", fmt.Sprintf("%d", invalidateClientID), + ), + ) + } + return opt +}