diff --git a/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala b/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala index c90d1855ba9..2dba2414df2 100644 --- a/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala +++ b/finagle-core/src/main/scala/com/twitter/finagle/context/MarshalledContext.scala @@ -127,6 +127,11 @@ final class MarshalledContext private[context] extends Context { letLocal(next)(fn) } + private[finagle] def retainIds[R](ids: Set[String])(fn: => R): R = { + val next = env.filter { case (id, _) => ids.contains(id) } + letLocal(next)(fn) + } + def letClearAll[R](fn: => R): R = local.letClear(fn) /** diff --git a/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala b/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala index 7bca8410c27..db1145e5a3f 100644 --- a/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala +++ b/finagle-core/src/test/scala/com/twitter/finagle/context/MarshalledContextTest.scala @@ -62,6 +62,55 @@ class MarshalledContextTest extends AbstractContextTest { } } + test("retainIds") { + val ctx = new MarshalledContext + + def stringKey(id: String): ctx.Key[String] = new ctx.Key[String](id) { + def marshal(value: String): Buf = Buf.Utf8(value) + def tryUnmarshal(buf: Buf): Return[String] = buf match { + case Buf.Utf8(value) => Return(value) + } + } + + val fooKey = stringKey("foo") + val barKey = stringKey("bar") + val bazKey = stringKey("baz") + + ctx.let( + Seq( + ctx.KeyValuePair(fooKey, "foo-value"), + ctx.KeyValuePair(barKey, "bar-value"), + ctx.KeyValuePair(bazKey, "baz-value"))) { + + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value"), + barKey.marshalId -> Buf.Utf8("bar-value"), + bazKey.marshalId -> Buf.Utf8("baz-value"), + )) + + ctx.retainIds(Set("foo", "baz")) { + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value"), + bazKey.marshalId -> Buf.Utf8("baz-value") + )) + + ctx.retainIds(Set("foo")) { + assert( + ctx.marshal() == Map( + fooKey.marshalId -> Buf.Utf8("foo-value") + )) + + // qux doesn't exist + ctx.retainIds(Set("qux")) { + assert(ctx.marshal().isEmpty) + } + } + } + } + } + test("key lookups are case insensitive") { val ctx = new MarshalledContext val lowerKey = new ctx.Key[String]("foo") {