Skip to content

GH-3540: Allow user-provided embeddings in VectorStore #3541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.SimpleVectorStoreFilterExpressionConverter;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.core.io.Resource;
Expand Down Expand Up @@ -102,17 +103,16 @@ public static SimpleVectorStoreBuilder builder(EmbeddingModel embeddingModel) {
}

@Override
public void doAdd(List<Document> documents) {
Objects.requireNonNull(documents, "Documents list cannot be null");
if (documents.isEmpty()) {
throw new IllegalArgumentException("Documents list cannot be empty");
public void doAdd(List<EmbeddedDocument> embeddedDocuments) {
if (embeddedDocuments.isEmpty()) {
throw new IllegalArgumentException("Embedded document list cannot be empty");
}

for (Document document : documents) {
for (EmbeddedDocument ed : embeddedDocuments) {
Document document = ed.document();
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
float[] embedding = this.embeddingModel.embed(document);
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(), document.getText(),
document.getMetadata(), embedding);
document.getMetadata(), ed.embedding());
this.store.put(document.getId(), storeContent);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.document.DocumentWriter;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.util.Assert;
Expand All @@ -50,6 +51,18 @@ default String getName() {
*/
void add(List<Document> documents);

/**
* Adds list of {@link Document}s along with their corresponding embeddings to the vector store.
* @param embeddedDocuments the list of {@link EmbeddedDocument} instances to store. Throws an exception if the
* underlying provider checks for duplicate IDs.
* @throws IllegalArgumentException if there is:
* <ul>
* <li> A mismatch between documents and embeddings
* <li> Dimensional inconsistency between embeddings
* <li> Embeddings contain {@code NaN}, {@code Infinity}, or null/empty vectors.
*/
void addEmbedded(List<EmbeddedDocument> embeddedDocuments);

@Override
default void accept(List<Document> documents) {
add(documents);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.springframework.ai.vectorstore.model;

import org.springframework.ai.document.Document;
import org.springframework.util.Assert;

public record EmbeddedDocument(Document document, float[] embedding) {
public EmbeddedDocument {
Assert.notNull(document, "Document cannot be null.");
Assert.notNull(embedding, "Embedding cannot be null.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
package org.springframework.ai.vectorstore.observation;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Abstract base class for {@link VectorStore} implementations that provides observation
Expand Down Expand Up @@ -82,17 +87,57 @@ public void add(List<Document> documents) {
VectorStoreObservationDocumentation.AI_VECTOR_STORE
.observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> this.doAdd(documents));
.observe(() -> this.doAdd(this.toEmbeddedDocuments(documents, this.embeddingModel
.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy))));
}

/**
* Create a new {@link AbstractObservationVectorStore} instance.
* @param embeddedDocuments the list of {@link EmbeddedDocument} instances to add
*/
@Override
public void addEmbedded(List<EmbeddedDocument> embeddedDocuments) {
this.validateNonTextEmbeddedDocuments(embeddedDocuments);

VectorStoreObservationContext observationContext = this
.createObservationContextBuilder(VectorStoreObservationContext.Operation.ADD.value())
.build();

VectorStoreObservationDocumentation.AI_VECTOR_STORE
.observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
this.validateEmbeddings(embeddedDocuments);
this.doAdd(embeddedDocuments);
});
}

private void validateNonTextDocuments(List<Document> documents) {
if (documents == null)
return;
Assert.notNull(documents, "The document list should not be null.");

for (Document document : documents) {
if (document != null && !document.isText()) {
throw new IllegalArgumentException(
"Only text documents are supported for now. One of the documents contains non-text content.");
}
isNonTextDocument(document);
}
}

private void validateNonTextEmbeddedDocuments(List<EmbeddedDocument> embeddedDocuments) {
Assert.notNull(embeddedDocuments, "Embedded documents list should not be null.");

for (EmbeddedDocument embeddedDocument : embeddedDocuments) {
isNonTextDocument(embeddedDocument.document());
}
}

private List<EmbeddedDocument> toEmbeddedDocuments(List<Document> documents, List<float[]> embeddings) {
return IntStream.range(0, documents.size())
.mapToObj(i -> new EmbeddedDocument(documents.get(i), embeddings.get(i)))
.collect(Collectors.toList());
}

private void isNonTextDocument(Document document) {
if (document != null && !document.isText()) {
throw new IllegalArgumentException(
"Only text documents are supported for now. One of the documents contains non-text content.");
}
}

Expand Down Expand Up @@ -144,9 +189,9 @@ public List<Document> similaritySearch(SearchRequest request) {

/**
* Perform the actual add operation.
* @param documents the documents to add
* @param embeddedDocuments the list of {@link EmbeddedDocument} instances to add
*/
public abstract void doAdd(List<Document> documents);
public abstract void doAdd(List<EmbeddedDocument> embeddedDocuments);

/**
* Perform the actual delete operation.
Expand Down Expand Up @@ -180,4 +225,41 @@ protected void doDelete(Filter.Expression filterExpression) {
*/
public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName);

/**
* Validates a list of {@link EmbeddedDocument}s.
*
* @param embeddedDocuments the list of embedded documents to validate
* @throws IllegalArgumentException if validation fails for:
* <ul>
* <li> Dimensional inconsistency between embeddings
* <li> Embeddings contain {@code NaN}, {@code Infinity}, or empty vectors.
*/
protected void validateEmbeddings(List<EmbeddedDocument> embeddedDocuments) {
if (embeddedDocuments.isEmpty()) return;

int embSize = embeddedDocuments.size();
float[] first = embeddedDocuments.get(0).embedding();
final int expectedDim = first.length;

if (expectedDim == 0) {
throw new IllegalArgumentException("First embedding is empty.");
}

for (int i = 0; i < embSize; i++) {
float[] emb = embeddedDocuments.get(i).embedding();

if (emb.length != expectedDim) {
throw new IllegalArgumentException(String.format(
"Embedding at index %d has dimension %d, expected %d.", i, emb.length, expectedDim));
}

for (float val : emb) {
if (Float.isNaN(val) || Float.isInfinite(val)) {
throw new IllegalArgumentException(String.format(
"Embedding at index %d contains NaN or Infinite value.", i));
}
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

import org.junit.jupiter.api.BeforeEach;
Expand All @@ -35,6 +30,7 @@
import org.springframework.ai.content.Media;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeType;
Expand Down Expand Up @@ -62,7 +58,8 @@ void setUp() {
this.mockEmbeddingModel = mock(EmbeddingModel.class);
when(this.mockEmbeddingModel.dimensions()).thenReturn(3);
when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(), any(), any()))
.thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f }));
this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder(this.mockEmbeddingModel));
}

Expand Down Expand Up @@ -91,17 +88,58 @@ void shouldAddMultipleDocuments() {
assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
}

@Test
void shouldAddMultipleDocsWithProvidedEmbeddings() {
List<EmbeddedDocument> embeddedDocs = Arrays.asList(
new EmbeddedDocument(Document.builder().id("1").text("first").build(), new float[] {0.1f, 0.2f, 0.3f}),
new EmbeddedDocument(Document.builder().id("2").text("second").build(), new float[] {0.4f, 0.5f, 0.6f})
);

this.vectorStore.addEmbedded(embeddedDocs);

List<Document> results = this.vectorStore.similaritySearch("first");
assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
}

@Test
void shouldHandleInvalidEmbeddings() {
List<EmbeddedDocument> invalidEmbeddings = List.of(
new EmbeddedDocument(Document.builder().id("1").text("first").build(), new float[] {})
);

assertThatThrownBy(() -> this.vectorStore.addEmbedded(invalidEmbeddings))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("First embedding is empty.");

List<EmbeddedDocument> invalidEmbeddingsDimensions = List.of(
new EmbeddedDocument(Document.builder().id("1").text("first").build(), new float[] {0.1f, 0.2f, 0.3f}),
new EmbeddedDocument(Document.builder().id("2").text("second").build(), new float[] {0.1f, 0.2f})
);

assertThatThrownBy(() -> this.vectorStore.addEmbedded(invalidEmbeddingsDimensions))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embedding at index 1 has dimension 2, expected 3.");

List<EmbeddedDocument> nanEmbeddings = List.of(
new EmbeddedDocument(Document.builder().id("1").text("first").build(), new float[]{Float.NaN})
);

assertThatThrownBy(() -> this.vectorStore.addEmbedded(nanEmbeddings))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embedding at index 0 contains NaN or Infinite value.");
}

@Test
void shouldHandleEmptyDocumentList() {
assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Documents list cannot be empty");
.hasMessage("Embedded document list cannot be empty");
}

@Test
void shouldHandleNullDocumentList() {
assertThatThrownBy(() -> this.vectorStore.add(null)).isInstanceOf(NullPointerException.class)
.hasMessage("Documents list cannot be null");
assertThatThrownBy(() -> this.vectorStore.add(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("The document list should not be null.");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ void setUp() {
this.mockEmbeddingModel = mock(EmbeddingModel.class);
when(this.mockEmbeddingModel.dimensions()).thenReturn(3);
when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(), any(), any()))
.thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f }));
this.vectorStore = SimpleVectorStore.builder(this.mockEmbeddingModel).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import reactor.core.publisher.Flux;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
Expand Down Expand Up @@ -226,15 +226,12 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) {
}

@Override
public void doAdd(List<Document> documents) {

// Batch the documents based on the batching strategy
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);
public void doAdd(List<EmbeddedDocument> embeddedDocuments) {

// Create a list to hold both the CosmosItemOperation and the corresponding
// document ID
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = embeddedDocuments.stream().map(ed -> {
Document doc = ed.document();
String partitionKeyValue;

if ("/id".equals(this.partitionKeyPath)) {
Expand All @@ -255,7 +252,7 @@ else if (this.partitionKeyPath.startsWith("/metadata/")) {
}

CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))),
mapCosmosDocument(doc, ed.embedding()),
new PartitionKey(partitionKeyValue)); // Pair the document ID
// with the operation
return new ImmutablePair<>(doc.getId(), operation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.model.EmbeddedDocument;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
Expand Down Expand Up @@ -151,20 +151,17 @@ public static Builder builder(SearchIndexClient searchIndexClient, EmbeddingMode
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(List<EmbeddedDocument> embeddedDocuments) {

Assert.notNull(documents, "The document list should not be null.");
if (CollectionUtils.isEmpty(documents)) {
if (CollectionUtils.isEmpty(embeddedDocuments)) {
return; // nothing to do;
}

List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

final var searchDocuments = documents.stream().map(document -> {
final var searchDocuments = embeddedDocuments.stream().map(ed -> {
Document document = ed.document();
SearchDocument searchDocument = new SearchDocument();
searchDocument.put(ID_FIELD_NAME, document.getId());
searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document)));
searchDocument.put(EMBEDDING_FIELD_NAME, ed.embedding());
searchDocument.put(CONTENT_FIELD_NAME, document.getText());
searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());

Expand Down
Loading