Skip to content

Commit bdffe76

Browse files
SOLR-17632: first draft implementation
1 parent 04d3ba0 commit bdffe76

File tree

10 files changed

+272
-136
lines changed

10 files changed

+272
-136
lines changed

solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TexToVectorUpdateProcessor.java

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,30 +32,16 @@
3232
import org.apache.solr.update.processor.UpdateRequestProcessor;
3333

3434
import java.io.IOException;
35+
import java.util.ArrayList;
36+
import java.util.List;
3537

36-
class TexToVectorUpdateProcessor extends UpdateRequestProcessor implements ResourceLoaderAware, ManagedResourceObserver {
38+
39+
class TexToVectorUpdateProcessor extends UpdateRequestProcessor{
3740
private final String inputField;
3841
private final String outputField;
3942
private final String model;
4043
private SolrTextToVectorModel textToVector;
4144
private ManagedTextToVectorModelStore modelStore = null;
42-
43-
@Override
44-
public void inform(ResourceLoader loader) {
45-
final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader;
46-
ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(solrResourceLoader, this);
47-
}
48-
49-
@Override
50-
public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
51-
throws SolrException {
52-
if (res instanceof ManagedTextToVectorModelStore) {
53-
modelStore = (ManagedTextToVectorModelStore) res;
54-
}
55-
if (modelStore != null) {
56-
modelStore.loadStoredModels();
57-
}
58-
}
5945

6046
public TexToVectorUpdateProcessor(
6147
String inputField,
@@ -88,8 +74,14 @@ public void processAdd(AddUpdateCommand cmd) throws IOException {
8874

8975
SolrInputDocument doc = cmd.getSolrInputDocument();
9076
String textToEmbed = doc.get(inputField).getValue().toString();//add null checks and
77+
9178
float[] vector = textToVector.vectorise(textToEmbed);
92-
doc.addField(outputField, vector);
79+
List<Float> vectorAsList = new ArrayList<Float>(vector.length);
80+
for (float f : vector) {
81+
vectorAsList.add(f);
82+
}
83+
84+
doc.addField(outputField, vectorAsList);
9385
super.processAdd(cmd);
9486
}
9587
}

solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactory.java

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,59 +31,56 @@
3131
/**
3232
* This class implements an UpdateProcessorFactory for the Text Embedder Update Processor. It takes
3333
* in input a series of parameter that will be necessary to instantiate and use the embedder
34-
*
3534
*/
3635
public class TextToVectorUpdateProcessorFactory extends UpdateRequestProcessorFactory {
3736
private static final String INPUT_FIELD_PARAM = "inputField";
38-
private static final String OUTPUT_FIELD_PARAM = "outputField";
39-
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";
37+
private static final String OUTPUT_FIELD_PARAM = "outputField";
38+
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";
39+
40+
String inputField;
41+
String outputField;
42+
String embeddingModelName;
43+
SolrParams params;
4044

41-
String inputField;
42-
String outputField;
43-
String embeddingModelName;
44-
SolrParams params;
4545

46-
47-
48-
@Override
49-
public void init(final NamedList<?> args) {
50-
if (args != null) {
51-
params = args.toSolrParams();
52-
inputField = params.get(INPUT_FIELD_PARAM);
53-
checkNotNull(INPUT_FIELD_PARAM, inputField);
46+
@Override
47+
public void init(final NamedList<?> args) {
48+
if (args != null) {
49+
params = args.toSolrParams();
50+
inputField = params.get(INPUT_FIELD_PARAM);
51+
checkNotNull(INPUT_FIELD_PARAM, inputField);
5452

55-
outputField = params.get(OUTPUT_FIELD_PARAM);
56-
checkNotNull(OUTPUT_FIELD_PARAM, outputField);
57-
58-
embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
59-
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);
60-
}
61-
}
53+
outputField = params.get(OUTPUT_FIELD_PARAM);
54+
checkNotNull(OUTPUT_FIELD_PARAM, outputField);
6255

63-
private void checkNotNull(String paramName, Object param) {
64-
if (param == null) {
65-
throw new SolrException(
66-
SolrException.ErrorCode.SERVER_ERROR,
67-
"Text to Vector UpdateProcessor '" + paramName + "' can not be null");
56+
embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
57+
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);
58+
}
6859
}
69-
}
7060

71-
@Override
72-
public UpdateRequestProcessor getInstance(
73-
SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
61+
private void checkNotNull(String paramName, Object param) {
62+
if (param == null) {
63+
throw new SolrException(
64+
SolrException.ErrorCode.SERVER_ERROR,
65+
"Text to Vector UpdateProcessor '" + paramName + "' can not be null");
66+
}
67+
}
7468

75-
final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
76-
assertIsDenseVectorField(outputFieldSchema);
69+
@Override
70+
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
71+
req.getCore().getLatestSchema().getField(inputField);
72+
final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
73+
assertIsDenseVectorField(outputFieldSchema);
7774

78-
return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
79-
}
75+
return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
76+
}
8077

8178
protected void assertIsDenseVectorField(SchemaField schemaField) {
8279
FieldType fieldType = schemaField.getType();
8380
if (!(fieldType instanceof DenseVectorField)) {
8481
throw new SolrException(
85-
SolrException.ErrorCode.BAD_REQUEST,
86-
"only DenseVectorField is compatible with Vector Query Parsers");
82+
SolrException.ErrorCode.SERVER_ERROR,
83+
"only DenseVectorField is compatible with Vector Query Parsers: " + schemaField.getName());
8784
}
8885
}
8986

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
<?xml version="1.0" ?>
2+
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
3+
license agreements. See the NOTICE file distributed with this work for additional
4+
information regarding copyright ownership. The ASF licenses this file to
5+
You under the Apache License, Version 2.0 (the "License"); you may not use
6+
this file except in compliance with the License. You may obtain a copy of
7+
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
8+
by applicable law or agreed to in writing, software distributed under the
9+
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
10+
OF ANY KIND, either express or implied. See the License for the specific
11+
language governing permissions and limitations under the License. -->
12+
13+
<config>
14+
<luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
15+
<dataDir>${solr.data.dir:}</dataDir>
16+
<directoryFactory name="DirectoryFactory"
17+
class="${solr.directoryFactory:solr.MockDirectoryFactory}" />
18+
<schemaFactory class="ClassicIndexSchemaFactory" />
19+
20+
<requestDispatcher>
21+
<requestParsers />
22+
</requestDispatcher>
23+
24+
<!-- Query parser used to run neural queries-->
25+
<queryParser name="knn_text_to_vector"
26+
class="org.apache.solr.llm.texttovector.search.TextToVectorQParserPlugin" />
27+
28+
<query>
29+
<filterCache class="solr.CaffeineCache" size="4096"
30+
initialSize="2048" autowarmCount="0" />
31+
</query>
32+
<requestHandler name="/select" class="solr.SearchHandler" />
33+
34+
<updateHandler class="solr.DirectUpdateHandler2">
35+
<autoCommit>
36+
<maxTime>15000</maxTime>
37+
<openSearcher>false</openSearcher>
38+
</autoCommit>
39+
<autoSoftCommit>
40+
<maxTime>1000</maxTime>
41+
</autoSoftCommit>
42+
<updateLog>
43+
<str name="dir">${solr.data.dir:}</str>
44+
</updateLog>
45+
</updateHandler>
46+
47+
<!-- Query request handler managing models and features -->
48+
<requestHandler name="/query" class="solr.SearchHandler">
49+
<lst name="defaults">
50+
<str name="echoParams">explicit</str>
51+
<str name="wt">json</str>
52+
<str name="indent">true</str>
53+
<str name="df">id</str>
54+
</lst>
55+
</requestHandler>
56+
57+
<initParams path="/update/**">
58+
<lst name="defaults">
59+
<str name="update.chain">textToVector</str>
60+
</lst>
61+
</initParams>
62+
63+
<updateRequestProcessorChain name="textToVector">
64+
<processor class="solr.llm.texttovector.update.processor.TextToVectorUpdateProcessorFactory">
65+
<str name="inputField">_text_</str>
66+
<str name="outputField">vector</str>
67+
<str name="model">dummy-1</str>
68+
</processor>
69+
<processor class="solr.RunUpdateProcessorFactory"/>
70+
</updateRequestProcessorChain>
71+
72+
</config>

solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,17 @@ public class TestLlmBase extends RestTestBase {
4545
protected static Path embeddingModelStoreFile = null;
4646

4747
protected static String IDField = "id";
48-
protected static String textField = "_text_";
4948
protected static String vectorField = "vector";
5049
protected static String vectorField2 = "vector2";
5150
protected static String vectorFieldByteEncoding = "vector_byte_encoding";
5251

5352
protected static void setupTest(
54-
String solrconfig, String schema, boolean buildIndexWithVectors, boolean buildIndexWithText, boolean persistModelStore)
53+
String solrconfig, String schema, boolean buildIndex, boolean persistModelStore)
5554
throws Exception {
5655
initFolders(persistModelStore);
5756
createJettyAndHarness(
5857
tmpSolrHome.toAbsolutePath().toString(), solrconfig, schema, "/solr", true, null);
59-
if (buildIndexWithVectors) prepareIndex(false);
60-
if (buildIndexWithText) prepareIndex(true);
58+
if (buildIndex) prepareIndex();
6159
}
6260

6361
protected static void initFolders(boolean isPersistent) throws Exception {
@@ -113,17 +111,12 @@ public static void loadModel(String fileName) throws Exception {
113111
ManagedTextToVectorModelStore.REST_END_POINT, multipleModels, "/responseHeader/status==0");
114112
}
115113

116-
protected static void prepareIndex(boolean textual) throws Exception {
117-
List<SolrInputDocument> docsToIndex;
118-
if(textual){
119-
docsToIndex = prepareTextualDocs();
120-
} else {
121-
docsToIndex = prepareTextualDocs();
122-
}
123-
114+
protected static void prepareIndex() throws Exception {
115+
List<SolrInputDocument> docsToIndex = prepareDocs();
124116
for (SolrInputDocument doc : docsToIndex) {
125117
assertU(adoc(doc));
126118
}
119+
127120
assertU(commit());
128121
}
129122

@@ -184,26 +177,4 @@ private static List<SolrInputDocument> prepareDocs() {
184177

185178
return docs;
186179
}
187-
188-
private static List<SolrInputDocument> prepareTextualDocs() {
189-
int docsCount = 13;
190-
List<SolrInputDocument> docs = new ArrayList<>(docsCount);
191-
for (int i = 1; i < docsCount + 1; i++) {
192-
SolrInputDocument doc = new SolrInputDocument();
193-
doc.addField(IDField, i);
194-
docs.add(doc);
195-
}
196-
197-
docs.get(0)
198-
.addField(textField, "Vegeta is the prince of all saiyans."); // cosine distance vector1= 1.0
199-
docs.get(1)
200-
.addField(
201-
textField, "Goku, also known as Kakaroth is a saiyan grown up on Earth."); // cosine distance vector1= 0.998
202-
docs.get(2)
203-
.addField(
204-
textField,
205-
Arrays.asList("Gohan is a Saiya-man hybrid.")); // cosine distance vector1= 0.992
206-
207-
return docs;
208-
}
209180
}

solr/modules/llm/src/test/org/apache/solr/llm/texttovector/search/TextToVectorQParserTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
public class TextToVectorQParserTest extends TestLlmBase {
2626
@BeforeClass
2727
public static void init() throws Exception {
28-
setupTest("solrconfig-llm.xml", "schema.xml", true, false, false);
28+
setupTest("solrconfig-llm.xml", "schema.xml", true, false);
2929
loadModel("dummy-model.json");
3030
}
3131

solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public class TestModelManager extends TestLlmBase {
3030

3131
@BeforeClass
3232
public static void init() throws Exception {
33-
setupTest("solrconfig-llm.xml", "schema.xml", false, false, false);
33+
setupTest("solrconfig-llm.xml", "schema.xml", false, false);
3434
}
3535

3636
@Test

solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManagerPersistence.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public class TestModelManagerPersistence extends TestLlmBase {
3030

3131
@Before
3232
public void init() throws Exception {
33-
setupTest("solrconfig-llm.xml", "schema.xml", false, false, true);
33+
setupTest("solrconfig-llm.xml", "schema.xml", false, true);
3434
}
3535

3636
@After

0 commit comments

Comments
 (0)