Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add prompt execution settings to AutoFunctionInvocationContext #10551

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,65 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync()
{
    // Arrange
    this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();

    var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

    AutoFunctionInvocationContext? actualContext = null;

    // The filter captures the invocation context; it intentionally does not call 'next',
    // which terminates further auto function invocation processing.
    var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
    {
        actualContext = context;
        return Task.CompletedTask;
    });

    var expectedExecutionSettings = new OpenAIPromptExecutionSettings
    {
        ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
    };

    // Act
    // The prompt result itself is not inspected, so discard it rather than binding an unused local.
    await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings));

    // Assert
    Assert.NotNull(actualContext);
    Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync()
{
    // Arrange
    this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();

    var testPlugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

    AutoFunctionInvocationContext? capturedContext = null;

    // Record the auto function invocation context; not calling 'next' ends further processing.
    var kernel = this.GetKernelWithFilter(testPlugin, (context, next) =>
    {
        capturedContext = context;
        return Task.CompletedTask;
    });

    var settings = new OpenAIPromptExecutionSettings
    {
        ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
    };

    // Act - drain the stream so the filter actually runs.
    await foreach (var chunk in kernel.InvokePromptStreamingAsync("Test prompt", new(settings)))
    { }

    // Assert
    Assert.NotNull(capturedContext);
    Assert.Same(settings, capturedContext!.ExecutionSettings);
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down Expand Up @@ -384,6 +385,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// Processes AI function calls by iterating over the function calls, invoking them and adding the results to the chat history.
/// </summary>
/// <param name="chatMessageContent">The chat message content representing AI model response and containing function calls.</param>
/// <param name="executionSettings">The prompt execution settings.</param>
/// <param name="chatHistory">The chat history to add function invocation results to.</param>
/// <param name="requestIndex">AI model function(s) call request sequence index.</param>
/// <param name="checkIfFunctionAdvertised">Callback to check if a function was advertised to AI model or not.</param>
Expand All @@ -129,6 +130,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// <returns>Last chat history message if function invocation filter requested processing termination, otherwise null.</returns>
public async Task<ChatMessageContent?> ProcessFunctionCallsAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be public (it is already an internal class and utility)?

Suggested change
public async Task<ChatMessageContent?> ProcessFunctionCallsAsync(
internal async Task<ChatMessageContent?> ProcessFunctionCallsAsync(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That actually doesn't matter so much since the class is internal. It would make sense to revise the method access modifier if and when we decide to make the class public.

ChatMessageContent chatMessageContent,
PromptExecutionSettings? executionSettings,
ChatHistory chatHistory,
int requestIndex,
Func<FunctionCallContent, bool> checkIfFunctionAdvertised,
Expand Down Expand Up @@ -177,7 +179,8 @@ public FunctionCallsProcessor(ILogger? logger = null)
FunctionCount = functionCalls.Length,
CancellationToken = cancellationToken,
IsStreaming = isStreaming,
ToolCallId = functionCall.Id
ToolCallId = functionCall.Id,
ExecutionSettings = executionSettings
};

s_inflightAutoInvokes.Value++;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.SemanticKernel.ChatCompletion;

Expand Down Expand Up @@ -79,6 +80,12 @@ public AutoFunctionInvocationContext(
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// The <see cref="PromptExecutionSettings"/> associated with the operation that triggered
/// this automatic function invocation, or <c>null</c> when no settings were supplied.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a separate property in AutoFunctionInvocationContext for execution settings? Can we re-use AutoFunctionInvocationContext.Arguments for such scenario? Since KernelArguments type already contains a collection of execution settings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dmytrostruk Not fully sure. Arguments today is only filled if PromptExecutionSettings was literally passed as an Argument right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RogerBarreto Yes, but that's the only way to pass PromptExecutionSettings if you use kernel.InvokePromptAsync instead of IChatCompletionService.

Copy link
Member

@RogerBarreto RogerBarreto Feb 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dmytrostruk I see, any changes to Arguments in this manner would change the behavior, could be a good thing anyways, but we would need to reconsider how Arguments will be reflected, because from this perspective if we reuse arguments for actual selected settings, would be also important to add knowledge of which of the arguments in the execution settings dictionary is the one being used by the function.


/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
using Microsoft.SemanticKernel.ChatCompletion;
#pragma warning disable IDE0005 // Using directive is unnecessary
using Microsoft.SemanticKernel.Connectors.FunctionCalling;
using Microsoft.SemanticKernel.Connectors.OpenAI;

#pragma warning restore IDE0005 // Using directive is unnecessary
using Moq;
using Xunit;
Expand All @@ -21,6 +23,7 @@ public class FunctionCallsProcessorTests
{
private readonly FunctionCallsProcessor _sut = new();
private readonly FunctionChoiceBehaviorOptions _functionChoiceBehaviorOptions = new();
private readonly OpenAIPromptExecutionSettings _openAIPromptExecutionSettings = new();

[Fact]
public void ItShouldReturnNoConfigurationIfNoBehaviorProvided()
Expand Down Expand Up @@ -94,6 +97,7 @@ async Task ProcessFunctionCallsRecursivelyToReachInflightLimitAsync()

await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: [],
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -123,6 +127,7 @@ public async Task ItShouldAddFunctionCallAssistantMessageToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -151,6 +156,7 @@ public async Task ItShouldAddFunctionCallExceptionToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -184,6 +190,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -212,6 +219,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync(
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => false, // Return false to simulate that the function is not advertised
Expand Down Expand Up @@ -240,6 +248,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -281,6 +290,7 @@ public async Task ItShouldInvokeFunctionsAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -347,6 +357,7 @@ public async Task ItShouldInvokeFiltersAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -436,6 +447,7 @@ public async Task ItShouldInvokeMultipleFiltersInOrderAsync(bool invokeConcurren
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -484,6 +496,7 @@ public async Task FilterCanOverrideArgumentsAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -536,6 +549,7 @@ public async Task FilterCanHandleExceptionAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -588,6 +602,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -634,6 +649,7 @@ public async Task PreFilterCanTerminateOperationAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -678,6 +694,7 @@ public async Task PostFilterCanTerminateOperationAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -732,6 +749,7 @@ public async Task ItShouldHandleChatMessageContentAsFunctionResultAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -767,6 +785,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknowTypeAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -837,6 +856,40 @@ public void ItShouldSerializeFunctionResultsWithStringProperties()
Assert.Equal("{\"Text\":\"テスト\"}", result);
}

[Fact]
public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFilterAsync()
{
    // Arrange
    var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

    AutoFunctionInvocationContext? actualContext = null;

    // The filter captures the invocation context; it intentionally does not call 'next',
    // which terminates further auto function invocation processing.
    Kernel kernel = CreateKernel(plugin, (context, next) =>
    {
        actualContext = context;
        return Task.CompletedTask;
    });

    var chatMessageContent = new ChatMessageContent();
    chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));

    // Act
    await this._sut.ProcessFunctionCallsAsync(
        chatMessageContent: chatMessageContent,
        executionSettings: this._openAIPromptExecutionSettings,
        chatHistory: [], // collection expression, consistent with the other tests in this class
        requestIndex: 0,
        checkIfFunctionAdvertised: (_) => true,
        options: this._functionChoiceBehaviorOptions,
        kernel: kernel, // 'kernel' is non-nullable; the previous null-forgiving '!' was redundant
        isStreaming: false,
        cancellationToken: CancellationToken.None);

    // Assert
    Assert.NotNull(actualContext);
    Assert.Same(this._openAIPromptExecutionSettings, actualContext!.ExecutionSettings);
}

private sealed class AutoFunctionInvocationFilter(
Func<AutoFunctionInvocationContext, Func<AutoFunctionInvocationContext, Task>, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter
{
Expand Down
Loading