Skip to content

Commit 55fc392

Browse files
authored
FIELDENG-589 EntraId support for Vectorizers (#522)
Entra ID support for Vectorizers
1 parent 40ffa00 commit 55fc392

File tree

5 files changed

+98
-7
lines changed

5 files changed

+98
-7
lines changed

src/Redis.OM.Vectorizers/AzureOpenAIVectorizer.cs

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11

2+
using System.Net.Http.Headers;
23
using System.Net.Http.Json;
34
using System.Text.Json;
5+
using Azure.Core;
6+
using Azure.Identity;
47
using Redis.OM.Contracts;
58
using Redis.OM.Modeling;
69

@@ -11,7 +14,8 @@ namespace Redis.OM.Vectorizers;
1114
/// </summary>
1215
public class AzureOpenAIVectorizer : IVectorizer<string>
1316
{
14-
private readonly string _apiKey;
17+
private readonly TokenCredential? _tokenCredential;
18+
private readonly string? _apiKey;
1519
private readonly string _resourceName;
1620
private readonly string _deploymentName;
1721

@@ -28,6 +32,22 @@ public AzureOpenAIVectorizer(string apiKey, string resourceName, string deployme
2832
_resourceName = resourceName;
2933
_deploymentName = deploymentName;
3034
Dim = dim;
35+
_tokenCredential = new DefaultAzureCredential();
36+
}
37+
38+
/// <summary>
39+
/// Initialize vectorizer
40+
/// </summary>
41+
/// <param name="resourceName">The Azure Resource's name</param>
42+
/// <param name="deploymentName">The Azure deployment name</param>
43+
/// <param name="dim">The dimensions of the model addressed by this resource/deployment.</param>
44+
public AzureOpenAIVectorizer(string resourceName, string deploymentName, int dim)
45+
{
46+
_apiKey = null;
47+
_resourceName = resourceName;
48+
_deploymentName = deploymentName;
49+
Dim = dim;
50+
_tokenCredential = new DefaultAzureCredential();
3151
}
3252

3353
/// <inheritdoc />
@@ -37,9 +57,9 @@ public AzureOpenAIVectorizer(string apiKey, string resourceName, string deployme
3757
public int Dim { get; }
3858

3959
/// <inheritdoc />
40-
public byte[] Vectorize(string str) => GetFloats(str, _resourceName, _deploymentName, _apiKey).SelectMany(BitConverter.GetBytes).ToArray();
60+
public byte[] Vectorize(string str) => GetFloats(str, _resourceName, _deploymentName, _apiKey, _tokenCredential).SelectMany(BitConverter.GetBytes).ToArray();
4161

42-
internal static float[] GetFloats(string s, string resourceName, string deploymentName, string apiKey)
62+
internal static float[] GetFloats(string s, string resourceName, string deploymentName, string? apiKey, TokenCredential? azureCredentials)
4363
{
4464
var client = Configuration.Instance.Client;
4565
var requestContent = JsonContent.Create(new { input = s });
@@ -49,9 +69,24 @@ internal static float[] GetFloats(string s, string resourceName, string deployme
4969
Method = HttpMethod.Post,
5070
RequestUri = new Uri(
5171
$"https://{resourceName}.openai.azure.com/openai/deployments/{deploymentName}/embeddings?api-version=2023-05-15"),
52-
Content = requestContent,
53-
Headers = { { "api-key", apiKey } }
72+
Content = requestContent
5473
};
74+
75+
if(string.IsNullOrEmpty(apiKey) && azureCredentials != null)
76+
{
77+
var scope = "https://cognitiveservices.azure.com/.default";
78+
AccessToken token = azureCredentials.GetToken(new TokenRequestContext(new []{scope}), CancellationToken.None);
79+
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token.Token);
80+
}
81+
else if(!string.IsNullOrEmpty(apiKey))
82+
{
83+
request.Headers.Add("api-key", apiKey);
84+
}
85+
else
86+
{
87+
throw new InvalidOperationException("Either apiKey or azureCredentials must be provided.");
88+
}
89+
5590

5691
var res = client.Send(request);
5792
if (!res.IsSuccessStatusCode)

src/Redis.OM.Vectorizers/AzureOpenAIVectorizerAttribute.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Azure.Core;
2+
using Azure.Identity;
13
using Redis.OM.Contracts;
24
using Redis.OM.Modeling;
35

@@ -6,13 +8,19 @@ namespace Redis.OM.Vectorizers;
68
/// <inheritdoc />
79
public class AzureOpenAIVectorizerAttribute : VectorizerAttribute<string>
810
{
11+
private readonly string? _apikey;
12+
private readonly TokenCredential? _tokenCredential;
13+
914
/// <inheritdoc />
1015
public AzureOpenAIVectorizerAttribute(string deploymentName, string resourceName, int dim)
1116
{
1217
DeploymentName = deploymentName;
1318
ResourceName = resourceName;
1419
Dim = dim;
15-
Vectorizer = new AzureOpenAIVectorizer(Configuration.Instance.AzureOpenAIApiKey, ResourceName, DeploymentName, Dim);
20+
_apikey = Configuration.Instance.AzureOpenAIApiKey;
21+
_tokenCredential = new DefaultAzureCredential();
22+
Vectorizer = string.IsNullOrEmpty(Configuration.Instance.AzureOpenAIApiKey) ? new AzureOpenAIVectorizer(resourceName, deploymentName, dim) : new AzureOpenAIVectorizer(Configuration.Instance.AzureOpenAIApiKey, ResourceName, DeploymentName, dim);
23+
1624
}
1725

1826
/// <summary>
@@ -42,7 +50,7 @@ public override byte[] Vectorize(object obj)
4250
throw new ArgumentException("Object must be a string to be embedded", nameof(obj));
4351
}
4452

45-
var floats = AzureOpenAIVectorizer.GetFloats(s, ResourceName, DeploymentName, Configuration.Instance.AzureOpenAIApiKey);
53+
var floats = AzureOpenAIVectorizer.GetFloats(s, ResourceName, DeploymentName, _apikey, _tokenCredential);
4654
return floats.SelectMany(BitConverter.GetBytes).ToArray();
4755
}
4856
}

src/Redis.OM.Vectorizers/Redis.OM.Vectorizers.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
<PackageReference Include="Microsoft.Extensions.Configuration" Version="7.0.0" />
2727
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="7.0.0" />
2828
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="7.0.4" />
29+
<PackageReference Include="Azure.Identity" Version="1.13.2" />
2930
</ItemGroup>
3031

3132
<ItemGroup>
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using Redis.OM.Modeling;
2+
using Redis.OM.Modeling.Vectors;
3+
using Redis.OM.Vectorizers;
4+
5+
namespace Redis.OM.Unit.Tests;
6+
7+
[Document(StorageType = StorageType.Json)]
8+
public class AzureOpenAIVectors
9+
{
10+
[RedisIdField]
11+
public string Id { get; set; }
12+
13+
[Indexed]
14+
[AzureOpenAIVectorizer("redisom-embedding", "redisom-openai", 1536)]
15+
public Vector<string> Sentence { get; set; }
16+
17+
[Indexed]
18+
public string Name { get; set; }
19+
20+
[Indexed]
21+
public int Age { get; set; }
22+
23+
public VectorScores VectorScore { get; set; }
24+
}

test/Redis.OM.Unit.Tests/RediSearchTests/VectorTests/VectorFunctionalTests.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,4 +388,27 @@ public void OpenAIQueryTest()
388388
result = collection.First(x=>x.TimeStamp > ts && x.Prompt.VectorRange(queryPrompt, .15));
389389
Assert.Equal("Paris", result.Response);
390390
}
391+
392+
[SkipIfMissingEnvVar("REDIS_OM_AZURE_OAI_TOKEN")]
393+
public void TestAzureOpenAIVectorizer()
394+
{
395+
_connection.DropIndexAndAssociatedRecords(typeof(AzureOpenAIVectors));
396+
_connection.CreateIndex(typeof(AzureOpenAIVectors));
397+
var collection = new RedisCollection<AzureOpenAIVectors>(_connection);
398+
var sentenceVector = Vector.Of("Hello World this is Hal.");
399+
var obj = new AzureOpenAIVectors
400+
{
401+
Age = 45,
402+
Sentence = sentenceVector,
403+
Name = "Hal"
404+
};
405+
406+
collection.Insert(obj);
407+
var queryVector = Vector.Of("Hello World this is Hal.");
408+
var res = collection.NearestNeighbors(x => x.Sentence, 2, queryVector).First();
409+
Assert.Equal(obj.Id, res.Id);
410+
Assert.True(res.VectorScore.NearestNeighborsScore < .01);
411+
Assert.Equal(obj.Sentence.Value, res.Sentence.Value);
412+
Assert.Equal(obj.Sentence.Embedding, res.Sentence.Embedding);
413+
}
391414
}

0 commit comments

Comments
 (0)