diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java
index 09d0ca709b..a0aab00b69 100644
--- a/src/java/org/apache/cassandra/config/Config.java
+++ b/src/java/org/apache/cassandra/config/Config.java
@@ -88,6 +88,9 @@ public class Config
 
     public long native_transport_idle_timeout_in_ms = 0L;
 
+    // Time budget, in milliseconds, that OutboundTcpConnection.connect() may spend establishing a connection
+    public volatile Long internode_connect_timeout_in_ms = 1000L;
+
     public volatile Long request_timeout_in_ms = 10000L;
 
     public volatile Long read_request_timeout_in_ms = 5000L;
diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
index 88877457dc..250958f315 100644
--- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
+++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
@@ -1077,6 +1077,19 @@ public static int getRpcListenBacklog()
         return conf.rpc_listen_backlog;
     }
 
+
+    /** @return the configured {@code internode_connect_timeout_in_ms}, in milliseconds */
+    public static long getInternodeConnectionTimeout()
+    {
+        return conf.internode_connect_timeout_in_ms;
+    }
+
+    /** Sets {@code internode_connect_timeout_in_ms} (milliseconds) and returns the value that was stored. */
+    public static long setInternodeConnectionTimeout(long timeoutInMillis)
+    {
+        return conf.internode_connect_timeout_in_ms = timeoutInMillis;
+    }
+
     public static long getRpcTimeout()
     {
         return conf.request_timeout_in_ms;
diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java
index c0ac495dad..f2fb3df728 100644
--- a/src/java/org/apache/cassandra/net/MessagingService.java
+++ b/src/java/org/apache/cassandra/net/MessagingService.java
@@ -766,17 +766,14 @@ public void shutdown()
         // We may need to schedule hints on the mutation stage, so it's erroneous to shut down the mutation stage first
         assert !StageManager.getStage(Stage.MUTATION).isShutdown();
 
-        // the important part
-        if (!callbacks.shutdownBlocking())
-            logger.warn("Failed to wait for messaging service callbacks shutdown");
-
         // attempt to humor tests that try to stop and restart MS
         try
         {
             clearMessageSinks();
-            for (SocketThread th : socketThreads)
+            for (SocketThread th : socketThreads) {
                 try
                 {
+                    // Close incoming connections
                     th.close();
                 }
                 catch (IOException e)
@@ -784,6 +781,10 @@ public void shutdown()
                     // see https://issues.apache.org/jira/browse/CASSANDRA-10545
                     handleIOException(e);
                 }
+            }
+            // Wait to finish callbacks before closing outbound connections
+            if (!callbacks.shutdownBlocking())
+                logger.warn("Failed to wait for messaging service callbacks shutdown");
 
             connectionManagers.values().forEach(OutboundTcpConnectionPool::close);
         }
diff --git a/src/java/org/apache/cassandra/net/OutboundTcpConnection.java b/src/java/org/apache/cassandra/net/OutboundTcpConnection.java
index 828ef0435f..b79d0c5035 100644
--- a/src/java/org/apache/cassandra/net/OutboundTcpConnection.java
+++ b/src/java/org/apache/cassandra/net/OutboundTcpConnection.java
@@ -37,11 +37,11 @@
 import java.util.zip.Checksum;
 
 import javax.net.ssl.SSLHandshakeException;
-import javax.net.ssl.SSLSocket;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.palantir.logsafe.SafeArg;
 import net.jpountz.lz4.LZ4BlockOutputStream;
 import net.jpountz.lz4.LZ4Compressor;
 import net.jpountz.lz4.LZ4Factory;
@@ -65,6 +65,7 @@
 import org.apache.cassandra.config.Config;
 import org.apache.cassandra.config.DatabaseDescriptor;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.Uninterruptibles;
 
@@ -216,8 +217,9 @@ public void run()
             //The timestamp of the first message has already been provided to the coalescing strategy
             //so skip logging it.
             inner:
-            for (QueuedMessage qm : drainedMessages)
+            for (int i = 0; i < drainedMessages.size(); i++)
             {
+                QueuedMessage qm = drainedMessages.get(i);
                 try
                 {
                     MessageOut m = qm.message;
@@ -236,8 +238,10 @@ else if (socket != null || connect())
                     else
                     {
                         // clear out the queue, else gossip messages back up.
-                        drainedMessages.clear();
-                        backlog.clear();
+                        int cleared = clearQueueWithFailureCallback(i, drainedMessages, drainedMessageSize, backlog);
+                        logger.warn("Failed to connect to endpoint. Cleared backlog and invoked failure callbacks",
+                                    SafeArg.of("clearedMessages", cleared),
+                                    SafeArg.of("endpoint", poolReference.endPoint()));
                         currentMsgBufferCount = 0;
                         break inner;
                     }
@@ -255,6 +259,39 @@ else if (socket != null || connect())
         }
     }
 
+    /**
+     * Fails every message that can no longer be delivered: the unsent tail of the current batch plus the entries
+     * queued in the backlog when this method is entered. The failure callback of each message is invoked.
+     *
+     * @param currentMessage index in {@code bufferedMessages} of the first message that was not written
+     * @param bufferedMessages the batch being processed; reused as the drain buffer and left empty on return
+     * @param bufferSize maximum number of backlog entries drained per iteration
+     * @param queue the backlog to empty
+     * @return the number of messages removed from the batch tail and from the backlog
+     */
+    @VisibleForTesting
+    int clearQueueWithFailureCallback(int currentMessage, List<QueuedMessage> bufferedMessages, int bufferSize, BlockingQueue<QueuedMessage> queue) {
+        int cleared = bufferedMessages.size() - currentMessage;
+        bufferedMessages.stream().skip(currentMessage).forEach(this::invokeFailureCallback);
+        bufferedMessages.clear();
+
+        // Bound the work by the backlog size seen on entry so that new arrivals cannot keep this thread here.
+        int remaining = queue.size();
+        while (remaining > 0) {
+            int drained = queue.drainTo(bufferedMessages, Math.min(bufferSize, remaining));
+            // drainTo returns 0 when the limit is not positive or when another thread removed entries after
+            // size() was read (e.g. message expiry); stop instead of spinning on a count that cannot be reached.
+            if (drained <= 0) {
+                break;
+            }
+            bufferedMessages.forEach(this::invokeFailureCallback);
+            bufferedMessages.clear();
+            remaining -= drained;
+            cleared += drained;
+        }
+        return cleared;
+    }
+
     public int getPendingMessages()
     {
         return backlog.size() + currentMsgBufferCount;
@@ -330,6 +367,8 @@ private void writeConnected(QueuedMessage qm, boolean flush)
                     {
                         throw new AssertionError(e1);
                     }
+                } else {
+                    invokeFailureCallback(qm);
                 }
             }
             else
@@ -340,6 +379,31 @@ private void writeConnected(QueuedMessage qm, boolean flush)
         }
     }
 
+    /**
+     * Notifies the callback registered for {@code qm} that the message could not be delivered. The callback is
+     * removed from the messaging service first, then {@code onFailure} is called with this connection's endpoint.
+     * Does nothing when {@code qm} is null, when no callback is registered for its id, or when the registered
+     * callback is not a failure callback.
+     */
+    @VisibleForTesting
+    void invokeFailureCallback(QueuedMessage qm) {
+        if (qm == null) {
+            return;
+        }
+        CallbackInfo registeredCallbackInfo = MessagingService.instance().getRegisteredCallback(qm.id);
+        if (registeredCallbackInfo != null && registeredCallbackInfo.isFailureCallback()) {
+            Optional.ofNullable(MessagingService.instance().removeRegisteredCallback(qm.id))
+                    .map(info -> info.callback)
+                    .map(callback -> (IAsyncCallbackWithFailure) callback)
+                    .ifPresent(callback -> {
+                        logger.debug("Invoking failure callback for message",
+                                     SafeArg.of("endpoint", poolReference.endPoint()),
+                                     SafeArg.of("messageId", qm.id), SafeArg.of("verb", qm.message.verb));
+                        callback.onFailure(poolReference.endPoint());
+                    });
+        }
+    }
+
     private void writeInternal(MessageOut message, int id, long timestamp) throws IOException
     {
         out.writeInt(MessagingService.PROTOCOL_MAGIC);
@@ -397,7 +461,7 @@ private boolean connect()
             logger.trace("attempting to connect to {}", poolReference.endPoint());
 
         long start = System.nanoTime();
-        long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getRpcTimeout());
+        long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getInternodeConnectionTimeout());
         while (System.nanoTime() - start < timeout && !isStopped)
         {
             targetVersion = MessagingService.instance().getVersion(poolReference.endPoint());
@@ -584,12 +648,13 @@ private void expireMessages()
             if (!qm.isTimedOut())
                 return;
             iter.remove();
+            invokeFailureCallback(qm);
             dropped.incrementAndGet();
         }
     }
 
     /** messages that have not been retried yet */
-    private static class QueuedMessage implements Coalescable
+    static class QueuedMessage implements Coalescable
     {
         final MessageOut message;
         final int id;
diff --git a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
index df602ef34b..f97fd50156 100644
--- a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
+++ b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
@@ -20,16 +20,55 @@
  */
 package org.apache.cassandra.net;
 
+import java.net.InetAddress;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Collection;
 import java.util.List;
 
+import com.google.common.collect.ImmutableList;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
+import org.apache.cassandra.SchemaLoader;
+import org.apache.cassandra.Util;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.config.KSMetaData;
+import org.apache.cassandra.db.ConsistencyLevel;
+import org.apache.cassandra.db.DecoratedKey;
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.Mutation;
+import org.apache.cassandra.db.WriteType;
+import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.exceptions.WriteFailureException;
+import org.apache.cassandra.exceptions.WriteTimeoutException;
+import org.apache.cassandra.locator.SimpleStrategy;
+import org.apache.cassandra.service.WriteResponseHandler;
+import org.apache.cassandra.utils.ByteBufferUtil;
+import org.apache.cassandra.utils.FBUtilities;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
 
 public class MessagingServiceTest
 {
+    private static final String KEYSPACE1 = "MessagingServiceKeyspace";
+    private static final String CF_STANDARD1 = "columnfamily";
     private final MessagingService messagingService = MessagingService.test();
 
+    @BeforeClass
+    public static void defineSchema() throws ConfigurationException
+    {
+        SchemaLoader.prepareServer();
+        SchemaLoader.createKeyspace(KEYSPACE1,
+                                    SimpleStrategy.class,
+                                    KSMetaData.optsWithRF(1),
+                                    SchemaLoader.standardCFMD(KEYSPACE1, CF_STANDARD1));
+    }
+
     @Test
     public void testDroppedMessages()
     {
@@ -53,4 +92,103 @@ public void testDroppedMessages()
         assertEquals("READ messages were dropped in last 5000 ms: 1250 for internal timeout and 1250 for cross node timeout", logs.get(0));
         assertEquals(7500, (int)messagingService.getDroppedMessages().get(verb.toString()));
     }
+
+    /**
+     * Shutting down the messaging service stops its accept threads; later sends to the local address are then
+     * expected to fail on the internode connect timeout, well before the write timeout elapses.
+     */
+    @Test
+    public void shutdown_refusesNewMessagesWhenInProgress() throws InterruptedException {
+        Keyspace keyspace = Keyspace.open(KEYSPACE1);
+        DecoratedKey dk = Util.dk("key1");
+
+        Mutation mutation = new Mutation(KEYSPACE1, dk.getKey());
+        mutation.add(CF_STANDARD1, Util.cellname("Column1"), ByteBufferUtil.bytes("asdf"), 0);
+        MessageOut message = mutation.createMessage();
+        List incomingAcceptThreads;
+        InetAddress oldBroadcast = FBUtilities.getBroadcastAddress();
+        try {
+            messagingService.listen();
+            assertTrue(messagingService.isListening());
+            assertFalse(MessagingService.instance().isListening());
+            incomingAcceptThreads = messagingService.getSocketThreads();
+            assertTrue(incomingAcceptThreads.size() > 0);
+
+            // Spoof broadcast address to avoid WriteCallbackInfo assertion
+            InetAddress mockBroadcast = mock(InetAddress.class);
+            doReturn(oldBroadcast.getAddress()).when(mockBroadcast).getAddress();
+            FBUtilities.setBroadcastInetAddress(mockBroadcast);
+
+            TestHandler handler = createHandler(keyspace);
+            MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler, false);
+            handler.get();
+            assertEquals(1, handler.success);
+        } finally {
+            messagingService.shutdown();
+        }
+        incomingAcceptThreads.forEach(thread -> assertFalse(thread.isAlive()));
+        DatabaseDescriptor.setWriteRpcTimeout(Duration.ofSeconds(1).toMillis());
+        DatabaseDescriptor.setInternodeConnectionTimeout(Duration.ofMillis(50).toMillis());
+
+        Instant start = Instant.now();
+        TestHandler handler2 = createHandler(keyspace);
+        MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler2, false);
+        handler2.get();
+        Duration handler2Time = Duration.between(start, Instant.now());
+        start = Instant.now();
+        TestHandler handler3 = createHandler(keyspace);
+        MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler3, false);
+        handler3.get();
+        Duration handler3Time = Duration.between(start, Instant.now());
+        start = Instant.now();
+        TestHandler handler4 = createHandler(keyspace);
+        MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler4, false);
+        handler4.get();
+        Duration handler4Time = Duration.between(start, Instant.now());
+
+        // handler2 may or may not fail vs timeout. Likely due to OS level buffering, flushing the socket that has now
+        // been closed on the receiving end may or may not throw an IOException
+        int failures = handler2.failures + handler3.failures + handler4.failures;
+        assertTrue(failures >= 2);
+
+        // Failures should fail in less time than the write timeout, as they hit connect timeout instead
+        assertTrue(handler3Time.minus(Duration.ofMillis(800)).isNegative());
+        assertTrue(handler4Time.minus(Duration.ofMillis(800)).isNegative());
+
+        FBUtilities.setBroadcastInetAddress(oldBroadcast);
+    }
+
+    private TestHandler createHandler(Keyspace ks) {
+        return new TestHandler(ImmutableList.of(FBUtilities.getLocalAddress()), ImmutableList.of(), ConsistencyLevel.ANY, ks, () -> {}, WriteType.SIMPLE);
+    }
+
+    /** Write handler that counts how each {@code get()} call ended instead of propagating the exception. */
+    static class TestHandler extends WriteResponseHandler {
+        public int success = 0;
+        public int failures = 0;
+        public int timeouts = 0;
+
+
+        public TestHandler(Collection writeEndpoints, Collection pendingEndpoints, ConsistencyLevel consistencyLevel, Keyspace keyspace, Runnable callback, WriteType writeType)
+        {
+            super(writeEndpoints, pendingEndpoints, consistencyLevel, keyspace, callback, writeType);
+        }
+
+        @Override
+        protected int totalBlockFor() {
+            return 1;
+        }
+
+        @Override
+        public void get() {
+            try {
+                super.get();
+                success++;
+            } catch (WriteTimeoutException e) {
+                timeouts++;
+            } catch (WriteFailureException e) {
+                failures++;
+            }
+        }
+    }
 }
diff --git a/test/unit/org/apache/cassandra/net/OutboundTcpConnectionTest.java b/test/unit/org/apache/cassandra/net/OutboundTcpConnectionTest.java
new file mode 100644
index 0000000000..4d9871d3c1
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/OutboundTcpConnectionTest.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.cassandra.io.IVersionedSerializer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.anyCollection;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+public class OutboundTcpConnectionTest
+{
+    private OutboundTcpConnection connection;
+    private static OutboundTcpConnectionPool pool;
+    private static final InetAddress TARGET = mock(InetAddress.class);
+    private static final OutboundTcpConnection.QueuedMessage QM1 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 1);
+    private static final OutboundTcpConnection.QueuedMessage QM2 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 2);
+    private static final OutboundTcpConnection.QueuedMessage QM3 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 3);
+
+    @BeforeClass
+    public static void beforeClass() throws UnknownHostException
+    {
+        pool = mock(OutboundTcpConnectionPool.class);
+        doReturn(InetAddress.getLocalHost()).when(pool).endPoint();
+    }
+
+    @Before
+    public void before() {
+        connection = spy(new OutboundTcpConnection(pool, "test"));
+        MessagingService.instance().clearCallbacksUnsafe();
+    }
+
+    @Test
+    public void invokeFailureCallback_ignoresNonFailureCallbacks() {
+        TestCallback cb = new TestCallback();
+        CallbackInfo nonFailureCallback = new CallbackInfo(TARGET, cb, mock(IVersionedSerializer.class), false);
+        MessagingService.instance().setCallbackForTests(QM1.id, nonFailureCallback);
+        connection.invokeFailureCallback(QM1);
+        assertEquals(0, cb.responses.get());
+    }
+
+    @Test
+    public void invokeFailureCallback_handlesExpiredCallback() {
+        assertNull(MessagingService.instance().getRegisteredCallback(QM1.id));
+        connection.invokeFailureCallback(QM1);
+    }
+
+    @Test
+    public void invokeFailureCallback_runsCallback() {
+        TestFailureCallback cb = registerFailureCallback(QM1);
+        connection.invokeFailureCallback(QM1);
+        assertEquals(0, cb.responses.get());
+        assertEquals(1, cb.failures.get());
+        assertNull(MessagingService.instance().getRegisteredCallback(QM1.id));
+    }
+
+    @Test
+    public void clearQueueWithFailureCallback_handlesInProgressDrainedList() throws InterruptedException
+    {
+        List drained = new ArrayList<>(2);
+        drained.add(QM1);
+        drained.add(QM2);
+        BlockingQueue backlog = new LinkedBlockingQueue<>();
+        backlog.put(QM3);
+
+        TestFailureCallback cb1 = registerFailureCallback(QM1);
+        TestFailureCallback cb2 = registerFailureCallback(QM2);
+        TestFailureCallback cb3 = registerFailureCallback(QM3);
+
+        connection.clearQueueWithFailureCallback(1, drained, 2, backlog);
+
+        assertEquals(0, cb1.failures.get());
+        assertEquals(1, cb2.failures.get());
+        assertEquals(1, cb3.failures.get());
+
+        assertTrue(drained.isEmpty());
+        assertTrue(backlog.isEmpty());
+    }
+
+    @Test
+    public void clearQueueWithFailureCallback_clearsLargeBacklog() throws InterruptedException
+    {
+        List drained = new ArrayList<>(2);
+        BlockingQueue backlog = spy(new LinkedBlockingQueue<>());
+        backlog.put(QM1);
+        backlog.put(QM2);
+        backlog.put(QM3);
+        backlog.put(QM3);
+        backlog.put(QM3);
+
+        TestFailureCallback cb1 = registerFailureCallback(QM1);
+        TestFailureCallback cb2 = registerFailureCallback(QM2);
+        TestFailureCallback cb3 = registerFailureCallback(QM3);
+
+        connection.clearQueueWithFailureCallback(0, drained, 2, backlog);
+        // With enough elements remaining, drain the buffer size
+        verify(backlog, times(2)).drainTo(anyCollection(), eq(2));
+        // Last call, don't take more off the backlog than needed from when we first called clearQueueWithFailureCallback
+        verify(backlog, times(1)).drainTo(anyCollection(), eq(1));
+
+        assertEquals(1, cb1.failures.get());
+        assertEquals(1, cb2.failures.get());
+        assertEquals(1, cb3.failures.get());
+
+        assertTrue(drained.isEmpty());
+        assertTrue(backlog.isEmpty());
+    }
+
+    @Test(timeout = 5000)
+    public void clearQueueWithFailureCallback_stopsWhenBacklogShrinks() throws InterruptedException
+    {
+        List<OutboundTcpConnection.QueuedMessage> drained = new ArrayList<>(2);
+        BlockingQueue<OutboundTcpConnection.QueuedMessage> backlog = spy(new LinkedBlockingQueue<>());
+        backlog.put(QM1);
+        // Report more entries than are really queued, as if some were removed after size() was read
+        doReturn(3).when(backlog).size();
+
+        TestFailureCallback cb1 = registerFailureCallback(QM1);
+
+        int cleared = connection.clearQueueWithFailureCallback(0, drained, 2, backlog);
+
+        assertEquals(1, cleared);
+        assertEquals(1, cb1.failures.get());
+        assertTrue(drained.isEmpty());
+        assertNull(backlog.peek());
+    }
+
+    static class TestCallback implements IAsyncCallback
+    {
+        public final AtomicInteger responses = new AtomicInteger(0);
+
+        public void response(MessageIn _msg)
+        {
+            responses.incrementAndGet();
+
+        }
+
+        public boolean isLatencyForSnitch()
+        {
+            return false;
+        }
+    }
+
+    static class TestFailureCallback extends TestCallback implements IAsyncCallbackWithFailure {
+        public final AtomicInteger failures = new AtomicInteger(0);
+
+        public void onFailure(InetAddress from)
+        {
+            failures.incrementAndGet();
+        }
+    }
+
+    private TestFailureCallback registerFailureCallback(OutboundTcpConnection.QueuedMessage qm) {
+        TestFailureCallback cb = new TestFailureCallback();
+        MessagingService.instance().setCallbackForTests(qm.id, new CallbackInfo(TARGET, cb, mock(IVersionedSerializer.class), true));
+        return cb;
+    }
+}