Commit fc6c2d4

.Net: Support for configuring dimensions in Google AI embeddings generation (#10489)
### Motivation and Context

This change addresses a limitation in the current implementation of the Google AI embeddings generation service in Semantic Kernel. Currently, users cannot configure the output dimensionality of the embeddings, even though the underlying Google AI API supports specifying the number of dimensions via the `output_dimensionality` parameter.

**Why is this change required?**
Allowing configuration of the dimensions provides greater flexibility for users to tailor the embeddings to their specific use cases, whether for optimizing memory usage, improving performance, or ensuring compatibility with downstream systems that expect a particular embedding size.

**What problem does it solve?**
It solves the issue of inflexibility by exposing the `dimensions` parameter in the service constructors, builder methods, and API request payloads. This ensures that developers can leverage the full capabilities of the Google API without being limited to the default embedding size.

**What scenario does it contribute to?**
This feature is particularly useful in scenarios where:
- Users need to optimize storage or computational resources.
- Downstream tasks or integrations require embeddings of a specific dimensionality.
- Fine-tuning the model output is essential for performance or compatibility reasons.

Relevant issue link: #10488

### Description

This PR introduces support for specifying the output dimensionality in the Google AI embeddings generation workflow. The main changes include:

- **Service Constructor Update:** The `GoogleAITextEmbeddingGenerationService` constructor now accepts an optional `dimensions` parameter, which is then forwarded to the lower-level client implementations.
- **Builder and Extension Methods:** Extension methods such as `AddGoogleAIEmbeddingGeneration` have been updated to accept a `dimensions` parameter. This allows developers to configure the embedding dimensions using the builder pattern.
- **Request Payload Enhancement:** The `GoogleAIEmbeddingRequest` class now includes a new optional property `Dimensions` (serialized as `output_dimensionality`). When provided, this value is included in the JSON payload sent to the Google AI API.
- **Metadata and Attributes Update:** The service’s metadata now reflects the provided dimensions, ensuring consistency in configuration tracking.
- **Unit Testing:** New unit tests have been added to confirm that:
  - When a `dimensions` value is provided, it is correctly included in the JSON request.
  - When not provided, the default behavior remains unchanged.

This enhancement maintains backward compatibility since the new parameter is optional. Existing implementations that do not specify a dimension will continue to work as before.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings.
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations.
- [x] All unit tests pass, and I have added new tests where possible.
- [x] I didn't break anyone 😄

---------

Co-authored-by: Roger Barreto <[email protected]>
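
The optional-property behavior described under "Request Payload Enhancement" can be sketched roughly as follows. This is a minimal illustration using `System.Text.Json`, not the connector's actual code: `EmbedItemSketch` and `PayloadSketch` are hypothetical names, and the JSON name `outputDimensionality` is taken from what the unit tests in this commit assert. It assumes the real request type omits the property when it is null, which is what those tests check.

```csharp
using System;
using System.Text.Json;
using System.Text.Json.Serialization;

// Hypothetical stand-in for one embedding request item; not the connector's real type.
public sealed class EmbedItemSketch
{
    [JsonPropertyName("model")]
    public string Model { get; set; } = string.Empty;

    // Left out of the JSON entirely when null, so the service would use its default size.
    [JsonPropertyName("outputDimensionality")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public int? Dimensions { get; set; }
}

public static class PayloadSketch
{
    // Serializes a single hypothetical request item with an optional dimension count.
    public static string ToJson(int? dimensions) =>
        JsonSerializer.Serialize(new EmbedItemSketch { Model = "fake-model", Dimensions = dimensions });

    public static void Main()
    {
        Console.WriteLine(ToJson(null)); // no outputDimensionality property expected
        Console.WriteLine(ToJson(512));  // outputDimensionality expected in the payload
    }
}
```

Making the property nullable and ignoring it when null is one way to keep existing callers' payloads byte-for-byte unchanged, which matches the backward-compatibility claim in the description.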
1 parent b198d9f commit fc6c2d4

17 files changed: +534 -80 lines changed

dotnet/Directory.Packages.props

Lines changed: 4 additions & 3 deletions
@@ -24,6 +24,7 @@
     <PackageVersion Include="Dapr.Actors.AspNetCore" Version="1.14.0" />
     <PackageVersion Include="Dapr.AspNetCore" Version="1.14.0" />
     <PackageVersion Include="FastBertTokenizer" Version="1.0.28" />
+    <PackageVersion Include="Google.Apis.Auth" Version="1.69.0" />
     <PackageVersion Include="mcpdotnet" Version="1.0.1.3" />
     <PackageVersion Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.13" />
     <PackageVersion Include="Microsoft.AspNetCore.OpenApi" Version="8.0.13" />
@@ -142,9 +143,9 @@
     <PackageVersion Include="Milvus.Client" Version="2.3.0-preview.1" />
     <PackageVersion Include="Testcontainers" Version="4.1.0" />
     <PackageVersion Include="Testcontainers.Milvus" Version="4.1.0" />
-    <PackageVersion Include="Testcontainers.MongoDB" Version="4.1.0"/>
-    <PackageVersion Include="Testcontainers.PostgreSql" Version="4.1.0"/>
-    <PackageVersion Include="Testcontainers.Redis" Version="4.1.0"/>
+    <PackageVersion Include="Testcontainers.MongoDB" Version="4.1.0" />
+    <PackageVersion Include="Testcontainers.PostgreSql" Version="4.1.0" />
+    <PackageVersion Include="Testcontainers.Redis" Version="4.1.0" />
     <PackageVersion Include="Microsoft.Data.SqlClient" Version="5.2.2" />
     <PackageVersion Include="Qdrant.Client" Version="1.12.0" />
     <!-- Symbols -->

dotnet/samples/Concepts/Concepts.csproj

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
   <ItemGroup>
     <PackageReference Include="Microsoft.Net.Compilers.Toolset" />
     <PackageReference Include="Docker.DotNet" />
+    <PackageReference Include="Google.Apis.Auth" />
     <PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
     <PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" />
     <PackageReference Include="Microsoft.NET.Test.Sdk" />

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Google.Apis.Auth.OAuth2;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Embeddings;
+using xRetry;
+
+namespace Memory;
+
+// The following example shows how to use Semantic Kernel with Google AI and Google's Vertex AI for embedding generation,
+// including the ability to specify custom dimensions.
+public class Google_EmbeddingGeneration(ITestOutputHelper output) : BaseTest(output)
+{
+    /// <summary>
+    /// This test demonstrates how to use the Google Vertex AI embedding generation service with default dimensions.
+    /// </summary>
+    /// <remarks>
+    /// Currently custom dimensions are not supported for Vertex AI.
+    /// </remarks>
+    [RetryFact(typeof(HttpOperationException))]
+    public async Task GenerateEmbeddingWithDefaultDimensionsUsingVertexAI()
+    {
+        string? bearerToken = null;
+
+        Assert.NotNull(TestConfiguration.VertexAI.EmbeddingModelId);
+        Assert.NotNull(TestConfiguration.VertexAI.ClientId);
+        Assert.NotNull(TestConfiguration.VertexAI.ClientSecret);
+        Assert.NotNull(TestConfiguration.VertexAI.Location);
+        Assert.NotNull(TestConfiguration.VertexAI.ProjectId);
+
+        IKernelBuilder kernelBuilder = Kernel.CreateBuilder();
+        kernelBuilder.AddVertexAIEmbeddingGeneration(
+            modelId: TestConfiguration.VertexAI.EmbeddingModelId!,
+            bearerTokenProvider: GetBearerToken,
+            location: TestConfiguration.VertexAI.Location,
+            projectId: TestConfiguration.VertexAI.ProjectId);
+        Kernel kernel = kernelBuilder.Build();
+
+        var embeddingGenerator = kernel.GetRequiredService<ITextEmbeddingGenerationService>();
+
+        // Generate embeddings with the default dimensions for the model
+        var embeddings = await embeddingGenerator.GenerateEmbeddingsAsync(
+            ["Semantic Kernel is a lightweight, open-source development kit that lets you easily build AI agents and integrate the latest AI models into your codebase."]);
+
+        Console.WriteLine($"Generated '{embeddings.Count}' embedding(s) with '{embeddings[0].Length}' dimensions (default) for the provided text");
+
+        // Uses Google.Apis.Auth.OAuth2 to get the bearer token
+        async ValueTask<string> GetBearerToken()
+        {
+            if (!string.IsNullOrEmpty(bearerToken))
+            {
+                return bearerToken;
+            }
+
+            var credential = GoogleWebAuthorizationBroker.AuthorizeAsync(
+                new ClientSecrets
+                {
+                    ClientId = TestConfiguration.VertexAI.ClientId,
+                    ClientSecret = TestConfiguration.VertexAI.ClientSecret
+                },
+                ["https://www.googleapis.com/auth/cloud-platform"],
+                "user",
+                CancellationToken.None);
+
+            var userCredential = await credential.WaitAsync(CancellationToken.None);
+            bearerToken = userCredential.Token.AccessToken;
+
+            return bearerToken;
+        }
+    }
+
+    [RetryFact(typeof(HttpOperationException))]
+    public async Task GenerateEmbeddingWithDefaultDimensionsUsingGoogleAI()
+    {
+        Assert.NotNull(TestConfiguration.GoogleAI.EmbeddingModelId);
+        Assert.NotNull(TestConfiguration.GoogleAI.ApiKey);
+
+        IKernelBuilder kernelBuilder = Kernel.CreateBuilder();
+        kernelBuilder.AddGoogleAIEmbeddingGeneration(
+            modelId: TestConfiguration.GoogleAI.EmbeddingModelId!,
+            apiKey: TestConfiguration.GoogleAI.ApiKey);
+        Kernel kernel = kernelBuilder.Build();
+
+        var embeddingGenerator = kernel.GetRequiredService<ITextEmbeddingGenerationService>();
+
+        // Generate embeddings with the default dimensions for the model
+        var embeddings = await embeddingGenerator.GenerateEmbeddingsAsync(
+            ["Semantic Kernel is a lightweight, open-source development kit that lets you easily build AI agents and integrate the latest AI models into your codebase."]);
+
+        Console.WriteLine($"Generated '{embeddings.Count}' embedding(s) with '{embeddings[0].Length}' dimensions (default) for the provided text");
+    }
+
+    [RetryFact(typeof(HttpOperationException))]
+    public async Task GenerateEmbeddingWithCustomDimensionsUsingGoogleAI()
+    {
+        Assert.NotNull(TestConfiguration.GoogleAI.EmbeddingModelId);
+        Assert.NotNull(TestConfiguration.GoogleAI.ApiKey);
+
+        // Specify custom dimensions for the embeddings
+        const int CustomDimensions = 512;
+
+        IKernelBuilder kernelBuilder = Kernel.CreateBuilder();
+        kernelBuilder.AddGoogleAIEmbeddingGeneration(
+            modelId: TestConfiguration.GoogleAI.EmbeddingModelId!,
+            apiKey: TestConfiguration.GoogleAI.ApiKey,
+            dimensions: CustomDimensions);
+        Kernel kernel = kernelBuilder.Build();
+
+        var embeddingGenerator = kernel.GetRequiredService<ITextEmbeddingGenerationService>();
+
+        // Generate embeddings with the specified custom dimensions
+        var embeddings = await embeddingGenerator.GenerateEmbeddingsAsync(
+            ["Semantic Kernel is a lightweight, open-source development kit that lets you easily build AI agents and integrate the latest AI models into your codebase."]);
+
+        Console.WriteLine($"Generated '{embeddings.Count}' embedding(s) with '{embeddings[0].Length}' dimensions (custom: '{CustomDimensions}') for the provided text");
+
+        // Verify that we received embeddings with our requested dimensions
+        Assert.Equal(CustomDimensions, embeddings[0].Length);
+    }
+}

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIClientEmbeddingsGenerationTests.cs

Lines changed: 49 additions & 2 deletions
@@ -142,14 +142,61 @@ public async Task ItCreatesPostRequestWithSemanticKernelVersionHeaderAsync()
         Assert.Equal(expectedVersion, header);
     }

+    [Fact]
+    public async Task ShouldIncludeDimensionsInAllRequestsAsync()
+    {
+        // Arrange
+        const int Dimensions = 512;
+        var client = this.CreateEmbeddingsClient(dimensions: Dimensions);
+        var dataToEmbed = new List<string>()
+        {
+            "First text to embed",
+            "Second text to embed",
+            "Third text to embed"
+        };
+
+        // Act
+        await client.GenerateEmbeddingsAsync(dataToEmbed);
+
+        // Assert
+        var request = JsonSerializer.Deserialize<GoogleAIEmbeddingRequest>(this._messageHandlerStub.RequestContent);
+        Assert.NotNull(request);
+        Assert.Equal(dataToEmbed.Count, request.Requests.Count);
+        Assert.All(request.Requests, item => Assert.Equal(Dimensions, item.Dimensions));
+    }
+
+    [Fact]
+    public async Task ShouldNotIncludeDimensionsInAllRequestsWhenNotProvidedAsync()
+    {
+        // Arrange
+        var client = this.CreateEmbeddingsClient();
+        var dataToEmbed = new List<string>()
+        {
+            "First text to embed",
+            "Second text to embed",
+            "Third text to embed"
+        };
+
+        // Act
+        await client.GenerateEmbeddingsAsync(dataToEmbed);
+
+        // Assert
+        var request = JsonSerializer.Deserialize<GoogleAIEmbeddingRequest>(this._messageHandlerStub.RequestContent);
+        Assert.NotNull(request);
+        Assert.Equal(dataToEmbed.Count, request.Requests.Count);
+        Assert.All(request.Requests, item => Assert.Null(item.Dimensions));
+    }
+
     private GoogleAIEmbeddingClient CreateEmbeddingsClient(
-        string modelId = "fake-model")
+        string modelId = "fake-model",
+        int? dimensions = null)
     {
         var client = new GoogleAIEmbeddingClient(
             httpClient: this._httpClient,
             modelId: modelId,
             apiVersion: GoogleAIVersion.V1,
-            apiKey: "fake-key");
+            apiKey: "fake-key",
+            dimensions: dimensions);
         return client;
     }

Lines changed: 58 additions & 13 deletions
@@ -1,41 +1,86 @@
 // Copyright (c) Microsoft. All rights reserved.

+using System.Text.Json;
 using Microsoft.SemanticKernel.Connectors.Google.Core;
 using Xunit;

 namespace SemanticKernel.Connectors.Google.UnitTests.Core.GoogleAI;

 public sealed class GoogleAIEmbeddingRequestTests
 {
+    // Arrange
+    private static readonly string[] s_data = ["text1", "text2"];
+    private const string ModelId = "modelId";
+    private const string DimensionalityJsonPropertyName = "\"outputDimensionality\"";
+    private const int Dimensions = 512;
+
     [Fact]
     public void FromDataReturnsValidRequestWithData()
     {
-        // Arrange
-        string[] data = ["text1", "text2"];
-        var modelId = "modelId";
-
         // Act
-        var request = GoogleAIEmbeddingRequest.FromData(data, modelId);
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId);

         // Assert
         Assert.Equal(2, request.Requests.Count);
-        Assert.Equal(data[0], request.Requests[0].Content.Parts![0].Text);
-        Assert.Equal(data[1], request.Requests[1].Content.Parts![0].Text);
+        Assert.Equal(s_data[0], request.Requests[0].Content.Parts![0].Text);
+        Assert.Equal(s_data[1], request.Requests[1].Content.Parts![0].Text);
     }

     [Fact]
     public void FromDataReturnsValidRequestWithModelId()
     {
-        // Arrange
-        string[] data = ["text1", "text2"];
-        var modelId = "modelId";
+        // Act
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId);
+
+        // Assert
+        Assert.Equal(2, request.Requests.Count);
+        Assert.Equal($"models/{ModelId}", request.Requests[0].Model);
+        Assert.Equal($"models/{ModelId}", request.Requests[1].Model);
+    }

+    [Fact]
+    public void FromDataSetsDimensionsToNullWhenNotProvided()
+    {
         // Act
-        var request = GoogleAIEmbeddingRequest.FromData(data, modelId);
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId);

         // Assert
         Assert.Equal(2, request.Requests.Count);
-        Assert.Equal($"models/{modelId}", request.Requests[0].Model);
-        Assert.Equal($"models/{modelId}", request.Requests[1].Model);
+        Assert.Null(request.Requests[0].Dimensions);
+        Assert.Null(request.Requests[1].Dimensions);
+    }
+
+    [Fact]
+    public void FromDataJsonDoesNotIncludeDimensionsWhenNull()
+    {
+        // Act
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId);
+        string json = JsonSerializer.Serialize(request);
+
+        // Assert
+        Assert.DoesNotContain(DimensionalityJsonPropertyName, json);
+    }
+
+    [Fact]
+    public void FromDataSetsDimensionsWhenProvided()
+    {
+        // Act
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId, Dimensions);
+
+        // Assert
+        Assert.Equal(2, request.Requests.Count);
+        Assert.Equal(Dimensions, request.Requests[0].Dimensions);
+        Assert.Equal(Dimensions, request.Requests[1].Dimensions);
+    }
+
+    [Fact]
+    public void FromDataJsonIncludesDimensionsWhenProvided()
+    {
+        // Act
+        var request = GoogleAIEmbeddingRequest.FromData(s_data, ModelId, Dimensions);
+        string json = JsonSerializer.Serialize(request);
+
+        // Assert
+        Assert.Contains($"{DimensionalityJsonPropertyName}:{Dimensions}", json);
     }
 }
