Skip to content

Commit

Permalink
Refactor into KNNVectorFieldMapperUtil and split test into two tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Apr 23, 2024
1 parent 6806555 commit b4971f1
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.Arrays;
import java.util.Locale;
Expand All @@ -34,9 +37,22 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;

/**
* Utility class for KNNVectorFieldMapper
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {

private static ModelDao modelDao;

/**
* Initializes static instance variables
* @param modelDao ModelDao object
*/
public static void initialize(final ModelDao modelDao) {
KNNVectorFieldMapperUtil.modelDao = modelDao;
}

/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range
* or is not within the FP16 range of [-65504 to 65504].
Expand Down Expand Up @@ -171,4 +187,46 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy

return vectorDataType.getVectorFromBytesRef(storedVector);
}

/**
* Get the expected dimensions from a specified knn vector field type.
*
* If the field is model-based, get dimensions from model metadata.
* @param knnVectorFieldType knn vector field type
* @return expected dimensions
*/
public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
int expectedDimensions = knnVectorFieldType.getDimension();
if (isModelBasedIndex(expectedDimensions)) {
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
expectedDimensions = modelMetadata.getDimension();
}
return expectedDimensions;
}

private static boolean isModelBasedIndex(int expectedDimensions) {
return expectedDimensions == -1;
}

/**
* Returns the model metadata for a specified knn vector field
*
* @param knnVectorField knn vector field
* @return the model metadata from knnVectorField
*/
private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(
String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName())
);
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
}
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand All @@ -33,7 +34,6 @@
import org.opensearch.knn.plugin.rest.RestTrainModelHandler;
import org.opensearch.knn.plugin.rest.RestClearCacheHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
Expand Down Expand Up @@ -205,7 +205,7 @@ public Collection<Object> createComponents(
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.mapper.MappedFieldType;
Expand Down Expand Up @@ -61,7 +62,7 @@ public L2(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
Expand Down Expand Up @@ -97,7 +98,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
Expand Down Expand Up @@ -192,7 +193,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
Expand Down Expand Up @@ -227,7 +228,7 @@ public LInf(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
Expand Down Expand Up @@ -264,7 +265,7 @@ public InnerProd(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import java.util.List;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -22,14 +19,11 @@
import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;

/**
* Utility class for KNNScoringSpace
*/
public class KNNScoringSpaceUtil {

private static ModelDao modelDao;

public static void initialize(ModelDao modelDao) {
KNNScoringSpaceUtil.modelDao = modelDao;
}

/**
* Check if the passed in fieldType is of type NumberFieldType with numericType being Long
*
Expand Down Expand Up @@ -146,43 +140,4 @@ public static float getVectorMagnitudeSquared(float[] inputVector) {
}
return normInputVector;
}

/**
* Get the expected dimensions from a specified knn vector field type.
*
* If the field is model-based, get dimensions from model metadata.
* @param knnVectorFieldType knn vector field type
* @return expected dimensions
*/
public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
int expectedDimensions = knnVectorFieldType.getDimension();
// Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata.
if (expectedDimensions == -1) {
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
expectedDimensions = modelMetadata.getDimension();
}
return expectedDimensions;
}

/**
* Returns the model metadata for a specified knn vector field
*
* @param knnVectorField knn vector field
* @return the model metadata from knnVectorField
*/
private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(
String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName())
);
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@

import org.apache.lucene.document.StoredField;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.io.ByteArrayInputStream;
import java.util.Arrays;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNNVectorFieldMapperUtilTests extends KNNTestCase {

private static final String TEST_FIELD_NAME = "test_field_name";
Expand Down Expand Up @@ -51,4 +59,59 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
assertTrue(vector instanceof float[]);
assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f);
}

public void testGetExpectedDimensionsSuccess() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldType.getDimension()).thenReturn(3);

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getDimension()).thenReturn(4);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

assertEquals(3, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldType));
assertEquals(4, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));
}

public void testGetExpectedDimensionsFailure() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage());

when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
MethodComponentContext methodComponentContext = mock(MethodComponentContext.class);
String fieldName = "test-field";
when(methodComponentContext.getName()).thenReturn(fieldName);
when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext);
when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext);

e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,10 @@
package org.opensearch.knn.plugin.script;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.math.BigInteger;
import java.util.ArrayList;
Expand Down Expand Up @@ -80,44 +75,4 @@ public void testParseKNNVectorQuery() {
String invalidObject = "invalidObject";
expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT));
}

public void testGetExpectedDimensions() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldType.getDimension()).thenReturn(3);

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getDimension()).thenReturn(4);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNScoringSpaceUtil.initialize(modelDao);

assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType));
assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));

when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);

IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage());

when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
MethodComponentContext methodComponentContext = mock(MethodComponentContext.class);
String fieldName = "test-field";
when(methodComponentContext.getName()).thenReturn(fieldName);
when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext);
when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext);

e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));
assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage());
}
}

0 comments on commit b4971f1

Please sign in to comment.