From 745f71118462894a0599996881ccd5d81f877184 Mon Sep 17 00:00:00 2001 From: Jillian Crossley Date: Fri, 25 Oct 2024 15:00:26 +0000 Subject: [PATCH] finagle/finagle-core: Add retaintIds method to MarshalledContext Problem A service may want to remove extraneous contexts from being broadcasted downstream. Solution Add method `retainIds` to remove all contexts except those specified. Differential Revision: https://phabricator.twitter.biz/D1178842 --- .../finagle/context/MarshalledContext.scala | 5 ++ .../context/MarshalledContextTest.scala | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala b/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala index c90d1855ba9..2dba2414df2 100644 --- a/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala +++ b/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala @@ -127,6 +127,11 @@ final class MarshalledContext private[context] extends Context { letLocal(next)(fn) } + private[finagle] def retainIds[R](ids: Set[String])(fn: => R): R = { + val next = env.filter { case (id, _) => ids.contains(id) } + letLocal(next)(fn) + } + def letClearAll[R](fn: => R): R = local.letClear(fn) /** diff --git a/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala b/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala index 7bca8410c27..db1145e5a3f 100644 --- a/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala +++ b/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala @@ -62,6 +62,55 @@ class MarshalledContextTest extends AbstractContextTest { } } + test("retainIds") { + val ctx = new MarshalledContext + + def stringKey(id: String): ctx.Key[String] = new ctx.Key[String](id) { + def marshal(value: String): Buf = Buf.Utf8(value) + def tryUnmarshal(buf: Buf): Return[String] = buf match { + case Buf.Utf8(value) => Return(value) + } + } + + val fooKey = stringKey("foo") + val barKey = stringKey("bar") + val bazKey = stringKey("baz") + + ctx.let( + Seq( + ctx.KeyValuePair(fooKey, "foo-value"), + ctx.KeyValuePair(barKey, "bar-value"), + ctx.KeyValuePair(bazKey, "baz-value"))) { + + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value"), + barKey.marshalId -> Buf.Utf8("bar-value"), + bazKey.marshalId -> Buf.Utf8("baz-value"), + )) + + ctx.retainIds(Set("foo", "baz")) { + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value"), + bazKey.marshalId -> Buf.Utf8("baz-value") + )) + + ctx.retainIds(Set("foo")) { + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value") + )) + + // qux doesn't exist + ctx.retainIds(Set("qux")) { + assert(ctx.marshal().isEmpty) + } + } + } + } + } + test("key lookups are case insensitive") { val ctx = new MarshalledContext val lowerKey = new ctx.Key[String]("foo") {