Compare commits

...

7 Commits

Author SHA1 Message Date
Pankaj Bhojwani
2a90fb94f2 things 2024-11-26 16:23:29 -08:00
Pankaj Bhojwani
733e12362b lmproviders timeout as well 2024-11-26 16:17:49 -08:00
Pankaj Bhojwani
f1019334e6 same for capi 2024-11-25 16:21:59 -08:00
Pankaj Bhojwani
7affcb6f76 remove as member 2024-11-25 15:59:14 -08:00
Pankaj Bhojwani
2fece1350c camel 2024-11-22 12:00:29 -08:00
Pankaj Bhojwani
48debd9463 cancel the http request too 2024-11-22 11:18:15 -08:00
Pankaj Bhojwani
ddb864d14a add 5s timeout 2024-11-21 16:17:34 -08:00
6 changed files with 136 additions and 45 deletions

View File

@@ -145,31 +145,49 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Send the request
try
{
const auto response = _httpClient.SendRequestAsync(request).get();
// Parse out the suggestion from the response
const auto string{ response.Content().ReadAsStringAsync().get() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else
{
if (_verifyModelIsValidHelper(jsonResult))
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
if (_verifyModelIsValidHelper(jsonResult))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
}
}
}
else
{
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)
{

View File

@@ -159,7 +159,16 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
if (_lmProvider)
{
result = _lmProvider.GetResponseAsync(promptCopy).get();
const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy);
if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed)
{
result = asyncOperation.GetResults();
}
else
{
asyncOperation.Cancel();
result = winrt::make<SystemResponse>(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{});
}
}
else
{

View File

@@ -248,6 +248,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Make sure we are on the background thread for the http request
auto strongThis = get_strong();
co_await winrt::resume_background();
auto cancellationToken{ co_await winrt::get_cancellation_token() };
for (bool refreshAttempted = false;;)
{
@@ -276,19 +277,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
};
// Send the request
const auto jsonResult = co_await _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
if (jsonResult.HasKey(errorKey))
const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
// if the caller cancels this operation, make sure to cancel the http request as well
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto jsonResult = sendRequestOperation.GetResults();
if (jsonResult.HasKey(errorKey))
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(choicesKey);
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
}
}
else
{
const auto choices = jsonResult.GetNamedArray(choicesKey);
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
break;
}
@@ -305,8 +324,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
break;
}
co_await _refreshAuthTokens();
refreshAttempted = true;
const auto refreshTokensAction = _refreshAuthTokens();
cancellationToken.callback([refreshTokensAction] {
refreshTokensAction.Cancel();
});
// allow up to 10 seconds for reauthentication
if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed)
{
refreshAttempted = true;
}
else
{
// if the refresh action takes too long, cancel it and return an error
refreshTokensAction.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
break;
}
}
// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
@@ -334,7 +368,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
try
{
const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([reAuthOperation] {
reAuthOperation.Cancel();
});
const auto jsonResult{ co_await reAuthOperation };
_authToken = jsonResult.GetNamedString(accessTokenKey);
_refreshToken = jsonResult.GetNamedString(refreshTokenKey);
@@ -360,7 +399,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
WWH::HttpRequestMessage request{ method, Uri{ uri } };
request.Content(content);
const auto response{ co_await _httpClient.SendRequestAsync(request) };
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
const auto response{ co_await sendRequestOperation };
const auto string{ co_await response.Content().ReadAsStringAsync() };
_lastResponse = string;
const auto jsonResult{ WDJ::JsonObject::Parse(string) };

View File

@@ -100,22 +100,40 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Send the request
try
{
const auto response = co_await _httpClient.SendRequestAsync(request);
// Parse out the suggestion from the response
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
}
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)

View File

@@ -126,7 +126,7 @@
<comment>The message presented to the user when they attempt to use the AI chat feature without providing an AI endpoint and key.</comment>
</data>
<data name="UnknownErrorMessage" xml:space="preserve">
<value>An error occurred. Your AI provider might not be correctly configured, or the service might be temporarily unavailable.</value>
<value>An error occurred. The service might be temporarily unavailable or there might be network connectivity issues.</value>
<comment>The error message presented to the user when we were unable to query the provided endpoint.</comment>
</data>
<data name="InvalidModelMessage" xml:space="preserve">

View File

@@ -53,6 +53,8 @@ TRACELOGGING_DECLARE_PROVIDER(g_hQueryExtensionProvider);
#include <winrt/Windows.Data.Json.h>
#include <chrono>
// Manually include til after we include Windows.Foundation to give it winrt superpowers
#include "til.h"