diff --git a/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/Interpreter.scala b/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/Interpreter.scala index 8d194801bed..ffe23ed6a7b 100644 --- a/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/Interpreter.scala +++ b/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/Interpreter.scala @@ -72,7 +72,7 @@ class Interpreter(map: AtomicMap[Buf, Entry]) { data(key) = Entry(value, expiry, flags) Stored } else { - NotStored + Exists } case _ => NotStored @@ -122,17 +122,19 @@ class Interpreter(map: AtomicMap[Buf, Entry]) { existing match { case Some(entry) if entry.valid => val Buf.Utf8(existingString) = entry.value - if (!existingString.isEmpty && !ParserUtils.isDigits(entry.value)) - throw new ClientError("cannot increment or decrement non-numeric value") + if (!existingString.isEmpty && !ParserUtils.isDigits(entry.value)) { + Error(new ClientError("cannot increment or decrement non-numeric value")) + } else { - val existingValue: Long = - if (existingString.isEmpty) 0L - else existingString.toLong + val existingValue: Long = + if (existingString.isEmpty) 0L + else existingString.toLong - val result: Long = existingValue + delta - data(key) = Entry(Buf.Utf8(result.toString), entry.expiry, 0) + val result: Long = existingValue + delta + data(key) = Entry(Buf.Utf8(result.toString), entry.expiry, 0) - Number(result) + Number(result) + } case Some(_) => data.remove(key) // expired NotFound diff --git a/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/protocol/text/server/ResponseToBuf.scala b/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/protocol/text/server/ResponseToBuf.scala index 0081882bcd7..9650162a7d7 100644 --- a/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/protocol/text/server/ResponseToBuf.scala +++ b/finagle-memcached/src/main/scala/com/twitter/finagle/memcached/protocol/text/server/ResponseToBuf.scala @@ -2,7 +2,9 @@ package com.twitter.finagle.memcached.protocol.text.server import com.twitter.finagle.memcached.protocol._ import com.twitter.finagle.memcached.protocol.text.EncodingConstants._ -import com.twitter.io.{Buf, BufByteWriter, ByteWriter} +import com.twitter.io.Buf +import com.twitter.io.BufByteWriter +import com.twitter.io.ByteWriter import java.nio.charset.StandardCharsets /** @@ -22,10 +24,17 @@ private[finagle] object ResponseToBuf { private[this] def encodeResponse(response: Seq[Buf]): Buf = { // + 2 to estimated size for DELIMITER. val bw = BufByteWriter.dynamic(10 * response.size + 2) - response.foreach { token => - bw.writeBytes(token) + var i = 0 + while (i < response.length - 1) { + bw.writeBytes(response(i)) bw.writeBytes(SPACE) + i += 1 + } + + if (response.nonEmpty) { + bw.writeBytes(response(i)) } + bw.writeBytes(DELIMITER) bw.owned() @@ -59,12 +68,8 @@ private[finagle] object ResponseToBuf { // + 5 to estimated size for END + DELIMITER. val bw = BufByteWriter.dynamic(100 * lines.size + 5) - lines.foreach { tokens => - tokens.foreach { token => - bw.writeBytes(token) - bw.writeBytes(SPACE) - } - bw.writeBytes(DELIMITER) + lines.foreach { line => + bw.writeBytes(encodeResponse(line)) } bw.writeBytes(END) bw.writeBytes(DELIMITER) diff --git a/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/integration/MemcachedServerTest.scala b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/integration/MemcachedServerTest.scala new file mode 100644 index 00000000000..d0c4d93835e --- /dev/null +++ b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/integration/MemcachedServerTest.scala @@ -0,0 +1,198 @@ +package com.twitter.finagle.memcached.integration + +import com.twitter.conversions.DurationOps._ +import com.twitter.finagle.Address +import com.twitter.finagle.Name +import com.twitter.finagle.Service +import com.twitter.finagle.ServiceFactory +import com.twitter.finagle.Stack +import com.twitter.finagle.client.StackClient +import com.twitter.finagle.client.StdStackClient +import com.twitter.finagle.client.Transporter +import com.twitter.finagle.dispatch.SerialClientDispatcher +import com.twitter.finagle.memcached.integration.external.InternalMemcached +import com.twitter.finagle.memcached.integration.external.TestMemcachedServer +import com.twitter.finagle.memcached.protocol.Add +import com.twitter.finagle.memcached.protocol.Cas +import com.twitter.finagle.memcached.protocol.Command +import com.twitter.finagle.memcached.protocol.Delete +import com.twitter.finagle.memcached.protocol.Get +import com.twitter.finagle.memcached.protocol.Gets +import com.twitter.finagle.memcached.protocol.Incr +import com.twitter.finagle.memcached.protocol.Set +import com.twitter.finagle.memcached.protocol.text.MessageEncoderHandler +import com.twitter.finagle.memcached.protocol.text.client.CommandToBuf +import com.twitter.finagle.netty4.Netty4Transporter +import com.twitter.finagle.netty4.encoder.BufEncoder +import com.twitter.finagle.stats.NullStatsReceiver +import com.twitter.finagle.transport.Transport +import com.twitter.finagle.transport.TransportContext +import com.twitter.io.Buf +import com.twitter.util.Await +import com.twitter.util.Time +import io.netty.channel.ChannelPipeline +import io.netty.handler.codec.string.StringDecoder +import java.net.SocketAddress +import java.nio.charset.StandardCharsets.UTF_8 +import org.scalatest.BeforeAndAfter +import org.scalatest.funsuite.AnyFunSuite + +// Because we use our Memcached server for testing, we need to ensure that it complies to the +// Memcached protocol. +private class MemcachedServerTest extends AnyFunSuite with BeforeAndAfter { + + private[this] var realServer: TestMemcachedServer = _ + private[this] var testServer: TestMemcachedServer = _ + + private[this] var realServerClient: Service[Command, String] = _ + private[this] var testServerClient: Service[Command, String] = _ + + before { + realServer = TestMemcachedServer.start().get + testServer = InternalMemcached.start(None).get + + realServerClient = StringClient + .apply().newService(Name.bound(Address(realServer.address)), "client") + + testServerClient = StringClient + .apply().newService(Name.bound(Address(testServer.address)), "client") + } + + after { + realServer.stop() + testServer.stop() + Await.result(realServerClient.close(), 5.seconds) + Await.result(testServerClient.close(), 5.seconds) + } + + if (Option(System.getProperty("EXTERNAL_MEMCACHED_PATH")).isDefined) { + test("NOT_FOUND") { + assertSameResponses(Incr(Buf.Utf8("key1"), 1), "NOT_FOUND\r\n") + } + + test("STORED") { + assertSameResponses( + Set(Buf.Utf8("key2"), 0, Time.epoch, Buf.Utf8("value")), + "STORED\r\n" + ) + } + + test("NOT_STORED") { + assertSameResponses(Add(Buf.Utf8("key3"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + assertSameResponses(Add(Buf.Utf8("key3"), 0, Time.epoch, Buf.Utf8("value")), "NOT_STORED\r\n") + } + + test("EXISTS") { + assertSameResponses( + Set(Buf.Utf8("key4"), 0, Time.epoch, Buf.Utf8("value")), + "STORED\r\n" + ) + assertSameResponses(Gets(Seq(Buf.Utf8("key4"))), "VALUE key4 0 5 \\d+\r\nvalue\r\nEND\r\n") + + assertSameResponses( + Cas(Buf.Utf8("key4"), 0, Time.epoch, Buf.Utf8("value2"), Buf.Utf8("9999")), + "EXISTS\r\n") + } + + test("DELETED") { + assertSameResponses( + Set(Buf.Utf8("key5"), 0, Time.epoch, Buf.Utf8("value")), + "STORED\r\n" + ) + assertSameResponses(Delete(Buf.Utf8("key5")), "DELETED\r\n") + } + + test("CLIENT_ERROR") { + assertSameResponses(Set(Buf.Utf8("key6"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + assertSameResponses( + Incr(Buf.Utf8("key6"), 1), + "CLIENT_ERROR cannot increment or decrement non-numeric value\r\n") + } + + // NO_OP will terminate the connection so can't be tested here. + // STATS not available in the interpreter so can't be tested here. + + test("VALUES (empty)") { + assertSameResponses(Gets(Seq(Buf.Utf8("key7"))), "END\r\n") + } + + test("VALUES without flags without casunique") { + assertSameResponses(Set(Buf.Utf8("key8"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + // Note how flag 0 is still returned here + assertSameResponses(Get(Seq(Buf.Utf8("key8"))), "VALUE key8 0 5\r\nvalue\r\nEND\r\n") + } + + test("VALUES with flags without casunique") { + assertSameResponses(Set(Buf.Utf8("key9"), 2, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + assertSameResponses(Get(Seq(Buf.Utf8("key9"))), "VALUE key9 2 5\r\nvalue\r\nEND\r\n") + } + + test("VALUES without flags with casunique") { + assertSameResponses(Set(Buf.Utf8("key10"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + // Note how flag 0 is still returned here + assertSameResponses(Gets(Seq(Buf.Utf8("key10"))), "VALUE key10 0 5 \\d+\r\nvalue\r\nEND\r\n") + } + + test("VALUES with flags with casunique") { + assertSameResponses(Set(Buf.Utf8("key11"), 2, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + assertSameResponses(Gets(Seq(Buf.Utf8("key11"))), "VALUE key11 2 5 \\d+\r\nvalue\r\nEND\r\n") + } + + test("VALUES (multiple lines)") { + assertSameResponses(Set(Buf.Utf8("key12"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + assertSameResponses(Set(Buf.Utf8("key13"), 0, Time.epoch, Buf.Utf8("value")), "STORED\r\n") + + assertSameResponses( + Get(Seq(Buf.Utf8("key12"), Buf.Utf8("key13"))), + "VALUE key12 0 5\r\nvalue\r\nVALUE key13 0 5\r\nvalue\r\nEND\r\n") + } + + test("NUMBER") { + assertSameResponses(Set(Buf.Utf8("key14"), 0, Time.epoch, Buf.Utf8("1")), "STORED\r\n") + assertSameResponses(Incr(Buf.Utf8("key14"), 2), "3\r\n") + } + } + + private[this] def assertSameResponses(command: Command, response: String): Unit = { + val testServerResponse = Await.result(testServerClient(command), 5.seconds) + val realServerResponse = Await.result(realServerClient(command), 5.seconds) + + assert(testServerResponse.matches(response)) + assert(realServerResponse.matches(response)) + } + + private case class StringClient( + stack: Stack[ServiceFactory[Command, String]] = StackClient.newStack, + params: Stack.Params = Stack.Params.empty) + extends StdStackClient[Command, String, StringClient] { + + override protected type In = Command + override protected type Out = String + override protected type Context = TransportContext + + object PipelineInit extends (ChannelPipeline => Unit) { + override def apply(pipeline: ChannelPipeline): Unit = { + pipeline.addLast("encoder", BufEncoder) + pipeline.addLast("messageToBuf", new MessageEncoderHandler(new CommandToBuf)) + pipeline.addLast("decoder", new StringDecoder(UTF_8)) + } + } + + protected def newDispatcher( + transport: Transport[In, Out] { type Context <: StringClient.this.Context } + ): Service[In, Out] = { + new SerialClientDispatcher(transport, NullStatsReceiver) + } + + override protected def newTransporter( + addr: SocketAddress + ): Transporter[Command, String, TransportContext] = { + Netty4Transporter.raw(PipelineInit, addr, params) + } + + override protected def copy1( + stack: Stack[ServiceFactory[Command, String]], + params: Stack.Params + ): StringClient = copy(stack, params) + } +} diff --git a/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/InterpreterTest.scala b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/InterpreterTest.scala index 8370b22ff6a..5ace28d4e34 100644 --- a/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/InterpreterTest.scala +++ b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/InterpreterTest.scala @@ -41,7 +41,11 @@ class InterpreterTest extends AnyFunSuite { assert(interpreter(Gets(Seq(key))) == Values(Seq(Value(key, value1, hashValue1, emptyFlags)))) assert(interpreter(Cas(key, 0, Time.epoch, value2, hashValue1.get)) == Stored) - assert(interpreter(Cas(key, 0, Time.epoch, value3, hashValue1.get)) == NotStored) + assert(interpreter(Cas(key, 0, Time.epoch, value3, hashValue1.get)) == Exists) + + assert( + interpreter( + Cas(Buf.Utf8("non_existant"), 0, Time.epoch, value2, hashValue1.get)) == NotStored) } test("correctly perform the QUIT command") { diff --git a/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/protocol/text/server/ResponseToBufTest.scala b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/protocol/text/server/ResponseToBufTest.scala index 47851df37a3..7172bea63b7 100644 --- a/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/protocol/text/server/ResponseToBufTest.scala +++ b/finagle-memcached/src/test/scala/com/twitter/finagle/memcached/unit/protocol/text/server/ResponseToBufTest.scala @@ -1,6 +1,9 @@ package com.twitter.finagle.memcached.unit.protocol.text.server -import com.twitter.finagle.memcached.protocol.{ClientError, Error, NonexistentCommand, ServerError} +import com.twitter.finagle.memcached.protocol.ClientError +import com.twitter.finagle.memcached.protocol.Error +import com.twitter.finagle.memcached.protocol.NonexistentCommand +import com.twitter.finagle.memcached.protocol.ServerError import com.twitter.finagle.memcached.protocol.text.server.ResponseToBuf import com.twitter.io.Buf import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks @@ -8,21 +11,29 @@ import org.scalatest.funsuite.AnyFunSuite class ResponseToBufTest extends AnyFunSuite with ScalaCheckDrivenPropertyChecks { - test("encode errors - ERROR") { + test("ERROR") { val error = Error(new NonexistentCommand("No such command")) val res = ResponseToBuf.encode(error) - assert(res == Buf.Utf8("ERROR \r\n")) + assert(res == Buf.Utf8("ERROR\r\n")) } - test("encode errors - CLIENT_ERROR") { - val error = Error(new ClientError("Invalid Input")) - val res = ResponseToBuf.encode(error) - assert(res == Buf.Utf8("CLIENT_ERROR Invalid Input \r\n")) + test("CLIENT_ERROR") { + val errorNoTrailingWhitespace = Error(new ClientError("Invalid Input")) + assert( + ResponseToBuf.encode(errorNoTrailingWhitespace) == Buf.Utf8("CLIENT_ERROR Invalid Input\r\n")) + + val errorTrailingWhitespace = Error(new ClientError("Invalid Input ")) + assert( + ResponseToBuf.encode(errorTrailingWhitespace) == Buf.Utf8("CLIENT_ERROR Invalid Input \r\n")) } - test("encode errors - SERVER_ERROR") { - val error = Error(new ServerError("Out of Memory")) - val res = ResponseToBuf.encode(error) - assert(res == Buf.Utf8("SERVER_ERROR Out of Memory \r\n")) + test("SERVER_ERROR") { + val errorNoTrailingWhitespace = Error(new ServerError("Out of Memory")) + assert( + ResponseToBuf.encode(errorNoTrailingWhitespace) == Buf.Utf8("SERVER_ERROR Out of Memory\r\n")) + + val errorTrailingWhitespace = Error(new ServerError("Out of Memory ")) + assert( + ResponseToBuf.encode(errorTrailingWhitespace) == Buf.Utf8("SERVER_ERROR Out of Memory \r\n")) } }