diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 9b1578a45..c988f9baa 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -21,6 +21,9 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import java.util.Arrays; import java.util.Locale; @@ -34,9 +37,22 @@ import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +/** + * Utility class for KNNVectorFieldMapper + */ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { + private static ModelDao modelDao; + + /** + * Initializes static instance variables + * @param modelDao ModelDao object + */ + public static void initialize(final ModelDao modelDao) { + KNNVectorFieldMapperUtil.modelDao = modelDao; + } + /** * Validate the float vector value and throw exception if it is not a number or not in the finite range * or is not within the FP16 range of [-65504 to 65504]. @@ -171,4 +187,46 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy return vectorDataType.getVectorFromBytesRef(storedVector); } + + /** + * Get the expected dimensions from a specified knn vector field type. + * + * If the field is model-based, get dimensions from model metadata. + * @param knnVectorFieldType knn vector field type + * @return expected dimensions + */ + public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + int expectedDimensions = knnVectorFieldType.getDimension(); + if (isModelBasedIndex(expectedDimensions)) { + ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); + expectedDimensions = modelMetadata.getDimension(); + } + return expectedDimensions; + } + + private static boolean isModelBasedIndex(int expectedDimensions) { + return expectedDimensions == -1; + } + + /** + * Returns the model metadata for a specified knn vector field + * + * @param knnVectorField knn vector field + * @return the model metadata from knnVectorField + */ + private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + String modelId = knnVectorField.getModelId(); + + if (modelId == null) { + throw new IllegalArgumentException( + String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + ); + } + + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index bc17e80e7..f898b622e 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -14,6 +14,7 @@ import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -33,7 +34,6 @@ import org.opensearch.knn.plugin.rest.RestTrainModelHandler; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; -import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; @@ -205,7 +205,7 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); - KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); + KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 3ba8bce63..8105539ba 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -8,6 +8,7 @@ import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import org.opensearch.knn.index.query.KNNWeight; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.index.mapper.MappedFieldType; @@ -61,7 +62,7 @@ public L2(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); @@ -97,7 +98,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); SpaceType.COSINESIMIL.validateVector(processedQuery); @@ -192,7 +193,7 @@ public L1(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); @@ -227,7 +228,7 @@ public LInf(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); @@ -264,7 +265,7 @@ public InnerProd(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), + KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 888184e54..889780d7a 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -8,9 +8,6 @@ import java.util.List; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -22,14 +19,11 @@ import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +/** + * Utility class for KNNScoringSpace + */ public class KNNScoringSpaceUtil { - private static ModelDao modelDao; - - public static void initialize(ModelDao modelDao) { - KNNScoringSpaceUtil.modelDao = modelDao; - } - /** * Check if the passed in fieldType is of type NumberFieldType with numericType being Long * @@ -146,43 +140,4 @@ public static float getVectorMagnitudeSquared(float[] inputVector) { } return normInputVector; } - - /** - * Get the expected dimensions from a specified knn vector field type. - * - * If the field is model-based, get dimensions from model metadata. - * @param knnVectorFieldType knn vector field type - * @return expected dimensions - */ - public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { - int expectedDimensions = knnVectorFieldType.getDimension(); - // Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata. - if (expectedDimensions == -1) { - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - expectedDimensions = modelMetadata.getDimension(); - } - return expectedDimensions; - } - - /** - * Returns the model metadata for a specified knn vector field - * - * @param knnVectorField knn vector field - * @return the model metadata from knnVectorField - */ - private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); - - if (modelId == null) { - throw new IllegalArgumentException( - String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) - ); - } - - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); - } - return modelMetadata; - } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 3fa9f2363..ff47dcd69 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,12 +13,20 @@ import org.apache.lucene.document.StoredField; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import java.io.ByteArrayInputStream; import java.util.Arrays; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class KNNVectorFieldMapperUtilTests extends KNNTestCase { private static final String TEST_FIELD_NAME = "test_field_name"; @@ -51,4 +59,59 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() { assertTrue(vector instanceof float[]); assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f); } + + public void testGetExpectedDimensionsSuccess() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldType.getDimension()).thenReturn(3); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNVectorFieldMapperUtil.initialize(modelDao); + + assertEquals(3, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldType)); + assertEquals(4, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + } + + public void testGetExpectedDimensionsFailure() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNVectorFieldMapperUtil.initialize(modelDao); + + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); + + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); + String fieldName = "test-field"; + when(methodComponentContext.getName()).thenReturn(fieldName); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); + + e = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 1497e3e17..b5bc4b95f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,15 +6,10 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import java.math.BigInteger; import java.util.ArrayList; @@ -80,44 +75,4 @@ public void testParseKNNVectorQuery() { String invalidObject = "invalidObject"; expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } - - public void testGetExpectedDimensions() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); - - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); - String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); - - ModelDao modelDao = mock(ModelDao.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.CREATED); - when(modelMetadata.getDimension()).thenReturn(4); - when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); - - KNNScoringSpaceUtil.initialize(modelDao); - - assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType)); - assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); - - when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); - - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); - - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); - String fieldName = "test-field"; - when(methodComponentContext.getName()).thenReturn(fieldName); - when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); - when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); - - e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); - assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); - } }