diff --git a/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TexToVectorUpdateProcessor.java b/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TexToVectorUpdateProcessor.java index 3db9573d992..bd3db2b81e0 100644 --- a/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TexToVectorUpdateProcessor.java +++ b/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TexToVectorUpdateProcessor.java @@ -32,30 +32,16 @@ import org.apache.solr.update.processor.UpdateRequestProcessor; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; -class TexToVectorUpdateProcessor extends UpdateRequestProcessor implements ResourceLoaderAware, ManagedResourceObserver { + +class TexToVectorUpdateProcessor extends UpdateRequestProcessor{ private final String inputField; private final String outputField; private final String model; private SolrTextToVectorModel textToVector; private ManagedTextToVectorModelStore modelStore = null; - - @Override - public void inform(ResourceLoader loader) { - final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader; - ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(solrResourceLoader, this); - } - - @Override - public void onManagedResourceInitialized(NamedList args, ManagedResource res) - throws SolrException { - if (res instanceof ManagedTextToVectorModelStore) { - modelStore = (ManagedTextToVectorModelStore) res; - } - if (modelStore != null) { - modelStore.loadStoredModels(); - } - } public TexToVectorUpdateProcessor( String inputField, @@ -88,8 +74,14 @@ public void processAdd(AddUpdateCommand cmd) throws IOException { SolrInputDocument doc = cmd.getSolrInputDocument(); String textToEmbed = doc.get(inputField).getValue().toString();//add null checks and + float[] vector = textToVector.vectorise(textToEmbed); - doc.addField(outputField, vector); + List vectorAsList = new ArrayList(vector.length); + for (float f : vector) { + vectorAsList.add(f); + } + + doc.addField(outputField, vectorAsList); super.processAdd(cmd); } } diff --git a/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactory.java b/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactory.java index cf92be54a0e..6d87ba68241 100644 --- a/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactory.java +++ b/solr/modules/llm/src/java/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactory.java @@ -31,59 +31,56 @@ /** * This class implements an UpdateProcessorFactory for the Text Embedder Update Processor. It takes * in input a series of parameter that will be necessary to instantiate and use the embedder - * */ public class TextToVectorUpdateProcessorFactory extends UpdateRequestProcessorFactory { private static final String INPUT_FIELD_PARAM = "inputField"; - private static final String OUTPUT_FIELD_PARAM = "outputField"; - private static final String EMBEDDING_MODEl_NAME_PARAM = "model"; + private static final String OUTPUT_FIELD_PARAM = "outputField"; + private static final String EMBEDDING_MODEl_NAME_PARAM = "model"; + + String inputField; + String outputField; + String embeddingModelName; + SolrParams params; - String inputField; - String outputField; - String embeddingModelName; - SolrParams params; - - - @Override - public void init(final NamedList args) { - if (args != null) { - params = args.toSolrParams(); - inputField = params.get(INPUT_FIELD_PARAM); - checkNotNull(INPUT_FIELD_PARAM, inputField); + @Override + public void init(final NamedList args) { + if (args != null) { + params = args.toSolrParams(); + inputField = params.get(INPUT_FIELD_PARAM); + checkNotNull(INPUT_FIELD_PARAM, inputField); - outputField = params.get(OUTPUT_FIELD_PARAM); - checkNotNull(OUTPUT_FIELD_PARAM, outputField); - - embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM); - checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName); - } - } + outputField = params.get(OUTPUT_FIELD_PARAM); + checkNotNull(OUTPUT_FIELD_PARAM, outputField); - private void checkNotNull(String paramName, Object param) { - if (param == null) { - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - "Text to Vector UpdateProcessor '" + paramName + "' can not be null"); + embeddingModelName = params.get(EMBEDDING_MODEl_NAME_PARAM); + checkNotNull(EMBEDDING_MODEl_NAME_PARAM, embeddingModelName); + } } - } - @Override - public UpdateRequestProcessor getInstance( - SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { + private void checkNotNull(String paramName, Object param) { + if (param == null) { + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + "Text to Vector UpdateProcessor '" + paramName + "' can not be null"); + } + } - final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField); - assertIsDenseVectorField(outputFieldSchema); + @Override + public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { + req.getCore().getLatestSchema().getField(inputField); + final SchemaField outputFieldSchema = req.getCore().getLatestSchema().getField(outputField); + assertIsDenseVectorField(outputFieldSchema); - return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next); - } + return new TexToVectorUpdateProcessor(inputField, outputField, embeddingModelName, req, next); + } protected void assertIsDenseVectorField(SchemaField schemaField) { FieldType fieldType = schemaField.getType(); if (!(fieldType instanceof DenseVectorField)) { throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "only DenseVectorField is compatible with Vector Query Parsers"); + SolrException.ErrorCode.SERVER_ERROR, + "only DenseVectorField is compatible with Vector Query Parsers: " + schemaField.getName()); } } diff --git a/solr/modules/llm/src/test-files/solr/collection1/conf/solrconfig-llm-indexing-notDenseVectorField.xml b/solr/modules/llm/src/test-files/solr/collection1/conf/solrconfig-llm-indexing-notDenseVectorField.xml new file mode 100644 index 00000000000..11ff24ce627 --- /dev/null +++ b/solr/modules/llm/src/test-files/solr/collection1/conf/solrconfig-llm-indexing-notDenseVectorField.xml @@ -0,0 +1,72 @@ + + + + + ${tests.luceneMatchVersion:LATEST} + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + explicit + json + true + id + + + + + + textToVector + + + + + + _text_ + vector + dummy-1 + + + + + diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java b/solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java index c764b2c76e5..6fb391ad364 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/TestLlmBase.java @@ -45,19 +45,17 @@ public class TestLlmBase extends RestTestBase { protected static Path embeddingModelStoreFile = null; protected static String IDField = "id"; - protected static String textField = "_text_"; protected static String vectorField = "vector"; protected static String vectorField2 = "vector2"; protected static String vectorFieldByteEncoding = "vector_byte_encoding"; protected static void setupTest( - String solrconfig, String schema, boolean buildIndexWithVectors, boolean buildIndexWithText, boolean persistModelStore) + String solrconfig, String schema, boolean buildIndex, boolean persistModelStore) throws Exception { initFolders(persistModelStore); createJettyAndHarness( tmpSolrHome.toAbsolutePath().toString(), solrconfig, schema, "/solr", true, null); - if (buildIndexWithVectors) prepareIndex(false); - if (buildIndexWithText) prepareIndex(true); + if (buildIndex) prepareIndex(); } protected static void initFolders(boolean isPersistent) throws Exception { @@ -113,17 +111,12 @@ public static void loadModel(String fileName) throws Exception { ManagedTextToVectorModelStore.REST_END_POINT, multipleModels, "/responseHeader/status==0"); } - protected static void prepareIndex(boolean textual) throws Exception { - List docsToIndex; - if(textual){ - docsToIndex = prepareTextualDocs(); - } else { - docsToIndex = prepareTextualDocs(); - } - + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); for (SolrInputDocument doc : docsToIndex) { assertU(adoc(doc)); } + assertU(commit()); } @@ -184,26 +177,4 @@ private static List prepareDocs() { return docs; } - - private static List prepareTextualDocs() { - int docsCount = 13; - List docs = new ArrayList<>(docsCount); - for (int i = 1; i < docsCount + 1; i++) { - SolrInputDocument doc = new SolrInputDocument(); - doc.addField(IDField, i); - docs.add(doc); - } - - docs.get(0) - .addField(textField, "Vegeta is the prince of all saiyans."); // cosine distance vector1= 1.0 - docs.get(1) - .addField( - textField, "Goku, also known as Kakaroth is a saiyan grown up on Earth."); // cosine distance vector1= 0.998 - docs.get(2) - .addField( - textField, - Arrays.asList("Gohan is a Saiya-man hybrid.")); // cosine distance vector1= 0.992 - - return docs; - } } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/search/TextToVectorQParserTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/search/TextToVectorQParserTest.java index f9ad5121074..5c406f08217 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/search/TextToVectorQParserTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/search/TextToVectorQParserTest.java @@ -25,7 +25,7 @@ public class TextToVectorQParserTest extends TestLlmBase { @BeforeClass public static void init() throws Exception { - setupTest("solrconfig-llm.xml", "schema.xml", true, false, false); + setupTest("solrconfig-llm.xml", "schema.xml", true, false); loadModel("dummy-model.json"); } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManager.java b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManager.java index bc4cd753a8a..37d40b3f6c4 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManager.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManager.java @@ -30,7 +30,7 @@ public class TestModelManager extends TestLlmBase { @BeforeClass public static void init() throws Exception { - setupTest("solrconfig-llm.xml", "schema.xml", false, false, false); + setupTest("solrconfig-llm.xml", "schema.xml", false, false); } @Test diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManagerPersistence.java b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManagerPersistence.java index 390e9fe097b..798e2f091b6 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManagerPersistence.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/store/rest/TestModelManagerPersistence.java @@ -30,7 +30,7 @@ public class TestModelManagerPersistence extends TestLlmBase { @Before public void init() throws Exception { - setupTest("solrconfig-llm.xml", "schema.xml", false, false, true); + setupTest("solrconfig-llm.xml", "schema.xml", false, true); } @After diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactoryTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactoryTest.java index 18e9feebde7..7d048c8dcba 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactoryTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorFactoryTest.java @@ -16,51 +16,114 @@ */ package org.apache.solr.llm.texttovector.update.processor; -import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.MultiMapSolrParams; +import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.NamedList; +import org.apache.solr.llm.TestLlmBase; +import org.apache.solr.request.SolrQueryRequestBase; +import org.junit.After; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; +import java.util.HashMap; +import java.util.Map; -public class TextToVectorUpdateProcessorFactoryTest extends SolrTestCaseJ4 { + +public class TextToVectorUpdateProcessorFactoryTest extends TestLlmBase { private TextToVectorUpdateProcessorFactory factoryToTest = new TextToVectorUpdateProcessorFactory(); private NamedList args = new NamedList<>(); + + @BeforeClass + public static void initArgs() throws Exception { + setupTest("solrconfig-llm.xml", "schema.xml", false, false); + } - @Before - public void initArgs() { - args.add("inputField", "inputField1"); - args.add("outputField", "outputField1"); - args.add("model", "model1"); + @AfterClass + public static void after() throws Exception { + afterTest(); } @Test public void init_fullArgs_shouldInitFullClassificationParams() { + args.add("inputField", "_text_"); + args.add("outputField", "vector"); + args.add("model", "model1"); factoryToTest.init(args); - assertEquals("inputField1", factoryToTest.getInputField()); - assertEquals("outputField1", factoryToTest.getOutputField()); + assertEquals("_text_", factoryToTest.getInputField()); + assertEquals("vector", factoryToTest.getOutputField()); assertEquals("model1", factoryToTest.getEmbeddingModelName()); } @Test - public void init_nullyInputFields_shouldThrowExceptionWithDetailedMessage() { - args.removeAll("inputField"); + public void init_nullInputField_shouldThrowExceptionWithDetailedMessage() { + args.add("outputField", "vector"); + args.add("model", "model1"); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.init(args)); assertEquals("Text to Vector UpdateProcessor 'inputField' can not be null", e.getMessage()); } + @Test + public void init_notExistentInputField_shouldThrowExceptionWithDetailedMessage() throws Exception { + args.add("inputField", "notExistentInput"); + args.add("outputField", "vector"); + args.add("model", "model1"); + + Map params = new HashMap<>(); + MultiMapSolrParams mmparams = new MultiMapSolrParams(params); + SolrQueryRequestBase req = new SolrQueryRequestBase(solrClientTestRule.getCoreContainer().getCore("collection1"), (SolrParams) mmparams) {}; + factoryToTest.init(args); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.getInstance(req,null,null)); + assertEquals("undefined field: \"notExistentInput\"", e.getMessage()); + } + @Test public void init_nullOutputField_shouldThrowExceptionWithDetailedMessage() { - args.removeAll("outputField"); + args.add("inputField", "_text_"); + args.add("model", "model1"); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.init(args)); assertEquals("Text to Vector UpdateProcessor 'outputField' can not be null", e.getMessage()); } + @Test + public void init_notExistentOutputField_shouldThrowExceptionWithDetailedMessage() throws Exception { + args.add("inputField", "_text_"); + args.add("outputField", "notExistentOutput"); + args.add("model", "model1"); + + Map params = new HashMap<>(); + MultiMapSolrParams mmparams = new MultiMapSolrParams(params); + SolrQueryRequestBase req = new SolrQueryRequestBase(solrClientTestRule.getCoreContainer().getCore("collection1"), (SolrParams) mmparams) {}; + factoryToTest.init(args); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.getInstance(req,null,null)); + assertEquals("undefined field: \"notExistentOutput\"", e.getMessage()); + } + + @Test + public void init_notDenseVectorOutputField_shouldThrowExceptionWithDetailedMessage() throws Exception { + args.add("inputField", "_text_"); + args.add("outputField", "_text_"); + args.add("model", "model1"); + + Map params = new HashMap<>(); + MultiMapSolrParams mmparams = new MultiMapSolrParams(params); + SolrQueryRequestBase req = new SolrQueryRequestBase(solrClientTestRule.getCoreContainer().getCore("collection1"), (SolrParams) mmparams) {}; + factoryToTest.init(args); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.getInstance(req,null,null)); + assertEquals("only DenseVectorField is compatible with Vector Query Parsers: _text_", e.getMessage()); + } + @Test public void init_nullModel_shouldThrowExceptionWithDetailedMessage() { - args.removeAll("model"); + args.add("inputField", "_text_"); + args.add("outputField", "vector"); + SolrException e = assertThrows(SolrException.class, () -> factoryToTest.init(args)); assertEquals("Text to Vector UpdateProcessor 'model' can not be null", e.getMessage()); } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorTest.java index 751134fdbcc..e4ffd1d24e5 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/texttovector/update/processor/TextToVectorUpdateProcessorTest.java @@ -22,49 +22,68 @@ import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.llm.TestLlmBase; +import org.apache.solr.llm.texttovector.store.rest.ManagedTextToVectorModelStore; import org.apache.solr.update.AddUpdateCommand; import org.junit.BeforeClass; import org.junit.Test; import static org.hamcrest.core.Is.is; -public class TextToVectorUpdateProcessorTest extends TestLlmBase { - /* field names are used in accordance with the solrconfig and schema supplied */ - private static final String ID = "id"; - private static final String TITLE = "title"; - private static final String CONTENT = "content"; - private static final String AUTHOR = "author"; - private static final String TRAINING_CLASS = "cat"; - private static final String PREDICTED_CLASS = "predicted"; +public class TextToVectorUpdateProcessorTest extends TestLlmBase { + /* field names are used in accordance with the solrconfig and schema supplied */ + private static final String ID = "id"; + private static final String TITLE = "title"; + private static final String CONTENT = "content"; + private static final String AUTHOR = "author"; + private static final String TRAINING_CLASS = "cat"; + private static final String PREDICTED_CLASS = "predicted"; - protected Directory directory; - protected IndexReader reader; - protected IndexSearcher searcher; - private TexToVectorUpdateProcessor updateProcessorToTest; + protected Directory directory; + protected IndexReader reader; + protected IndexSearcher searcher; + private TexToVectorUpdateProcessor updateProcessorToTest; @BeforeClass public static void init() throws Exception { - setupTest("solrconfig-llm-indexing.xml", "schema.xml", false, false, false); + setupTest("solrconfig-llm-indexing.xml", "schema.xml", false, false); + + } + + @Test + public void processAdd_inputField_shouldVectoriseInputField() + throws Exception { loadModel("dummy-model.json"); + assertU(adoc("id", "99", "_text_", "Vegeta is the saiyan prince.")); + assertU(adoc("id", "98", "_text_", "Vegeta is the saiyan prince.")); + assertU(commit()); + + final String solrQuery = "*:*"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "id,vector"); + + assertJQ( + "/query" + query.toQueryString(), + "/response/numFound==2]", + "/response/docs/[0]/id=='99'", + "/response/docs/[0]/vector==[1.0, 2.0, 3.0, 4.0]", + "/response/docs/[1]/id=='98'", + "/response/docs/[1]/vector==[1.0, 2.0, 3.0, 4.0]"); + + restTestHarness.delete(ManagedTextToVectorModelStore.REST_END_POINT + "/dummy-1"); } - @Test - public void - classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() - throws Exception { - assertU(adoc("id", "99", "_text_", "Vegeta is the saiyan prince.")); - assertU(commit()); + /* + This test looks for the 'dummy-1' model, but such model is not loaded, the model store is empty, so the update fails + */ + @Test + public void processAdd_modelNotFound_shouldRaiseException() { + assertFailedU("This update should fail but actually succeeded", adoc("id", "99", "_text_", "Vegeta is the saiyan prince.")); - final String solrQuery = "*:*"; - final SolrQuery query = new SolrQuery(); - query.setQuery(solrQuery); - query.add("fl", "id"); + checkUpdateU(adoc("id", "99", "_text_", "Vegeta is the saiyan prince."), + "/response/lst[@name='error']/str[@name='msg']=\"The model requested 'dummy-1' can't be found in the store: /schema/text-to-vector-model-store\"", + "/response/lst[@name='error']/int[@name='code']='400'"); + } - assertJQ( - "/query" + query.toQueryString(), - "/response/numFound==1]", - "/response/docs/[0]/id=='99'"); - } - } diff --git a/solr/test-framework/src/java/org/apache/solr/util/RestTestBase.java b/solr/test-framework/src/java/org/apache/solr/util/RestTestBase.java index ce066c4e1eb..fdbb964892a 100644 --- a/solr/test-framework/src/java/org/apache/solr/util/RestTestBase.java +++ b/solr/test-framework/src/java/org/apache/solr/util/RestTestBase.java @@ -88,13 +88,35 @@ private static void checkUpdateU(String message, String update, boolean shouldSu if (response != null) fail(m + "update was not successful: " + response); } else { String response = restTestHarness.validateErrorUpdate(update); - if (response != null) fail(m + "update succeeded, but should have failed: " + response); + if (response == null) fail(m + "update succeeded, but should have failed: " + response); } } catch (SAXException e) { throw new RuntimeException("Invalid XML", e); } } + public static void checkUpdateU(String update, String... tests) { + try { + String response = restTestHarness.validateUpdate(update); + String results = TestHarness.validateXPath(response, tests); + if (null != results) { + log.error( + "REQUEST FAILED: xpath={}\n\txml response was: {}\n\trequest was:{}", + results, + response, + update); + fail(results); + } + } catch (XPathExpressionException e1) { + throw new RuntimeException("XPath is invalid", e1); + } catch (Exception e2) { + log.error("REQUEST FAILED: {}", update, e2); + throw new RuntimeException("Exception during query", e2); + } + } + + + //String results = TestHarness.validateXPath(response, tests); /** * Validates a query matches some XPath test expressions *