diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index 5a6a2fd55c9..bc0cabb5e53 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -79,6 +79,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -145,7 +153,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = _httpClient.SendRequestAsync(request).get(); + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; // Parse out the suggestion from the response const auto string{ response.Content().ReadAsStringAsync().get() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.h b/src/cascadia/QueryExtension/AzureLLMProvider.h index 1899bb93099..99139769f82 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.h +++ b/src/cascadia/QueryExtension/AzureLLMProvider.h @@ -39,6 +39,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _azureKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/ExtensionPalette.cpp b/src/cascadia/QueryExtension/ExtensionPalette.cpp index 52922ffb7b9..77b80e9b142 100644 --- a/src/cascadia/QueryExtension/ExtensionPalette.cpp +++ b/src/cascadia/QueryExtension/ExtensionPalette.cpp @@ -166,6 +166,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation } else { + asyncOperation.Cancel(); result = winrt::make(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{}); } } diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index b72ceba2b2a..bfaa4f29f6b 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -237,6 +237,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation GithubCopilotLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -360,7 +368,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation WWH::HttpRequestMessage request{ method, Uri{ uri } }; request.Content(content); - const auto response{ co_await _httpClient.SendRequestAsync(request) }; + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; const auto string{ co_await response.Content().ReadAsStringAsync() }; _lastResponse = string; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h index 98f69cd6fcc..d711607c131 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h @@ -51,6 +51,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; winrt::hstring _lastResponse; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index a8184f72593..e7e25d26333 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -62,6 +62,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -100,7 +108,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = co_await _httpClient.SendRequestAsync(request); + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; // Parse out the suggestion from the response const auto string{ co_await response.Content().ReadAsStringAsync() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.h b/src/cascadia/QueryExtension/OpenAILLMProvider.h index c1f489d310c..5f4f770e97b 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.h +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.h @@ -38,6 +38,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _AIKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context;