Skip to content

Commit e20dfc4

Browse files
.Net: Chat completion agent (#4911)
### Motivation and Context Today, SK already has agent that leverages OpenAI Assistants API. The new agent will expand the options SK consumers can select from. ### Description This PR introduces MVP of the chat completion agent that utilizes the SK ChatCompletion API to communicate with LLM. The MVP will be iteratively extended over the time to support agent abstraction, streaming, multi-modality, etc. <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 --------- Co-authored-by: Mark Wallace <[email protected]>
1 parent 12e4358 commit e20dfc4

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Examples;
9+
using Kusto.Cloud.Platform.Utils;
10+
using Microsoft.SemanticKernel;
11+
using Microsoft.SemanticKernel.ChatCompletion;
12+
using Microsoft.SemanticKernel.Connectors.OpenAI;
13+
using Microsoft.SemanticKernel.Experimental.Agents;
14+
using Xunit;
15+
using Xunit.Abstractions;
16+
17+
public class Example79_ChatCompletionAgent : BaseTest
18+
{
19+
/// <summary>
20+
/// This example demonstrates a chat with the chat completion agent that utilizes the SK ChatCompletion API to communicate with LLM.
21+
/// </summary>
22+
[Fact]
23+
public async Task ChatWithAgentAsync()
24+
{
25+
var kernel = Kernel.CreateBuilder()
26+
.AddAzureOpenAIChatCompletion(
27+
deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
28+
endpoint: TestConfiguration.AzureOpenAI.Endpoint,
29+
apiKey: TestConfiguration.AzureOpenAI.ApiKey,
30+
modelId: TestConfiguration.AzureOpenAI.ChatModelId)
31+
.Build();
32+
33+
var agent = new ChatCompletionAgent(
34+
kernel,
35+
instructions: "You act as a professional financial adviser. However, clients may not know the terminology, so please provide a simple explanation.",
36+
new OpenAIPromptExecutionSettings
37+
{
38+
MaxTokens = 500,
39+
Temperature = 0.7,
40+
TopP = 1.0,
41+
PresencePenalty = 0.0,
42+
FrequencyPenalty = 0.0,
43+
}
44+
);
45+
46+
var prompt = PrintPrompt("I need help with my investment portfolio. Please guide me.");
47+
PrintConversation(await agent.InvokeAsync(new[] { new ChatMessageContent(AuthorRole.User, prompt) }));
48+
}
49+
50+
/// <summary>
51+
/// This example demonstrates a round-robin chat between two chat completion agents using the TurnBasedChat collaboration experience.
52+
/// </summary>
53+
[Fact]
54+
public async Task TurnBasedAgentsChatAsync()
55+
{
56+
var kernel = Kernel.CreateBuilder()
57+
.AddAzureOpenAIChatCompletion(
58+
deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
59+
endpoint: TestConfiguration.AzureOpenAI.Endpoint,
60+
apiKey: TestConfiguration.AzureOpenAI.ApiKey,
61+
modelId: TestConfiguration.AzureOpenAI.ChatModelId)
62+
.Build();
63+
64+
var settings = new OpenAIPromptExecutionSettings
65+
{
66+
MaxTokens = 1500,
67+
Temperature = 0.7,
68+
TopP = 1.0,
69+
PresencePenalty = 0.0,
70+
FrequencyPenalty = 0.0,
71+
};
72+
73+
var fitnessTrainer = new ChatCompletionAgent(
74+
kernel,
75+
instructions: "As a fitness trainer, suggest workout routines, and exercises for beginners. " +
76+
"You are not a stress management expert, so refrain from recommending stress management strategies. " +
77+
"Collaborate with the stress management expert to create a holistic wellness plan." +
78+
"Always incorporate stress reduction techniques provided by the stress management expert into the fitness plan." +
79+
"Always include your role at the beginning of each response, such as 'As a fitness trainer.",
80+
settings
81+
);
82+
83+
var stressManagementExpert = new ChatCompletionAgent(
84+
kernel,
85+
instructions: "As a stress management expert, provide guidance on stress reduction strategies. " +
86+
"Collaborate with the fitness trainer to create a simple and holistic wellness plan." +
87+
"You are not a fitness expert; therefore, avoid recommending fitness exercises." +
88+
"If the plan is not aligned with recommended stress reduction plan, ask the fitness trainer to rework it to incorporate recommended stress reduction techniques. " +
89+
"Only you can stop the conversation by saying WELLNESS_PLAN_COMPLETE if suggested fitness plan is good." +
90+
"Always include your role at the beginning of each response such as 'As a stress management expert.",
91+
settings
92+
);
93+
94+
var chat = new TurnBasedChat(new[] { fitnessTrainer, stressManagementExpert }, (chatHistory, replies, turn) =>
95+
turn >= 10 || // Limit the number of turns to 10
96+
replies.Any(
97+
message => message.Role == AuthorRole.Assistant &&
98+
message.Content!.Contains("WELLNESS_PLAN_COMPLETE", StringComparison.InvariantCulture))); // Exit when the message "WELLNESS_PLAN_COMPLETE" received from agent
99+
100+
var prompt = "I need help creating a simple wellness plan for a beginner. Please guide me.";
101+
PrintConversation(await chat.SendMessageAsync(prompt));
102+
}
103+
104+
private string PrintPrompt(string prompt)
105+
{
106+
this.WriteLine($"Prompt: {prompt}");
107+
108+
return prompt;
109+
}
110+
111+
private void PrintConversation(IEnumerable<ChatMessageContent> messages)
112+
{
113+
foreach (var message in messages)
114+
{
115+
this.WriteLine($"------------------------------- {message.Role} ------------------------------");
116+
this.WriteLine(message.Content);
117+
this.WriteLine();
118+
}
119+
120+
this.WriteLine();
121+
}
122+
123+
private sealed class TurnBasedChat
124+
{
125+
public TurnBasedChat(IEnumerable<ChatCompletionAgent> agents, Func<ChatHistory, IEnumerable<ChatMessageContent>, int, bool> exitCondition)
126+
{
127+
this._agents = agents.ToArray();
128+
this._exitCondition = exitCondition;
129+
}
130+
131+
public async Task<IReadOnlyList<ChatMessageContent>> SendMessageAsync(string message, CancellationToken cancellationToken = default)
132+
{
133+
var chat = new ChatHistory();
134+
chat.AddUserMessage(message);
135+
136+
IReadOnlyList<ChatMessageContent> result = new List<ChatMessageContent>();
137+
138+
var turn = 0;
139+
140+
do
141+
{
142+
var agent = this._agents[turn % this._agents.Length];
143+
144+
result = await agent.InvokeAsync(chat, cancellationToken);
145+
146+
chat.AddRange(result);
147+
148+
turn++;
149+
}
150+
while (!this._exitCondition(chat, result, turn));
151+
152+
return chat;
153+
}
154+
155+
private readonly ChatCompletionAgent[] _agents;
156+
private readonly Func<ChatHistory, IEnumerable<ChatMessageContent>, int, bool> _exitCondition;
157+
}
158+
159+
public Example79_ChatCompletionAgent(ITestOutputHelper output) : base(output)
160+
{
161+
}
162+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Suppressing errors for Test projects under dotnet folder
2+
[*.cs]
3+
dotnet_diagnostic.CA2007.severity = none # Do not directly await a Task
4+
dotnet_diagnostic.VSTHRD111.severity = none # Use .ConfigureAwait(bool) is hidden by default, set to none to prevent IDE from changing on autosave
5+
dotnet_diagnostic.CS1591.severity = none # Missing XML comment for publicly visible type or member
6+
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
using Microsoft.Extensions.DependencyInjection;
8+
using Microsoft.SemanticKernel;
9+
using Microsoft.SemanticKernel.ChatCompletion;
10+
using Microsoft.SemanticKernel.Experimental.Agents;
11+
using Moq;
12+
using Xunit;
13+
14+
namespace SemanticKernel.Experimental.Agents.UnitTests;
15+
public class ChatCompletionAgentTests
16+
{
17+
private readonly IKernelBuilder _kernelBuilder;
18+
19+
public ChatCompletionAgentTests()
20+
{
21+
this._kernelBuilder = Kernel.CreateBuilder();
22+
}
23+
24+
[Fact]
25+
public async Task ItShouldResolveChatCompletionServiceFromKernelAsync()
26+
{
27+
// Arrange
28+
var mockChatCompletionService = new Mock<IChatCompletionService>();
29+
30+
this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);
31+
32+
var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");
33+
34+
// Act
35+
var result = await agent.InvokeAsync(new List<ChatMessageContent>());
36+
37+
// Assert
38+
mockChatCompletionService.Verify(x =>
39+
x.GetChatMessageContentsAsync(
40+
It.IsAny<ChatHistory>(),
41+
It.IsAny<PromptExecutionSettings>(),
42+
It.IsAny<Kernel>(),
43+
It.IsAny<CancellationToken>()),
44+
Times.Once);
45+
}
46+
47+
[Fact]
48+
public async Task ItShouldAddSystemInstructionsAndMessagesToChatHistoryAsync()
49+
{
50+
// Arrange
51+
var mockChatCompletionService = new Mock<IChatCompletionService>();
52+
53+
this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);
54+
55+
var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");
56+
57+
// Act
58+
var result = await agent.InvokeAsync(new List<ChatMessageContent>() { new(AuthorRole.User, "fake-user-message") });
59+
60+
// Assert
61+
mockChatCompletionService.Verify(
62+
x => x.GetChatMessageContentsAsync(
63+
It.Is<ChatHistory>(ch => ch.Count == 2 &&
64+
ch.Any(m => m.Role == AuthorRole.System && m.Content == "fake-instructions") &&
65+
ch.Any(m => m.Role == AuthorRole.User && m.Content == "fake-user-message")),
66+
It.IsAny<PromptExecutionSettings>(),
67+
It.IsAny<Kernel>(),
68+
It.IsAny<CancellationToken>()),
69+
Times.Once);
70+
}
71+
72+
[Fact]
73+
public async Task ItShouldReturnChatCompletionServiceMessagesAsync()
74+
{
75+
// Arrange
76+
var mockChatCompletionService = new Mock<IChatCompletionService>();
77+
mockChatCompletionService
78+
.Setup(ccs => ccs.GetChatMessageContentsAsync(It.IsAny<ChatHistory>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
79+
.ReturnsAsync(new List<ChatMessageContent> {
80+
new(AuthorRole.Assistant, "fake-assistant-message-1"),
81+
new(AuthorRole.Assistant, "fake-assistant-message-2")
82+
});
83+
84+
this._kernelBuilder.Services.AddSingleton<IChatCompletionService>(mockChatCompletionService.Object);
85+
86+
var agent = new ChatCompletionAgent(this._kernelBuilder.Build(), "fake-instructions");
87+
88+
// Act
89+
var result = await agent.InvokeAsync(new List<ChatMessageContent>());
90+
91+
// Assert
92+
Assert.Equal(2, result.Count);
93+
Assert.Contains(result, m => m.Role == AuthorRole.Assistant && m.Content == "fake-assistant-message-1");
94+
Assert.Contains(result, m => m.Role == AuthorRole.Assistant && m.Content == "fake-assistant-message-2");
95+
}
96+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Collections.Generic;
4+
using System.Threading;
5+
using System.Threading.Tasks;
6+
using Microsoft.SemanticKernel.ChatCompletion;
7+
8+
namespace Microsoft.SemanticKernel.Experimental.Agents;
9+
10+
/// <summary>
11+
/// Represent an agent that is built around the SK ChatCompletion API and leverages the API's capabilities.
12+
/// </summary>
13+
public sealed class ChatCompletionAgent
14+
{
15+
private readonly Kernel _kernel;
16+
private readonly string _instructions;
17+
private readonly PromptExecutionSettings? _promptExecutionSettings;
18+
19+
/// <summary>
20+
/// Initializes a new instance of the <see cref="ChatCompletionAgent"/> class.
21+
/// </summary>
22+
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use by the agent.</param>
23+
/// <param name="instructions">The instructions for the agent.</param>
24+
/// <param name="executionSettings">The optional execution settings for the agent. If not provided, default settings will be used.</param>
25+
public ChatCompletionAgent(Kernel kernel, string instructions, PromptExecutionSettings? executionSettings = null)
26+
{
27+
Verify.NotNull(kernel, nameof(kernel));
28+
this._kernel = kernel;
29+
30+
Verify.NotNullOrWhiteSpace(instructions, nameof(instructions));
31+
this._instructions = instructions;
32+
33+
this._promptExecutionSettings = executionSettings;
34+
}
35+
36+
/// <summary>
37+
/// Invokes the agent to process the given messages and generate a response.
38+
/// </summary>
39+
/// <param name="messages">A list of the messages for the agent to process.</param>
40+
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/> to cancel the operation.</param>
41+
/// <returns>List of messages representing the agent's response.</returns>
42+
public async Task<IReadOnlyList<ChatMessageContent>> InvokeAsync(IReadOnlyList<ChatMessageContent> messages, CancellationToken cancellationToken = default)
43+
{
44+
var chat = new ChatHistory(this._instructions);
45+
chat.AddRange(messages);
46+
47+
var chatCompletionService = this.GetChatCompletionService();
48+
49+
var chatMessageContent = await chatCompletionService.GetChatMessageContentsAsync(
50+
chat,
51+
this._promptExecutionSettings,
52+
this._kernel,
53+
cancellationToken).ConfigureAwait(false);
54+
55+
return chatMessageContent;
56+
}
57+
58+
/// <summary>
59+
/// Resolves and returns the chat completion service.
60+
/// </summary>
61+
/// <returns>An instance of the chat completion service.</returns>
62+
private IChatCompletionService GetChatCompletionService()
63+
{
64+
return this._kernel.GetRequiredService<IChatCompletionService>();
65+
}
66+
}

0 commit comments

Comments
 (0)