Skip to content

Commit

Permalink
SOLR-17632: first draft implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Jan 31, 2025
1 parent 04d3ba0 commit bdffe76
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,16 @@
import org.apache.solr.update.processor.UpdateRequestProcessor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

class TexToVectorUpdateProcessor extends UpdateRequestProcessor implements ResourceLoaderAware, ManagedResourceObserver {

class TexToVectorUpdateProcessor extends UpdateRequestProcessor{
private final String inputField;
private final String outputField;
private final String model;
private SolrTextToVectorModel textToVector;
private ManagedTextToVectorModelStore modelStore = null;

/**
 * Hands the given loader and this processor to
 * {@code ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore}.
 *
 * @param loader the resource loader; it is cast to {@code SolrResourceLoader} before use
 */
@Override
public void inform(ResourceLoader loader) {
  ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(
      (SolrResourceLoader) loader, this);
}

/**
 * Remembers the given resource when it is a {@code ManagedTextToVectorModelStore}, then asks the
 * currently known store (if any) to load its stored models.
 *
 * @param args initialization arguments; not read by this method
 * @param res the managed resource passed to this callback
 */
@Override
public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
    throws SolrException {
  final boolean isModelStore = res instanceof ManagedTextToVectorModelStore;
  if (isModelStore) {
    modelStore = (ManagedTextToVectorModelStore) res;
  }
  if (modelStore == null) {
    // No store has been seen so far, so there is nothing to load.
    return;
  }
  modelStore.loadStoredModels();
}

public TexToVectorUpdateProcessor(
String inputField,
Expand Down Expand Up @@ -88,8 +74,14 @@ public void processAdd(AddUpdateCommand cmd) throws IOException {

SolrInputDocument doc = cmd.getSolrInputDocument();
String textToEmbed = doc.get(inputField).getValue().toString();//add null checks and

float[] vector = textToVector.vectorise(textToEmbed);
doc.addField(outputField, vector);
List<Float> vectorAsList = new ArrayList<Float>(vector.length);
for (float f : vector) {
vectorAsList.add(f);
}

doc.addField(outputField, vectorAsList);
super.processAdd(cmd);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,59 +31,56 @@
/**
* This class implements an UpdateProcessorFactory for the Text Embedder Update Processor. It takes
* as input a series of parameters that are necessary to instantiate and use the embedder.
*
*/
public class TextToVectorUpdateProcessorFactory extends UpdateRequestProcessorFactory {
private static final String INPUT_FIELD_PARAM = "inputField";
private static final String OUTPUT_FIELD_PARAM = "outputField";
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";
private static final String OUTPUT_FIELD_PARAM = "outputField";
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";

String inputField;
String outputField;
String embeddingModelName;
SolrParams params;

String inputField;
String outputField;
String embeddingModelName;
SolrParams params;



@Override
public void init(final NamedList<?> args) {
if (args != null) {
params = args.toSolrParams();
inputField = params.get(INPUT_FIELD_PARAM);
checkNotNull(INPUT_FIELD_PARAM, inputField);
@Override
public void init(final NamedList<?> args) {
if (args != null) {
params = args.toSolrParams();
inputField = params.get(INPUT_FIELD_PARAM);
checkNotNull(INPUT_FIELD_PARAM, inputField);

outputField = params.get(OUTPUT_FIELD_PARAM);
checkNotNull(OUTPUT_FIELD_PARAM, outputField);

embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);
}
}
outputField = params.get(OUTPUT_FIELD_PARAM);
checkNotNull(OUTPUT_FIELD_PARAM, outputField);

private void checkNotNull(String paramName, Object param) {
if (param == null) {
throw new SolrException(
SolrException.ErrorCode.SERVER_ERROR,
"Text to Vector UpdateProcessor '" + paramName + "' can not be null");
embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);
}
}
}

@Override
public UpdateRequestProcessor getInstance(
SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
/**
 * Rejects a missing configuration parameter.
 *
 * @param paramName name of the parameter, used in the error message
 * @param param value to check
 * @throws SolrException with {@code SERVER_ERROR} when {@code param} is null
 */
private void checkNotNull(String paramName, Object param) {
  if (param != null) {
    return;
  }
  final String message = "Text to Vector UpdateProcessor '" + paramName + "' can not be null";
  throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, message);
}

final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
assertIsDenseVectorField(outputFieldSchema);
@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
req.getCore().getLatestSchema().getField(inputField);
final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
assertIsDenseVectorField(outputFieldSchema);

return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
}
return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
}

protected void assertIsDenseVectorField(SchemaField schemaField) {
FieldType fieldType = schemaField.getType();
if (!(fieldType instanceof DenseVectorField)) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"only DenseVectorField is compatible with Vector Query Parsers");
SolrException.ErrorCode.SERVER_ERROR,
"only DenseVectorField is compatible with Vector Query Parsers: " + schemaField.getName());
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->

<config>
<luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.MockDirectoryFactory}" />
<schemaFactory class="ClassicIndexSchemaFactory" />

<requestDispatcher>
<requestParsers />
</requestDispatcher>

<!-- Query parser used to run neural queries-->
<queryParser name="knn_text_to_vector"
class="org.apache.solr.llm.texttovector.search.TextToVectorQParserPlugin" />

<query>
<filterCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
</query>
<requestHandler name="/select" class="solr.SearchHandler" />

<updateHandler class="solr.DirectUpdateHandler2">
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>

<!-- Query request handler with default response settings (explicit echoParams, JSON output, indentation, 'id' as default field) -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>

<initParams path="/update/**">
<lst name="defaults">
<str name="update.chain">textToVector</str>
</lst>
</initParams>

<updateRequestProcessorChain name="textToVector">
<processor class="solr.llm.texttovector.update.processor.TextToVectorUpdateProcessorFactory">
<str name="inputField">_text_</str>
<str name="outputField">vector</str>
<str name="model">dummy-1</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>

</config>
39 changes: 5 additions & 34 deletions solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,17 @@ public class TestLlmBase extends RestTestBase {
protected static Path embeddingModelStoreFile = null;

protected static String IDField = "id";
protected static String textField = "_text_";
protected static String vectorField = "vector";
protected static String vectorField2 = "vector2";
protected static String vectorFieldByteEncoding = "vector_byte_encoding";

protected static void setupTest(
String solrconfig, String schema, boolean buildIndexWithVectors, boolean buildIndexWithText, boolean persistModelStore)
String solrconfig, String schema, boolean buildIndex, boolean persistModelStore)
throws Exception {
initFolders(persistModelStore);
createJettyAndHarness(
tmpSolrHome.toAbsolutePath().toString(), solrconfig, schema, "/solr", true, null);
if (buildIndexWithVectors) prepareIndex(false);
if (buildIndexWithText) prepareIndex(true);
if (buildIndex) prepareIndex();
}

protected static void initFolders(boolean isPersistent) throws Exception {
Expand Down Expand Up @@ -113,17 +111,12 @@ public static void loadModel(String fileName) throws Exception {
ManagedTextToVectorModelStore.REST_END_POINT, multipleModels, "/responseHeader/status==0");
}

protected static void prepareIndex(boolean textual) throws Exception {
List<SolrInputDocument> docsToIndex;
if(textual){
docsToIndex = prepareTextualDocs();
} else {
docsToIndex = prepareTextualDocs();
}

/** Adds every document returned by {@code prepareDocs()} and then issues a commit. */
protected static void prepareIndex() throws Exception {
  for (SolrInputDocument document : prepareDocs()) {
    assertU(adoc(document));
  }
  assertU(commit());
}

Expand Down Expand Up @@ -184,26 +177,4 @@ private static List<SolrInputDocument> prepareDocs() {

return docs;
}

/**
 * Builds 13 documents with ids 1 to 13. The first three documents also receive a value in the
 * text field; the third one receives it as a single-element list.
 *
 * @return the list of prepared documents, in id order
 */
private static List<SolrInputDocument> prepareTextualDocs() {
  final int docsCount = 13;
  final List<SolrInputDocument> docs = new ArrayList<>(docsCount);
  for (int id = 1; id <= docsCount; id++) {
    final SolrInputDocument doc = new SolrInputDocument();
    doc.addField(IDField, id);
    docs.add(doc);
  }

  final SolrInputDocument first = docs.get(0);
  first.addField(textField, "Vegeta is the prince of all saiyans."); // cosine distance vector1= 1.0

  final SolrInputDocument second = docs.get(1);
  second.addField(
      textField,
      "Goku, also known as Kakaroth is a saiyan grown up on Earth."); // cosine distance vector1= 0.998

  final SolrInputDocument third = docs.get(2);
  third.addField(
      textField, Arrays.asList("Gohan is a Saiya-man hybrid.")); // cosine distance vector1= 0.992

  return docs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
public class TextToVectorQParserTest extends TestLlmBase {
@BeforeClass
public static void init() throws Exception {
setupTest("solrconfig-llm.xml", "schema.xml", true, false, false);
setupTest("solrconfig-llm.xml", "schema.xml", true, false);
loadModel("dummy-model.json");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class TestModelManager extends TestLlmBase {

@BeforeClass
public static void init() throws Exception {
setupTest("solrconfig-llm.xml", "schema.xml", false, false, false);
setupTest("solrconfig-llm.xml", "schema.xml", false, false);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class TestModelManagerPersistence extends TestLlmBase {

@Before
public void init() throws Exception {
setupTest("solrconfig-llm.xml", "schema.xml", false, false, true);
setupTest("solrconfig-llm.xml", "schema.xml", false, true);
}

@After
Expand Down
Loading

0 comments on commit bdffe76

Please sign in to comment.