
Commit aa98754

.Net: Add activities to MistralClient (#6297)
Replicates the ModelDiagnostics instrumentation in the MistralAI chat completion service implementation. I still need to test it; the best I can say for now is that it compiles :) cc: @markwallace-microsoft, @TaoChenOSU
1 parent a136cd4 commit aa98754
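
The change repeated across all five connectors is the same error-reporting pattern: an unconditional catch that called activity?.SetError(ex) becomes a catch with an exception filter, when (activity is not null), so the handler is entered only when diagnostics are actually enabled. Below is a minimal sketch of that pattern using only the standard System.Diagnostics APIs; the source name, class name, and the SetStatus call are illustrative stand-ins for the Semantic Kernel ModelDiagnostics/SetError helpers used in the diff.

using System;
using System.Diagnostics;

internal static class DiagnosticsSketch
{
    // StartActivity returns null when nothing listens to this source, which is
    // what makes the `when (activity is not null)` filter meaningful.
    private static readonly ActivitySource s_source = new("Demo.ModelDiagnostics");

    public static void CallModel()
    {
        using Activity? activity = s_source.StartActivity("chat.completions");
        try
        {
            // ... send the request to the model here ...
        }
        catch (Exception ex) when (activity is not null)
        {
            // With the filter, the untraced path never catches and rethrows at
            // all; the old `activity?.SetError(ex)` form entered the handler on
            // every failure and rethrew even when there was nothing to record.
            activity.SetStatus(ActivityStatusCode.Error, ex.Message);
            throw;
        }
    }
}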

File tree

5 files changed: +136 -55 lines changed

dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs
dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs
dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs


dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs

Lines changed: 4 additions & 4 deletions
@@ -175,9 +175,9 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
             .ConfigureAwait(false);
         chatResponses = this.ProcessChatResponse(geminiResponse);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -259,9 +259,9 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
             break;
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceClient.cs

Lines changed: 4 additions & 4 deletions
@@ -147,9 +147,9 @@ public async Task<IReadOnlyList<TextContent>> GenerateTextAsync(

         response = DeserializeResponse<TextGenerationResponse>(body);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
    {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -204,9 +204,9 @@ public async IAsyncEnumerable<StreamingTextContent> StreamGenerateTextAsync(
             break;
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
    {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

dotnet/src/Connectors/Connectors.HuggingFace/Core/HuggingFaceMessageApiClient.cs

Lines changed: 4 additions & 4 deletions
@@ -120,9 +120,9 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> StreamCompleteChatM
             break;
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -162,9 +162,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> CompleteChatMessageAsync(

         response = HuggingFaceClient.DeserializeResponse<ChatCompletionResponse>(body);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs

Lines changed: 110 additions & 29 deletions
@@ -15,6 +15,7 @@
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;
 using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Diagnostics;
 using Microsoft.SemanticKernel.Http;
 using Microsoft.SemanticKernel.Text;

@@ -25,6 +26,8 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
 /// </summary>
 internal sealed class MistralClient
 {
+    private const string ModelProvider = "mistralai";
+
     internal MistralClient(
         string modelId,
         HttpClient httpClient,

@@ -56,18 +59,56 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

     for (int requestIndex = 1; ; requestIndex++)
     {
-        using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
-        var responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
-        if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
+        ChatCompletionResponse? responseData = null;
+        List<ChatMessageContent> responseContent;
+        using (var activity = ModelDiagnostics.StartCompletionActivity(this._endpoint, this._modelId, ModelProvider, chatHistory, mistralExecutionSettings))
         {
-            throw new KernelException("Chat completions not found");
+            try
+            {
+                using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
+                responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
+                if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
+                {
+                    throw new KernelException("Chat completions not found");
+                }
+            }
+            catch (Exception ex) when (activity is not null)
+            {
+                activity.SetError(ex);
+
+                // Capture available metadata even if the operation failed.
+                if (responseData is not null)
+                {
+                    if (responseData.Id is string id)
+                    {
+                        activity.SetResponseId(id);
+                    }
+
+                    if (responseData.Usage is MistralUsage usage)
+                    {
+                        if (usage.PromptTokens is int promptTokens)
+                        {
+                            activity.SetPromptTokenUsage(promptTokens);
+                        }
+                        if (usage.CompletionTokens is int completionTokens)
+                        {
+                            activity.SetCompletionTokenUsage(completionTokens);
+                        }
+                    }
+                }
+
+                throw;
+            }
+
+            responseContent = this.ToChatMessageContent(modelId, responseData);
+            activity?.SetCompletionResponse(responseContent, responseData.Usage?.PromptTokens, responseData.Usage?.CompletionTokens);
         }

         // If we don't want to attempt to invoke any functions, just return the result.
         // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
         if (!autoInvoke || responseData.Choices.Count != 1)
         {
-            return this.ToChatMessageContent(modelId, responseData);
+            return responseContent;
         }

         // Get our single result and extract the function call information. If this isn't a function call, or if it is
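
One wrinkle worth noting before the next hunk: unlike the OpenAI ClientCore further down, which chains SetResponseId/SetPromptTokenUsage fluently on the failure path, the Mistral version must unwrap each value first because MistralUsage exposes its counters as nullable. The is-patterns above (usage.PromptTokens is int promptTokens) test for a value and bind it to a non-nullable local in one step. A standalone illustration, with a hypothetical stand-in for the usage object:

using System;

// Hypothetical stand-in for MistralUsage, whose token counts are nullable.
var usage = new { PromptTokens = (int?)42, CompletionTokens = (int?)null };

// `x is int y` is true only when the int? holds a value, binding it to y.
if (usage.PromptTokens is int promptTokens)
{
    Console.WriteLine($"prompt tokens: {promptTokens}");         // runs: 42
}

if (usage.CompletionTokens is int completionTokens)
{
    Console.WriteLine($"completion tokens: {completionTokens}"); // skipped
}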
@@ -78,7 +119,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
         MistralChatChoice chatChoice = responseData.Choices[0]; // TODO Handle multiple choices
         if (!chatChoice.IsToolCall)
         {
-            return this.ToChatMessageContent(modelId, responseData);
+            return responseContent;
         }

         if (this._logger.IsEnabled(LogLevel.Debug))

@@ -237,35 +278,75 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
         toolCalls?.Clear();

         // Stream the responses
-        var response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
-        string? streamedRole = null;
-        await foreach (var update in response.ConfigureAwait(false))
+        using (var activity = ModelDiagnostics.StartCompletionActivity(this._endpoint, this._modelId, ModelProvider, chatHistory, mistralExecutionSettings))
         {
-            // If we're intending to invoke function calls, we need to consume that function call information.
-            if (autoInvoke)
+            // Make the request.
+            IAsyncEnumerable<StreamingChatMessageContent> response;
+            try
             {
-                if (update.InnerContent is not MistralChatCompletionChunk completionChunk || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
-                {
-                    continue;
-                }
+                response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
+            }
+            catch (Exception e) when (activity is not null)
+            {
+                activity.SetError(e);
+                throw;
+            }

-                MistralChatCompletionChoice chatChoice = completionChunk!.Choices![0]; // TODO Handle multiple choices
-                streamedRole ??= chatChoice.Delta!.Role;
-                if (chatChoice.IsToolCall)
+            var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator();
+            List<StreamingKernelContent>? streamedContents = activity is not null ? [] : null;
+            string? streamedRole = null;
+            try
+            {
+                while (true)
                 {
-                    // Create a copy of the tool calls to avoid modifying the original list
-                    toolCalls = new List<MistralToolCall>(chatChoice.ToolCalls!);
-
-                    // Add the original assistant message to the chatRequest; this is required for the service
-                    // to understand the tool call responses. Also add the result message to the caller's chat
-                    // history: if they don't want it, they can remove it, but this makes the data available,
-                    // including metadata like usage.
-                    chatRequest.AddMessage(new MistralChatMessage(streamedRole, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
-                    chatHistory.Add(this.ToChatMessageContent(modelId, streamedRole!, completionChunk, chatChoice));
+                    try
+                    {
+                        if (!await responseEnumerator.MoveNextAsync())
+                        {
+                            break;
+                        }
+                    }
+                    catch (Exception ex) when (activity is not null)
+                    {
+                        activity.SetError(ex);
+                        throw;
+                    }
+
+                    StreamingChatMessageContent update = responseEnumerator.Current;
+
+                    // If we're intending to invoke function calls, we need to consume that function call information.
+                    if (autoInvoke)
+                    {
+                        if (update.InnerContent is not MistralChatCompletionChunk completionChunk || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
+                        {
+                            continue;
+                        }
+
+                        MistralChatCompletionChoice chatChoice = completionChunk!.Choices![0]; // TODO Handle multiple choices
+                        streamedRole ??= chatChoice.Delta!.Role;
+                        if (chatChoice.IsToolCall)
+                        {
+                            // Create a copy of the tool calls to avoid modifying the original list
+                            toolCalls = new List<MistralToolCall>(chatChoice.ToolCalls!);
+
+                            // Add the original assistant message to the chatRequest; this is required for the service
+                            // to understand the tool call responses. Also add the result message to the caller's chat
+                            // history: if they don't want it, they can remove it, but this makes the data available,
+                            // including metadata like usage.
+                            chatRequest.AddMessage(new MistralChatMessage(streamedRole, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
+                            chatHistory.Add(this.ToChatMessageContent(modelId, streamedRole!, completionChunk, chatChoice));
+                        }
+                    }
+
+                    streamedContents?.Add(update);
+                    yield return update;
                 }
             }
-
-            yield return update;
+            finally
+            {
+                activity?.EndStreaming(streamedContents);
+                await responseEnumerator.DisposeAsync();
+            }
         }

         // If we don't have a function to invoke, we're done.
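
The shape of the streaming rewrite above is forced by a C# restriction: yield return is not allowed inside a try block that has a catch clause (compiler error CS1626). That is why the await foreach becomes a manual enumerator loop in which only MoveNextAsync is guarded, the yield happens outside any catch, and a finally block both ends the streaming activity (EndStreaming) and disposes the enumerator. A minimal sketch of the same shape, with illustrative names and a plain Activity standing in for the Semantic Kernel helpers:

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;

internal static class StreamingSketch
{
    public static async IAsyncEnumerable<string> StreamWithErrorReporting(
        IAsyncEnumerable<string> source, Activity? activity)
    {
        var enumerator = source.ConfigureAwait(false).GetAsyncEnumerator();
        try
        {
            while (true)
            {
                try
                {
                    // Only the advance is guarded; yielding here would be CS1626.
                    if (!await enumerator.MoveNextAsync()) { break; }
                }
                catch (Exception ex) when (activity is not null)
                {
                    activity.SetStatus(ActivityStatusCode.Error, ex.Message);
                    throw;
                }

                // Legal: a try with only a finally may contain yield return.
                yield return enumerator.Current;
            }
        }
        finally
        {
            // Runs whether the stream completes, faults, or the caller stops early.
            await enumerator.DisposeAsync();
        }
    }
}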

dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs

Lines changed: 14 additions & 14 deletions
@@ -148,13 +148,13 @@ internal async Task<IReadOnlyList<TextContent>> GetTextResultsAsync(
             throw new KernelException("Text completions not found");
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         if (responseData != null)
         {
             // Capture available metadata even if the operation failed.
-            activity?
+            activity
                 .SetResponseId(responseData.Id)
                 .SetPromptTokenUsage(responseData.Usage.PromptTokens)
                 .SetCompletionTokenUsage(responseData.Usage.CompletionTokens);

@@ -190,9 +190,9 @@ internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAs
     {
         response = await RunRequestAsync(() => this.Client.GetCompletionsStreamingAsync(options, cancellationToken)).ConfigureAwait(false);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -209,9 +209,9 @@ internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAs
             break;
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -402,13 +402,13 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
             throw new KernelException("Chat completions not found");
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         if (responseData != null)
         {
             // Capture available metadata even if the operation failed.
-            activity?
+            activity
                 .SetResponseId(responseData.Id)
                 .SetPromptTokenUsage(responseData.Usage.PromptTokens)
                 .SetCompletionTokenUsage(responseData.Usage.CompletionTokens);

@@ -671,9 +671,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
     {
         response = await RunRequestAsync(() => this.Client.GetChatCompletionsStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false);
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }

@@ -690,9 +690,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
             break;
         }
     }
-    catch (Exception ex)
+    catch (Exception ex) when (activity is not null)
     {
-        activity?.SetError(ex);
+        activity.SetError(ex);
         throw;
     }
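
A practical note on all of the filters above: none of the new code runs unless something is listening, because StartActivity returns null when no listener is registered for the source, and the diff's null checks suggest the ModelDiagnostics wrapper behaves the same way. One way to observe these activities in a test host is a standard ActivityListener; the source-name prefix below is a guess for illustration, not something this diff shows.

using System;
using System.Diagnostics;

using var listener = new ActivityListener
{
    ShouldListenTo = source => source.Name.StartsWith("Microsoft.SemanticKernel", StringComparison.Ordinal),
    Sample = (ref ActivityCreationOptions<ActivityContext> _) => ActivitySamplingResult.AllData,
    ActivityStopped = activity => Console.WriteLine($"{activity.DisplayName}: {activity.Status}")
};
ActivitySource.AddActivityListener(listener);

// With no such listener, StartActivity returns null and every
// `when (activity is not null)` filter in this commit evaluates to false.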
