Skip to content

Commit

Permalink
lmproviders timeout as well
Browse files Browse the repository at this point in the history
  • Loading branch information
PankajBhojwani committed Nov 27, 2024
1 parent f101933 commit 733e123
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 49 deletions.
49 changes: 31 additions & 18 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,35 +146,48 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
try
{
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);

// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
const auto response{ co_await sendRequestOperation };
// Parse out the suggestion from the response
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
{
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
if (_verifyModelIsValidHelper(jsonResult))
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
if (_verifyModelIsValidHelper(jsonResult))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
}
}
}
else
{
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)
{
Expand Down
2 changes: 1 addition & 1 deletion src/cascadia/QueryExtension/ExtensionPalette.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
if (_lmProvider)
{
const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy);
if (asyncOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed)
{
result = asyncOperation.GetResults();
}
Expand Down
68 changes: 51 additions & 17 deletions src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation

// Make sure we are on the background thread for the http request
auto strongThis = get_strong();

co_await winrt::resume_background();
auto cancellationToken{ co_await winrt::get_cancellation_token() };

for (bool refreshAttempted = false;;)
{
Expand Down Expand Up @@ -276,24 +278,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
};

// Send the request
const auto jsonResultOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([jsonResultOperation] {
jsonResultOperation.Cancel();
const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());

// if the caller cancels this operation, make sure to cancel the http request as well
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
const auto jsonResult = co_await jsonResultOperation;
if (jsonResult.HasKey(errorKey))

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto jsonResult = sendRequestOperation.GetResults();
if (jsonResult.HasKey(errorKey))
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(L"ayy");

Check failure

Code scanning / check-spelling

Unrecognized Spelling Error

ayy is not a recognized word. (unrecognized-spelling)
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
}
}
else
{
const auto choices = jsonResult.GetNamedArray(choicesKey);
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
break;
}
Expand All @@ -310,8 +325,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
break;
}

co_await _refreshAuthTokens();
refreshAttempted = true;
const auto refreshTokensAction = _refreshAuthTokens();
cancellationToken.callback([refreshTokensAction] {
refreshTokensAction.Cancel();
});
// allow up to 10 seconds for reauthentication
if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed)
{
refreshAttempted = true;
}
else
{
// if the refresh action takes too long, cancel it and return an error
refreshTokensAction.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
break;
}
}

// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
Expand Down Expand Up @@ -339,7 +369,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation

try
{
const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([reAuthOperation] {
reAuthOperation.Cancel();
});
const auto jsonResult{ co_await reAuthOperation };

_authToken = jsonResult.GetNamedString(accessTokenKey);
_refreshToken = jsonResult.GetNamedString(refreshTokenKey);
Expand Down Expand Up @@ -371,7 +406,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
sendRequestOperation.Cancel();
});
const auto response{ co_await sendRequestOperation };
_lastRequest = sendRequestOperation;
const auto string{ co_await response.Content().ReadAsStringAsync() };
_lastResponse = string;
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
Expand Down
1 change: 0 additions & 1 deletion src/cascadia/QueryExtension/GithubCopilotLLMProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };
IBrandingData _brandingData{ winrt::make<GithubCopilotBranding>() };
winrt::hstring _lastResponse;
winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Web::Http::HttpResponseMessage, winrt::Windows::Web::Http::HttpProgress> _lastRequest{ nullptr };

Extension::IContext _context;

Expand Down
37 changes: 25 additions & 12 deletions src/cascadia/QueryExtension/OpenAILLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,39 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
try
{
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);

// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
const auto response{ co_await sendRequestOperation };
// Parse out the suggestion from the response
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
}
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)
Expand Down

0 comments on commit 733e123

Please sign in to comment.