Skip to content

Commit

Permalink
SOLR-17632: first draft implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Jan 31, 2025
1 parent 8f51adc commit 04d3ba0
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 110 deletions.
1 change: 1 addition & 0 deletions solr/modules/llm/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies {
testImplementation project(':solr:test-framework')
testImplementation libs.junit.junit
testImplementation libs.commonsio.commonsio
testImplementation(libs.mockito.core)

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.solr.llm.texttovector.update.processor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.util.ResourceLoader;
import org.apache.lucene.util.ResourceLoaderAware;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.llm.texttovector.model.SolrTextToVectorModel;
import org.apache.solr.llm.texttovector.store.rest.ManagedTextToVectorModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.UpdateRequestProcessor;

/**
 * Update processor that turns the text of one document field into a vector and stores that vector
 * in another field of the same document before passing the command down the chain.
 *
 * <p>The model is looked up by name in the {@link ManagedTextToVectorModelStore} on every add, so
 * the processor always uses whatever the store currently holds under that name.
 */
class TexToVectorUpdateProcessor extends UpdateRequestProcessor
    implements ResourceLoaderAware, ManagedResourceObserver {
  private final String inputField;
  private final String outputField;
  private final String model;
  private ManagedTextToVectorModelStore modelStore;

  /**
   * @param inputField name of the field whose content is vectorised
   * @param outputField name of the field that receives the vector
   * @param model name under which the model is registered in the model store
   * @param req current request, used to reach the core and its managed model store
   * @param next next processor in the chain
   */
  public TexToVectorUpdateProcessor(
      String inputField,
      String outputField,
      String model,
      SolrQueryRequest req,
      UpdateRequestProcessor next) {
    super(next);
    this.inputField = inputField;
    this.outputField = outputField;
    this.model = model;
    this.modelStore = ManagedTextToVectorModelStore.getManagedModelStore(req.getCore());
  }

  /** Registers this instance as an observer of the managed text-to-vector model store. */
  @Override
  public void inform(ResourceLoader loader) {
    final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader;
    ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(solrResourceLoader, this);
  }

  /**
   * Adopts the managed resource as the model store when it is of the expected type, then asks the
   * current store (if any) to load its persisted models.
   */
  @Override
  public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
      throws SolrException {
    if (res instanceof ManagedTextToVectorModelStore) {
      modelStore = (ManagedTextToVectorModelStore) res;
    }
    if (modelStore != null) {
      modelStore.loadStoredModels();
    }
  }

  /**
   * Vectorises the content of the input field and adds the result to the output field.
   *
   * <p>Documents that do not carry the input field, or carry it with a null or empty value, are
   * forwarded unchanged: there is nothing to vectorise and the rest of the document should still
   * be indexed.
   *
   * @param cmd the update command in input containing the Document to vectorise
   * @throws IOException If there is a low-level I/O error
   * @throws SolrException if the model store is unavailable or does not contain the model
   */
  @Override
  public void processAdd(AddUpdateCommand cmd) throws IOException {
    if (modelStore == null) {
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "The text to vector model store is not available: "
              + ManagedTextToVectorModelStore.REST_END_POINT);
    }
    // Kept local on purpose: the resolved model is per-call state, not processor state.
    final SolrTextToVectorModel textToVector = modelStore.getModel(model);
    if (textToVector == null) {
      throw new SolrException(
          SolrException.ErrorCode.BAD_REQUEST,
          "The model requested '"
              + model
              + "' can't be found in the store: "
              + ManagedTextToVectorModelStore.REST_END_POINT);
    }

    final SolrInputDocument doc = cmd.getSolrInputDocument();
    final String textToEmbed = readInputText(doc);
    if (textToEmbed != null && !textToEmbed.isEmpty()) {
      final float[] vector = textToVector.vectorise(textToEmbed);
      if (vector != null) {
        // A primitive float[] would be stored as one opaque value; a List<Float> is stored
        // element by element, which is the shape a vector field can read back.
        doc.addField(outputField, toFloatList(vector));
      }
    }
    super.processAdd(cmd);
  }

  /** Returns the input field content as text, or null when the field or its value is absent. */
  private String readInputText(SolrInputDocument doc) {
    if (doc == null || doc.get(inputField) == null) {
      return null;
    }
    final Object value = doc.get(inputField).getValue();
    return value == null ? null : value.toString();
  }

  /** Boxes a primitive float array into a list, preserving element order. */
  private static List<Float> toFloatList(float[] vector) {
    final List<Float> boxed = new ArrayList<>(vector.length);
    for (float component : vector) {
      boxed.add(component);
    }
    return boxed;
  }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,11 @@

package org.apache.solr.llm.texttovector.update.processor;

import org.apache.lucene.util.ResourceLoader;
import org.apache.lucene.util.ResourceLoaderAware;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.llm.texttovector.model.SolrTextToVectorModel;
import org.apache.solr.llm.texttovector.store.rest.ManagedTextToVectorModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.schema.DenseVectorField;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.SchemaField;
Expand All @@ -40,34 +33,17 @@
* in input a series of parameters that will be necessary to instantiate and use the embedder
*
*/
public class TextEmbedderUpdateProcessorFactory extends UpdateRequestProcessorFactory implements ResourceLoaderAware, ManagedResourceObserver {
public class TextToVectorUpdateProcessorFactory extends UpdateRequestProcessorFactory {
private static final String INPUT_FIELD_PARAM = "inputField";
private static final String OUTPUT_FIELD_PARAM = "outputField";
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";

private String inputField;
private String outputField;
private String embeddingModelName;
private ManagedTextToVectorModelStore modelStore = null;
private SolrTextToVectorModel textToVector;
private SolrParams params;
String inputField;
String outputField;
String embeddingModelName;
SolrParams params;

@Override
public void inform(ResourceLoader loader) {
final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader;
ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(solrResourceLoader, this);
}

@Override
public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
throws SolrException {
if (res instanceof ManagedTextToVectorModelStore) {
modelStore = (ManagedTextToVectorModelStore) res;
}
if (modelStore != null) {
modelStore.loadStoredModels();
}
}


@Override
public void init(final NamedList<?> args) {
Expand All @@ -79,28 +55,16 @@ public void init(final NamedList<?> args) {
outputField = params.get(OUTPUT_FIELD_PARAM);
checkNotNull(OUTPUT_FIELD_PARAM, outputField);



embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);

textToVector = modelStore.getModel(embeddingModelName);
if (textToVector == null) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"The model requested '"
+ embeddingModelName
+ "' can't be found in the store: "
+ ManagedTextToVectorModelStore.REST_END_POINT);
}
}
}

private void checkNotNull(String paramName, Object param) {
if (param == null) {
throw new SolrException(
SolrException.ErrorCode.SERVER_ERROR,
"Text Embedder UpdateProcessor '" + paramName + "' can not be null");
"Text to Vector UpdateProcessor '" + paramName + "' can not be null");
}
}

Expand All @@ -111,7 +75,7 @@ public UpdateRequestProcessor getInstance(
final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
assertIsDenseVectorField(outputFieldSchema);

return new TextEmbedderUpdateProcessor(inputField, outputField, textToVector, next);
return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
}

protected void assertIsDenseVectorField(SchemaField schemaField) {
Expand All @@ -131,7 +95,7 @@ public String getOutputField() {
return outputField;
}

public SolrTextToVectorModel getTextToVector() {
return textToVector;
public String getEmbeddingModelName() {
return embeddingModelName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->

<!-- Test solrconfig that wires the text-to-vector query parser and routes
every /update/** request through the "textToVector" update processor chain. -->
<config>
<luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<!-- Directory factory is overridable via solr.directoryFactory; defaults to the mock one -->
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.MockDirectoryFactory}" />
<schemaFactory class="ClassicIndexSchemaFactory" />

<requestDispatcher>
<requestParsers />
</requestDispatcher>

<!-- Query parser used to run neural queries, registered as "knn_text_to_vector" -->
<queryParser name="knn_text_to_vector"
class="org.apache.solr.llm.texttovector.search.TextToVectorQParserPlugin" />

<query>
<!-- autowarmCount=0: no filter cache entries are autowarmed -->
<filterCache class="solr.CaffeineCache" size="4096"
initialSize="2048" autowarmCount="0" />
</query>
<requestHandler name="/select" class="solr.SearchHandler" />

<updateHandler class="solr.DirectUpdateHandler2">
<!-- Hard commit that does not open a new searcher -->
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<!-- Soft commit on a shorter interval than the hard commit above -->
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>

<!-- Query request handler with JSON output defaults and "id" as default field -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>

<!-- Make "textToVector" the default update chain for every handler under /update -->
<initParams path="/update/**">
<lst name="defaults">
<str name="update.chain">textToVector</str>
</lst>
</initParams>

<!-- Update chain under test: the factory is configured to read the "_text_" field,
use the model named "dummy-1" and write into the "vector" field; the chain then
ends with RunUpdateProcessorFactory. -->
<updateRequestProcessorChain name="textToVector">
<processor class="solr.llm.texttovector.update.processor.TextToVectorUpdateProcessorFactory">
<str name="inputField">_text_</str>
<str name="outputField">vector</str>
<str name="model">dummy-1</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>

</config>
Loading

0 comments on commit 04d3ba0

Please sign in to comment.