Skip to content

Commit

Permalink
.Net Agents - Streaming Bug Fix and Support Additional Assistant Opti…
Browse files Browse the repository at this point in the history
…on (#8852)

### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Respond to two customer identified issues:
1. Add support for `AdditionalInstructions` for creating an assistant as
well as invocation override.
Fixes #8715

2. Fix issue with duplicated tool-call result when using
`ChatCompletionAgent` streaming
Fixes #8825

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

1. `AdditionalInstructions` option wasn't included in the V2 migration
as oversight. This is a pure addition.
2. Unit-tests added for new `AdditionalInstructions` option.
4. Duplication of the terminated function result addressed within
`ChatCompletionAgent`
5. Streaming cases added to existing sample demonstrating use of
`IAutoFunctionInvocationFilter` :
`Concepts/Agents/ChatCompletion_FunctionTermination`

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Sep 17, 2024
1 parent ac4f394 commit 00f3a6b
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 30 deletions.
148 changes: 131 additions & 17 deletions dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,8 @@ public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync()
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the chat history.
Console.WriteLine("================================");
Console.WriteLine("CHAT HISTORY");
Console.WriteLine("================================");
foreach (ChatMessageContent message in chat)
{
this.WriteAgentChatMessage(message);
}
// Display the entire chat history.
WriteChatHistory(chat);

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
Expand Down Expand Up @@ -91,15 +85,8 @@ public async Task UseAutoFunctionInvocationFilterWithAgentChatAsync()
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the chat history.
Console.WriteLine("================================");
Console.WriteLine("CHAT HISTORY");
Console.WriteLine("================================");
ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync();
for (int index = history.Length; index > 0; --index)
{
this.WriteAgentChatMessage(history[index - 1]);
}
// Display the entire chat history.
WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync());

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
Expand All @@ -115,6 +102,133 @@ async Task InvokeAgentAsync(string input)
}
}

[Fact]
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync()
{
// Define the agent
ChatCompletionAgent agent =
new()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
agent.Kernel.Plugins.Add(plugin);

/// Create the chat history to capture the agent interaction.
ChatHistory chat = [];

// Respond to user input, invoking functions where appropriate.
await InvokeAgentAsync("Hello");
await InvokeAgentAsync("What is the special soup?");
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the entire chat history.
WriteChatHistory(chat);

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
ChatMessageContent message = new(AuthorRole.User, input);
chat.Add(message);
this.WriteAgentChatMessage(message);

int historyCount = chat.Count;

bool isFirst = false;
await foreach (StreamingChatMessageContent response in agent.InvokeStreamingAsync(chat))
{
if (string.IsNullOrEmpty(response.Content))
{
continue;
}

if (!isFirst)
{
Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:");
isFirst = true;
}

Console.WriteLine($"\t > streamed: '{response.Content}'");
}

if (historyCount <= chat.Count)
{
for (int index = historyCount; index < chat.Count; index++)
{
this.WriteAgentChatMessage(chat[index]);
}
}
}
}

[Fact]
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentChatAsync()
{
// Define the agent
ChatCompletionAgent agent =
new()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
agent.Kernel.Plugins.Add(plugin);

// Create a chat for agent interaction.
AgentGroupChat chat = new();

// Respond to user input, invoking functions where appropriate.
await InvokeAgentAsync("Hello");
await InvokeAgentAsync("What is the special soup?");
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the entire chat history.
WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync());

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
ChatMessageContent message = new(AuthorRole.User, input);
chat.AddChatMessage(message);
this.WriteAgentChatMessage(message);

bool isFirst = false;
await foreach (StreamingChatMessageContent response in chat.InvokeStreamingAsync(agent))
{
if (string.IsNullOrEmpty(response.Content))
{
continue;
}

if (!isFirst)
{
Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:");
isFirst = true;
}

Console.WriteLine($"\t > streamed: '{response.Content}'");
}
}
}

private void WriteChatHistory(IEnumerable<ChatMessageContent> chat)
{
Console.WriteLine("================================");
Console.WriteLine("CHAT HISTORY");
Console.WriteLine("================================");
foreach (ChatMessageContent message in chat)
{
this.WriteAgentChatMessage(message);
}
}

private Kernel CreateKernelWithFilter()
{
IKernelBuilder builder = Kernel.CreateBuilder();
Expand Down
10 changes: 7 additions & 3 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
StringBuilder builder = new();
await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
{
role ??= message.Role;
role = message.Role;
message.Role ??= AuthorRole.Assistant;
message.AuthorName = this.Name;

Expand All @@ -103,8 +103,6 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
yield return message;
}

chat.Add(new(role ?? AuthorRole.Assistant, builder.ToString()) { AuthorName = this.Name });

// Capture mutated messages related function calling / tools
for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++)
{
Expand All @@ -114,6 +112,12 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream

history.Add(message);
}

// Do not duplicate terminated function result to history
if (role != AuthorRole.Tool)
{
history.Add(new(role ?? AuthorRole.Assistant, builder.ToString()) { AuthorName = this.Name });
}
}

internal static (IChatCompletionService service, PromptExecutionSettings? executionSettings) GetChatCompletionService(Kernel kernel, KernelArguments? arguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public static RunCreationOptions GenerateOptions(OpenAIAssistantDefinition defin
RunCreationOptions options =
new()
{
AdditionalInstructions = invocationOptions?.AdditionalInstructions ?? definition.ExecutionOptions?.AdditionalInstructions,
MaxCompletionTokens = ResolveExecutionSetting(invocationOptions?.MaxCompletionTokens, definition.ExecutionOptions?.MaxCompletionTokens),
MaxPromptTokens = ResolveExecutionSetting(invocationOptions?.MaxPromptTokens, definition.ExecutionOptions?.MaxPromptTokens),
ModelOverride = invocationOptions?.ModelName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using Azure;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;

using OpenAI;
using OpenAI.Assistants;

Expand Down
6 changes: 6 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantExecutionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI;
/// </remarks>
public sealed class OpenAIAssistantExecutionOptions
{
/// <summary>
/// Appends additional instructions.
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AdditionalInstructions { get; init; }

/// <summary>
/// The maximum number of completion tokens that may be used over the course of the run.
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantInvocationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ public sealed class OpenAIAssistantInvocationOptions
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ModelName { get; init; }

/// <summary>
/// Appends additional instructions.
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AdditionalInstructions { get; init; }

/// <summary>
/// Set if code_interpreter tool is enabled.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest()
new("gpt-anything")
{
Temperature = 0.5F,
ExecutionOptions =
new()
{
AdditionalInstructions = "test",
},
};

// Act
Expand All @@ -32,6 +37,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest()
Assert.NotNull(options);
Assert.Null(options.Temperature);
Assert.Null(options.NucleusSamplingFactor);
Assert.Equal("test", options.AdditionalInstructions);
Assert.Empty(options.Metadata);
}

Expand Down Expand Up @@ -77,13 +83,15 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest()
ExecutionOptions =
new()
{
AdditionalInstructions = "test1",
TruncationMessageCount = 5,
},
};

OpenAIAssistantInvocationOptions invocationOptions =
new()
{
AdditionalInstructions = "test2",
Temperature = 0.9F,
TruncationMessageCount = 8,
EnableJsonResponse = true,
Expand All @@ -96,6 +104,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest()
Assert.NotNull(options);
Assert.Equal(0.9F, options.Temperature);
Assert.Equal(8, options.TruncationStrategy.LastMessages);
Assert.Equal("test2", options.AdditionalInstructions);
Assert.Equal(AssistantResponseFormat.JsonObject, options.ResponseFormat);
Assert.Null(options.NucleusSamplingFactor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public void VerifyOpenAIAssistantDefinitionAssignment()
ExecutionOptions =
new()
{
AdditionalInstructions = "test instructions",
MaxCompletionTokens = 1000,
MaxPromptTokens = 1000,
ParallelToolCallsEnabled = false,
Expand All @@ -83,6 +84,7 @@ public void VerifyOpenAIAssistantDefinitionAssignment()
Assert.Equal(2, definition.Temperature);
Assert.Equal(0, definition.TopP);
Assert.NotNull(definition.ExecutionOptions);
Assert.Equal("test instructions", definition.ExecutionOptions.AdditionalInstructions);
Assert.Equal(1000, definition.ExecutionOptions.MaxCompletionTokens);
Assert.Equal(1000, definition.ExecutionOptions.MaxPromptTokens);
Assert.Equal(12, definition.ExecutionOptions.TruncationMessageCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public void OpenAIAssistantInvocationOptionsInitialState()

// Assert
Assert.Null(options.ModelName);
Assert.Null(options.AdditionalInstructions);
Assert.Null(options.Metadata);
Assert.Null(options.Temperature);
Assert.Null(options.TopP);
Expand All @@ -48,6 +49,7 @@ public void OpenAIAssistantInvocationOptionsAssignment()
new()
{
ModelName = "testmodel",
AdditionalInstructions = "test instructions",
Metadata = new Dictionary<string, string>() { { "a", "1" } },
MaxCompletionTokens = 1000,
MaxPromptTokens = 1000,
Expand All @@ -62,6 +64,7 @@ public void OpenAIAssistantInvocationOptionsAssignment()

// Assert
Assert.Equal("testmodel", options.ModelName);
Assert.Equal("test instructions", options.AdditionalInstructions);
Assert.Equal(2, options.Temperature);
Assert.Equal(0, options.TopP);
Assert.Equal(1000, options.MaxCompletionTokens);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using xRetry;
using Xunit;

namespace SemanticKernel.IntegrationTests.Agents;
Expand All @@ -32,7 +33,7 @@ public sealed class ChatCompletionAgentTests()
/// Integration test for <see cref="ChatCompletionAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
[Theory]
[RetryTheory(typeof(HttpOperationException))]
[InlineData("What is the special soup?", "Clam Chowder", false)]
[InlineData("What is the special soup?", "Clam Chowder", true)]
public async Task AzureChatCompletionAgentAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination)
Expand Down Expand Up @@ -96,7 +97,7 @@ public async Task AzureChatCompletionAgentAsync(string input, string expectedAns
/// Integration test for <see cref="ChatCompletionAgent"/> using new function calling model
/// and targeting Azure OpenAI services.
/// </summary>
[Theory]
[RetryTheory(typeof(HttpOperationException))]
[InlineData("What is the special soup?", "Clam Chowder", false)]
[InlineData("What is the special soup?", "Clam Chowder", true)]
public async Task AzureChatCompletionAgentUsingNewFunctionCallingModelAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination)
Expand Down Expand Up @@ -160,7 +161,7 @@ public async Task AzureChatCompletionAgentUsingNewFunctionCallingModelAsync(stri
/// Integration test for <see cref="ChatCompletionAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
[Fact]
[RetryFact(typeof(HttpOperationException))]
public async Task AzureChatCompletionStreamingAsync()
{
// Arrange
Expand Down Expand Up @@ -206,7 +207,7 @@ public async Task AzureChatCompletionStreamingAsync()
/// Integration test for <see cref="ChatCompletionAgent"/> using new function calling model
/// and targeting Azure OpenAI services.
/// </summary>
[Fact]
[RetryFact(typeof(HttpOperationException))]
public async Task AzureChatCompletionStreamingUsingNewFunctionCallingModelAsync()
{
// Arrange
Expand Down
3 changes: 2 additions & 1 deletion dotnet/src/IntegrationTests/Agents/MixedAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using xRetry;
using Xunit;

namespace SemanticKernel.IntegrationTests.Agents;
Expand Down Expand Up @@ -50,7 +51,7 @@ await this.VerifyAgentExecutionAsync(
/// Integration test for <see cref="OpenAIAssistantAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
[Theory]
[RetryTheory(typeof(HttpOperationException))]
[InlineData(false)]
[InlineData(true)]
public async Task AzureOpenAIMixedAgentAsync(bool useNewFunctionCallingModel)
Expand Down
Loading

0 comments on commit 00f3a6b

Please sign in to comment.