Skip to content

Commit

Permalink
SOLR-17632: empty and null checks
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Jan 31, 2025
1 parent bdffe76 commit e983e04
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 121 deletions.
2 changes: 0 additions & 2 deletions solr/modules/llm/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ dependencies {
testImplementation project(':solr:test-framework')
testImplementation libs.junit.junit
testImplementation libs.commonsio.commonsio
testImplementation(libs.mockito.core)

}

// langchain4j has reflection issues, and requires the following permissions
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.solr.llm.texttovector.update.processor;

import org.apache.solr.common.SolrException;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.SolrInputField;
import org.apache.solr.llm.texttovector.model.SolrTextToVectorModel;
import org.apache.solr.llm.texttovector.store.rest.ManagedTextToVectorModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;


/**
 * Update processor that encodes the text of a configured input field to a vector and adds that
 * vector to the document under a configured output field.
 *
 * <p>The model used for the encoding is looked up by name in the {@link
 * ManagedTextToVectorModelStore} of the core the request belongs to. Documents whose input field is
 * missing, has a null value, or has an empty string value are not vectorised: a warning is logged
 * and the document is handed to {@code super.processAdd} without the output field being added.
 */
class TextToVectorUpdateProcessor extends UpdateRequestProcessor {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  private final String inputField;
  private final String outputField;
  private final String model;
  private final ManagedTextToVectorModelStore modelStore;

  /**
   * @param inputField name of the document field holding the text to vectorise
   * @param outputField name of the document field the resulting vector is added to
   * @param model name of the model to look up in the managed model store
   * @param req the current request, used to reach the core and its managed model store
   * @param next the next processor, passed to the superclass constructor
   */
  public TextToVectorUpdateProcessor(
      String inputField,
      String outputField,
      String model,
      SolrQueryRequest req,
      UpdateRequestProcessor next) {
    super(next);
    this.inputField = inputField;
    this.outputField = outputField;
    this.model = model;
    this.modelStore = ManagedTextToVectorModelStore.getManagedModelStore(req.getCore());
  }

  /**
   * Vectorises the input field of the document carried by {@code cmd} and adds the result to the
   * output field, then delegates to {@code super.processAdd}.
   *
   * @param cmd the update command in input containing the Document to process
   * @throws IOException If there is a low-level I/O error
   * @throws SolrException with {@link SolrException.ErrorCode#BAD_REQUEST} if the configured model
   *     is not present in the model store
   */
  @Override
  public void processAdd(AddUpdateCommand cmd) throws IOException {
    // The model is looked up on every add (as before), but kept in a local: nothing outside this
    // method needs it, so it should not be mutable instance state.
    final SolrTextToVectorModel textToVector = modelStore.getModel(model);
    if (textToVector == null) {
      throw new SolrException(
          SolrException.ErrorCode.BAD_REQUEST,
          "The model requested '"
              + model
              + "' can't be found in the store: "
              + ManagedTextToVectorModelStore.REST_END_POINT);
    }

    SolrInputDocument doc = cmd.getSolrInputDocument();
    SolrInputField inputFieldContent = doc.get(inputField);
    if (!isNullOrEmpty(inputFieldContent, doc, inputField)) {
      String textToVectorise = inputFieldContent.getValue().toString();
      float[] vector = textToVector.vectorise(textToVectorise);
      // The vector is added to the document as a list of boxed floats, one entry per dimension.
      List<Float> vectorAsList = new ArrayList<>(vector.length);
      for (float f : vector) {
        vectorAsList.add(f);
      }
      doc.addField(outputField, vectorAsList);
    }
    // Always forward the document, vectorised or not.
    super.processAdd(cmd);
  }

  /**
   * Tells whether the input field carries no usable text, logging a warning when it does not.
   *
   * @param inputFieldContent the input field as read from the document, possibly {@code null}
   * @param doc the document being processed, only used in the warning message
   * @param fieldName the name of the input field, only used in the warning message
   * @return {@code true} if the field is {@code null}, its value is {@code null}, or the string
   *     form of its value has zero length; {@code false} otherwise
   */
  protected boolean isNullOrEmpty(
      SolrInputField inputFieldContent, SolrInputDocument doc, String fieldName) {
    if (inputFieldContent == null || inputFieldContent.getValue() == null) {
      // Parameterized logging: the document is only rendered if WARN is enabled.
      log.warn("the input field: {} is missing for the document: {}", fieldName, doc);
      return true;
    }
    if (inputFieldContent.getValue().toString().isEmpty()) {
      log.warn(
          "the input field: {} is empty (string instance of zero length) for the document: {}",
          fieldName,
          doc);
      return true;
    }
    return false;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@
import org.apache.solr.update.processor.UpdateRequestProcessorFactory;

/**
* This class implements an UpdateProcessorFactory for the Text Embedder Update Processor. It takes
* in input a series of parameter that will be necessary to instantiate and use the embedder
* This class implements an UpdateProcessorFactory for the Text To Vector Update Processor.
*/
public class TextToVectorUpdateProcessorFactory extends UpdateRequestProcessorFactory {
private static final String INPUT_FIELD_PARAM = "inputField";
private static final String OUTPUT_FIELD_PARAM = "outputField";
private static final String EMBEDDING_MODEl_NAME_PARAM = "model";
private static final String MODEL_NAME = "model";

String inputField;
String outputField;
String embeddingModelName;
String modelName;
SolrParams params;


Expand All @@ -53,8 +52,8 @@ public void init(final NamedList<?> args) {
outputField = params.get(OUTPUT_FIELD_PARAM);
checkNotNull(OUTPUT_FIELD_PARAM, outputField);

embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM);
checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName);
modelName = params.get(MODEL_NAME);
checkNotNull(MODEL_NAME, modelName);
}
}

Expand All @@ -72,7 +71,7 @@ public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryRespons
final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField);
assertIsDenseVectorField(outputFieldSchema);

return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next);
return new TextToVectorUpdateProcessor(inputField, outputField, modelName, req, next);
}

protected void assertIsDenseVectorField(SchemaField schemaField) {
Expand All @@ -92,7 +91,7 @@ public String getOutputField() {
return outputField;
}

public String getEmbeddingModelName() {
return embeddingModelName;
public String getModelName() {
return modelName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import org.apache.solr.common.util.NamedList;
import org.apache.solr.llm.TestLlmBase;
import org.apache.solr.request.SolrQueryRequestBase;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

Expand Down Expand Up @@ -56,7 +54,7 @@ public void init_fullArgs_shouldInitFullClassificationParams() {

assertEquals("_text_", factoryToTest.getInputField());
assertEquals("vector", factoryToTest.getOutputField());
assertEquals("model1", factoryToTest.getEmbeddingModelName());
assertEquals("model1", factoryToTest.getModelName());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,14 @@
*/
package org.apache.solr.llm.texttovector.update.processor;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.llm.TestLlmBase;
import org.apache.solr.llm.texttovector.store.rest.ManagedTextToVectorModelStore;
import org.apache.solr.update.AddUpdateCommand;
import org.junit.BeforeClass;
import org.junit.Test;

import static org.hamcrest.core.Is.is;

public class TextToVectorUpdateProcessorTest extends TestLlmBase {
/* field names are used in accordance with the solrconfig and schema supplied */
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String TRAINING_CLASS = "cat";
private static final String PREDICTED_CLASS = "predicted";

protected Directory directory;
protected IndexReader reader;
protected IndexSearcher searcher;
private TexToVectorUpdateProcessor updateProcessorToTest;

@BeforeClass
public static void init() throws Exception {
Expand Down Expand Up @@ -85,5 +67,51 @@ public void processAdd_modelNotFound_shouldRaiseException() {
"/response/lst[@name='error']/int[@name='code']='400'");
}

/**
 * Indexes one document with an empty input field and one with real text, then checks that only the
 * latter got a vector while both were indexed.
 */
@Test
public void processAdd_emptyInputField_shouldLogAndIndexWithNoVector() throws Exception {
  loadModel("dummy-model.json");
  try {
    assertU(adoc("id", "99", "_text_", ""));
    assertU(adoc("id", "98", "_text_", "Vegeta is the saiyan prince."));
    assertU(commit());

    final String solrQuery = "*:*";
    final SolrQuery query = new SolrQuery();
    query.setQuery(solrQuery);
    query.add("fl", "id,vector");

    assertJQ(
        "/query" + query.toQueryString(),
        "/response/numFound==2",
        "/response/docs/[0]/id=='99'",
        "!/response/docs/[0]/vector==", // no vector field for the document 99
        "/response/docs/[1]/id=='98'",
        "/response/docs/[1]/vector==[1.0, 2.0, 3.0, 4.0]");
  } finally {
    // Remove the model even when an assertion fails, so it does not leak into other tests.
    // NOTE(review): "dummy-1" is presumably the model name declared in dummy-model.json - confirm.
    restTestHarness.delete(ManagedTextToVectorModelStore.REST_END_POINT + "/dummy-1");
  }
}

/**
 * Indexes one document with real text and one without the input field at all, then checks that
 * only the former got a vector while both were indexed.
 */
@Test
public void processAdd_nullInputField_shouldLogAndIndexWithNoVector() throws Exception {
  loadModel("dummy-model.json");
  try {
    assertU(adoc("id", "99", "_text_", "Vegeta is the saiyan prince."));
    assertU(adoc("id", "98"));
    assertU(commit());

    final String solrQuery = "*:*";
    final SolrQuery query = new SolrQuery();
    query.setQuery(solrQuery);
    query.add("fl", "id,vector");

    assertJQ(
        "/query" + query.toQueryString(),
        "/response/numFound==2",
        "/response/docs/[0]/id=='99'",
        "/response/docs/[0]/vector==[1.0, 2.0, 3.0, 4.0]",
        "/response/docs/[1]/id=='98'",
        "!/response/docs/[1]/vector=="); // no vector field for the document 98
  } finally {
    // Remove the model even when an assertion fails, so it does not leak into other tests.
    // NOTE(review): "dummy-1" is presumably the model name declared in dummy-model.json - confirm.
    restTestHarness.delete(ManagedTextToVectorModelStore.REST_END_POINT + "/dummy-1");
  }
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ public static void checkUpdateU(String update, String... tests) {
throw new RuntimeException("Exception during query", e2);
}
}


//String results = TestHarness.validateXPath(response, tests);
/**
* Validates a query matches some XPath test expressions
*
Expand Down

0 comments on commit e983e04

Please sign in to comment.