Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TaskResource fixes #24296

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;

import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

import static com.google.common.util.concurrent.Futures.catching;
import static com.google.common.util.concurrent.Futures.withTimeout;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * Static helpers for futures that back asynchronous responses.
 */
public class AsyncResponseUtils
{
// Private constructor: the class only exposes static methods.
private AsyncResponseUtils() {}

/**
 * Wraps {@code future} with a timeout and substitutes a fallback value when that timeout fires.
 * <p>
 * The future is first wrapped with {@code withTimeout}, using {@code timeout} converted to
 * milliseconds and {@code timeoutExecutor}. A resulting {@link TimeoutException} is then replaced
 * by the value of {@code fallback.get()}; no other exception type is caught here.
 * <p>
 * Two signatures are listed below: the one taking {@code responseExecutor} evaluates the fallback
 * on that executor, the one without it evaluates the fallback via {@code directExecutor()}.
 *
 * @param future the future whose result is normally returned
 * @param timeout how long to wait before the fallback is used
 * @param fallback supplier of the value to use once the timeout has fired
 * @param timeoutExecutor scheduler passed to {@code withTimeout}
 * @return a future yielding either the original result or the fallback value
 */
public static <V> ListenableFuture<V> withFallbackAfterTimeout(ListenableFuture<V> future, Duration timeout, Supplier<V> fallback, Executor responseExecutor, ScheduledExecutorService timeoutExecutor)
public static <V> ListenableFuture<V> withFallbackAfterTimeout(ListenableFuture<V> future, Duration timeout, Supplier<V> fallback, ScheduledExecutorService timeoutExecutor)
{
return catching(withTimeout(future, timeout.toMillis(), MILLISECONDS, timeoutExecutor), TimeoutException.class, _ -> fallback.get(), responseExecutor);
return catching(withTimeout(future, timeout.toMillis(), MILLISECONDS, timeoutExecutor), TimeoutException.class, _ -> fallback.get(), directExecutor());
}
}
22 changes: 17 additions & 5 deletions core/trino-main/src/main/java/io/trino/server/TaskResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
import java.util.concurrent.ThreadLocalRandom;

import static com.google.common.collect.Iterables.transform;
import static com.google.common.util.concurrent.Futures.withTimeout;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.addTimeout;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
Expand All @@ -83,6 +82,7 @@
import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED;
import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID;
import static io.trino.server.security.ResourceSecurity.AccessType.INTERNAL_ONLY;
import static jakarta.ws.rs.core.Response.status;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand Down Expand Up @@ -221,9 +221,11 @@ public void getTaskInfo(
futureTaskInfo = Futures.transform(futureTaskInfo, TaskInfo::summarize, directExecutor());
}

ListenableFuture<Response> response = Futures.transform(futureTaskInfo, taskInfo ->
Response.ok(taskInfo).build(), directExecutor());
// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse, withTimeout(futureTaskInfo, timeout.toMillis(), MILLISECONDS, timeoutExecutor), responseExecutor);
bindAsyncResponse(asyncResponse, withFallbackAfterTimeout(response, timeout, () -> serviceUnavailable(timeout), timeoutExecutor), responseExecutor);
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -266,7 +268,9 @@ public void getTaskStatus(

// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse, withTimeout(futureTaskStatus, timeout.toMillis(), MILLISECONDS, timeoutExecutor), responseExecutor);

ListenableFuture<Response> response = Futures.transform(futureTaskStatus, taskStatus -> Response.ok(taskStatus).build(), directExecutor());
bindAsyncResponse(asyncResponse, withFallbackAfterTimeout(response, timeout, () -> serviceUnavailable(timeout), timeoutExecutor), responseExecutor);
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -367,14 +371,14 @@ public void getResults(
// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse,
withFallbackAfterTimeout(responseFuture, timeout, () -> createBufferResultResponse(pagesInputStreamFactory, taskWithResults, emptyBufferResults), responseExecutor, timeoutExecutor), responseExecutor);
withFallbackAfterTimeout(responseFuture, timeout, () -> createBufferResultResponse(pagesInputStreamFactory, taskWithResults, emptyBufferResults), timeoutExecutor), responseExecutor);
responseFuture.addListener(() -> readFromOutputBufferTime.add(Duration.nanosSince(start)), directExecutor());
}

@ResourceSecurity(INTERNAL_ONLY)
@GET
@Path("{taskId}/results/{bufferId}/{token}/acknowledge")
public void acknowledgeResults(
public Response acknowledgeResults(
@PathParam("taskId") TaskId taskId,
@PathParam("bufferId") PipelinedOutputBuffers.OutputBufferId bufferId,
@PathParam("token") long token)
Expand All @@ -383,6 +387,7 @@ public void acknowledgeResults(
requireNonNull(bufferId, "bufferId is null");

taskManager.acknowledgeTaskResults(taskId, bufferId, token);
return Response.ok().build();
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -561,4 +566,11 @@ private static Response createBufferResultResponse(PagesInputStreamFactory pages
pagesInputStreamFactory.write(output, serializedPages))
.build();
}

/**
 * Builds the 503 (Service Unavailable) response used as the timeout fallback for task requests.
 *
 * @param timeout the wait that elapsed; rendered in its most succinct unit in the response entity
 * @return a response with status {@code SERVICE_UNAVAILABLE} and a plain message as its entity
 */
private static Response serviceUnavailable(Duration timeout)
{
    // Compose the message first so the builder chain below stays a plain status/entity/build sequence.
    String message = "Timed out after waiting for " + timeout.convertToMostSuccinctTimeUnit();
    return status(Response.Status.SERVICE_UNAVAILABLE).entity(message).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Throwables;
import com.google.inject.Inject;
import io.airlift.jaxrs.ParsingException;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.InternalServerErrorException;
Expand Down Expand Up @@ -83,6 +84,9 @@ public Response toResponse(Throwable throwable)
case TimeoutException timeoutException -> plainTextError(Response.Status.REQUEST_TIMEOUT)
.entity("Error 408 Timeout: " + timeoutException.getMessage())
.build();
case ParsingException parsingException -> Response.status(Response.Status.BAD_REQUEST)
.entity(Throwables.getStackTraceAsString(parsingException))
.build();
case WebApplicationException webApplicationException -> webApplicationException.getResponse();
default -> {
ResponseBuilder responseBuilder = plainTextError(Response.Status.INTERNAL_SERVER_ERROR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ public void fail(Throwable throwable)
/**
 * Releases what this object holds: closes {@code exchangeDataSource} and clears {@code lastResult}.
 * <p>
 * Runs while holding this instance's monitor ({@code synchronized}).
 */
public synchronized void dispose()
{
// Close the data source first; if close() throws, lastResult is left untouched.
exchangeDataSource.close();
// Drop the reference to the last result.
// NOTE(review): presumably so the cached result can be garbage collected after disposal — confirm.
lastResult = null;
}

public QueryId getQueryId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.concurrent.ScheduledExecutorService;

import static com.google.common.util.concurrent.Futures.transform;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.server.AsyncResponseUtils.withFallbackAfterTimeout;
import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
Expand Down Expand Up @@ -97,8 +98,8 @@ public void getAuthenticationToken(@PathParam("authId") UUID authId, @Suspended
// hang if the client retries the request. The response will timeout eventually.
ListenableFuture<TokenPoll> tokenFuture = tokenExchange.getTokenPoll(authId);
ListenableFuture<Response> responseFuture = withFallbackAfterTimeout(
transform(tokenFuture, OAuth2TokenExchangeResource::toResponse, responseExecutor),
MAX_POLL_TIME, () -> pendingResponse(request), responseExecutor, timeoutExecutor);
transform(tokenFuture, OAuth2TokenExchangeResource::toResponse, directExecutor()),
MAX_POLL_TIME, () -> pendingResponse(request), timeoutExecutor);
bindAsyncResponse(asyncResponse, responseFuture, responseExecutor);
}

Expand Down
Loading