Skip to content

Commit

Permalink
.Net: Chat completion agent (#4911)
Browse files Browse the repository at this point in the history
### Motivation and Context
Today, SK already has agent that leverages OpenAI Assistants API. The
new agent will expand the options SK consumers can select from.

### Description
This PR introduces MVP of the chat completion agent that utilizes the SK
ChatCompletion API to communicate with LLM. The MVP will be iteratively
extended over the time to support agent abstraction, streaming,
multi-modality, etc.

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Mark Wallace <[email protected]>
  • Loading branch information
SergeyMenshykh and markwallace-microsoft authored Feb 12, 2024
1 parent 12e4358 commit e20dfc4
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 0 deletions.
162 changes: 162 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example79_ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Examples;
using Kusto.Cloud.Platform.Utils;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Experimental.Agents;
using Xunit;
using Xunit.Abstractions;

public class Example79_ChatCompletionAgent : BaseTest
{
/// <summary>
/// This example demonstrates a chat with the chat completion agent that utilizes the SK ChatCompletion API to communicate with LLM.
/// </summary>
[Fact]
public async Task ChatWithAgentAsync()
{
var kernel = Kernel.CreateBuilder()
.AddAzureOpenAIChatCompletion(
deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
endpoint: TestConfiguration.AzureOpenAI.Endpoint,
apiKey: TestConfiguration.AzureOpenAI.ApiKey,
modelId: TestConfiguration.AzureOpenAI.ChatModelId)
.Build();

var agent = new ChatCompletionAgent(
kernel,
instructions: "You act as a professional financial adviser. However, clients may not know the terminology, so please provide a simple explanation.",
new OpenAIPromptExecutionSettings
{
MaxTokens = 500,
Temperature = 0.7,
TopP = 1.0,
PresencePenalty = 0.0,
FrequencyPenalty = 0.0,
}
);

var prompt = PrintPrompt("I need help with my investment portfolio. Please guide me.");
PrintConversation(await agent.InvokeAsync(new[] { new ChatMessageContent(AuthorRole.User, prompt) }));
}

/// <summary>
/// This example demonstrates a round-robin chat between two chat completion agents using the TurnBasedChat collaboration experience.
/// </summary>
[Fact]
public async Task TurnBasedAgentsChatAsync()
{
var kernel = Kernel.CreateBuilder()
.AddAzureOpenAIChatCompletion(
deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
endpoint: TestConfiguration.AzureOpenAI.Endpoint,
apiKey: TestConfiguration.AzureOpenAI.ApiKey,
modelId: TestConfiguration.AzureOpenAI.ChatModelId)
.Build();

var settings = new OpenAIPromptExecutionSettings
{
MaxTokens = 1500,
Temperature = 0.7,
TopP = 1.0,
PresencePenalty = 0.0,
FrequencyPenalty = 0.0,
};

var fitnessTrainer = new ChatCompletionAgent(
kernel,
instructions: "As a fitness trainer, suggest workout routines, and exercises for beginners. " +
"You are not a stress management expert, so refrain from recommending stress management strategies. " +
"Collaborate with the stress management expert to create a holistic wellness plan." +
"Always incorporate stress reduction techniques provided by the stress management expert into the fitness plan." +
"Always include your role at the beginning of each response, such as 'As a fitness trainer.",
settings
);

var stressManagementExpert = new ChatCompletionAgent(
kernel,
instructions: "As a stress management expert, provide guidance on stress reduction strategies. " +
"Collaborate with the fitness trainer to create a simple and holistic wellness plan." +
"You are not a fitness expert; therefore, avoid recommending fitness exercises." +
"If the plan is not aligned with recommended stress reduction plan, ask the fitness trainer to rework it to incorporate recommended stress reduction techniques. " +
"Only you can stop the conversation by saying WELLNESS_PLAN_COMPLETE if suggested fitness plan is good." +
"Always include your role at the beginning of each response such as 'As a stress management expert.",
settings
);

var chat = new TurnBasedChat(new[] { fitnessTrainer, stressManagementExpert }, (chatHistory, replies, turn) =>
turn >= 10 || // Limit the number of turns to 10
replies.Any(
message => message.Role == AuthorRole.Assistant &&
message.Content!.Contains("WELLNESS_PLAN_COMPLETE", StringComparison.InvariantCulture))); // Exit when the message "WELLNESS_PLAN_COMPLETE" received from agent

var prompt = "I need help creating a simple wellness plan for a beginner. Please guide me.";
PrintConversation(await chat.SendMessageAsync(prompt));
}

private string PrintPrompt(string prompt)
{
this.WriteLine($"Prompt: {prompt}");

return prompt;
}

private void PrintConversation(IEnumerable<ChatMessageContent> messages)
{
foreach (var message in messages)
{
this.WriteLine($"------------------------------- {message.Role} ------------------------------");
this.WriteLine(message.Content);
this.WriteLine();
}

this.WriteLine();
}

private sealed class TurnBasedChat
{
public TurnBasedChat(IEnumerable<ChatCompletionAgent> agents, Func<ChatHistory, IEnumerable<ChatMessageContent>, int, bool> exitCondition)
{
this._agents = agents.ToArray();
this._exitCondition = exitCondition;
}

public async Task<IReadOnlyList<ChatMessageContent>> SendMessageAsync(string message, CancellationToken cancellationToken = default)
{
var chat = new ChatHistory();
chat.AddUserMessage(message);

IReadOnlyList<ChatMessageContent> result = new List<ChatMessageContent>();

var turn = 0;

do
{
var agent = this._agents[turn % this._agents.Length];

result = await agent.InvokeAsync(chat, cancellationToken);

chat.AddRange(result);

turn++;
}
while (!this._exitCondition(chat, result, turn));

return chat;
}

private readonly ChatCompletionAgent[] _agents;
private readonly Func<ChatHistory, IEnumerable<ChatMessageContent>, int, bool> _exitCondition;
}

public Example79_ChatCompletionAgent(ITestOutputHelper output) : base(output)
{
}
}
6 changes: 6 additions & 0 deletions dotnet/src/Experimental/Agents.UnitTests/.editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Suppressing errors for Test projects under dotnet folder
[*.cs]
dotnet_diagnostic.CA2007.severity = none # Do not directly await a Task
dotnet_diagnostic.VSTHRD111.severity = none # Use .ConfigureAwait(bool) is hidden by default, set to none to prevent IDE from changing on autosave
dotnet_diagnostic.CS1591.severity = none # Missing XML comment for publicly visible type or member
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Experimental.Agents;
using Moq;
using Xunit;

namespace SemanticKernel.Experimental.Agents.UnitTests;
public class ChatCompletionAgentTests
{
private readonly IKernelBuilder _kernelBuilder;

public ChatCompletionAgentTests()
{
this._kernelBuilder = Kernel.CreateBuilder();
}

[Fact]
public async Task ItShouldResolveChatCompletionServiceFromKernelAsync()
{
// Arrange
var mockChatCompletionService = new Mock<IChatCompletionService>();

this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);

var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");

// Act
var result = await agent.InvokeAsync(new List<ChatMessageContent>());

// Assert
mockChatCompletionService.Verify(x =>
x.GetChatMessageContentsAsync(
It.IsAny<ChatHistory>(),
It.IsAny<PromptExecutionSettings>(),
It.IsAny<Kernel>(),
It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public async Task ItShouldAddSystemInstructionsAndMessagesToChatHistoryAsync()
{
// Arrange
var mockChatCompletionService = new Mock<IChatCompletionService>();

this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);

var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");

// Act
var result = await agent.InvokeAsync(new List<ChatMessageContent>() { new(AuthorRole.User, "fake-user-message") });

// Assert
mockChatCompletionService.Verify(
x => x.GetChatMessageContentsAsync(
It.Is<ChatHistory>(ch => ch.Count == 2 &&
ch.Any(m => m.Role == AuthorRole.System && m.Content == "fake-instructions") &&
ch.Any(m => m.Role == AuthorRole.User && m.Content == "fake-user-message")),
It.IsAny<PromptExecutionSettings>(),
It.IsAny<Kernel>(),
It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public async Task ItShouldReturnChatCompletionServiceMessagesAsync()
{
// Arrange
var mockChatCompletionService = new Mock<IChatCompletionService>();
mockChatCompletionService
.Setup(ccs => ccs.GetChatMessageContentsAsync(It.IsAny<ChatHistory>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new List<ChatMessageContent> {
new(AuthorRole.Assistant, "fake-assistant-message-1"),
new(AuthorRole.Assistant, "fake-assistant-message-2")
});

this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);

var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");

// Act
var result = await agent.InvokeAsync(new List<ChatMessageContent>());

// Assert
Assert.Equal(2, result.Count);
Assert.Contains(result, m => m.Role == AuthorRole.Assistant && m.Content == "fake-assistant-message-1");
Assert.Contains(result, m => m.Role == AuthorRole.Assistant && m.Content == "fake-assistant-message-2");
}
}
66 changes: 66 additions & 0 deletions dotnet/src/Experimental/Agents/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Experimental.Agents;

/// <summary>
/// Represent an agent that is built around the SK ChatCompletion API and leverages the API's capabilities.
/// </summary>
public sealed class ChatCompletionAgent
{
private readonly Kernel _kernel;
private readonly string _instructions;
private readonly PromptExecutionSettings? _promptExecutionSettings;

/// <summary>
/// Initializes a new instance of the <see cref="ChatCompletionAgent"/> class.
/// </summary>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use by the agent.</param>
/// <param name="instructions">The instructions for the agent.</param>
/// <param name="executionSettings">The optional execution settings for the agent. If not provided, default settings will be used.</param>
public ChatCompletionAgent(Kernel kernel, string instructions, PromptExecutionSettings? executionSettings = null)
{
Verify.NotNull(kernel, nameof(kernel));
this._kernel = kernel;

Verify.NotNullOrWhiteSpace(instructions, nameof(instructions));
this._instructions = instructions;

this._promptExecutionSettings = executionSettings;
}

/// <summary>
/// Invokes the agent to process the given messages and generate a response.
/// </summary>
/// <param name="messages">A list of the messages for the agent to process.</param>
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/> to cancel the operation.</param>
/// <returns>List of messages representing the agent's response.</returns>
public async Task<IReadOnlyList<ChatMessageContent>> InvokeAsync(IReadOnlyList<ChatMessageContent> messages, CancellationToken cancellationToken = default)
{
var chat = new ChatHistory(this._instructions);
chat.AddRange(messages);

var chatCompletionService = this.GetChatCompletionService();

var chatMessageContent = await chatCompletionService.GetChatMessageContentsAsync(
chat,
this._promptExecutionSettings,
this._kernel,
cancellationToken).ConfigureAwait(false);

return chatMessageContent;
}

/// <summary>
/// Resolves and returns the chat completion service.
/// </summary>
/// <returns>An instance of the chat completion service.</returns>
private IChatCompletionService GetChatCompletionService()
{
return this._kernel.GetRequiredService<IChatCompletionService>();
}
}

0 comments on commit e20dfc4

Please sign in to comment.