Skip to content

Commit

Permalink
Tweak MessagingService Shutdown and Add Internode Connect Timeout (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
zpear authored Dec 4, 2024
1 parent ad819da commit 27f0c7c
Show file tree
Hide file tree
Showing 6 changed files with 375 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/java/org/apache/cassandra/config/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public class Config

public long native_transport_idle_timeout_in_ms = 0L;

public volatile Long internode_connect_timeout_in_ms = 1000L;

public volatile Long request_timeout_in_ms = 10000L;

public volatile Long read_request_timeout_in_ms = 5000L;
Expand Down
11 changes: 11 additions & 0 deletions src/java/org/apache/cassandra/config/DatabaseDescriptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,17 @@ public static int getRpcListenBacklog()
return conf.rpc_listen_backlog;
}


public static long getInternodeConnectionTimeout()
{
return conf.internode_connect_timeout_in_ms;
}

public static long setInternodeConnectionTimeout(long timeoutInMillis)
{
return conf.internode_connect_timeout_in_ms = timeoutInMillis;
}

public static long getRpcTimeout()
{
return conf.request_timeout_in_ms;
Expand Down
11 changes: 6 additions & 5 deletions src/java/org/apache/cassandra/net/MessagingService.java
Original file line number Diff line number Diff line change
Expand Up @@ -766,24 +766,25 @@ public void shutdown()
// We may need to schedule hints on the mutation stage, so it's erroneous to shut down the mutation stage first
assert !StageManager.getStage(Stage.MUTATION).isShutdown();

// the important part
if (!callbacks.shutdownBlocking())
logger.warn("Failed to wait for messaging service callbacks shutdown");

// attempt to humor tests that try to stop and restart MS
try
{
clearMessageSinks();
for (SocketThread th : socketThreads)
for (SocketThread th : socketThreads) {
try
{
// Close incoming connections
th.close();
}
catch (IOException e)
{
// see https://issues.apache.org/jira/browse/CASSANDRA-10545
handleIOException(e);
}
}
// Wait to finish callbacks before closing outbound connections
if (!callbacks.shutdownBlocking())
logger.warn("Failed to wait for messaging service callbacks shutdown");

connectionManagers.values().forEach(OutboundTcpConnectionPool::close);
}
Expand Down
56 changes: 50 additions & 6 deletions src/java/org/apache/cassandra/net/OutboundTcpConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
import java.util.zip.Checksum;

import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSocket;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.palantir.logsafe.SafeArg;
import net.jpountz.lz4.LZ4BlockOutputStream;
import net.jpountz.lz4.LZ4Compressor;
import net.jpountz.lz4.LZ4Factory;
Expand All @@ -65,6 +65,7 @@
import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Uninterruptibles;

Expand Down Expand Up @@ -216,8 +217,9 @@ public void run()
//The timestamp of the first message has already been provided to the coalescing strategy
//so skip logging it.
inner:
for (QueuedMessage qm : drainedMessages)
for (int i = 0; i < drainedMessages.size(); i++)
{
QueuedMessage qm = drainedMessages.get(i);
try
{
MessageOut<?> m = qm.message;
Expand All @@ -236,8 +238,10 @@ else if (socket != null || connect())
else
{
// clear out the queue, else gossip messages back up.
drainedMessages.clear();
backlog.clear();
int cleared = clearQueueWithFailureCallback(i, drainedMessages, drainedMessageSize, backlog);
logger.warn("Failed to connect to endpoint. Cleared backlog and invoked failure callbacks",
SafeArg.of("clearedMessages", cleared),
SafeArg.of("endpoint", poolReference.endPoint()));
currentMsgBufferCount = 0;
break inner;
}
Expand All @@ -255,6 +259,24 @@ else if (socket != null || connect())
}
}

@VisibleForTesting
int clearQueueWithFailureCallback(int currentMessage, List<QueuedMessage> bufferedMessages, int bufferSize, BlockingQueue<QueuedMessage> queue) {
bufferedMessages.stream().skip(currentMessage).forEach(this::invokeFailureCallback);
int initialCleared = bufferedMessages.size() - currentMessage;
bufferedMessages.clear();

int queueSize = queue.size();
int remaining = queueSize;
while (remaining > 0) {
remaining -= queue.drainTo(bufferedMessages, Math.min(bufferSize, remaining));
for (QueuedMessage qm : bufferedMessages) {
invokeFailureCallback(qm);
}
bufferedMessages.clear();
}
return initialCleared + queueSize;
}

public int getPendingMessages()
{
return backlog.size() + currentMsgBufferCount;
Expand Down Expand Up @@ -330,6 +352,8 @@ private void writeConnected(QueuedMessage qm, boolean flush)
{
throw new AssertionError(e1);
}
} else {
invokeFailureCallback(qm);
}
}
else
Expand All @@ -340,6 +364,25 @@ private void writeConnected(QueuedMessage qm, boolean flush)
}
}

@VisibleForTesting
void invokeFailureCallback(QueuedMessage qm) {
if (qm == null) {
return;
}
CallbackInfo registeredCallbackInfo = MessagingService.instance().getRegisteredCallback(qm.id);
if (registeredCallbackInfo != null && registeredCallbackInfo.isFailureCallback()) {
Optional.ofNullable(MessagingService.instance().removeRegisteredCallback(qm.id))
.map(info -> info.callback)
.map(callback -> (IAsyncCallbackWithFailure) callback)
.ifPresent(callback -> {
logger.debug("Invoking failure callback for message",
SafeArg.of("endpoint", poolReference.endPoint()),
SafeArg.of("messageId", qm.id), SafeArg.of("verb", qm.message.verb));
callback.onFailure(poolReference.endPoint());
});
}
}

private void writeInternal(MessageOut message, int id, long timestamp) throws IOException
{
out.writeInt(MessagingService.PROTOCOL_MAGIC);
Expand Down Expand Up @@ -397,7 +440,7 @@ private boolean connect()
logger.trace("attempting to connect to {}", poolReference.endPoint());

long start = System.nanoTime();
long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getRpcTimeout());
long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getInternodeConnectionTimeout());
while (System.nanoTime() - start < timeout && !isStopped)
{
targetVersion = MessagingService.instance().getVersion(poolReference.endPoint());
Expand Down Expand Up @@ -584,12 +627,13 @@ private void expireMessages()
if (!qm.isTimedOut())
return;
iter.remove();
invokeFailureCallback(qm);
dropped.incrementAndGet();
}
}

/** messages that have not been retried yet */
private static class QueuedMessage implements Coalescable
static class QueuedMessage implements Coalescable
{
final MessageOut<?> message;
final int id;
Expand Down
132 changes: 132 additions & 0 deletions test/unit/org/apache/cassandra/net/MessagingServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,55 @@
*/
package org.apache.cassandra.net;

import java.net.InetAddress;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.List;

import com.google.common.collect.ImmutableList;
import org.junit.BeforeClass;
import org.junit.Test;

import org.apache.cassandra.SchemaLoader;
import org.apache.cassandra.Util;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.config.KSMetaData;
import org.apache.cassandra.db.ConsistencyLevel;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.Mutation;
import org.apache.cassandra.db.WriteType;
import org.apache.cassandra.exceptions.ConfigurationException;
import org.apache.cassandra.exceptions.WriteFailureException;
import org.apache.cassandra.exceptions.WriteTimeoutException;
import org.apache.cassandra.locator.SimpleStrategy;
import org.apache.cassandra.service.WriteResponseHandler;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.FBUtilities;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

public class MessagingServiceTest
{
private static final String KEYSPACE1 = "MessagingServiceKeyspace";
private static final String CF_STANDARD1 = "columnfamily";
private final MessagingService messagingService = MessagingService.test();

@BeforeClass
public static void defineSchema() throws ConfigurationException
{
SchemaLoader.prepareServer();
SchemaLoader.createKeyspace(KEYSPACE1,
SimpleStrategy.class,
KSMetaData.optsWithRF(1),
SchemaLoader.standardCFMD(KEYSPACE1, CF_STANDARD1));
}

@Test
public void testDroppedMessages()
{
Expand All @@ -53,4 +92,97 @@ public void testDroppedMessages()
assertEquals("READ messages were dropped in last 5000 ms: 1250 for internal timeout and 1250 for cross node timeout", logs.get(0));
assertEquals(7500, (int)messagingService.getDroppedMessages().get(verb.toString()));
}

@Test
public void shutdown_refusesNewMessagesWhenInProgress() throws InterruptedException {
Keyspace keyspace = Keyspace.open(KEYSPACE1);
DecoratedKey dk = Util.dk("key1");

Mutation mutation = new Mutation(KEYSPACE1, dk.getKey());
mutation.add(CF_STANDARD1, Util.cellname("Column1"), ByteBufferUtil.bytes("asdf"), 0);
MessageOut<Mutation> message = mutation.createMessage();
List<MessagingService.SocketThread> incomingAcceptThreads;
InetAddress oldBroadcast = FBUtilities.getBroadcastAddress();
try {
messagingService.listen();
assertTrue(messagingService.isListening());
assertFalse(MessagingService.instance().isListening());
incomingAcceptThreads = messagingService.getSocketThreads();
assertTrue(incomingAcceptThreads.size() > 0);

// Spoof broadcast address to avoid WriteCallbackInfo assertion
InetAddress mockBroadcast = mock(InetAddress.class);
doReturn(oldBroadcast.getAddress()).when(mockBroadcast).getAddress();
FBUtilities.setBroadcastInetAddress(mockBroadcast);

TestHandler handler = createHandler(keyspace);
MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler, false);
handler.get();
assertEquals(1, handler.success);
} finally {
messagingService.shutdown();
}
incomingAcceptThreads.forEach(thread -> assertFalse(thread.isAlive()));
DatabaseDescriptor.setWriteRpcTimeout(Duration.ofSeconds(1).toMillis());
DatabaseDescriptor.setInternodeConnectionTimeout(Duration.ofMillis(50).toMillis());

Instant start = Instant.now();
TestHandler handler2 = createHandler(keyspace);
MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler2, false);
handler2.get();
Duration handler2Time = Duration.between(start, Instant.now());
start = Instant.now();
TestHandler handler3 = createHandler(keyspace);
MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler3, false);
handler3.get();
Duration handler3Time = Duration.between(start, Instant.now());
TestHandler handler4 = createHandler(keyspace);
MessagingService.instance().sendRR(message, FBUtilities.getLocalAddress(), handler4, false);
handler4.get();
Duration handler4Time = Duration.between(start, Instant.now());

// handler2 may or may not fail vs timeout. Likely due to OS level buffering, flushing the socket that has now
// been closed on the receiving end may or may not throw an IOException
int failures = handler2.failures + handler3.failures + handler4.failures;
assertTrue(failures >= 2);

// Failures should fail in less time than the write timeout, as they hit connect timeout instead
assertTrue(handler3Time.minus(Duration.ofMillis(800)).isNegative());
assertTrue(handler4Time.minus(Duration.ofMillis(800)).isNegative());

FBUtilities.setBroadcastInetAddress(oldBroadcast);
}

private TestHandler createHandler(Keyspace ks) {
return new TestHandler(ImmutableList.of(FBUtilities.getLocalAddress()), ImmutableList.of(), ConsistencyLevel.ANY, ks, () -> {}, WriteType.SIMPLE);
}

static class TestHandler<T> extends WriteResponseHandler<T> {
public int success = 0;
public int failures = 0;
public int timeouts = 0;


public TestHandler(Collection<InetAddress> writeEndpoints, Collection<InetAddress> pendingEndpoints, ConsistencyLevel consistencyLevel, Keyspace keyspace, Runnable callback, WriteType writeType)
{
super(writeEndpoints, pendingEndpoints, consistencyLevel, keyspace, callback, writeType);
}

@Override
protected int totalBlockFor() {
return 1;
}

@Override
public void get() {
try {
super.get();
success++;
} catch (WriteTimeoutException e) {
timeouts++;
} catch (WriteFailureException e) {
failures++;
}
}
}
}
Loading

0 comments on commit 27f0c7c

Please sign in to comment.