Chroma memory store - C# implementation (#1634)
### Motivation and Context

This PR adds a C# memory store implementation for Chroma.
For more information about Chroma, see
https://docs.trychroma.com/getting-started.

To run the integration tests, a local Chroma server must be up and running.

To start a local Chroma server for testing (this information is also
included in the
[README](https://github.com/microsoft/semantic-kernel/blob/0fb77e0c0f347e3f22d79b3c99e7f84d35340462/dotnet/src/Connectors/Connectors.Memory.Chroma/README.md)):
1. Clone Chroma:

```bash
git clone https://github.com/chroma-core/chroma.git
cd chroma
```

2. Run a local Chroma server with Docker from the Chroma repository root:

```bash
docker-compose up -d --build
```

3. Use Semantic Kernel with Chroma, pointing it at the local server endpoint
`http://localhost:8000`:
```csharp
const string endpoint = "http://localhost:8000";

ChromaMemoryStore memoryStore = new(endpoint);

IKernel kernel = Kernel.Builder
    .WithLogger(logger)
    .WithOpenAITextEmbeddingGenerationService("text-embedding-ada-002", "OPENAI_API_KEY")
    .WithMemoryStorage(memoryStore)
    //.WithChromaMemoryStore(endpoint) // This method offers an alternative approach to registering Chroma memory store.
    .Build();
```
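
As a rough illustration of what happens to the endpoint string passed above, the sketch below mirrors the endpoint handling in `ChromaClient` from this commit: a trailing slash is added if missing, the `api/v1/` route is appended, and the relative operation path is resolved against the result. The class name `ChromaUrlSketch` and the `collections` operation path are assumptions made for illustration only; the actual request paths are built by request classes that are not shown here.

```csharp
using System;

// Hypothetical helper that mirrors the endpoint handling in ChromaClient from this commit.
public static class ChromaUrlSketch
{
    private const string ApiRoute = "api/v1/";

    // Appends a trailing slash when it is missing, then appends the API route.
    public static string SanitizeEndpoint(string endpoint)
    {
        string withSlash = endpoint.EndsWith("/", StringComparison.Ordinal) ? endpoint : endpoint + "/";
        return withSlash + ApiRoute;
    }

    // Resolves a relative operation path against the sanitized endpoint.
    public static Uri BuildRequestUri(string endpoint, string operationName) =>
        new Uri(new Uri(SanitizeEndpoint(endpoint)), operationName);

    public static void Main()
    {
        // "collections" is an assumed operation path, used only for illustration.
        Console.WriteLine(BuildRequestUri("http://localhost:8000", "collections"));
        Console.WriteLine(BuildRequestUri("http://localhost:8000/", "collections"));
        // Both lines print: http://localhost:8000/api/v1/collections
    }
}
```

Because the route is always appended, an endpoint that already ends in `api/v1` would end up with the route twice, so the endpoint is expected to be the server root.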

### Description

Main public classes:
1. `ChromaMemoryStore` - an implementation of the `IMemoryStore` interface.
Uses `ChromaClient` to communicate with the Chroma API. Throws
`ChromaMemoryStoreException` when an operation fails.
2. `ChromaClient` - responsible for communication with the Chroma API.
Throws `ChromaClientException` when a request fails.
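
`ChromaClientException` in this commit also exposes two helpers that classify a failure by matching text in the exception message. The sketch below reproduces that string matching outside the class, so it can be run without a Chroma server. The names `ChromaErrorSketch`, `IsCollectionMissing`, and `IsDeleteOfMissingCollection` are hypothetical, and the sample message is made up; a real message would come from the Chroma API response.

```csharp
using System;
using System.Globalization;

// Hypothetical stand-alone version of the message checks in ChromaClientException.
public static class ChromaErrorSketch
{
    // Message fragments taken from ChromaClientException in this commit.
    private const string CollectionDoesNotExistErrorFormat = "Collection {0} does not exist";
    private const string DeleteNonExistentCollectionErrorMessage = "list index out of range";

    // True when the message reports that the named collection does not exist.
    public static bool IsCollectionMissing(string errorMessage, string collectionName) =>
        errorMessage.Contains(string.Format(CultureInfo.InvariantCulture, CollectionDoesNotExistErrorFormat, collectionName));

    // True when the message matches a delete of a non-existent collection.
    public static bool IsDeleteOfMissingCollection(string errorMessage) =>
        errorMessage.Contains(DeleteNonExistentCollectionErrorMessage);

    public static void Main()
    {
        // Invented sample message, for illustration only.
        string sample = "operation failed: Collection chroma-test does not exist";

        Console.WriteLine(IsCollectionMissing(sample, "chroma-test")); // True
        Console.WriteLine(IsDeleteOfMissingCollection(sample));        // False
    }
}
```

Matching on message text keeps the client independent of Chroma error codes, but it depends on the server's wording and may need updating if that wording changes.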

Chroma entity models:
1. `ChromaCollectionModel` - represents a Chroma collection.
2. `ChromaEmbeddingsModel` - represents Chroma embeddings, constructed from
`MemoryRecord`.
3. `ChromaQueryResultModel` - represents a Chroma search result set.

### Contribution Checklist
- [x] The code builds clean without any errors or warnings
- [x] The PR follows SK Contribution Guidelines
(https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
- [x] The code follows the .NET coding conventions
(https://learn.microsoft.com/dotnet/csharp/fundamentals/coding-style/coding-conventions)
verified with `dotnet format`
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk authored Jun 26, 2023
1 parent 527d378 commit 85d420f
Showing 28 changed files with 2,047 additions and 0 deletions.
9 changes: 9 additions & 0 deletions dotnet/SK-dotnet.sln
@@ -74,6 +74,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Postgres"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Redis", "src\Connectors\Connectors.Memory.Redis\Connectors.Memory.Redis.csproj", "{3720F5ED-FB4D-485E-8A93-CDE60DEF0805}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Memory.Chroma", "src\Connectors\Connectors.Memory.Chroma\Connectors.Memory.Chroma.csproj", "{185E0CE8-C2DA-4E4C-A491-E8EB40316315}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.AI.OpenAI", "src\Connectors\Connectors.AI.OpenAI\Connectors.AI.OpenAI.csproj", "{AFA81EB7-F869-467D-8A90-744305D80AAC}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.Abstractions", "src\SemanticKernel.Abstractions\SemanticKernel.Abstractions.csproj", "{627742DB-1E52-468A-99BD-6FF1A542D25B}"
@@ -355,6 +357,12 @@ Global
{50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Publish|Any CPU.Build.0 = Publish|Any CPU
{50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{50FAE231-6F24-4779-9D02-12ABBC9A49E2}.Release|Any CPU.Build.0 = Release|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Debug|Any CPU.Build.0 = Debug|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Publish|Any CPU.Build.0 = Debug|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Release|Any CPU.ActiveCfg = Release|Any CPU
{185E0CE8-C2DA-4E4C-A491-E8EB40316315}.Release|Any CPU.Build.0 = Release|Any CPU
{0D0C4DAD-E6BC-4504-AE3A-EEA4E35920C1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{0D0C4DAD-E6BC-4504-AE3A-EEA4E35920C1}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0D0C4DAD-E6BC-4504-AE3A-EEA4E35920C1}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
@@ -418,6 +426,7 @@ Global
{B00AD427-0047-4850-BEF9-BA8237EA9D8B} = {958AD708-F048-4FAF-94ED-D2F2B92748B9}
{DB950192-30F1-48B1-88D7-F43FECCA1A1C} = {958AD708-F048-4FAF-94ED-D2F2B92748B9}
{1C19D805-3573-4477-BF07-40180FCDE1BD} = {958AD708-F048-4FAF-94ED-D2F2B92748B9}
{185E0CE8-C2DA-4E4C-A491-E8EB40316315} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C}
{0D0C4DAD-E6BC-4504-AE3A-EEA4E35920C1} = {9ECD1AA0-75B3-4E25-B0B5-9F0945B64974}
{E6EDAB8F-3406-4DBF-9AAB-DF40DC2CA0FA} = {FA3720F1-C99A-49B2-9577-A940257098BF}
EndGlobalSection
70 changes: 70 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example50_Chroma.cs
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma;
using Microsoft.SemanticKernel.Memory;
using RepoUtils;

// ReSharper disable once InconsistentNaming
public static class Example50_Chroma
{
    private const string MemoryCollectionName = "chroma-test";

    public static async Task RunAsync()
    {
        string endpoint = Env.Var("CHROMA_ENDPOINT");

        var memoryStore = new ChromaMemoryStore(endpoint);

        IKernel kernel = Kernel.Builder
            .WithLogger(ConsoleLogger.Log)
            .WithOpenAITextCompletionService("text-davinci-003", Env.Var("OPENAI_API_KEY"))
            .WithOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY"))
            .WithMemoryStorage(memoryStore)
            //.WithChromaMemoryStore(endpoint) // This method offers an alternative approach to registering Chroma memory store.
            .Build();

        Console.WriteLine("== Printing Collections in DB ==");
        var collections = memoryStore.GetCollectionsAsync();
        await foreach (var collection in collections)
        {
            Console.WriteLine(collection);
        }

        Console.WriteLine("== Adding Memories ==");

        var key1 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: Guid.NewGuid().ToString(), text: "british short hair");
        var key2 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: Guid.NewGuid().ToString(), text: "orange tabby");
        var key3 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: Guid.NewGuid().ToString(), text: "norwegian forest cat");

        Console.WriteLine("== Printing Collections in DB ==");
        collections = memoryStore.GetCollectionsAsync();
        await foreach (var collection in collections)
        {
            Console.WriteLine(collection);
        }

        Console.WriteLine("== Retrieving Memories Through the Kernel ==");
        MemoryQueryResult? lookup = await kernel.Memory.GetAsync(MemoryCollectionName, key1);
        Console.WriteLine(lookup != null ? lookup.Metadata.Text : "ERROR: memory not found");

        Console.WriteLine("== Similarity Searching Memories: My favorite color is orange ==");
        var searchResults = kernel.Memory.SearchAsync(MemoryCollectionName, "My favorite color is orange", limit: 3, minRelevanceScore: 0.6);

        await foreach (var item in searchResults)
        {
            Console.WriteLine(item.Metadata.Text + " : " + item.Relevance);
        }

        Console.WriteLine("== Removing Collection {0} ==", MemoryCollectionName);
        await memoryStore.DeleteCollectionAsync(MemoryCollectionName);

        Console.WriteLine("== Printing Collections in DB ==");
        await foreach (var collection in collections)
        {
            Console.WriteLine(collection);
        }
    }
}
@@ -29,6 +29,7 @@
<ProjectReference Include="..\..\src\Connectors\Connectors.AI.OpenAI\Connectors.AI.OpenAI.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.AI.HuggingFace\Connectors.AI.HuggingFace.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.Memory.AzureCognitiveSearch\Connectors.Memory.AzureCognitiveSearch.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.Memory.Chroma\Connectors.Memory.Chroma.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.Memory.Postgres\Connectors.Memory.Postgres.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.Memory.Weaviate\Connectors.Memory.Weaviate.csproj" />
<ProjectReference Include="..\..\src\Connectors\Connectors.Memory.Redis\Connectors.Memory.Redis.csproj" />
3 changes: 3 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Program.cs
@@ -154,5 +154,8 @@ public static async Task Main()

        await Example49_LogitBias.RunAsync();
        Console.WriteLine("== DONE ==");

        await Example50_Chroma.RunAsync();
        Console.WriteLine("== DONE ==");
    }
}
207 changes: 207 additions & 0 deletions dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs
@@ -0,0 +1,207 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma.Http.ApiSchema;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma.Http.ApiSchema.Internal;

namespace Microsoft.SemanticKernel.Connectors.Memory.Chroma;

/// <summary>
/// An implementation of a client for the Chroma Vector DB. This class is used to
/// create, delete, and get embeddings data from Chroma Vector DB instance.
/// </summary>
#pragma warning disable CA1001 // Types that own disposable fields should be disposable. Explanation - In this case, there is no need to dispose because either the NonDisposableHttpClientHandler or a custom HTTP client is being used.
public class ChromaClient : IChromaClient
#pragma warning restore CA1001 // Types that own disposable fields should be disposable. Explanation - In this case, there is no need to dispose because either the NonDisposableHttpClientHandler or a custom HTTP client is being used.
{
    /// <summary>
    /// Initializes a new instance of the <see cref="ChromaClient"/> class.
    /// </summary>
    /// <param name="endpoint">Chroma server endpoint URL.</param>
    /// <param name="logger">Optional logger instance.</param>
    public ChromaClient(string endpoint, ILogger? logger = null)
    {
        this._httpClient = new HttpClient(NonDisposableHttpClientHandler.Instance, disposeHandler: false);
        this._endpoint = endpoint;
        this._logger = logger ?? NullLogger<ChromaClient>.Instance;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ChromaClient"/> class.
    /// </summary>
    /// <param name="httpClient">The <see cref="HttpClient"/> instance used for making HTTP requests.</param>
    /// <param name="endpoint">Chroma server endpoint URL.</param>
    /// <param name="logger">Optional logger instance.</param>
    /// <exception cref="ChromaClientException">Occurs when <see cref="HttpClient"/> doesn't have base address and endpoint parameter is not provided.</exception>
    public ChromaClient(HttpClient httpClient, string? endpoint = null, ILogger? logger = null)
    {
        if (string.IsNullOrEmpty(httpClient.BaseAddress?.AbsoluteUri) && string.IsNullOrEmpty(endpoint))
        {
            throw new ChromaClientException("The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided.");
        }

        this._httpClient = httpClient;
        this._endpoint = endpoint;
        this._logger = logger ?? NullLogger<ChromaClient>.Instance;
    }

    /// <inheritdoc />
    public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Creating collection {0}", collectionName);

        using var request = CreateCollectionRequest.Create(collectionName).Build();

        await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async Task<ChromaCollectionModel?> GetCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Getting collection {0}", collectionName);

        using var request = GetCollectionRequest.Create(collectionName).Build();

        (HttpResponseMessage response, string responseContent) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);

        var collection = JsonSerializer.Deserialize<ChromaCollectionModel>(responseContent);

        return collection;
    }

    /// <inheritdoc />
    public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Deleting collection {0}", collectionName);

        using var request = DeleteCollectionRequest.Create(collectionName).Build();

        await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<string> ListCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Listing collections");

        using var request = ListCollectionsRequest.Create().Build();

        (HttpResponseMessage response, string responseContent) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);

        var collections = JsonSerializer.Deserialize<List<ChromaCollectionModel>>(responseContent);

        foreach (var collection in collections!)
        {
            yield return collection.Name;
        }
    }

    /// <inheritdoc />
    public async Task AddEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Adding embeddings to collection with id: {0}", collectionId);

        using var request = AddEmbeddingsRequest.Create(collectionId, ids, embeddings, metadatas).Build();

        await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async Task<ChromaEmbeddingsModel> GetEmbeddingsAsync(string collectionId, string[] ids, string[]? include = null, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Getting embeddings from collection with id: {0}", collectionId);

        using var request = GetEmbeddingsRequest.Create(collectionId, ids, include).Build();

        (HttpResponseMessage response, string responseContent) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);

        var embeddings = JsonSerializer.Deserialize<ChromaEmbeddingsModel>(responseContent);

        return embeddings ?? new ChromaEmbeddingsModel();
    }

    /// <inheritdoc />
    public async Task DeleteEmbeddingsAsync(string collectionId, string[] ids, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Deleting embeddings from collection with id: {0}", collectionId);

        using var request = DeleteEmbeddingsRequest.Create(collectionId, ids).Build();

        await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async Task<ChromaQueryResultModel> QueryEmbeddingsAsync(string collectionId, float[][] queryEmbeddings, int nResults, string[]? include = null, CancellationToken cancellationToken = default)
    {
        this._logger.LogDebug("Query embeddings in collection with id: {0}", collectionId);

        using var request = QueryEmbeddingsRequest.Create(collectionId, queryEmbeddings, nResults, include).Build();

        (HttpResponseMessage response, string responseContent) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);

        var queryResult = JsonSerializer.Deserialize<ChromaQueryResultModel>(responseContent);

        return queryResult ?? new ChromaQueryResultModel();
    }

    #region private ================================================================================

    private const string ApiRoute = "api/v1/";

    private readonly ILogger _logger;
    private readonly HttpClient _httpClient;
    private readonly string? _endpoint = null;

    private async Task<(HttpResponseMessage response, string responseContent)> ExecuteHttpRequestAsync(
        HttpRequestMessage request,
        CancellationToken cancellationToken = default)
    {
        string endpoint = this._endpoint ?? this._httpClient.BaseAddress.ToString();
        endpoint = this.SanitizeEndpoint(endpoint);

        string operationName = request.RequestUri.ToString();

        request.RequestUri = new Uri(new Uri(endpoint), operationName);

        HttpResponseMessage response = await this._httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);

        string responseContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

        try
        {
            response.EnsureSuccessStatusCode();
        }
        catch (HttpRequestException e)
        {
            this._logger.LogError(e, "{0} {1} operation failed: {2}, {3}", request.Method.Method, operationName, e.Message, responseContent);
            throw new ChromaClientException($"{request.Method.Method} {operationName} operation failed: {e.Message}, {responseContent}", e);
        }

        return (response, responseContent);
    }

    private string SanitizeEndpoint(string endpoint)
    {
        StringBuilder builder = new(endpoint);

        if (!endpoint.EndsWith("/", StringComparison.Ordinal))
        {
            builder.Append('/');
        }

        builder.Append(ApiRoute);

        return builder.ToString();
    }

    #endregion
}
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Globalization;

namespace Microsoft.SemanticKernel.Connectors.Memory.Chroma;

/// <summary>
/// Exception to identify issues in <see cref="ChromaClient"/> class.
/// </summary>
public class ChromaClientException : Exception
{
    private const string CollectionDoesNotExistErrorFormat = "Collection {0} does not exist";
    private const string DeleteNonExistentCollectionErrorMessage = "list index out of range";

    /// <summary>
    /// Initializes a new instance of the <see cref="ChromaClientException"/> class.
    /// </summary>
    public ChromaClientException() : base()
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ChromaClientException"/> class.
    /// </summary>
    /// <param name="message">The message that describes the error.</param>
    public ChromaClientException(string message) : base(message)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ChromaClientException"/> class.
    /// </summary>
    /// <param name="message">The message that describes the error.</param>
    /// <param name="innerException">Instance of inner exception.</param>
    public ChromaClientException(string message, Exception innerException) : base(message, innerException)
    {
    }

    /// <summary>
    /// Checks if Chroma API error means that collection does not exist.
    /// </summary>
    /// <param name="collectionName">Collection name.</param>
    public bool CollectionDoesNotExistException(string collectionName) =>
        this.Message.Contains(string.Format(CultureInfo.InvariantCulture, CollectionDoesNotExistErrorFormat, collectionName));

    /// <summary>
    /// Checks if Chroma API error means that there was an attempt to delete non-existent collection.
    /// </summary>
    public bool DeleteNonExistentCollectionException() =>
        this.Message.Contains(DeleteNonExistentCollectionErrorMessage);
}