Skip to content

Commit

Permalink
.Net: Gemini added support of system messages and removed messages or…
Browse files Browse the repository at this point in the history
…der limitation (#7067)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Finaly Google Gemini API is supporting system messages.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Added support one or multiple system messages.
Removed messages order limitation because google API not permitting this
anymore.
(so Agents framework and handlebars planner now should work with Gemini!
yeah!)

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

@RogerBarreto
  • Loading branch information
Krzysztof318 authored Jul 3, 2024
1 parent b071168 commit 1b4d348
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private async Task SimpleChatAsync(Kernel kernel)
{
Console.WriteLine("======== Simple Chat ========");

var chatHistory = new ChatHistory();
var chatHistory = new ChatHistory("You are an expert in the tool shop.");
var chat = kernel.GetRequiredService<IChatCompletionService>();

// First user message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private async Task StreamingChatAsync(Kernel kernel)
{
Console.WriteLine("======== Streaming Chat ========");

var chatHistory = new ChatHistory();
var chatHistory = new ChatHistory("You are an expert in the tool shop.");
var chat = kernel.GetRequiredService<IChatCompletionService>();

// First user message
Expand Down
8 changes: 4 additions & 4 deletions dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public async Task GoogleAIAsync()
Console.WriteLine("============= Google AI - Gemini Chat Completion with vision =============");

string geminiApiKey = TestConfiguration.GoogleAI.ApiKey;
string geminiModelId = "gemini-pro-vision";
string geminiModelId = TestConfiguration.GoogleAI.Gemini.ModelId;

if (geminiApiKey is null)
{
Expand All @@ -28,7 +28,7 @@ public async Task GoogleAIAsync()
apiKey: geminiApiKey)
.Build();

var chatHistory = new ChatHistory();
var chatHistory = new ChatHistory("Your job is describing images.");
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();

// Load the image from the resources
Expand All @@ -55,7 +55,7 @@ public async Task VertexAIAsync()
Console.WriteLine("============= Vertex AI - Gemini Chat Completion with vision =============");

string geminiBearerKey = TestConfiguration.VertexAI.BearerKey;
string geminiModelId = "gemini-pro-vision";
string geminiModelId = TestConfiguration.VertexAI.Gemini.ModelId;
string geminiLocation = TestConfiguration.VertexAI.Location;
string geminiProject = TestConfiguration.VertexAI.ProjectId;

Expand Down Expand Up @@ -96,7 +96,7 @@ public async Task VertexAIAsync()
// location: TestConfiguration.VertexAI.Location,
// projectId: TestConfiguration.VertexAI.ProjectId);

var chatHistory = new ChatHistory();
var chatHistory = new ChatHistory("Your job is describing images.");
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();

// Load the image from the resources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,7 @@ await Assert.ThrowsAsync<InvalidOperationException>(
}

[Fact]
public async Task ShouldThrowInvalidOperationExceptionIfChatHistoryContainsMoreThanOneSystemMessageAsync()
{
var client = this.CreateChatCompletionClient();
var chatHistory = new ChatHistory("System message");
chatHistory.AddSystemMessage("System message 2");
chatHistory.AddSystemMessage("System message 3");
chatHistory.AddUserMessage("hello");

// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(
() => client.GenerateChatMessageAsync(chatHistory));
}

[Fact]
public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync()
public async Task ShouldPassSystemMessageToRequestAsync()
{
// Arrange
var client = this.CreateChatCompletionClient();
Expand All @@ -287,40 +273,35 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync()
// Assert
GeminiRequest? request = JsonSerializer.Deserialize<GeminiRequest>(this._messageHandlerStub.RequestContent);
Assert.NotNull(request);
var systemMessage = request.Contents[0].Parts![0].Text;
var messageRole = request.Contents[0].Role;
Assert.Equal(AuthorRole.User, messageRole);
Assert.NotNull(request.SystemInstruction);
var systemMessage = request.SystemInstruction.Parts![0].Text;
Assert.Null(request.SystemInstruction.Role);
Assert.Equal(message, systemMessage);
}

[Fact]
public async Task ShouldThrowNotSupportedIfChatHistoryHaveIncorrectOrderAsync()
public async Task ShouldPassMultipleSystemMessagesToRequestAsync()
{
// Arrange
string[] messages = ["System message 1", "System message 2", "System message 3"];
var client = this.CreateChatCompletionClient();
var chatHistory = new ChatHistory();
var chatHistory = new ChatHistory(messages[0]);
chatHistory.AddSystemMessage(messages[1]);
chatHistory.AddSystemMessage(messages[2]);
chatHistory.AddUserMessage("Hello");
chatHistory.AddAssistantMessage("Hi");
chatHistory.AddAssistantMessage("Hi me again");
chatHistory.AddUserMessage("How are you?");

// Act & Assert
await Assert.ThrowsAsync<NotSupportedException>(
() => client.GenerateChatMessageAsync(chatHistory));
}

[Fact]
public async Task ShouldThrowNotSupportedIfChatHistoryNotEndWithUserMessageAsync()
{
// Arrange
var client = this.CreateChatCompletionClient();
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("Hello");
chatHistory.AddAssistantMessage("Hi");
// Act
await client.GenerateChatMessageAsync(chatHistory);

// Act & Assert
await Assert.ThrowsAsync<NotSupportedException>(
() => client.GenerateChatMessageAsync(chatHistory));
// Assert
GeminiRequest? request = JsonSerializer.Deserialize<GeminiRequest>(this._messageHandlerStub.RequestContent);
Assert.NotNull(request);
Assert.NotNull(request.SystemInstruction);
Assert.Null(request.SystemInstruction.Role);
Assert.Collection(request.SystemInstruction.Parts!,
item => Assert.Equal(messages[0], item.Text),
item => Assert.Equal(messages[1], item.Text),
item => Assert.Equal(messages[2], item.Text));
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public async Task ShouldUsePromptExecutionSettingsAsync()
}

[Fact]
public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync()
public async Task ShouldPassSystemMessageToRequestAsync()
{
// Arrange
var client = this.CreateChatCompletionClient();
Expand All @@ -262,12 +262,37 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync()
// Assert
GeminiRequest? request = JsonSerializer.Deserialize<GeminiRequest>(this._messageHandlerStub.RequestContent);
Assert.NotNull(request);
var systemMessage = request.Contents[0].Parts![0].Text;
var messageRole = request.Contents[0].Role;
Assert.Equal(AuthorRole.User, messageRole);
Assert.NotNull(request.SystemInstruction);
var systemMessage = request.SystemInstruction.Parts![0].Text;
Assert.Null(request.SystemInstruction.Role);
Assert.Equal(message, systemMessage);
}

[Fact]
public async Task ShouldPassMultipleSystemMessagesToRequestAsync()
{
// Arrange
string[] messages = ["System message 1", "System message 2", "System message 3"];
var client = this.CreateChatCompletionClient();
var chatHistory = new ChatHistory(messages[0]);
chatHistory.AddSystemMessage(messages[1]);
chatHistory.AddSystemMessage(messages[2]);
chatHistory.AddUserMessage("Hello");

// Act
await client.StreamGenerateChatMessageAsync(chatHistory).ToListAsync();

// Assert
GeminiRequest? request = JsonSerializer.Deserialize<GeminiRequest>(this._messageHandlerStub.RequestContent);
Assert.NotNull(request);
Assert.NotNull(request.SystemInstruction);
Assert.Null(request.SystemInstruction.Role);
Assert.Collection(request.SystemInstruction.Parts!,
item => Assert.Equal(messages[0], item.Text),
item => Assert.Equal(messages[1], item.Text),
item => Assert.Equal(messages[2], item.Text));
}

[Theory]
[InlineData(0)]
[InlineData(-15)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace SemanticKernel.Connectors.Google.UnitTests.Core.Gemini;
public sealed class GeminiRequestTests
{
[Fact]
public void FromPromptItReturnsGeminiRequestWithConfiguration()
public void FromPromptItReturnsWithConfiguration()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -37,7 +37,7 @@ public void FromPromptItReturnsGeminiRequestWithConfiguration()
}

[Fact]
public void FromPromptItReturnsGeminiRequestWithSafetySettings()
public void FromPromptItReturnsWithSafetySettings()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -59,7 +59,7 @@ public void FromPromptItReturnsGeminiRequestWithSafetySettings()
}

[Fact]
public void FromPromptItReturnsGeminiRequestWithPrompt()
public void FromPromptItReturnsWithPrompt()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -73,7 +73,7 @@ public void FromPromptItReturnsGeminiRequestWithPrompt()
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithConfiguration()
public void FromChatHistoryItReturnsWithConfiguration()
{
// Arrange
ChatHistory chatHistory = [];
Expand All @@ -98,7 +98,7 @@ public void FromChatHistoryItReturnsGeminiRequestWithConfiguration()
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings()
public void FromChatHistoryItReturnsWithSafetySettings()
{
// Arrange
ChatHistory chatHistory = [];
Expand All @@ -123,10 +123,11 @@ public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings()
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithChatHistory()
public void FromChatHistoryItReturnsWithChatHistory()
{
// Arrange
ChatHistory chatHistory = [];
string systemMessage = "system-message";
var chatHistory = new ChatHistory(systemMessage);
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage("user-message2");
Expand All @@ -136,18 +137,41 @@ public void FromChatHistoryItReturnsGeminiRequestWithChatHistory()
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.NotNull(request.SystemInstruction?.Parts);
Assert.Single(request.SystemInstruction.Parts);
Assert.Equal(request.SystemInstruction.Parts[0].Text, systemMessage);
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text));
c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[3].Content, c.Parts![0].Text));
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Role, c.Role),
c => Assert.Equal(chatHistory[1].Role, c.Role),
c => Assert.Equal(chatHistory[2].Role, c.Role));
c => Assert.Equal(chatHistory[2].Role, c.Role),
c => Assert.Equal(chatHistory[3].Role, c.Role));
}

[Fact]
public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages()
{
// Arrange
string[] systemMessages = ["system-message", "system-message2", "system-message3", "system-message4"];
var chatHistory = new ChatHistory(systemMessages[0]);
chatHistory.AddUserMessage("user-message");
chatHistory.AddSystemMessage(systemMessages[1]);
chatHistory.AddMessage(AuthorRole.System,
[new TextContent(systemMessages[2]), new TextContent(systemMessages[3])]);
var executionSettings = new GeminiPromptExecutionSettings();

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.NotNull(request.SystemInstruction?.Parts);
Assert.All(systemMessages, msg => Assert.Contains(request.SystemInstruction.Parts, p => p.Text == msg));
}

[Fact]
public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistory()
public void FromChatHistoryTextAsTextContentItReturnsWithChatHistory()
{
// Arrange
ChatHistory chatHistory = [];
Expand All @@ -163,11 +187,11 @@ public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistor
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[2].Items!.Cast<TextContent>().Single().Text, c.Parts![0].Text));
c => Assert.Equal(chatHistory[2].Items.Cast<TextContent>().Single().Text, c.Parts![0].Text));
}

[Fact]
public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHistory()
public void FromChatHistoryImageAsImageContentItReturnsWithChatHistory()
{
// Arrange
ReadOnlyMemory<byte> imageAsBytes = new byte[] { 0x00, 0x01, 0x02, 0x03 };
Expand All @@ -187,7 +211,7 @@ public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHist
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text),
c => Assert.Equal(chatHistory[2].Items!.Cast<ImageContent>().Single().Uri,
c => Assert.Equal(chatHistory[2].Items.Cast<ImageContent>().Single().Uri,
c.Parts![0].FileData!.FileUri),
c => Assert.True(imageAsBytes.ToArray()
.SequenceEqual(Convert.FromBase64String(c.Parts![0].InlineData!.InlineData))));
Expand Down Expand Up @@ -272,7 +296,7 @@ public void FromChatHistoryToolCallsNotNullAddsFunctionCalls()
}

[Fact]
public void AddFunctionItAddsFunctionToGeminiRequest()
public void AddFunctionToGeminiRequest()
{
// Arrange
var request = new GeminiRequest();
Expand All @@ -287,7 +311,7 @@ public void AddFunctionItAddsFunctionToGeminiRequest()
}

[Fact]
public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest()
public void AddMultipleFunctionsToGeminiRequest()
{
// Arrange
var request = new GeminiRequest();
Expand All @@ -308,7 +332,7 @@ public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest()
}

[Fact]
public void AddChatMessageToRequestItAddsChatMessageToGeminiRequest()
public void AddChatMessageToRequest()
{
// Arrange
ChatHistory chat = [];
Expand Down
Loading

0 comments on commit 1b4d348

Please sign in to comment.